Files
aza/APP/backup 24.2.26/aza_security.py

104 lines
3.3 KiB
Python
Raw Normal View History

2026-03-25 14:14:07 +01:00
import hmac
import os
from pathlib import Path
from typing import List, Optional
from fastapi import Header, HTTPException, status
# Hard-coded token intended for local development (see its use site in
# get_required_api_tokens for when it is accepted).
# NOTE(review): this value is committed to source control, so anyone with
# access to the code knows it; it must never be the only protection of a
# reachable deployment.
_DEV_TOKEN: str = "AZA_LOCAL_TOKEN_123456"
def _read_fallback_tokens(token_file: Optional[str] = None) -> List[str]:
    """
    Read API tokens from a plain-text fallback file.

    Supports:
    - Single token in file
    - Multiple tokens separated by newlines
    - Multiple tokens separated by commas

    Args:
        token_file: Path to the token file. ``None`` means ``backend_token.txt``
            next to this module. A relative path is resolved against this
            module's directory rather than the process cwd.

    Returns:
        The non-empty, whitespace-stripped tokens in file order (duplicates
        are kept). An empty list if the file is missing, unreadable, or not
        valid UTF-8.
    """
    module_dir = Path(__file__).resolve().parent
    path = module_dir / "backend_token.txt" if token_file is None else Path(token_file)
    if not path.is_absolute():
        path = module_dir / path

    try:
        # "utf-8-sig" drops a leading byte-order mark (as written by some
        # Windows editors). Plain "utf-8" would leave U+FEFF glued to the first
        # token, and str.strip() does not remove it, so that token could
        # never match.
        raw = path.read_text(encoding="utf-8-sig")
    except (OSError, ValueError):
        # Best effort: a missing, unreadable, or undecodable file simply
        # contributes no tokens. ValueError covers UnicodeDecodeError.
        return []

    return [t.strip() for t in raw.replace(",", "\n").splitlines() if t.strip()]
def get_required_api_tokens() -> List[str]:
    """
    Return every API token currently accepted, de-duplicated, in source order.

    Token rotation support:
    - MEDWORK_API_TOKENS can contain comma-separated tokens (preferred)
    - MEDWORK_API_TOKEN can contain a single token (legacy) or comma-separated tokens
    - backend_token.txt fallback supports single or multiple tokens

    The hard-coded local-development token (_DEV_TOKEN) is accepted only when
    none of the sources above yields a token. A configured deployment therefore
    does not also accept a credential known to anyone with the source code.
    Set AZA_ALLOW_DEV_TOKEN to 1/true/yes/on to accept it alongside configured
    tokens, which was the previous behavior.

    Sources are re-read on every call, so rotated env/file values take effect
    without a restart. The result is never empty because the dev token is the
    last resort.
    """
    tokens: List[str] = []

    # Preferred variable first, then the legacy one. Both are split on commas
    # for convenience, and blank entries are dropped.
    for env_name in ("MEDWORK_API_TOKENS", "MEDWORK_API_TOKEN"):
        value = os.getenv(env_name) or ""
        tokens.extend(t.strip() for t in value.split(",") if t.strip())

    # Fallback file (absolute path, works regardless of cwd).
    tokens.extend(_read_fallback_tokens())

    allow_dev = os.getenv("AZA_ALLOW_DEV_TOKEN", "").strip().lower() in {"1", "true", "yes", "on"}
    if allow_dev or not tokens:
        tokens.append(_DEV_TOKEN)

    # dict keeps insertion order: O(n) de-duplication keeping first occurrences.
    return list(dict.fromkeys(t for t in tokens if t))
def require_api_token(
    x_api_token: Optional[str] = Header(default=None, alias="X-API-Token"),
    authorization: Optional[str] = Header(default=None, alias="Authorization"),
) -> None:
    """FastAPI dependency that rejects requests lacking a valid API token.

    The token comes from the ``X-API-Token`` header. If that header is missing
    or empty, an ``Authorization: Bearer <token>`` header is used instead (the
    scheme name is matched case-insensitively).

    Raises:
        HTTPException: 401 if no token was supplied, or if it matches none of
            the tokens returned by get_required_api_tokens().
    """
    token = x_api_token
    if not token and authorization and authorization.lower().startswith("bearer "):
        token = authorization.split(" ", 1)[1].strip()
    if not token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")

    # hmac.compare_digest only accepts ASCII-only str and raises TypeError
    # otherwise, so a non-ASCII header would surface as a 500. Comparing bytes
    # avoids that. "surrogatepass" lets any str be encoded without error.
    candidate = token.encode("utf-8", "surrogatepass")
    matched = False
    # Check every configured token without short-circuiting, so timing does
    # not reveal which entry matched.
    for expected in get_required_api_tokens():
        matched |= hmac.compare_digest(candidate, expected.encode("utf-8", "surrogatepass"))
    if not matched:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
def get_admin_token() -> Optional[str]:
    """Return the admin token from the ``AZA_ADMIN_TOKEN`` environment variable.

    Surrounding whitespace, such as a stray carriage return from a CRLF env
    file, is stripped, consistent with how the API tokens are parsed. A value
    that is empty after stripping counts as unset and yields ``None``.
    """
    token = (os.getenv("AZA_ADMIN_TOKEN") or "").strip()
    return token or None
def require_admin_token(
    x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token"),
) -> None:
    """FastAPI dependency guarding admin endpoints via the ``X-Admin-Token`` header.

    Raises:
        HTTPException: 503 if no admin token is configured (get_admin_token()
            returned a falsy value); 401 if the header is missing or does not
            match the configured token.
    """
    expected = get_admin_token()
    if not expected:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Admin token not configured",
        )
    # Compare bytes: compare_digest raises TypeError for non-ASCII str, which
    # would turn a bad header into a 500 instead of a 401.
    if not x_admin_token or not hmac.compare_digest(
        x_admin_token.encode("utf-8", "surrogatepass"),
        expected.encode("utf-8", "surrogatepass"),
    ):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")