mirror of
https://github.com/complexcaresolutions/dak.c2s.git
synced 2026-03-17 18:23:42 +00:00
feat: JWT auth, bcrypt, MFA, dependency injection, security tests
Add core security layer: - security.py: password hashing (bcrypt), JWT access/refresh tokens, SHA-256 token hashing, TOTP MFA (generate, verify, provisioning URI), plus passlib/bcrypt 5.x compatibility patch - dependencies.py: FastAPI deps for get_current_user (Bearer JWT) and require_admin (role check) - exceptions.py: domain-specific HTTP exceptions (CaseNotFound, DuplicateCase, InvalidImportFile, ICDValidation, AccountLocked, InvalidCredentials) - test_security.py: 9 tests covering all security functions Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e0ca8c31c3
commit
178d40d036
4 changed files with 331 additions and 0 deletions
55
backend/app/core/dependencies.py
Normal file
55
backend/app/core/dependencies.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
"""FastAPI dependency functions for authentication and authorisation."""
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from jose import JWTError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.security import decode_access_token
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Extract and validate the JWT bearer token, then return the active user.
|
||||
|
||||
Raises 401 if the token is invalid/expired or the user is inactive.
|
||||
"""
|
||||
try:
|
||||
payload = decode_access_token(credentials.credentials)
|
||||
user_id = int(payload["sub"])
|
||||
except (JWTError, KeyError, ValueError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
)
|
||||
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.id == user_id, User.is_active == True) # noqa: E712
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found or inactive",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(user: User = Depends(get_current_user)) -> User:
|
||||
"""Require the authenticated user to have the ``admin`` role.
|
||||
|
||||
Raises 403 if the user is not an admin.
|
||||
"""
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin access required",
|
||||
)
|
||||
return user
|
||||
49
backend/app/core/exceptions.py
Normal file
49
backend/app/core/exceptions.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
"""Domain-specific HTTP exceptions for the DAK Zweitmeinungs-Portal."""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
class CaseNotFoundError(HTTPException):
|
||||
"""Raised when a requested case does not exist (404)."""
|
||||
|
||||
def __init__(self, detail: str = "Case not found") -> None:
|
||||
super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=detail)
|
||||
|
||||
|
||||
class DuplicateCaseError(HTTPException):
|
||||
"""Raised when a case with the same identifier already exists (409)."""
|
||||
|
||||
def __init__(self, detail: str = "Case already exists") -> None:
|
||||
super().__init__(status_code=status.HTTP_409_CONFLICT, detail=detail)
|
||||
|
||||
|
||||
class InvalidImportFileError(HTTPException):
|
||||
"""Raised when an uploaded import file is malformed or invalid (400)."""
|
||||
|
||||
def __init__(self, detail: str = "Invalid import file") -> None:
|
||||
super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)
|
||||
|
||||
|
||||
class ICDValidationError(HTTPException):
|
||||
"""Raised when ICD code validation fails (422)."""
|
||||
|
||||
def __init__(self, detail: str = "ICD code validation failed") -> None:
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail
|
||||
)
|
||||
|
||||
|
||||
class AccountLockedError(HTTPException):
|
||||
"""Raised when a user account is temporarily locked (423)."""
|
||||
|
||||
def __init__(self, detail: str = "Account is locked") -> None:
|
||||
super().__init__(status_code=status.HTTP_423_LOCKED, detail=detail)
|
||||
|
||||
|
||||
class InvalidCredentialsError(HTTPException):
|
||||
"""Raised when login credentials are incorrect (401)."""
|
||||
|
||||
def __init__(self, detail: str = "Invalid credentials") -> None:
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=detail
|
||||
)
|
||||
114
backend/app/core/security.py
Normal file
114
backend/app/core/security.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""Core security utilities: JWT, password hashing, MFA (TOTP)."""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pyotp
|
||||
from jose import jwt, JWTError # noqa: F401 – re-exported for convenience
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Monkey-patch passlib for bcrypt >= 4.1 / 5.x compatibility.
|
||||
# Newer bcrypt removed ``__about__`` and rejects >72-byte passwords in its
|
||||
# internal wrap-bug detection. The patch is applied before CryptContext is
|
||||
# instantiated so that passlib's backend initialisation succeeds.
|
||||
# ---------------------------------------------------------------------------
|
||||
import passlib.handlers.bcrypt as _bcrypt_mod # noqa: E402
|
||||
|
||||
_orig_finalize = _bcrypt_mod._BcryptBackend._finalize_backend_mixin.__func__ # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@classmethod # type: ignore[misc]
|
||||
def _patched_finalize(cls, name: str, dryrun: bool = False): # type: ignore[no-untyped-def]
|
||||
try:
|
||||
return _orig_finalize(cls, name, dryrun)
|
||||
except ValueError:
|
||||
# bcrypt 4.1+ raises ValueError on >72-byte secrets during
|
||||
# passlib's internal wrap-bug detection — safe to ignore.
|
||||
return True
|
||||
|
||||
|
||||
_bcrypt_mod._BcryptBackend._finalize_backend_mixin = _patched_finalize # type: ignore[assignment]
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from passlib.context import CryptContext # noqa: E402
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Password hashing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a plain-text password using bcrypt."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
"""Verify a plain-text password against a bcrypt hash."""
|
||||
return pwd_context.verify(plain, hashed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JWT tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_access_token(user_id: int, role: str) -> str:
|
||||
"""Create a short-lived JWT access token."""
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
payload = {
|
||||
"sub": str(user_id),
|
||||
"role": role,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(
|
||||
payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token() -> str:
|
||||
"""Create a cryptographically secure refresh token (URL-safe)."""
|
||||
return secrets.token_urlsafe(64)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Return a SHA-256 hex digest of *token* (used for DB storage)."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict:
|
||||
"""Decode and validate a JWT access token.
|
||||
|
||||
Raises ``jose.JWTError`` on invalid or expired tokens.
|
||||
"""
|
||||
return jwt.decode(
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MFA / TOTP
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_mfa_secret() -> str:
|
||||
"""Generate a new TOTP base-32 secret."""
|
||||
return pyotp.random_base32()
|
||||
|
||||
|
||||
def verify_mfa_code(secret: str, code: str) -> bool:
|
||||
"""Verify a 6-digit TOTP code against *secret*."""
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.verify(code)
|
||||
|
||||
|
||||
def get_mfa_uri(secret: str, email: str) -> str:
|
||||
"""Return an ``otpauth://`` provisioning URI for QR code generation."""
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.provisioning_uri(name=email, issuer_name=settings.APP_NAME)
|
||||
113
backend/tests/test_security.py
Normal file
113
backend/tests/test_security.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""Tests for app.core.security — JWT, bcrypt, MFA/TOTP, token hashing."""
|
||||
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pyotp
|
||||
import pytest
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.config import get_settings
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
generate_mfa_secret,
|
||||
get_mfa_uri,
|
||||
hash_password,
|
||||
hash_token,
|
||||
verify_mfa_code,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Password hashing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_hash_and_verify_password():
|
||||
"""Hashing then verifying the same password returns True; wrong one False."""
|
||||
hashed = hash_password("s3cret!")
|
||||
assert verify_password("s3cret!", hashed) is True
|
||||
assert verify_password("wrong-password", hashed) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JWT access tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_create_and_decode_access_token():
|
||||
"""Round-trip: create a token and decode it, payload should match."""
|
||||
token = create_access_token(user_id=1, role="admin")
|
||||
payload = decode_access_token(token)
|
||||
assert payload["sub"] == "1"
|
||||
assert payload["role"] == "admin"
|
||||
assert "exp" in payload
|
||||
|
||||
|
||||
def test_expired_token_raises():
|
||||
"""A token with a negative expiry must fail to decode."""
|
||||
expire = datetime.now(timezone.utc) - timedelta(seconds=10)
|
||||
token = jwt.encode(
|
||||
{"sub": "1", "role": "admin", "exp": expire},
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM,
|
||||
)
|
||||
with pytest.raises(JWTError):
|
||||
decode_access_token(token)
|
||||
|
||||
|
||||
def test_invalid_token_raises():
|
||||
"""Decoding a garbage string must raise JWTError."""
|
||||
with pytest.raises(JWTError):
|
||||
decode_access_token("not.a.valid.jwt.token")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Refresh / hash tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_create_refresh_token_uniqueness():
|
||||
"""Two refresh tokens should never collide."""
|
||||
t1 = create_refresh_token()
|
||||
t2 = create_refresh_token()
|
||||
assert t1 != t2
|
||||
assert len(t1) > 40 # url-safe base64 of 64 bytes
|
||||
|
||||
|
||||
def test_hash_token():
|
||||
"""hash_token must produce a consistent SHA-256 hex digest."""
|
||||
token = "test-token-value"
|
||||
expected = hashlib.sha256(token.encode()).hexdigest()
|
||||
assert hash_token(token) == expected
|
||||
# Deterministic: calling twice yields the same result
|
||||
assert hash_token(token) == hash_token(token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MFA / TOTP
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_mfa_secret_and_verify():
|
||||
"""Generate a secret, produce the current OTP, verify it succeeds."""
|
||||
secret = generate_mfa_secret()
|
||||
assert len(secret) > 0
|
||||
totp = pyotp.TOTP(secret)
|
||||
current_code = totp.now()
|
||||
assert verify_mfa_code(secret, current_code) is True
|
||||
|
||||
|
||||
def test_mfa_wrong_code():
|
||||
"""A clearly wrong code must be rejected."""
|
||||
secret = generate_mfa_secret()
|
||||
assert verify_mfa_code(secret, "000000") is False or verify_mfa_code(secret, "999999") is False
|
||||
|
||||
|
||||
def test_mfa_uri_format():
|
||||
"""The provisioning URI must start with the otpauth scheme."""
|
||||
secret = generate_mfa_secret()
|
||||
uri = get_mfa_uri(secret, "user@example.com")
|
||||
assert uri.startswith("otpauth://totp/")
|
||||
assert "user%40example.com" in uri or "user@example.com" in uri
|
||||
Loading…
Reference in a new issue