Files
2026-03-30 07:59:11 +02:00

1023 lines
34 KiB
Python

"""
AZA / MedWork - License Server (FastAPI)
Meilenstein A:
- Login generiert echte Tokens (DB-backed)
- Token an User gebunden
- Token in DB gespeichert
- /license/check validiert Token via DB (kein hardcoded if mehr)
- Geräte-Limit pro Plan bleibt aktiv (basic: 3, team: 5; siehe get_device_limit_for_plan)
SQLite DB: aza_license.db (wird automatisch initialisiert/migriert)
"""
from __future__ import annotations
import os
import logging
import sqlite3
import secrets
import hashlib
import hmac
import uuid
from datetime import datetime, timezone
from typing import Optional
# Optional .env support: python-dotenv is used when it is installed, but it is
# not a hard dependency of the server.
try:
    from dotenv import load_dotenv  # type: ignore
except Exception:  # package missing (or failing to import) -> run without .env
    load_dotenv = None

if callable(load_dotenv):
    # Copy variables from a local .env file into os.environ; a missing file is fine.
    load_dotenv()
from fastapi import Depends, FastAPI, HTTPException, Query, Request
from pydantic import BaseModel, Field
from starlette.middleware.base import BaseHTTPMiddleware
from admin_routes import build_admin_router
from aza_security import require_admin_token, require_api_token
from aza_license_logic import compute_license_decision
from aza_device_enforcement import enforce_and_touch_device
from aza_stripe_idempotency import try_claim_event
import stripe
# Log verbosity comes from AZA_LOG_LEVEL (case-insensitive). Names that are not
# attributes of the logging module fall back to INFO instead of failing at import.
LOG_LEVEL = os.getenv("AZA_LOG_LEVEL", "INFO").upper()
_resolved_log_level = getattr(logging, LOG_LEVEL, logging.INFO)
logging.basicConfig(level=_resolved_log_level)
# Shared application logger used by the endpoints below.
logger = logging.getLogger("aza")
class RequestIdMiddleware(BaseHTTPMiddleware):
    """Attach a request id to every request/response pair.

    The caller-supplied ``x-request-id`` header is reused when it is non-empty.
    Otherwise a random UUID4 is generated. Handlers can read the id from
    ``request.state.request_id``, and it is echoed in the ``X-Request-Id``
    response header.

    NOTE(review): the incoming header value is neither validated nor
    length-limited before it is echoed back.
    """

    async def dispatch(self, request: Request, call_next):
        incoming_id = request.headers.get("x-request-id")
        rid = incoming_id if incoming_id else str(uuid.uuid4())
        request.state.request_id = rid
        resp = await call_next(request)
        resp.headers["X-Request-Id"] = rid
        return resp
class SimpleProxyHeadersMiddleware(BaseHTTPMiddleware):
    """Expose the client IP and scheme as seen by a fronting reverse proxy.

    This middleware does not depend on Starlette's ProxyHeadersMiddleware module.

    - ``request.state.client_ip``: the first entry of ``X-Forwarded-For``.
      Without that header it is the socket peer host, or ``"unknown"``.
    - ``request.state.scheme``: the first entry of ``X-Forwarded-Proto``.
      Without that header it is the URL scheme of the request.

    NOTE(review): the forwarded headers are trusted unconditionally, so a client
    that reaches the app directly (not via a header-rewriting proxy) can spoof both.
    """

    async def dispatch(self, request: Request, call_next):
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            # Left-most entry is the original client in a proxy chain.
            client_ip = forwarded_for.split(",")[0].strip()
        elif request.client:
            client_ip = request.client.host
        else:
            client_ip = "unknown"
        request.state.client_ip = client_ip

        forwarded_proto = request.headers.get("x-forwarded-proto")
        if forwarded_proto:
            scheme = forwarded_proto.split(",")[0].strip()
        else:
            scheme = request.url.scheme
        request.state.scheme = scheme

        return await call_next(request)
# -----------------------------
# Config
# -----------------------------
def _env_stripped(name: str) -> str:
    """Return environment variable *name* without surrounding whitespace ('' when unset)."""
    return os.environ.get(name, "").strip()


# SQLite file holding users, tokens, devices and the admin audit log.
DB_PATH = os.environ.get("AZA_LICENSE_DB", "aza_license.db")
# NOTE(review): DEVICE_LIMIT_DEFAULT is not referenced elsewhere in this file;
# enforced device limits come from get_device_limit_for_plan().
# A non-integer value raises ValueError at import time.
DEVICE_LIMIT_DEFAULT = int(os.environ.get("AZA_DEVICE_LIMIT_DEFAULT", "2"))
# Raw "plan:limit,..." string; parsed by parse_plan_limits() into PLAN_LIMITS.
PLAN_LIMITS_RAW = _env_stripped("AZA_PLAN_LIMITS")
# Admin key handed to the admin router.
ADMIN_KEY = _env_stripped("AZA_ADMIN_KEY")
# How many non-revoked tokens a user may keep; login revokes older ones.
MAX_ACTIVE_TOKENS_PER_USER = int(os.environ.get("AZA_MAX_ACTIVE_TOKENS_PER_USER", "3"))

# Stripe credentials, price ids (monthly / yearly) and checkout redirect URLs.
STRIPE_SECRET_KEY = _env_stripped("AZA_STRIPE_SECRET_KEY")
STRIPE_WEBHOOK_SECRET = _env_stripped("AZA_STRIPE_WEBHOOK_SECRET")
STRIPE_PRICE_BASIC = _env_stripped("AZA_STRIPE_PRICE_BASIC")
STRIPE_PRICE_TEAM = _env_stripped("AZA_STRIPE_PRICE_TEAM")
STRIPE_PRICE_BASIC_YEARLY = _env_stripped("AZA_STRIPE_PRICE_BASIC_YEARLY")
STRIPE_PRICE_TEAM_YEARLY = _env_stripped("AZA_STRIPE_PRICE_TEAM_YEARLY")
STRIPE_SUCCESS_URL = _env_stripped("AZA_STRIPE_SUCCESS_URL")
STRIPE_CANCEL_URL = _env_stripped("AZA_STRIPE_CANCEL_URL")

# Optional demo token (for tests). The client usually runs demo mode locally.
DEMO_TOKEN = os.environ.get("AZA_DEMO_TOKEN", "DEMO")
# ASGI application instance used by all route decorators in this module.
# Starlette places each newly added middleware outermost, so RequestIdMiddleware
# runs before SimpleProxyHeadersMiddleware on the way in.
app = FastAPI(title="AZA License Server", version="1.0.0")
app.add_middleware(SimpleProxyHeadersMiddleware)
app.add_middleware(RequestIdMiddleware)
# -----------------------------
# Helpers
# -----------------------------
def utc_now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string (``+00:00`` offset)."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.isoformat()
def get_db() -> sqlite3.Connection:
    """Open a new connection to the license database at ``DB_PATH``.

    Rows come back as :class:`sqlite3.Row`, so columns are accessible by name.
    Foreign-key enforcement is switched on for this connection, which SQLite
    leaves off by default. The caller owns the connection and must close it.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    # Needed for the ON DELETE CASCADE clauses on tokens/devices to take effect.
    connection.execute("PRAGMA foreign_keys = ON;")
    return connection
def sha256_hex(data: bytes) -> str:
    """Return the lowercase hexadecimal SHA-256 digest of *data*."""
    digest = hashlib.sha256()
    digest.update(data)
    return digest.hexdigest()


def hash_password(password: str, salt_hex: str) -> str:
    """Return ``sha256(salt_bytes + utf8(password))`` as a hex string.

    ``salt_hex`` must be valid hex; ``bytes.fromhex`` raises ValueError otherwise.

    NOTE(review): a single unstretched SHA-256 round is cheap to brute-force.
    A KDF (``hashlib.scrypt`` / ``hashlib.pbkdf2_hmac``) would be stronger, but
    switching requires a migration because stored hashes use this exact scheme.
    """
    salted = bytes.fromhex(salt_hex) + password.encode("utf-8")
    return sha256_hex(salted)


def verify_password(password: str, salt_hex: str, pw_hash_hex: str) -> bool:
    """Return True when *password* with *salt_hex* hashes to *pw_hash_hex*.

    The comparison uses ``hmac.compare_digest`` so it runs in constant time.
    """
    recomputed = hash_password(password, salt_hex)
    return hmac.compare_digest(recomputed, pw_hash_hex)
def parse_plan_limits(raw: str) -> dict[str, int]:
    """Parse an ``AZA_PLAN_LIMITS`` value such as ``"basic:2,pro:4,business:10"``.

    Returns a ``{plan: limit}`` dict with lower-cased plan names. The following
    entries are skipped silently:

    - entries without a colon
    - entries with an empty plan name
    - entries whose limit is not an integer
    - entries with a negative limit

    When a plan appears more than once, the later entry wins.
    """
    limits: dict[str, int] = {}
    if not raw:
        return limits
    for entry in raw.split(","):
        entry = entry.strip()
        name, colon, value = entry.partition(":")
        if not entry or not colon:
            continue
        plan = name.strip().lower()
        try:
            limit = int(value.strip())
        except ValueError:
            continue
        if plan and limit >= 0:
            limits[plan] = limit
    return limits
# Parsed AZA_PLAN_LIMITS overrides ({plan: limit}).
# NOTE(review): get_device_limit_for_plan() does not read this mapping, so the
# setting currently has no effect on enforced device limits.
PLAN_LIMITS = parse_plan_limits(PLAN_LIMITS_RAW)
def get_device_limit_for_plan(plan: str) -> int:
    """Return how many devices a user on *plan* may register.

    Fixed business rule:

    - ``team`` allows 5 devices.
    - ``basic``, any other plan name, and an empty/None plan allow 3 devices.

    Matching ignores case and surrounding whitespace. AZA_PLAN_LIMITS and
    AZA_DEVICE_LIMIT_DEFAULT are not consulted.
    """
    normalized = (plan or "basic").strip().lower()
    return 5 if normalized == "team" else 3
def log_admin_action(
    conn: sqlite3.Connection,
    action: str,
    email: Optional[str] = None,
    token: Optional[str] = None,
    old_value: Optional[str] = None,
    new_value: Optional[str] = None,
) -> None:
    """Append one row to the ``admin_audit`` table, stamped with the current UTC time.

    The function does not commit; the caller decides when the transaction ends.
    """
    audit_row = (action, email, token, old_value, new_value, utc_now_iso())
    insert_sql = """
        INSERT INTO admin_audit (action, email, token, old_value, new_value, created_at)
        VALUES (?, ?, ?, ?, ?, ?)
    """
    conn.execute(insert_sql, audit_row)
def _plan_from_price_id(price_id: Optional[str]) -> str:
    """Map a Stripe price id to the internal plan name.

    The monthly or yearly team price maps to ``"team"``. Everything else maps to
    ``"basic"``: the basic prices, unknown ids and ``None``. Price settings that
    are empty (unconfigured) never match.
    """
    pid = (price_id or "").strip()
    team_prices = {p for p in (STRIPE_PRICE_TEAM, STRIPE_PRICE_TEAM_YEARLY) if p}
    return "team" if pid in team_prices else "basic"
def _extract_price_id_from_subscription(subscription_obj: dict) -> Optional[str]:
items = subscription_obj.get("items")
if not isinstance(items, dict):
return None
data = items.get("data")
if not isinstance(data, list) or not data:
return None
first = data[0]
if not isinstance(first, dict):
return None
price = first.get("price")
if not isinstance(price, dict):
return None
pid = price.get("id")
return str(pid).strip() if pid else None
def _ensure_licenses_customer_id_column(conn: sqlite3.Connection) -> None:
cur = conn.execute("PRAGMA table_info(licenses);")
cols = [r[1] for r in cur.fetchall()] # (cid, name, type, notnull, dflt_value, pk)
if "customer_id" not in cols:
conn.execute("ALTER TABLE licenses ADD COLUMN customer_id TEXT;")
conn.commit()
def init_db() -> None:
    """Create the license-server schema and apply in-place column migrations.

    Safe to run on every start, for two reasons:

    - Every table is created with ``CREATE TABLE IF NOT EXISTS``.
    - Each migration column is added only when ``PRAGMA table_info`` does not
      already list it.
    """
    conn = get_db()
    try:
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                email TEXT NOT NULL UNIQUE,
                salt_hex TEXT NOT NULL,
                password_hash_hex TEXT NOT NULL,
                plan TEXT NOT NULL DEFAULT 'basic',
                status TEXT NOT NULL DEFAULT 'active',
                created_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS tokens (
                token TEXT PRIMARY KEY,
                user_id INTEGER NOT NULL,
                created_at TEXT NOT NULL,
                last_used_at TEXT,
                revoked INTEGER NOT NULL DEFAULT 0,
                FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
            );
            CREATE TABLE IF NOT EXISTS devices (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                device_id TEXT NOT NULL,
                first_seen_at TEXT NOT NULL,
                last_seen_at TEXT NOT NULL,
                UNIQUE(user_id, device_id),
                FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
            );
            CREATE TABLE IF NOT EXISTS admin_audit (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                action TEXT NOT NULL,
                email TEXT,
                token TEXT,
                old_value TEXT,
                new_value TEXT,
                created_at TEXT NOT NULL
            );
            """
        )

        def has_column(table: str, column: str) -> bool:
            # Table names come only from the fixed list below, never from input.
            rows = conn.execute(f"PRAGMA table_info({table});").fetchall()
            return column in {r["name"] for r in rows}

        # (table, column, DDL). Applied in order, and only when the column is
        # missing, which upgrades databases created by older versions.
        migrations = (
            ("users", "plan", "ALTER TABLE users ADD COLUMN plan TEXT NOT NULL DEFAULT 'basic';"),
            ("users", "status", "ALTER TABLE users ADD COLUMN status TEXT NOT NULL DEFAULT 'active';"),
            # Stripe linkage columns (needed for the Billing Portal)
            ("users", "stripe_customer_id", "ALTER TABLE users ADD COLUMN stripe_customer_id TEXT;"),
            ("users", "stripe_subscription_id", "ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT;"),
            ("tokens", "revoked", "ALTER TABLE tokens ADD COLUMN revoked INTEGER NOT NULL DEFAULT 0;"),
        )
        for table, column, ddl in migrations:
            if not has_column(table, column):
                conn.execute(ddl)
        conn.commit()
    finally:
        conn.close()
@app.on_event("startup")
def _startup() -> None:
    """Create or migrate the SQLite schema once when the application starts."""
    init_db()


# Admin endpoints come from the admin_routes module (kept separate for
# maintainability). It receives the DB factory, the configured admin key and
# the audit-log helper from this module.
_admin_router = build_admin_router(
    get_db=get_db,
    admin_key=ADMIN_KEY,
    log_admin_action=log_admin_action,
)
app.include_router(_admin_router)
# -----------------------------
# API Models
# -----------------------------
class RegisterRequest(BaseModel):
    """Request body of ``POST /register``."""
    email: str = Field(..., min_length=3, max_length=255)
    password: str = Field(..., min_length=6, max_length=1024)
class RegisterResponse(BaseModel):
    """Response body of ``POST /register``."""
    ok: bool
    message: str
class LoginRequest(BaseModel):
    """Request body of ``POST /login``."""
    email: str = Field(..., min_length=3, max_length=255)
    password: str = Field(..., min_length=6, max_length=1024)
class LoginResponse(BaseModel):
    """Response body of ``POST /login``: the freshly issued token and the user's plan."""
    ok: bool
    token: str
    plan: str
class LicenseCheckRequest(BaseModel):
    """Request body of ``POST /license/check``."""
    token: str = Field(..., min_length=1, max_length=4096)
    device_id: str = Field(..., min_length=1, max_length=512)
class LicenseCheckResponse(BaseModel):
    """Response body of ``POST /license/check``."""
    ok: bool
    mode: str # 'active' | 'demo'
    plan: str
    device_limit: int  # 0 for the demo token
    devices_registered: int
    message: str
# -----------------------------
# Endpoints
# -----------------------------
@app.get("/health")
def health():
    """Liveness probe: always answers ``{"ok": True}`` without touching the database."""
    status_payload = {"ok": True}
    return status_payload
@app.get("/version")
def version():
    """Report the service name and build label (leak-free, for deployments).

    The build label comes from ``AZA_BUILD`` and is ``"dev"`` when unset.
    Nothing else about the deployment is exposed.
    """
    build_label = os.getenv("AZA_BUILD", "dev")
    return {"name": "AZA", "build": build_label}
@app.get("/license/status")
def license_status(request: Request, _: None = Depends(require_api_token)):
    """
    Production license check (API-token authenticated).

    Reads the most recently updated row of the Stripe DB ``licenses`` table. It
    then applies the time/status decision from ``compute_license_decision`` and
    the device/seat enforcement from ``enforce_and_touch_device``. The device id
    comes from the ``X-Device-Id`` header and the seat key is always
    ``"default"``.

    The response format must remain ``{"valid": bool, "valid_until": int|None}``.
    Whenever ``valid`` is false, ``valid_until`` is null.
    """
    conn = sqlite3.connect("data/stripe_webhook.sqlite")
    try:
        latest = conn.execute(
            "SELECT status, current_period_end, customer_email FROM licenses ORDER BY updated_at DESC LIMIT 1;"
        ).fetchone()
        if not latest:
            # No license row -> demo
            return {"valid": False, "valid_until": None}

        raw_status, period_end, licensed_email = latest[0], latest[1], latest[2]

        # Time-based validity (optional grace period is handled by the helper).
        decision = compute_license_decision(
            current_period_end=None if period_end is None else int(period_end),
            status=None if raw_status is None else str(raw_status),
        )
        if not decision.valid:
            return {"valid": False, "valid_until": None}

        # Device + seat enforcement for the calling device.
        device_decision = enforce_and_touch_device(
            customer_email="" if licensed_email is None else str(licensed_email),
            user_key="default",
            device_id=request.headers.get("X-Device-Id"),
        )
        if not device_decision.allowed:
            return {"valid": False, "valid_until": None}

        return {"valid": True, "valid_until": decision.valid_until}
    finally:
        conn.close()
@app.post("/register", response_model=RegisterResponse)
def register(req: RegisterRequest):
    """Create a user with plan ``basic`` and status ``active``.

    The email is stripped and lower-cased before it is stored, and the password
    is stored as a salted hash (see ``hash_password``).

    Raises:
        HTTPException 400: the email is empty or has no ``@``.
        HTTPException 409: the email is already registered (UNIQUE constraint).
    """
    email = req.email.strip().lower()
    if not email or "@" not in email:
        raise HTTPException(status_code=400, detail="Invalid email")

    salt_hex = secrets.token_hex(16)
    password_digest = hash_password(req.password, salt_hex)

    conn = get_db()
    try:
        try:
            conn.execute(
                "INSERT INTO users (email, salt_hex, password_hash_hex, plan, status, created_at) VALUES (?, ?, ?, ?, ?, ?)",
                (email, salt_hex, password_digest, "basic", "active", utc_now_iso()),
            )
            conn.commit()
        except sqlite3.IntegrityError:
            raise HTTPException(status_code=409, detail="User already exists")
    finally:
        conn.close()
    return RegisterResponse(ok=True, message="Registered")
@app.post("/login", response_model=LoginResponse)
def login(req: LoginRequest):
    """Authenticate by email and password, then issue a fresh bearer token.

    Every successful login inserts a new token. Afterwards only the newest
    ``MAX_ACTIVE_TOKENS_PER_USER`` non-revoked tokens of the user stay active
    (at least 1); older ones are revoked.

    Raises:
        HTTPException 401: unknown email or wrong password (same message for
            both cases).
        HTTPException 403: the account status is not ``active``.
    """
    email = req.email.strip().lower()
    conn = get_db()
    try:
        user = conn.execute(
            "SELECT id, email, salt_hex, password_hash_hex, plan, status FROM users WHERE email = ?",
            (email,),
        ).fetchone()
        if not user or not verify_password(req.password, user["salt_hex"], user["password_hash_hex"]):
            raise HTTPException(status_code=401, detail="Invalid credentials")

        # Account status gate; a NULL status counts as active.
        account_status = str(user["status"] or "active").strip().lower()
        if account_status != "active":
            raise HTTPException(status_code=403, detail="Account not active")

        user_id = user["id"]
        new_token = secrets.token_urlsafe(32)
        conn.execute(
            "INSERT INTO tokens (token, user_id, created_at, last_used_at, revoked) VALUES (?, ?, ?, ?, 0)",
            (new_token, user_id, utc_now_iso(), utc_now_iso()),
        )

        # Revoke everything beyond the newest `keep` active tokens.
        keep = max(1, int(MAX_ACTIVE_TOKENS_PER_USER))
        active_rows = conn.execute(
            """
            SELECT token
            FROM tokens
            WHERE user_id = ? AND revoked = 0
            ORDER BY created_at DESC
            """,
            (user_id,),
        ).fetchall()
        stale_tokens = [r["token"] for r in active_rows[keep:]]
        if stale_tokens:
            conn.executemany(
                "UPDATE tokens SET revoked = 1 WHERE token = ?",
                [(t,) for t in stale_tokens],
            )
        conn.commit()
        return LoginResponse(ok=True, token=new_token, plan=user["plan"])
    finally:
        conn.close()
@app.post("/license/check", response_model=LicenseCheckResponse)
def license_check(req: LicenseCheckRequest):
    """Validate a login token and register or refresh the calling device.

    Flow:

    - A token equal to ``DEMO_TOKEN`` is accepted immediately in ``demo`` mode.
      NOTE(review): its default is the literal ``"DEMO"``.
    - Otherwise the token must exist, must not be revoked, and its user must be
      ``active``.
    - A device already known for the user has its ``last_seen_at`` refreshed.
    - A new device is registered only while the user is below the plan's device
      limit (see ``get_device_limit_for_plan``). At the limit, the response has
      ``ok=False``.
    - Each accepted check also refreshes the token's ``last_used_at``.

    NOTE(review): counting and inserting devices are separate statements. Two
    concurrent checks can therefore exceed the limit, or hit the
    UNIQUE(user_id, device_id) constraint.

    Raises:
        HTTPException 401: the token is unknown or revoked.
        HTTPException 403: the account is not active.
    """
    token = req.token.strip()
    device_id = req.device_id.strip()

    # Optional demo token for testing
    if token == DEMO_TOKEN:
        return LicenseCheckResponse(
            ok=True,
            mode="demo",
            plan="demo",
            device_limit=0,
            devices_registered=0,
            message="Demo token accepted",
        )

    conn = get_db()
    try:
        tok = conn.execute(
            """
            SELECT t.token, t.user_id, t.revoked, u.plan, u.status
            FROM tokens t
            JOIN users u ON u.id = t.user_id
            WHERE t.token = ?
            """,
            (token,),
        ).fetchone()
        if not tok or int(tok["revoked"]) == 1:
            raise HTTPException(status_code=401, detail="Invalid or revoked token")
        if str(tok["status"] or "active").strip().lower() != "active":
            raise HTTPException(status_code=403, detail="Account not active")

        user_id = int(tok["user_id"])
        plan = str(tok["plan"]) if tok["plan"] else "basic"
        device_limit = get_device_limit_for_plan(plan)

        def count_devices() -> int:
            # Number of devices currently bound to this user.
            row = conn.execute(
                "SELECT COUNT(*) AS c FROM devices WHERE user_id = ?",
                (user_id,),
            ).fetchone()
            return int(row["c"])

        def respond(ok: bool, registered: int, message: str) -> LicenseCheckResponse:
            # All non-demo responses share mode/plan/limit.
            return LicenseCheckResponse(
                ok=ok,
                mode="active",
                plan=plan,
                device_limit=device_limit,
                devices_registered=registered,
                message=message,
            )

        known_device = conn.execute(
            "SELECT id FROM devices WHERE user_id = ? AND device_id = ?",
            (user_id, device_id),
        ).fetchone()
        registered = count_devices()

        if known_device:
            conn.execute(
                "UPDATE devices SET last_seen_at = ? WHERE user_id = ? AND device_id = ?",
                (utc_now_iso(), user_id, device_id),
            )
            conn.execute(
                "UPDATE tokens SET last_used_at = ? WHERE token = ?",
                (utc_now_iso(), token),
            )
            conn.commit()
            return respond(True, registered, "License OK (existing device)")

        if registered >= int(device_limit):
            return respond(False, registered, f"Device limit reached ({device_limit}).")

        conn.execute(
            "INSERT INTO devices (user_id, device_id, first_seen_at, last_seen_at) VALUES (?, ?, ?, ?)",
            (user_id, device_id, utc_now_iso(), utc_now_iso()),
        )
        conn.execute(
            "UPDATE tokens SET last_used_at = ? WHERE token = ?",
            (utc_now_iso(), token),
        )
        conn.commit()
        return respond(True, count_devices(), "License OK (new device registered)")
    finally:
        conn.close()
def _customer_details_email(data_obj: dict) -> Optional[str]:
    """Return the lower-cased ``customer_details.email`` of a Stripe object, or None."""
    details = data_obj.get("customer_details")
    if not isinstance(details, dict):
        return None
    return (details.get("email") or "").strip().lower() or None


def _subscription_changes(
    conn: sqlite3.Connection,
    event_type: str,
    data_obj: object,
    email: Optional[str],
    event_id: str,
) -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Derive ``(new_status, new_plan, new_subscription_id)`` from a subscription event.

    Event types other than ``customer.subscription.*`` yield all-None.

    A scheduled cancellation (``cancel_at_period_end``) is recorded in the
    audit log. It leaves the account active while Stripe still reports the
    subscription as active or trialing, so access lasts until the paid period ends.
    """
    new_status: Optional[str] = None
    new_plan: Optional[str] = None
    new_sub_id: Optional[str] = None

    if event_type in ("customer.subscription.created", "customer.subscription.updated"):
        if isinstance(data_obj, dict):
            new_sub_id = (data_obj.get("id") or "").strip() or None
            new_plan = _plan_from_price_id(_extract_price_id_from_subscription(data_obj))
            sub_status = (data_obj.get("status") or "").strip().lower()
            if bool(data_obj.get("cancel_at_period_end", False)):
                # Keep access until the end of the paid period (monthly or yearly).
                if sub_status in ("active", "trialing"):
                    new_status = "active"
                current_period_end = data_obj.get("current_period_end")
                log_admin_action(
                    conn,
                    action="stripe:cancel_at_period_end",
                    email=email,
                    token=event_id,
                    old_value="active",
                    new_value=f"scheduled_end:{current_period_end}" if current_period_end else "scheduled_end",
                )
            elif sub_status in ("active", "trialing"):
                new_status = "active"
            elif sub_status in ("past_due", "unpaid", "paused"):
                new_status = "suspended"
    elif event_type == "customer.subscription.deleted":
        # Sent when the subscription really ends. With "cancel at period end",
        # Stripe sends it only at the actual end of the period.
        new_status = "cancelled"
        if isinstance(data_obj, dict):
            new_sub_id = (data_obj.get("id") or "").strip() or None

    return new_status, new_plan, new_sub_id


def _sync_license_customer_id(email: str, customer_id: str) -> None:
    """Best effort: store *customer_id* on the Stripe-DB ``licenses`` row for *email*.

    If no row exists for the email, a minimal row is inserted so that the
    billing-portal lookup can find the customer. Failures are logged and
    swallowed: the users-table update must not fail because of this secondary
    store.
    """
    try:
        stripe_conn = sqlite3.connect("data/stripe_webhook.sqlite")
        try:
            _ensure_licenses_customer_id_column(stripe_conn)
            cur = stripe_conn.execute(
                "UPDATE licenses SET customer_id = ?, updated_at = strftime('%s','now') WHERE customer_email = ?;",
                (customer_id, email),
            )
            stripe_conn.commit()
            if cur.rowcount == 0:
                stripe_conn.execute(
                    """
                    INSERT INTO licenses (customer_email, customer_id, updated_at)
                    VALUES (?, ?, strftime('%s','now'));
                    """,
                    (email, customer_id),
                )
                stripe_conn.commit()
        finally:
            stripe_conn.close()
    except Exception:
        logger.warning("stripe webhook: could not sync customer_id into licenses table", exc_info=True)


@app.post("/stripe/webhook")
async def stripe_webhook(request: Request):
    """Receive a Stripe webhook, verify its signature and sync the user's plan/status.

    Flow:

    1. Verify ``Stripe-Signature`` against the webhook secret. The secret is the
       ``STRIPE_WEBHOOK_SECRET`` env var, falling back to the
       ``AZA_STRIPE_WEBHOOK_SECRET`` setting.
    2. Claim the event id in the Stripe DB. A redelivery answers
       ``{"status": "duplicate"}``.
    3. Map the event to a user email. A user already linked to the Stripe
       customer id wins. Otherwise the event's ``customer_email`` or
       ``customer_details.email`` is used.
    4. For ``customer.subscription.*`` events, update plan, status and
       subscription id.
    5. Link the customer id to the user: in the users table, and best effort in
       the Stripe DB ``licenses`` table. Then write an audit entry.

    Raises:
        HTTPException 503: no webhook secret is configured.
        HTTPException 400: the signature is missing or invalid.

    NOTE(review): this ``async`` endpoint performs blocking sqlite calls, which
    block the event loop while they run.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")
    # Env var first (original behaviour), then the AZA_-prefixed setting that the
    # rest of the configuration uses.
    webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET") or STRIPE_WEBHOOK_SECRET
    if not webhook_secret:
        raise HTTPException(status_code=503, detail="Stripe webhook not configured")
    if not sig_header:
        raise HTTPException(status_code=400, detail="Missing Stripe-Signature")
    try:
        event = stripe.Webhook.construct_event(payload=payload, sig_header=sig_header, secret=webhook_secret)
    except Exception:
        # invalid payload or invalid signature
        raise HTTPException(status_code=400, detail="Invalid webhook signature")
    stripe.api_key = STRIPE_SECRET_KEY

    # --- Idempotency: prevent duplicate event processing ---
    stripe_conn = sqlite3.connect("data/stripe_webhook.sqlite")
    try:
        claimed = try_claim_event(stripe_conn, event["id"])
        # Make the claim durable before anything else can fail.
        stripe_conn.commit()
        try:
            _ensure_licenses_customer_id_column(stripe_conn)
        except sqlite3.Error:
            logger.warning("stripe webhook: could not ensure licenses.customer_id column", exc_info=True)
        if not claimed:
            # Duplicate delivery -> already processed
            return {"status": "duplicate"}
        obj = event.get("data", {}).get("object", {}) or {}
        raw_customer = obj.get("customer")
        customer_id = str(raw_customer).strip() if raw_customer else ""
    finally:
        stripe_conn.close()

    event_type = getattr(event, "type", None) or event.get("type", "unknown")
    event_id = getattr(event, "id", None) or event.get("id", "")
    data_container = getattr(event, "data", None) or event.get("data", {})
    data_obj = data_container.get("object", {}) if isinstance(data_container, dict) else {}

    conn = get_db()
    try:
        email: Optional[str] = None
        if isinstance(data_obj, dict):
            customer_id = customer_id or str(data_obj.get("customer") or "").strip()
            details_email = _customer_details_email(data_obj)
            direct_email = (data_obj.get("customer_email") or "").strip().lower() or None
            if customer_id:
                # customer_details.email is now a fallback here too, so customers
                # who are not linked yet can still be matched by email.
                email = direct_email or details_email
            else:
                email = details_email or direct_email

        # A user already linked to this Stripe customer takes precedence.
        if customer_id:
            user_row = conn.execute(
                "SELECT email FROM users WHERE stripe_customer_id = ?",
                (customer_id,),
            ).fetchone()
            if user_row:
                email = str(user_row["email"]).strip().lower()

        new_status, new_plan, new_sub_id = _subscription_changes(conn, event_type, data_obj, email, event_id)

        # Persist updates if the event maps to a known user.
        if email and (new_status or new_plan or new_sub_id):
            row = conn.execute(
                "SELECT plan, status, stripe_subscription_id FROM users WHERE email = ?",
                (email,),
            ).fetchone()
            if row:
                plan_to_set = new_plan or str(row["plan"] or "basic")
                status_to_set = new_status or str(row["status"] or "active")
                sub_to_set = new_sub_id or row["stripe_subscription_id"]
                conn.execute(
                    "UPDATE users SET plan = ?, status = ?, stripe_subscription_id = ? WHERE email = ?",
                    (plan_to_set, status_to_set, sub_to_set, email),
                )

        if email and customer_id:
            conn.execute(
                "UPDATE users SET stripe_customer_id = ? WHERE email = ?",
                (customer_id, email),
            )
            _sync_license_customer_id(email, customer_id)

        log_admin_action(
            conn,
            action=f"stripe:{event_type}",
            email=email,
            token=event_id,
            old_value="received",
            new_value="ok",
        )
        conn.commit()
    finally:
        conn.close()
    return {"ok": True}
# --------------------------------
# Stripe Checkout + Billing Portal
# --------------------------------
class CreateCheckoutRequest(BaseModel):
    """Request body of ``POST /stripe/create_checkout_session``."""
    email: str = Field(..., min_length=3, max_length=255)
    plan: str = Field(..., min_length=1, max_length=64) # basic | team (any other value is treated as basic)
    billing_cycle: str = Field(..., min_length=1, max_length=16) # monthly | yearly (any other value is treated as monthly)
class CreateCheckoutResponse(BaseModel):
    """Response body of ``POST /stripe/create_checkout_session``: the hosted Checkout URL."""
    ok: bool
    url: str
@app.post("/stripe/create_checkout_session", response_model=CreateCheckoutResponse)
def create_checkout_session(req: CreateCheckoutRequest):
    """Create a Stripe Checkout Session for a subscription and return its URL.

    Price selection:

    - plan ``team`` uses the team prices; any other plan uses the basic prices.
    - billing cycle ``yearly`` uses the yearly price; any other value uses the
      monthly price.

    Requested payment methods are card and TWINT. Apple Pay and Google Pay are
    offered through card when enabled in the Stripe dashboard.

    Raises:
        HTTPException 503: Stripe is not configured.
        HTTPException 400: no price id is configured for the selection, or
            Stripe rejects the request. The Stripe error is logged server-side.
    """
    if stripe is None:
        raise HTTPException(status_code=503, detail="Stripe SDK not installed")
    if not STRIPE_SECRET_KEY:
        raise HTTPException(status_code=503, detail="Stripe not configured")
    stripe.api_key = STRIPE_SECRET_KEY

    email = req.email.strip().lower()
    plan = req.plan.strip().lower()
    billing = req.billing_cycle.strip().lower()
    yearly = billing == "yearly"

    if plan == "team":
        price_id = STRIPE_PRICE_TEAM_YEARLY if yearly else STRIPE_PRICE_TEAM
    else:
        price_id = STRIPE_PRICE_BASIC_YEARLY if yearly else STRIPE_PRICE_BASIC
    if not price_id:
        raise HTTPException(status_code=400, detail="Stripe price not configured")

    try:
        session = stripe.checkout.Session.create(  # type: ignore[attr-defined]
            mode="subscription",
            payment_method_types=["card", "twint"],
            customer_email=email,
            line_items=[{"price": price_id, "quantity": 1}],
            success_url=STRIPE_SUCCESS_URL or "https://example.com/success",
            cancel_url=STRIPE_CANCEL_URL or "https://example.com/cancel",
        )
        return CreateCheckoutResponse(ok=True, url=session.url)
    except Exception as exc:
        # The client gets a generic message. The real cause (for example an
        # unsupported payment method or a bad price id) must not vanish silently.
        logger.exception("Stripe checkout session creation failed (plan=%r, billing=%r)", plan, billing)
        raise HTTPException(status_code=400, detail="Stripe checkout creation failed") from exc
class CreatePortalRequest(BaseModel):
    """Request body of ``POST /stripe/create_billing_portal``."""
    email: str = Field(..., min_length=3, max_length=255)
class CreatePortalResponse(BaseModel):
    """Response body of ``POST /stripe/create_billing_portal``: the portal session URL."""
    ok: bool
    url: str
@app.get("/stripe/billing_portal_url")
def stripe_billing_portal_url(request: Request, _: None = Depends(require_api_token)):
    """Create a Stripe Billing Portal session for the current customer and return its URL.

    This assumes single-tenant local desktop usage: the customer is taken from
    the most recently updated ``licenses`` row that has a ``customer_id``.

    The secret key is resolved in this order:

    1. ``STRIPE_SECRET_KEY`` env var
    2. ``STRIPE_API_KEY`` env var
    3. the ``AZA_STRIPE_SECRET_KEY`` setting used by the other Stripe endpoints

    The return URL is ``AZA_PORTAL_RETURN_URL``, or the API root when unset.

    Raises:
        HTTPException 503: no Stripe key is configured.
        HTTPException 404: no customer is known. This includes the case where
            the ``licenses`` table does not exist yet.
        HTTPException 502: Stripe fails to create the session.
    """
    secret_key = os.getenv("STRIPE_SECRET_KEY") or os.getenv("STRIPE_API_KEY") or STRIPE_SECRET_KEY
    if not secret_key:
        raise HTTPException(status_code=503, detail="Stripe API key not configured")
    stripe.api_key = secret_key

    # Where Stripe sends the user back.
    return_url = os.getenv("AZA_PORTAL_RETURN_URL")
    if not return_url:
        # best effort: send back to API root (safe default)
        return_url = str(request.base_url).rstrip("/") + "/"

    row = None
    conn = sqlite3.connect("data/stripe_webhook.sqlite")
    try:
        has_licenses = conn.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'licenses';"
        ).fetchone()
        # Without the table nothing was synced yet; altering/selecting would raise.
        if has_licenses:
            _ensure_licenses_customer_id_column(conn)
            row = conn.execute(
                "SELECT customer_id FROM licenses WHERE customer_id IS NOT NULL ORDER BY updated_at DESC LIMIT 1;"
            ).fetchone()
    finally:
        conn.close()
    if not row or not row[0]:
        raise HTTPException(status_code=404, detail="No Stripe customer found")
    customer_id = str(row[0])

    try:
        session = stripe.billing_portal.Session.create(
            customer=customer_id,
            return_url=return_url,
        )
    except Exception as exc:
        logger.exception("Stripe billing portal session creation failed")
        raise HTTPException(status_code=502, detail="Failed to create billing portal session") from exc
    return {"url": session.url}
class DeviceResetRequest(BaseModel):
    """Request body of ``POST /admin/devices/reset``."""
    customer_email: str  # matched exactly against device_bindings.customer_email (not normalized)
    user_key: str = "default"  # seat key; "default" is what /license/status uses
@app.post("/stripe/create_billing_portal", response_model=CreatePortalResponse)
def create_billing_portal(req: CreatePortalRequest):
    """Create a Stripe Billing Portal link for the user with the given email.

    In the portal the customer can cancel, change the payment method or upgrade.
    The email is stripped and lower-cased before it is looked up in ``users``.

    Raises:
        HTTPException 503: Stripe is not configured.
        HTTPException 404: the user has no linked Stripe customer.
        HTTPException 502: Stripe fails to create the session (the error is
            logged). This matches ``/stripe/billing_portal_url``.

    NOTE(review): this endpoint has no authentication dependency. Anyone who
    knows a customer's email can therefore obtain a portal link for that
    customer. Requiring ``require_api_token`` would need client changes.
    """
    if stripe is None:
        raise HTTPException(status_code=503, detail="Stripe SDK not installed")
    if not STRIPE_SECRET_KEY:
        raise HTTPException(status_code=503, detail="Stripe not configured")
    stripe.api_key = STRIPE_SECRET_KEY

    email = req.email.strip().lower()
    conn = get_db()
    try:
        row = conn.execute(
            "SELECT stripe_customer_id FROM users WHERE email = ?",
            (email,),
        ).fetchone()
        customer_id = row["stripe_customer_id"] if row else None
    finally:
        # Release the DB before the (slow) Stripe network call.
        conn.close()
    if not customer_id:
        raise HTTPException(status_code=404, detail="Stripe customer not found")

    try:
        session = stripe.billing_portal.Session.create(  # type: ignore[attr-defined]
            customer=customer_id,
            return_url=STRIPE_SUCCESS_URL or "https://example.com",
        )
    except Exception as exc:
        logger.exception("Stripe billing portal session creation failed")
        raise HTTPException(status_code=502, detail="Failed to create billing portal session") from exc
    return CreatePortalResponse(ok=True, url=session.url)
@app.post("/admin/devices/reset")
def admin_devices_reset(payload: DeviceResetRequest, _: None = Depends(require_admin_token)):
    """Support endpoint: remove every device binding of ``(customer_email, user_key)``.

    Requires the admin token (X-Admin-Token, enforced by ``require_admin_token``).
    Returns how many bindings remain for the pair, which is expected to be 0.
    """
    binding_key = (payload.customer_email, payload.user_key)
    conn = sqlite3.connect("data/stripe_webhook.sqlite")
    try:
        conn.execute(
            "DELETE FROM device_bindings WHERE customer_email = ? AND user_key = ?;",
            binding_key,
        )
        conn.commit()
        remaining = int(
            conn.execute(
                "SELECT COUNT(*) FROM device_bindings WHERE customer_email = ? AND user_key = ?;",
                binding_key,
            ).fetchone()[0]
        )
    finally:
        conn.close()
    return {"ok": True, "remaining": remaining}
@app.get("/admin/usage")
def admin_usage(
    customer_email: str = Query(...),
    _: None = Depends(require_admin_token),
):
    """Admin-only: return usage counters for one ``customer_email``.

    The response reports:

    - the number of distinct seats (``user_key``) in use
    - the device count per seat
    - the license limits from ``licenses``. A NULL column counts as 1; a missing
      row reports 0/0.

    No device identifiers are returned.
    """
    conn = sqlite3.connect("data/stripe_webhook.sqlite")
    try:
        seats_used = int(
            conn.execute(
                "SELECT COUNT(DISTINCT user_key) FROM device_bindings WHERE customer_email = ?;",
                (customer_email,),
            ).fetchone()[0]
        )

        per_seat_rows = conn.execute(
            """
            SELECT user_key, COUNT(*) as devices
            FROM device_bindings
            WHERE customer_email = ?
            GROUP BY user_key
            ORDER BY user_key ASC;
            """,
            (customer_email,),
        ).fetchall()
        per_user = [{"user_key": seat, "devices": int(count)} for seat, count in per_seat_rows]

        # License limits (best effort)
        limits = conn.execute(
            "SELECT allowed_users, devices_per_user FROM licenses WHERE customer_email = ? LIMIT 1;",
            (customer_email,),
        ).fetchone()
        if limits is None:
            allowed_users, devices_per_user = 0, 0
        else:
            allowed_users = 1 if limits[0] is None else int(limits[0])
            devices_per_user = 1 if limits[1] is None else int(limits[1])
    finally:
        conn.close()

    return {
        "customer_email": customer_email,
        "allowed_users": allowed_users,
        "devices_per_user": devices_per_user,
        "seats_used": seats_used,
        "per_user": per_user,
    }