"""Authentication business logic: login, register, tokens, MFA.""" from datetime import datetime, timedelta, timezone from fastapi import HTTPException, status from sqlalchemy.orm import Session from app.config import get_settings from app.core.exceptions import AccountLockedError, InvalidCredentialsError from app.core.security import ( create_access_token, create_refresh_token, generate_mfa_secret, get_mfa_uri, hash_password, hash_token, verify_mfa_code, verify_password, ) from app.models.user import AllowedDomain, InvitationLink, RefreshToken, User settings = get_settings() # --------------------------------------------------------------------------- # Login # --------------------------------------------------------------------------- def authenticate_user( db: Session, email: str, password: str, mfa_code: str | None = None, ) -> User: """Authenticate a user by email/password (+ optional MFA). Handles: * Account-lock detection (locked_until). * Failed-attempt counting with auto-lock after 5 failures. * TOTP verification when MFA is enabled. """ user = db.query(User).filter(User.email == email).first() if not user or not user.is_active: raise InvalidCredentialsError() # Check whether the account is currently locked. if user.locked_until and user.locked_until > datetime.now(timezone.utc): raise AccountLockedError( detail=f"Account locked until {user.locked_until.isoformat()}" ) # Verify password. if not verify_password(password, user.password_hash): user.failed_login_attempts += 1 if user.failed_login_attempts >= 5: user.locked_until = datetime.now(timezone.utc) + timedelta(minutes=30) db.commit() raise InvalidCredentialsError() # MFA check (if enabled on this account). if user.mfa_enabled: if not mfa_code: raise InvalidCredentialsError(detail="MFA code required") if not verify_mfa_code(user.mfa_secret, mfa_code): raise InvalidCredentialsError(detail="Invalid MFA code") # Success -- reset counters and record last login. user.failed_login_attempts = 0 user.locked_until = None user.last_login = datetime.now(timezone.utc) db.commit() return user # --------------------------------------------------------------------------- # Registration # --------------------------------------------------------------------------- def register_user( db: Session, username: str, email: str, password: str, invitation_token: str | None = None, ) -> User: """Register a new user account. Either a valid (unexpired, unused) invitation token OR an email whose domain is in the ``allowed_domains`` whitelist is required. """ # Uniqueness check. existing = ( db.query(User) .filter((User.email == email) | (User.username == username)) .first() ) if existing: raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail="User already exists", ) role = "dak_mitarbeiter" if invitation_token: invite = ( db.query(InvitationLink) .filter( InvitationLink.token == invitation_token, InvitationLink.is_active == True, # noqa: E712 InvitationLink.used_at == None, # noqa: E711 InvitationLink.expires_at > datetime.now(timezone.utc), ) .first() ) if not invite: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired invitation", ) if invite.email and invite.email.lower() != email.lower(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Email does not match invitation", ) role = invite.role else: # Domain whitelist. domain = email.split("@")[1].lower() allowed = ( db.query(AllowedDomain) .filter( AllowedDomain.domain == domain, AllowedDomain.is_active == True, # noqa: E712 ) .first() ) if not allowed: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Email domain not allowed for registration", ) role = allowed.role user = User( username=username, email=email, password_hash=hash_password(password), role=role, ) db.add(user) db.commit() db.refresh(user) # Mark invitation as consumed. if invitation_token: invite.used_at = datetime.now(timezone.utc) invite.used_by = user.id invite.is_active = False db.commit() return user # --------------------------------------------------------------------------- # Token management # --------------------------------------------------------------------------- def create_tokens(db: Session, user: User) -> tuple[str, str]: """Create an access/refresh token pair and persist the refresh hash.""" access = create_access_token(user.id, user.role) refresh = create_refresh_token() rt = RefreshToken( user_id=user.id, token_hash=hash_token(refresh), expires_at=datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS), ) db.add(rt) db.commit() return access, refresh def refresh_access_token(db: Session, refresh_token: str) -> tuple[str, User]: """Validate a refresh token and return a fresh access token + user.""" token_hash = hash_token(refresh_token) rt = ( db.query(RefreshToken) .filter( RefreshToken.token_hash == token_hash, RefreshToken.revoked == False, # noqa: E712 RefreshToken.expires_at > datetime.now(timezone.utc), ) .first() ) if not rt: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired refresh token", ) user = ( db.query(User) .filter(User.id == rt.user_id, User.is_active == True) # noqa: E712 .first() ) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive", ) access = create_access_token(user.id, user.role) return access, user def revoke_refresh_token(db: Session, refresh_token: str) -> None: """Revoke a refresh token (used on logout).""" token_hash = hash_token(refresh_token) rt = ( db.query(RefreshToken) .filter(RefreshToken.token_hash == token_hash) .first() ) if rt: rt.revoked = True db.commit() # --------------------------------------------------------------------------- # MFA helpers # --------------------------------------------------------------------------- def setup_mfa(user: User) -> tuple[str, str]: """Generate a fresh MFA secret and provisioning URI for *user*. Does NOT persist the secret -- the caller must save it after the user confirms activation via ``activate_mfa``. """ secret = generate_mfa_secret() uri = get_mfa_uri(secret, user.email) return secret, uri def activate_mfa(db: Session, user: User, secret: str, code: str) -> None: """Verify the TOTP *code* against *secret* and enable MFA on *user*.""" if not verify_mfa_code(secret, code): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid MFA code", ) user.mfa_secret = secret user.mfa_enabled = True db.commit() # --------------------------------------------------------------------------- # Password change # --------------------------------------------------------------------------- def change_password( db: Session, user: User, old_password: str, new_password: str, ) -> None: """Change the authenticated user's password after verifying the old one.""" if not verify_password(old_password, user.password_hash): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Current password is incorrect", ) user.password_hash = hash_password(new_password) user.must_change_password = False db.commit() # --------------------------------------------------------------------------- # Profile update # --------------------------------------------------------------------------- def update_profile( db: Session, user: User, username: str | None = None, email: str | None = None, first_name: str | None = None, last_name: str | None = None, display_name: str | None = None, ) -> User: """Update the authenticated user's own profile fields.""" if username and username != user.username: existing = db.query(User).filter(User.username == username, User.id != user.id).first() if existing: raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail="Username already taken", ) user.username = username if email and email != user.email: existing = db.query(User).filter(User.email == email, User.id != user.id).first() if existing: raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail="Email already taken", ) user.email = email if first_name is not None: user.first_name = first_name if last_name is not None: user.last_name = last_name if display_name is not None: user.display_name = display_name db.commit() db.refresh(user) return user # --------------------------------------------------------------------------- # MFA disable # --------------------------------------------------------------------------- def disable_mfa(db: Session, user: User, password: str) -> None: """Disable MFA on the user's account after verifying their password.""" if not verify_password(password, user.password_hash): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Incorrect password", ) user.mfa_secret = None user.mfa_enabled = False db.commit()