# -*- coding: utf-8 -*-
"""Repositories for Absence and BalanceAccount CRUD."""

import datetime
from typing import Optional

from sqlalchemy.orm import Session
from sqlalchemy import and_, func

from ..core.models import Absence, BalanceAccount, Employee
from ..core.enums import AbsenceStatus, AbsenceCategory, ABSENCE_META
from ..core.schemas import AbsenceCreate, AbsenceUpdate


def _year_bounds(year: int) -> tuple[datetime.date, datetime.date]:
    """Return the first and last calendar day (Jan 1, Dec 31) of *year*."""
    return datetime.date(year, 1, 1), datetime.date(year, 12, 31)


class AbsenceRepository:
    """Data-access helpers for ``Absence`` rows on a SQLAlchemy session.

    Write methods only ``flush()``; no method here commits, so transaction
    control is left to the caller.
    """

    def __init__(self, db: Session):
        self.db = db

    def get_by_id(self, absence_id: str) -> Optional[Absence]:
        """Return the absence with the given id, or ``None`` if not found."""
        return self.db.query(Absence).filter(Absence.id == absence_id).first()

    def list_for_employee(
        self, employee_id: str, year: Optional[int] = None
    ) -> list[Absence]:
        """Return an employee's absences ordered by start date.

        If *year* is given, only absences overlapping that calendar year are
        returned (including ones that cross the year boundary).
        """
        q = self.db.query(Absence).filter(Absence.employee_id == employee_id)
        if year is not None:
            jan1, dec31 = _year_bounds(year)
            q = q.filter(Absence.end_date >= jan1, Absence.start_date <= dec31)
        return q.order_by(Absence.start_date).all()

    def list_for_period(
        self,
        start: datetime.date,
        end: datetime.date,
        status: Optional[AbsenceStatus] = None,
    ) -> list[Absence]:
        """Return absences overlapping [start, end], ordered by start date.

        If *status* is given, only absences with that status are returned.
        """
        q = self.db.query(Absence).filter(
            Absence.end_date >= start,
            Absence.start_date <= end,
        )
        if status is not None:
            q = q.filter(Absence.status == status)
        return q.order_by(Absence.start_date).all()

    def list_all(self, year: Optional[int] = None) -> list[Absence]:
        """Return all absences ordered by start date, optionally limited to
        those overlapping calendar year *year*."""
        q = self.db.query(Absence)
        if year is not None:
            jan1, dec31 = _year_bounds(year)
            q = q.filter(Absence.end_date >= jan1, Absence.start_date <= dec31)
        return q.order_by(Absence.start_date).all()

    def find_overlapping(
        self,
        employee_id: str,
        start: datetime.date,
        end: datetime.date,
        exclude_id: Optional[str] = None,
    ) -> list[Absence]:
        """Return the employee's non-cancelled absences overlapping [start, end].

        *exclude_id* lets callers skip one absence, e.g. the one currently
        being edited, so it does not collide with itself.
        """
        q = self.db.query(Absence).filter(
            Absence.employee_id == employee_id,
            Absence.end_date >= start,
            Absence.start_date <= end,
            Absence.status != AbsenceStatus.CANCELLED,
        )
        if exclude_id is not None:
            q = q.filter(Absence.id != exclude_id)
        return q.all()

    def count_absent_on_date(self, date: datetime.date) -> int:
        """Count APPROVED or PENDING absences that cover *date*.

        This counts absence rows, not distinct employees.
        """
        return self.db.query(Absence).filter(
            Absence.start_date <= date,
            Absence.end_date >= date,
            Absence.status.in_([AbsenceStatus.APPROVED, AbsenceStatus.PENDING]),
        ).count()

    def create(self, data: AbsenceCreate, business_days: float) -> Absence:
        """Create and flush a new absence from *data*.

        *business_days* is supplied by the caller and stored as-is.
        """
        absence = Absence(
            **data.model_dump(),
            business_days=business_days,
        )
        self.db.add(absence)
        self.db.flush()
        return absence

    def update(self, absence_id: str, data: AbsenceUpdate) -> Optional[Absence]:
        """Apply the explicitly set fields of *data* to an absence.

        Returns the updated absence, or ``None`` if no absence has that id.
        """
        absence = self.get_by_id(absence_id)
        if absence is None:
            return None
        # exclude_unset: only fields the client actually sent are written.
        # NOTE(review): business_days is not recomputed here when dates
        # change — confirm the caller handles that.
        for field, value in data.model_dump(exclude_unset=True).items():
            setattr(absence, field, value)
        self.db.flush()
        return absence

    def delete(self, absence_id: str) -> bool:
        """Delete an absence; return ``False`` if it did not exist."""
        absence = self.get_by_id(absence_id)
        if absence is None:
            return False
        self.db.delete(absence)
        self.db.flush()
        return True

    def used_days(self, employee_id: str, year: int) -> float:
        """Return the vacation days an employee has used in *year*.

        Only categories whose ``ABSENCE_META`` entry has ``deducts_balance``
        set are counted; cancelled absences are ignored.
        """
        deducting = [
            cat for cat, meta in ABSENCE_META.items() if meta["deducts_balance"]
        ]
        if not deducting:
            # No category deducts balance: nothing can be used, skip the query.
            return 0.0
        jan1, dec31 = _year_bounds(year)
        # Fetch only the summed column instead of loading full Absence rows.
        rows = (
            self.db.query(Absence.business_days)
            .filter(
                Absence.employee_id == employee_id,
                Absence.category.in_(deducting),
                Absence.status != AbsenceStatus.CANCELLED,
                Absence.end_date >= jan1,
                Absence.start_date <= dec31,
            )
            .all()
        )
        # NOTE(review): an absence crossing the year boundary contributes its
        # full business_days to *both* years — confirm whether it should be
        # split per year.
        # NOTE(review): every status except CANCELLED is counted (PENDING
        # included); if AbsenceStatus has other non-effective states (e.g. a
        # rejection), they are counted too — confirm this is intended.
        # `or 0` tolerates NULL business_days; float() keeps the declared
        # return type even when there are no rows.
        return float(sum(days or 0 for (days,) in rows))


class BalanceRepository:
    """Data-access helpers for per-employee, per-year ``BalanceAccount`` rows."""

    # Yearly quota given to a new account when the employee cannot be found.
    DEFAULT_YEARLY_QUOTA = 25

    def __init__(self, db: Session):
        self.db = db

    def get_or_create(self, employee_id: str, year: int) -> BalanceAccount:
        """Return the balance account for (employee, year), creating it if missing.

        A new account takes its ``yearly_quota`` from the employee's
        ``vacation_days_per_year``, or ``DEFAULT_YEARLY_QUOTA`` if no employee
        with that id exists. The new account is added and flushed.
        """
        ba = (
            self.db.query(BalanceAccount)
            .filter(
                BalanceAccount.employee_id == employee_id,
                BalanceAccount.year == year,
            )
            .first()
        )
        if ba is None:
            emp = self.db.query(Employee).filter(Employee.id == employee_id).first()
            quota = (
                emp.vacation_days_per_year
                if emp is not None
                else self.DEFAULT_YEARLY_QUOTA
            )
            ba = BalanceAccount(employee_id=employee_id, year=year, yearly_quota=quota)
            self.db.add(ba)
            self.db.flush()
        return ba

    def update(self, employee_id: str, year: int, **kwargs) -> BalanceAccount:
        """Set attributes on the (employee, year) balance account and flush.

        The account is created first if needed. Keyword arguments whose value
        is ``None``, or whose name is not an attribute of the account, are
        silently skipped — so ``None`` cannot be used to clear a field.
        """
        ba = self.get_or_create(employee_id, year)
        for key, value in kwargs.items():
            if hasattr(ba, key) and value is not None:
                setattr(ba, key, value)
        self.db.flush()
        return ba