"""ICD service — normalize, split, validate, save, import, and generate coding templates."""

import logging
from datetime import datetime, timezone
from io import BytesIO
from typing import Optional

from openpyxl import Workbook, load_workbook
from sqlalchemy.orm import Query, Session

from app.config import get_settings
from app.models.case import Case, CaseICDCode
from app.utils.validators import normalize_icd_hauptgruppe, split_icd_codes, validate_icd

settings = get_settings()
logger = logging.getLogger(__name__)


def normalize_and_validate_icd(raw: str) -> list[tuple[str, str]]:
    """Split, validate, and normalize ICD codes from a raw string.

    Each code produced by ``split_icd_codes`` is passed through ``validate_icd``
    and then mapped to its Hauptgruppe with ``normalize_icd_hauptgruppe``.

    Args:
        raw: Free-text string possibly containing several ICD codes.

    Returns:
        List of ``(icd_code, hauptgruppe)`` tuples in input order.

    Raises:
        ValueError: For any invalid code (raised by the validators).
    """
    pairs: list[tuple[str, str]] = []
    for code in split_icd_codes(raw):
        # Validate and normalize code-by-code so the first bad code aborts immediately.
        validated = validate_icd(code)
        pairs.append((validated, normalize_icd_hauptgruppe(validated)))
    return pairs


def save_icd_for_case(
    db: Session,
    case_id: int,
    icd_raw: str,
    user_id: int,
) -> Case:
    """Set ICD codes for a case, replacing any existing ones, and commit.

    All codes are validated before anything in the session is modified, so
    invalid input leaves the case untouched.

    Args:
        db: Active SQLAlchemy session; this function commits it.
        case_id: Primary key of the case to update.
        icd_raw: Raw ICD input (one or more codes).
        user_id: ID recorded as the user who entered the codes.

    Returns:
        The refreshed ``Case`` instance.

    Raises:
        CaseNotFoundError: If no case with ``case_id`` exists.
        ValueError: If any code fails validation.
    """
    case = db.query(Case).filter(Case.id == case_id).first()
    if not case:
        # Imported lazily (as originally written) — presumably to avoid an import cycle.
        from app.core.exceptions import CaseNotFoundError

        raise CaseNotFoundError()

    # Validate all codes first so nothing is deleted if the input is bad.
    icd_pairs = normalize_and_validate_icd(icd_raw)

    # Delete existing ICD codes for this case.
    db.query(CaseICDCode).filter(CaseICDCode.case_id == case_id).delete()

    # Store the validated codes as a comma-separated string on the case itself.
    case.icd = ", ".join(code for code, _ in icd_pairs)
    case.icd_entered_by = user_id
    case.icd_entered_at = datetime.now(timezone.utc)

    # One row per individual code, carrying its Hauptgruppe.
    db.add_all(
        [
            CaseICDCode(
                case_id=case_id,
                icd_code=code,
                icd_hauptgruppe=hauptgruppe,
            )
            for code, hauptgruppe in icd_pairs
        ]
    )

    db.commit()
    db.refresh(case)
    return case


def _pending_icd_query(
    db: Session,
    jahr: Optional[int] = None,
    fallgruppe: Optional[str] = None,
) -> Query:
    """Build the unordered query for configured-insurer cases that have no ICD yet."""
    query = db.query(Case).filter(
        Case.versicherung == settings.VERSICHERUNG_FILTER,
        Case.icd.is_(None),
    )
    # Falsy filter values (None, 0, "") mean "don't filter", as before.
    if jahr:
        query = query.filter(Case.jahr == jahr)
    if fallgruppe:
        query = query.filter(Case.fallgruppe == fallgruppe)
    return query


def get_pending_icd_cases(
    db: Session,
    jahr: Optional[int] = None,
    fallgruppe: Optional[str] = None,
    page: int = 1,
    per_page: int = 50,
) -> tuple[list[Case], int]:
    """Get one page of cases without ICD codes.

    Args:
        db: Active SQLAlchemy session.
        jahr: Optional year filter.
        fallgruppe: Optional case-group filter.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size.

    Returns:
        ``(cases, total)``: the page of cases ordered by ``datum`` descending and
        the total number of matching cases.
    """
    # A page below 1 would produce a negative OFFSET, which databases reject or misinterpret.
    page = max(page, 1)
    query = _pending_icd_query(db, jahr=jahr, fallgruppe=fallgruppe)
    total = query.count()
    cases = (
        query.order_by(Case.datum.desc())
        .offset((page - 1) * per_page)
        .limit(per_page)
        .all()
    )
    return cases, total


def generate_coding_template(
    db: Session,
    jahr: Optional[int] = None,
    fallgruppe: Optional[str] = None,
) -> bytes:
    """Generate an Excel template for ICD coding.

    Returns .xlsx bytes with columns:
    Case_ID, Fall_ID, Fallgruppe, Datum, ICD (empty)

    Every pending case matching the filters is included (previously the list was
    silently capped at 10,000 rows). Rows are ordered by ``datum`` descending.

    Patient names are excluded for data privacy (DSGVO).
    """
    cases = (
        _pending_icd_query(db, jahr=jahr, fallgruppe=fallgruppe)
        .order_by(Case.datum.desc())
        .all()
    )

    wb = Workbook()
    ws = wb.active
    ws.title = "ICD Coding"

    # Header row.
    ws.append(["Case_ID", "Fall_ID", "Fallgruppe", "Datum", "ICD"])

    # Data rows; column 5 (ICD) is left empty for the admin to fill in.
    for case in cases:
        ws.append(
            [
                case.id,
                case.fall_id,
                case.fallgruppe,
                case.datum.isoformat() if case.datum else "",
            ]
        )

    # Size each column to its longest value, capped at 30 characters.
    for col in ws.columns:
        max_length = max(len(str(cell.value or "")) for cell in col)
        ws.column_dimensions[col[0].column_letter].width = min(max_length + 2, 30)

    buffer = BytesIO()
    wb.save(buffer)
    return buffer.getvalue()


def import_icd_from_xlsx(db: Session, content: bytes, user_id: int) -> dict:
    """Import ICD codes from a filled-in coding template Excel file.

    Expects columns: Case_ID (col 1), ICD (last col — col 5 or col 7).
    The header row is skipped. Rows are skipped when they are empty, have no
    integer Case_ID, or have an empty ICD cell. Each remaining row is saved
    (and committed) independently via ``save_icd_for_case``. A failing row is
    rolled back and recorded, and the import continues with the next row.

    Returns:
        {"updated": int, "errors": list[str]}
    """
    updated = 0
    errors: list[str] = []

    wb = load_workbook(BytesIO(content), read_only=True)
    try:
        ws = wb.active
        for row in ws.iter_rows(min_row=2, values_only=True):
            # Read-only sheets may yield an empty tuple for a blank row.
            if not row:
                continue

            raw_case_id = row[0]
            if not raw_case_id:
                continue
            try:
                case_id = int(raw_case_id)
            except (ValueError, TypeError):
                continue

            # ICD is the last column (col 5 in the new template, col 7 in the legacy one).
            raw_icd = row[-1]
            icd_value = str(raw_icd).strip() if raw_icd else ""
            if not icd_value:
                continue

            try:
                save_icd_for_case(db, case_id, icd_value, user_id)
            except Exception as exc:
                # Reset the session: after a failed flush/commit SQLAlchemy refuses
                # further work until rollback, which would fail every later row.
                db.rollback()
                logger.warning("ICD import failed for case %s: %s", case_id, exc)
                errors.append(f"Case {case_id}: {exc}")
            else:
                updated += 1
    finally:
        # Read-only workbooks keep their source open until explicitly closed.
        wb.close()

    return {"updated": updated, "errors": errors}