104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
|
|
import hmac
|
||
|
|
import os
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import List, Optional
|
||
|
|
|
||
|
|
from fastapi import Header, HTTPException, status
|
||
|
|
|
||
|
|
# Hard-coded API token intended for local development. It is committed to
# source code, so treat it as public knowledge: get_required_api_tokens()
# decides whether it is accepted, and it must never be the only thing
# protecting a deployed service.
_DEV_TOKEN: str = "AZA_LOCAL_TOKEN_123456"
|
||
|
|
|
||
|
|
|
||
|
|
def _read_fallback_tokens(token_file: Optional[str] = None) -> List[str]:
    """Read API tokens from an optional fallback token file.

    Supported file layouts:
    - a single token in the file
    - multiple tokens separated by newlines
    - multiple tokens separated by commas (commas and newlines may be mixed)

    Whitespace around each token is stripped and empty entries are skipped.
    A leading UTF-8 byte-order mark (as written by some Windows editors) is
    ignored, so it does not end up glued to the first token.

    Args:
        token_file: Path to the token file (``str`` or path-like). A relative
            path is resolved against this module's directory, so the lookup
            works regardless of the current working directory. Defaults to
            ``backend_token.txt`` next to this module.

    Returns:
        The tokens in file order. The list is empty if the file is missing,
        cannot be read, or is not valid UTF-8. The file is only a fallback
        source, so any of these conditions means "no tokens from the file".
    """
    path = Path("backend_token.txt" if token_file is None else token_file)
    if not path.is_absolute():
        path = Path(__file__).resolve().parent / path

    try:
        # "utf-8-sig" decodes plain UTF-8 identically but drops a leading BOM.
        raw = path.read_text(encoding="utf-8-sig")
    except (OSError, ValueError):
        # OSError covers a missing file, a permission error, a directory, etc.
        # ValueError covers UnicodeDecodeError and invalid paths such as an
        # embedded NUL byte.
        return []

    return [tok.strip() for tok in raw.replace(",", "\n").splitlines() if tok.strip()]
|
||
|
|
|
||
|
|
|
||
|
|
def get_required_api_tokens() -> List[str]:
    """Return the API tokens accepted by ``require_api_token``.

    Token rotation support. Tokens are collected from these sources, in order
    (blank entries are dropped; duplicates keep their first position):

    - ``MEDWORK_API_TOKENS``: comma-separated tokens (preferred; list the old
      and the new token together while rotating).
    - ``MEDWORK_API_TOKEN``: a single token (legacy); commas are also accepted
      here for convenience.
    - the ``backend_token.txt`` fallback file read by
      ``_read_fallback_tokens()``, which may hold one or more tokens.

    The hard-coded ``_DEV_TOKEN`` is returned only when none of the sources
    above yields a token. A zero-config local checkout therefore keeps
    working, while a deployment that configures real tokens no longer accepts
    a credential that is visible in the source code. To use the dev token
    alongside real tokens, list it explicitly in one of the sources.

    Returns:
        A non-empty list of accepted tokens, in first-seen order.
    """
    collected: List[str] = []

    for env_name in ("MEDWORK_API_TOKENS", "MEDWORK_API_TOKEN"):
        raw = os.getenv(env_name) or ""
        collected.extend(part.strip() for part in raw.split(",") if part.strip())

    # Fallback file (resolved relative to this module, independent of cwd).
    collected.extend(_read_fallback_tokens())

    # dict preserves insertion order, so this de-duplicates while keeping the
    # first occurrence of each token.
    configured = list(dict.fromkeys(tok for tok in collected if tok))

    if not configured:
        # Nothing configured at all: fall back to the local development token
        # instead of failing every request.
        return [_DEV_TOKEN]
    return configured
|
||
|
|
|
||
|
|
|
||
|
|
def require_api_token(
    x_api_token: Optional[str] = Header(default=None, alias="X-API-Token"),
    authorization: Optional[str] = Header(default=None, alias="Authorization"),
) -> None:
    """FastAPI dependency that rejects requests without a valid API token.

    The token is read from the ``X-API-Token`` header. If that header is
    missing or empty, the token is taken from an ``Authorization: Bearer
    <token>`` header instead (the ``Bearer`` scheme is matched
    case-insensitively). The supplied token is then compared with every entry
    returned by ``get_required_api_tokens()``.

    Raises:
        HTTPException: 401 when no token is supplied or it matches none of the
            accepted tokens.
    """
    if not x_api_token and authorization and authorization.lower().startswith("bearer "):
        x_api_token = authorization.split(" ", 1)[1].strip()
    if not x_api_token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")

    # hmac.compare_digest() raises TypeError for str arguments that contain
    # non-ASCII characters, and request headers can carry such characters.
    # Comparing bytes turns a malformed token into a 401 instead of an
    # unhandled 500. "surrogatepass" makes the encoding total (it cannot raise).
    supplied = x_api_token.encode("utf-8", "surrogatepass")

    ok = False
    for expected in get_required_api_tokens():
        # Deliberately no short-circuit: every candidate is checked, so
        # response timing does not reveal which configured token matched.
        ok |= hmac.compare_digest(supplied, expected.encode("utf-8", "surrogatepass"))

    if not ok:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
||
|
|
|
||
|
|
|
||
|
|
def get_admin_token() -> Optional[str]:
    """Return the configured admin token, or ``None`` when it is unset.

    The value is read from the ``AZA_ADMIN_TOKEN`` environment variable on
    every call, so changes to the environment take effect without a restart.
    """
    admin_token = os.environ.get("AZA_ADMIN_TOKEN")
    return admin_token
|
||
|
|
|
||
|
|
|
||
|
|
def require_admin_token(x_admin_token: Optional[str] = Header(default=None, alias="X-Admin-Token")) -> None:
    """FastAPI dependency guarding admin-only endpoints.

    Compares the ``X-Admin-Token`` request header with the value returned by
    ``get_admin_token()``.

    Raises:
        HTTPException: 503 when no admin token is configured (admin endpoints
            are disabled rather than left open), and 401 when the header is
            missing or does not match.
    """
    expected = get_admin_token()
    if not expected:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Admin token not configured")

    # Compare bytes, not str: hmac.compare_digest() raises TypeError for
    # non-ASCII str input, which would turn a malformed header into a 500.
    # "surrogatepass" makes the encoding total (it cannot raise).
    if not x_admin_token or not hmac.compare_digest(
        x_admin_token.encode("utf-8", "surrogatepass"),
        expected.encode("utf-8", "surrogatepass"),
    ):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|