diff --git a/backend/app/core/dependencies.py b/backend/app/core/dependencies.py new file mode 100644 index 0000000..900b5c0 --- /dev/null +++ b/backend/app/core/dependencies.py @@ -0,0 +1,55 @@ +"""FastAPI dependency functions for authentication and authorisation.""" + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from jose import JWTError +from sqlalchemy.orm import Session + +from app.core.security import decode_access_token +from app.database import get_db +from app.models.user import User + +security = HTTPBearer() + + +def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(security), + db: Session = Depends(get_db), +) -> User: + """Extract and validate the JWT bearer token, then return the active user. + + Raises 401 if the token is invalid/expired or the user is inactive. + """ + try: + payload = decode_access_token(credentials.credentials) + user_id = int(payload["sub"]) + except (JWTError, KeyError, ValueError): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token", + ) + + user = ( + db.query(User) + .filter(User.id == user_id, User.is_active == True) # noqa: E712 + .first() + ) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found or inactive", + ) + return user + + +def require_admin(user: User = Depends(get_current_user)) -> User: + """Require the authenticated user to have the ``admin`` role. + + Raises 403 if the user is not an admin. + """ + if user.role != "admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin access required", + ) + return user diff --git a/backend/app/core/exceptions.py b/backend/app/core/exceptions.py new file mode 100644 index 0000000..62c8ae4 --- /dev/null +++ b/backend/app/core/exceptions.py @@ -0,0 +1,49 @@ +"""Domain-specific HTTP exceptions for the DAK Zweitmeinungs-Portal.""" + +from fastapi import HTTPException, status + + +class CaseNotFoundError(HTTPException): + """Raised when a requested case does not exist (404).""" + + def __init__(self, detail: str = "Case not found") -> None: + super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=detail) + + +class DuplicateCaseError(HTTPException): + """Raised when a case with the same identifier already exists (409).""" + + def __init__(self, detail: str = "Case already exists") -> None: + super().__init__(status_code=status.HTTP_409_CONFLICT, detail=detail) + + +class InvalidImportFileError(HTTPException): + """Raised when an uploaded import file is malformed or invalid (400).""" + + def __init__(self, detail: str = "Invalid import file") -> None: + super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) + + +class ICDValidationError(HTTPException): + """Raised when ICD code validation fails (422).""" + + def __init__(self, detail: str = "ICD code validation failed") -> None: + super().__init__( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail + ) + + +class AccountLockedError(HTTPException): + """Raised when a user account is temporarily locked (423).""" + + def __init__(self, detail: str = "Account is locked") -> None: + super().__init__(status_code=status.HTTP_423_LOCKED, detail=detail) + + +class InvalidCredentialsError(HTTPException): + """Raised when login credentials are incorrect (401).""" + + def __init__(self, detail: str = "Invalid credentials") -> None: + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, detail=detail + ) diff --git a/backend/app/core/security.py b/backend/app/core/security.py new file mode 100644 index 0000000..f171918 --- /dev/null +++ b/backend/app/core/security.py @@ -0,0 +1,114 @@ +"""Core security utilities: JWT, password hashing, MFA (TOTP).""" + +import hashlib +import secrets +from datetime import datetime, timedelta, timezone + +import pyotp +from jose import jwt, JWTError # noqa: F401 – re-exported for convenience + +# --------------------------------------------------------------------------- +# Monkey-patch passlib for bcrypt >= 4.1 / 5.x compatibility. +# Newer bcrypt removed ``__about__`` and rejects >72-byte passwords in its +# internal wrap-bug detection. The patch is applied before CryptContext is +# instantiated so that passlib's backend initialisation succeeds. +# --------------------------------------------------------------------------- +import passlib.handlers.bcrypt as _bcrypt_mod # noqa: E402 + +_orig_finalize = _bcrypt_mod._BcryptBackend._finalize_backend_mixin.__func__ # type: ignore[attr-defined] + + +@classmethod # type: ignore[misc] +def _patched_finalize(cls, name: str, dryrun: bool = False): # type: ignore[no-untyped-def] + try: + return _orig_finalize(cls, name, dryrun) + except ValueError: + # bcrypt 4.1+ raises ValueError on >72-byte secrets during + # passlib's internal wrap-bug detection — safe to ignore. + return True + + +_bcrypt_mod._BcryptBackend._finalize_backend_mixin = _patched_finalize # type: ignore[assignment] +# --------------------------------------------------------------------------- + +from passlib.context import CryptContext # noqa: E402 + +from app.config import get_settings + +settings = get_settings() + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +# --------------------------------------------------------------------------- +# Password hashing +# --------------------------------------------------------------------------- + +def hash_password(password: str) -> str: + """Hash a plain-text password using bcrypt.""" + return pwd_context.hash(password) + + +def verify_password(plain: str, hashed: str) -> bool: + """Verify a plain-text password against a bcrypt hash.""" + return pwd_context.verify(plain, hashed) + + +# --------------------------------------------------------------------------- +# JWT tokens +# --------------------------------------------------------------------------- + +def create_access_token(user_id: int, role: str) -> str: + """Create a short-lived JWT access token.""" + expire = datetime.now(timezone.utc) + timedelta( + minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + ) + payload = { + "sub": str(user_id), + "role": role, + "exp": expire, + } + return jwt.encode( + payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM + ) + + +def create_refresh_token() -> str: + """Create a cryptographically secure refresh token (URL-safe).""" + return secrets.token_urlsafe(64) + + +def hash_token(token: str) -> str: + """Return a SHA-256 hex digest of *token* (used for DB storage).""" + return hashlib.sha256(token.encode()).hexdigest() + + +def decode_access_token(token: str) -> dict: + """Decode and validate a JWT access token. + + Raises ``jose.JWTError`` on invalid or expired tokens. + """ + return jwt.decode( + token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] + ) + + +# --------------------------------------------------------------------------- +# MFA / TOTP +# --------------------------------------------------------------------------- + +def generate_mfa_secret() -> str: + """Generate a new TOTP base-32 secret.""" + return pyotp.random_base32() + + +def verify_mfa_code(secret: str, code: str) -> bool: + """Verify a 6-digit TOTP code against *secret*.""" + totp = pyotp.TOTP(secret) + return totp.verify(code) + + +def get_mfa_uri(secret: str, email: str) -> str: + """Return an ``otpauth://`` provisioning URI for QR code generation.""" + totp = pyotp.TOTP(secret) + return totp.provisioning_uri(name=email, issuer_name=settings.APP_NAME) diff --git a/backend/tests/test_security.py b/backend/tests/test_security.py new file mode 100644 index 0000000..15ad0cc --- /dev/null +++ b/backend/tests/test_security.py @@ -0,0 +1,113 @@ +"""Tests for app.core.security — JWT, bcrypt, MFA/TOTP, token hashing.""" + +import hashlib +from datetime import datetime, timedelta, timezone + +import pyotp +import pytest +from jose import JWTError, jwt + +from app.config import get_settings +from app.core.security import ( + create_access_token, + create_refresh_token, + decode_access_token, + generate_mfa_secret, + get_mfa_uri, + hash_password, + hash_token, + verify_mfa_code, + verify_password, +) + +settings = get_settings() + + +# --------------------------------------------------------------------------- +# Password hashing +# --------------------------------------------------------------------------- + +def test_hash_and_verify_password(): + """Hashing then verifying the same password returns True; wrong one False.""" + hashed = hash_password("s3cret!") + assert verify_password("s3cret!", hashed) is True + assert verify_password("wrong-password", hashed) is False + + +# --------------------------------------------------------------------------- +# JWT access tokens +# --------------------------------------------------------------------------- + +def test_create_and_decode_access_token(): + """Round-trip: create a token and decode it, payload should match.""" + token = create_access_token(user_id=1, role="admin") + payload = decode_access_token(token) + assert payload["sub"] == "1" + assert payload["role"] == "admin" + assert "exp" in payload + + +def test_expired_token_raises(): + """A token with a negative expiry must fail to decode.""" + expire = datetime.now(timezone.utc) - timedelta(seconds=10) + token = jwt.encode( + {"sub": "1", "role": "admin", "exp": expire}, + settings.JWT_SECRET_KEY, + algorithm=settings.JWT_ALGORITHM, + ) + with pytest.raises(JWTError): + decode_access_token(token) + + +def test_invalid_token_raises(): + """Decoding a garbage string must raise JWTError.""" + with pytest.raises(JWTError): + decode_access_token("not.a.valid.jwt.token") + + +# --------------------------------------------------------------------------- +# Refresh / hash tokens +# --------------------------------------------------------------------------- + +def test_create_refresh_token_uniqueness(): + """Two refresh tokens should never collide.""" + t1 = create_refresh_token() + t2 = create_refresh_token() + assert t1 != t2 + assert len(t1) > 40 # url-safe base64 of 64 bytes + + +def test_hash_token(): + """hash_token must produce a consistent SHA-256 hex digest.""" + token = "test-token-value" + expected = hashlib.sha256(token.encode()).hexdigest() + assert hash_token(token) == expected + # Deterministic: calling twice yields the same result + assert hash_token(token) == hash_token(token) + + +# --------------------------------------------------------------------------- +# MFA / TOTP +# --------------------------------------------------------------------------- + +def test_mfa_secret_and_verify(): + """Generate a secret, produce the current OTP, verify it succeeds.""" + secret = generate_mfa_secret() + assert len(secret) > 0 + totp = pyotp.TOTP(secret) + current_code = totp.now() + assert verify_mfa_code(secret, current_code) is True + + +def test_mfa_wrong_code(): + """A clearly wrong code must be rejected.""" + secret = generate_mfa_secret() + assert verify_mfa_code(secret, "000000") is False or verify_mfa_code(secret, "999999") is False + + +def test_mfa_uri_format(): + """The provisioning URI must start with the otpauth scheme.""" + secret = generate_mfa_secret() + uri = get_mfa_uri(secret, "user@example.com") + assert uri.startswith("otpauth://totp/") + assert "user%40example.com" in uri or "user@example.com" in uri