""" AZA / MedWork - License Server (FastAPI) Meilenstein A: - Login generiert echte Tokens (DB-backed) - Token an User gebunden - Token in DB gespeichert - /license/check validiert Token via DB (kein hardcoded if mehr) - Geräte-Limit (default 2) bleibt aktiv SQLite DB: aza_license.db (wird automatisch initialisiert/migriert) """ from __future__ import annotations import os import logging import sqlite3 import secrets import hashlib import hmac import uuid from datetime import datetime, timezone from typing import Optional # Auto-load .env if python-dotenv is available (no hard dependency) try: from dotenv import load_dotenv # type: ignore except Exception: load_dotenv = None if load_dotenv is not None: # Loads variables from .env into process environment (safe if file missing) load_dotenv() from fastapi import Depends, FastAPI, HTTPException, Query, Request from pydantic import BaseModel, Field from starlette.middleware.base import BaseHTTPMiddleware from admin_routes import build_admin_router from aza_security import require_admin_token, require_api_token from aza_license_logic import compute_license_decision from aza_device_enforcement import enforce_and_touch_device from aza_stripe_idempotency import try_claim_event import stripe LOG_LEVEL = os.getenv("AZA_LOG_LEVEL", "INFO").upper() logging.basicConfig(level=getattr(logging, LOG_LEVEL, logging.INFO)) logger = logging.getLogger("aza") class RequestIdMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): request_id = request.headers.get("x-request-id") or str(uuid.uuid4()) request.state.request_id = request_id response = await call_next(request) response.headers["X-Request-Id"] = request_id return response class SimpleProxyHeadersMiddleware(BaseHTTPMiddleware): """ Minimal proxy header support without relying on Starlette's ProxyHeadersMiddleware module. 
- Sets request.state.client_ip from X-Forwarded-For if present - Sets request.state.scheme from X-Forwarded-Proto if present """ async def dispatch(self, request: Request, call_next): xff = request.headers.get("x-forwarded-for") if xff: # Take the first IP in the list (client) request.state.client_ip = xff.split(",")[0].strip() else: request.state.client_ip = request.client.host if request.client else "unknown" xfp = request.headers.get("x-forwarded-proto") request.state.scheme = (xfp.split(",")[0].strip() if xfp else request.url.scheme) return await call_next(request) # ----------------------------- # Config # ----------------------------- DB_PATH = os.environ.get("AZA_LICENSE_DB", "aza_license.db") DEVICE_LIMIT_DEFAULT = int(os.environ.get("AZA_DEVICE_LIMIT_DEFAULT", "2")) PLAN_LIMITS_RAW = os.environ.get("AZA_PLAN_LIMITS", "").strip() ADMIN_KEY = os.environ.get("AZA_ADMIN_KEY", "").strip() MAX_ACTIVE_TOKENS_PER_USER = int(os.environ.get("AZA_MAX_ACTIVE_TOKENS_PER_USER", "3")) STRIPE_SECRET_KEY = os.environ.get("AZA_STRIPE_SECRET_KEY", "").strip() STRIPE_WEBHOOK_SECRET = os.environ.get("AZA_STRIPE_WEBHOOK_SECRET", "").strip() STRIPE_PRICE_BASIC = os.environ.get("AZA_STRIPE_PRICE_BASIC", "").strip() STRIPE_PRICE_TEAM = os.environ.get("AZA_STRIPE_PRICE_TEAM", "").strip() STRIPE_PRICE_BASIC_YEARLY = os.environ.get("AZA_STRIPE_PRICE_BASIC_YEARLY", "").strip() STRIPE_PRICE_TEAM_YEARLY = os.environ.get("AZA_STRIPE_PRICE_TEAM_YEARLY", "").strip() STRIPE_SUCCESS_URL = os.environ.get("AZA_STRIPE_SUCCESS_URL", "").strip() STRIPE_CANCEL_URL = os.environ.get("AZA_STRIPE_CANCEL_URL", "").strip() # Optional: Demo-Token (für Tests). Client nutzt Demo i.d.R. lokal. 
DEMO_TOKEN = os.environ.get("AZA_DEMO_TOKEN", "DEMO")

# SQLite DB used by the Stripe webhook, /license/status, the billing-portal
# lookup and the admin device endpoints (previously a repeated string literal).
# NOTE(review): imported helpers (e.g. aza_device_enforcement) may open the
# same file on their own — keep the paths in sync if this ever changes.
STRIPE_DB_PATH = "data/stripe_webhook.sqlite"

app = FastAPI(title="AZA License Server", version="1.0.0")
app.add_middleware(SimpleProxyHeadersMiddleware)
app.add_middleware(RequestIdMiddleware)


# -----------------------------
# Helpers
# -----------------------------
def utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def get_db() -> sqlite3.Connection:
    """
    Open a new connection to the license DB at DB_PATH.

    Rows are returned as sqlite3.Row (name-based access) and foreign-key
    enforcement is enabled for this connection. The caller must close it.
    """
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys = ON;")
    return conn


def sha256_hex(data: bytes) -> str:
    """Return the lowercase hex SHA-256 digest of *data*."""
    return hashlib.sha256(data).hexdigest()


def hash_password(password: str, salt_hex: str) -> str:
    """
    Return hex(sha256(salt_bytes + utf8(password))).

    Args:
        password: Plain-text password.
        salt_hex: Per-user salt encoded as hex.
    """
    # NOTE(review): one SHA-256 round is cheap to brute-force; a slow KDF
    # (hashlib.scrypt / argon2 / bcrypt) would be stronger. Switching needs a
    # migration because stored hashes use this exact scheme.
    return sha256_hex(bytes.fromhex(salt_hex) + password.encode("utf-8"))


def verify_password(password: str, salt_hex: str, pw_hash_hex: str) -> bool:
    """Check *password* against a stored hash using a constant-time comparison."""
    candidate = hash_password(password, salt_hex)
    return hmac.compare_digest(candidate, pw_hash_hex)


def parse_plan_limits(raw: str) -> dict[str, int]:
    """
    Parse AZA_PLAN_LIMITS like: "basic:2,pro:4,business:10"

    Returns:
        dict {plan (lower-cased): limit}. Entries without ':' , with a
        non-integer limit, an empty plan name or a negative limit are ignored.
    """
    out: dict[str, int] = {}
    if not raw:
        return out

    parts = [p.strip() for p in raw.split(",") if p.strip()]
    for p in parts:
        if ":" not in p:
            continue
        k, v = p.split(":", 1)
        plan = k.strip().lower()
        try:
            limit = int(v.strip())
        except Exception:
            continue
        if plan and limit >= 0:
            out[plan] = limit
    return out


PLAN_LIMITS = parse_plan_limits(PLAN_LIMITS_RAW)
# NOTE(review): PLAN_LIMITS (and DEVICE_LIMIT_DEFAULT) are parsed but not
# consulted by get_device_limit_for_plan, which uses fixed per-plan limits.


def get_device_limit_for_plan(plan: str) -> int:
    """
    Returns device limit for a plan.

    Business definition:
      basic -> 3 devices
      team  -> 5 devices
      Fallback -> 3 devices
    """
    p = (plan or "basic").strip().lower()

    if p == "team":
        return 5

    # default plan
    return 3


def log_admin_action(
    conn: sqlite3.Connection,
    action: str,
    email: Optional[str] = None,
    token: Optional[str] = None,
    old_value: Optional[str] = None,
    new_value: Optional[str] = None,
) -> None:
    """
    Insert one row into admin_audit with the current UTC timestamp.

    Does not commit; the caller owns the transaction.
    """
    conn.execute(
        """
        INSERT INTO admin_audit (action, email, token, old_value, new_value, created_at)
        VALUES (?, ?, ?, ?, ?, ?)
        """,
        (action, email, token, old_value, new_value, utc_now_iso()),
    )


def _plan_from_price_id(price_id: Optional[str]) -> str:
    """
    Map Stripe price -> internal plan.
    Defaults to 'basic' if unknown.
    """
    pid = (price_id or "").strip()

    if STRIPE_PRICE_TEAM and pid == STRIPE_PRICE_TEAM:
        return "team"
    if STRIPE_PRICE_TEAM_YEARLY and pid == STRIPE_PRICE_TEAM_YEARLY:
        return "team"

    if STRIPE_PRICE_BASIC and pid == STRIPE_PRICE_BASIC:
        return "basic"
    if STRIPE_PRICE_BASIC_YEARLY and pid == STRIPE_PRICE_BASIC_YEARLY:
        return "basic"

    return "basic"


def _extract_price_id_from_subscription(subscription_obj: dict) -> Optional[str]:
    """
    Return ``items.data[0].price.id`` from a Stripe subscription payload.

    Returns None when any level of that path is missing or has an unexpected type.
    """
    items = subscription_obj.get("items")
    if not isinstance(items, dict):
        return None
    data = items.get("data")
    if not isinstance(data, list) or not data:
        return None
    first = data[0]
    if not isinstance(first, dict):
        return None
    price = first.get("price")
    if not isinstance(price, dict):
        return None
    pid = price.get("id")
    return str(pid).strip() if pid else None


def _ensure_licenses_customer_id_column(conn: sqlite3.Connection) -> None:
    """Add the ``customer_id`` column to ``licenses`` if it is missing (and commit)."""
    cur = conn.execute("PRAGMA table_info(licenses);")
    cols = [r[1] for r in cur.fetchall()]  # (cid, name, type, notnull, dflt_value, pk)
    if "customer_id" not in cols:
        conn.execute("ALTER TABLE licenses ADD COLUMN customer_id TEXT;")
        conn.commit()


def init_db() -> None:
    """Create the license DB tables if needed and apply additive column migrations."""
    conn = get_db()
    try:
        # Base tables
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                email TEXT NOT NULL UNIQUE,
                salt_hex TEXT NOT NULL,
                password_hash_hex TEXT NOT NULL,
                plan TEXT NOT NULL DEFAULT 'basic',
                status TEXT NOT NULL DEFAULT 'active',
                created_at TEXT NOT NULL
            );

            CREATE TABLE IF NOT EXISTS tokens (
                token TEXT PRIMARY KEY,
                user_id INTEGER NOT NULL,
                created_at TEXT NOT NULL,
                last_used_at TEXT,
                revoked INTEGER NOT NULL DEFAULT 0,
                FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
            );

            CREATE TABLE IF NOT EXISTS devices (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                device_id TEXT NOT NULL,
                first_seen_at TEXT NOT NULL,
                last_seen_at TEXT NOT NULL,
                UNIQUE(user_id, device_id),
                FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
            );

            CREATE TABLE IF NOT EXISTS admin_audit (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                action TEXT NOT NULL,
                email TEXT,
                token TEXT,
                old_value TEXT,
                new_value TEXT,
                created_at TEXT NOT NULL
            );
            """
        )

        # Minimal migrations (in case an older DB exists):
        # check for columns and add them if missing.
        def column_exists(table: str, col: str) -> bool:
            # table is always an internal constant here, never user input.
            cur = conn.execute(f"PRAGMA table_info({table});")
            return any(r["name"] == col for r in cur.fetchall())

        if not column_exists("users", "plan"):
            conn.execute("ALTER TABLE users ADD COLUMN plan TEXT NOT NULL DEFAULT 'basic';")
        if not column_exists("users", "status"):
            conn.execute("ALTER TABLE users ADD COLUMN status TEXT NOT NULL DEFAULT 'active';")

        # Stripe linkage columns (needed for Billing Portal)
        if not column_exists("users", "stripe_customer_id"):
            conn.execute("ALTER TABLE users ADD COLUMN stripe_customer_id TEXT;")
        if not column_exists("users", "stripe_subscription_id"):
            conn.execute("ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT;")

        if not column_exists("tokens", "revoked"):
            conn.execute("ALTER TABLE tokens ADD COLUMN revoked INTEGER NOT NULL DEFAULT 0;")

        conn.commit()
    finally:
        conn.close()


@app.on_event("startup")
def _startup() -> None:
    """Initialise / migrate the license DB when the app starts."""
    init_db()


# Register admin routes (kept separate for maintainability)
app.include_router(
    build_admin_router(
        get_db=get_db,
        admin_key=ADMIN_KEY,
        log_admin_action=log_admin_action,
    )
)


# -----------------------------
# API Models
# -----------------------------
class RegisterRequest(BaseModel):
    email: str = Field(..., min_length=3, max_length=255)
    password: str = Field(..., min_length=6, max_length=1024)


class RegisterResponse(BaseModel):
    ok: bool
    message: str


class LoginRequest(BaseModel):
    email: str = Field(..., min_length=3, max_length=255)
    password: str = Field(..., min_length=6, max_length=1024)


class LoginResponse(BaseModel):
    ok: bool
    token: str
    plan: str


class LicenseCheckRequest(BaseModel):
    token: str = Field(..., min_length=1, max_length=4096)
    device_id: str = Field(..., min_length=1, max_length=512)


class LicenseCheckResponse(BaseModel):
    ok: bool
    mode: str  # 'active' | 'demo'
    plan: str
    device_limit: int
    devices_registered: int
    message: str


# -----------------------------
# Endpoints
# -----------------------------
@app.get("/health")
def health():
    """Liveness probe."""
    return {"ok": True}


@app.get("/version")
def version():
    """
    Leak-free version endpoint for deployments.
    """
    return {
        "name": "AZA",
        "build": os.getenv("AZA_BUILD", "dev"),
    }


@app.get("/license/status")
def license_status(request: Request, _: None = Depends(require_api_token)):
    """
    Production license check (token-authenticated).

    Reads the most recently updated row of ``licenses`` in STRIPE_DB_PATH,
    applies the time-based decision and device/seat enforcement for the
    ``X-Device-Id`` header.

    Response format must remain: {"valid": bool, "valid_until": int|None}
    """
    conn = sqlite3.connect(STRIPE_DB_PATH)
    try:
        cur = conn.execute(
            "SELECT status, current_period_end, customer_email FROM licenses ORDER BY updated_at DESC LIMIT 1;"
        )
        row = cur.fetchone()

        if not row:
            # No license row -> demo
            return {"valid": False, "valid_until": None}

        status = row[0]
        current_period_end = row[1]
        customer_email = row[2]

        # Compute license validity (time-based, optional grace via ENV)
        decision = compute_license_decision(
            current_period_end=int(current_period_end) if current_period_end is not None else None,
            status=str(status) if status is not None else None,
        )

        # If already invalid by time, return early
        if not decision.valid:
            # IMPORTANT: schema rule
            # valid=false -> valid_until MUST be null
            return {"valid": False, "valid_until": None}

        # Device + seats enforcement
        device_id = request.headers.get("X-Device-Id")
        user_key = "default"

        dd = enforce_and_touch_device(
            customer_email=str(customer_email) if customer_email is not None else "",
            user_key=user_key,
            device_id=device_id,
        )

        if not dd.allowed:
            # IMPORTANT: schema rule
            # valid=false -> valid_until MUST be null
            return {"valid": False, "valid_until": None}

        return {"valid": True, "valid_until": decision.valid_until}
    finally:
        conn.close()


@app.post("/register", response_model=RegisterResponse)
def register(req: RegisterRequest):
    """
    Create a user with a fresh random salt on the 'basic' plan and 'active' status.

    Raises:
        HTTPException 400: e-mail is empty or has no '@'.
        HTTPException 409: a user with this (lower-cased) e-mail already exists.
    """
    email = req.email.strip().lower()
    if not email or "@" not in email:
        raise HTTPException(status_code=400, detail="Invalid email")

    salt_hex = secrets.token_hex(16)
    pw_hash_hex = hash_password(req.password, salt_hex)

    conn = get_db()
    try:
        try:
            conn.execute(
                "INSERT INTO users (email, salt_hex, password_hash_hex, plan, status, created_at) VALUES (?, ?, ?, ?, ?, ?)",
                (email, salt_hex, pw_hash_hex, "basic", "active", utc_now_iso()),
            )
            conn.commit()
        except sqlite3.IntegrityError:
            # UNIQUE(email) violation
            raise HTTPException(status_code=409, detail="User already exists") from None

        return RegisterResponse(ok=True, message="Registered")
    finally:
        conn.close()


@app.post("/login", response_model=LoginResponse)
def login(req: LoginRequest):
    """
    Verify credentials and issue a new bearer token for the user.

    Only the newest MAX_ACTIVE_TOKENS_PER_USER (at least 1) non-revoked tokens
    are kept; older ones are revoked in the same transaction.

    Raises:
        HTTPException 401: unknown e-mail or wrong password.
        HTTPException 403: account status is not 'active'.
    """
    email = req.email.strip().lower()

    conn = get_db()
    try:
        row = conn.execute(
            "SELECT id, email, salt_hex, password_hash_hex, plan, status FROM users WHERE email = ?",
            (email,),
        ).fetchone()

        if not row:
            raise HTTPException(status_code=401, detail="Invalid credentials")

        if not verify_password(req.password, row["salt_hex"], row["password_hash_hex"]):
            raise HTTPException(status_code=401, detail="Invalid credentials")

        # Account status gate (SaaS-ready)
        if str(row["status"] or "active").strip().lower() != "active":
            raise HTTPException(status_code=403, detail="Account not active")

        # Generate a new token per login
        token = secrets.token_urlsafe(32)
        now = utc_now_iso()
        conn.execute(
            "INSERT INTO tokens (token, user_id, created_at, last_used_at, revoked) VALUES (?, ?, ?, ?, 0)",
            (token, row["id"], now, now),
        )

        # Enforce max active tokens per user (revoke older tokens)
        # Keep newest MAX_ACTIVE_TOKENS_PER_USER non-revoked tokens
        max_keep = max(1, int(MAX_ACTIVE_TOKENS_PER_USER))

        active_tokens = conn.execute(
            """
            SELECT token
            FROM tokens
            WHERE user_id = ?
              AND revoked = 0
            ORDER BY created_at DESC
            """,
            (row["id"],),
        ).fetchall()

        if len(active_tokens) > max_keep:
            to_revoke = [r["token"] for r in active_tokens[max_keep:]]
            conn.executemany(
                "UPDATE tokens SET revoked = 1 WHERE token = ?",
                [(t,) for t in to_revoke],
            )

        conn.commit()

        return LoginResponse(ok=True, token=token, plan=row["plan"])
    finally:
        conn.close()


@app.post("/license/check", response_model=LicenseCheckResponse)
def license_check(req: LicenseCheckRequest):
    """
    Validate a login token and register/touch the calling device.

    Known devices are always accepted. A new device is registered only while
    the user's device count is below the plan limit; otherwise ``ok=False``.

    Raises:
        HTTPException 401: token unknown or revoked.
        HTTPException 403: account status is not 'active'.
    """
    token = req.token.strip()
    device_id = req.device_id.strip()

    # Optional demo token for testing
    if token == DEMO_TOKEN:
        return LicenseCheckResponse(
            ok=True,
            mode="demo",
            plan="demo",
            device_limit=0,
            devices_registered=0,
            message="Demo token accepted",
        )

    conn = get_db()
    try:
        tok = conn.execute(
            """
            SELECT t.token, t.user_id, t.revoked, u.plan, u.status
            FROM tokens t
            JOIN users u ON u.id = t.user_id
            WHERE t.token = ?
            """,
            (token,),
        ).fetchone()

        if not tok or int(tok["revoked"]) == 1:
            raise HTTPException(status_code=401, detail="Invalid or revoked token")

        # Account status gate (SaaS-ready)
        if str(tok["status"] or "active").strip().lower() != "active":
            raise HTTPException(status_code=403, detail="Account not active")

        user_id = int(tok["user_id"])
        plan = str(tok["plan"]) if tok["plan"] else "basic"

        # Device limit is plan-based (see get_device_limit_for_plan).
        device_limit = get_device_limit_for_plan(plan)

        # count current devices
        existing = conn.execute(
            "SELECT id FROM devices WHERE user_id = ? AND device_id = ?",
            (user_id, device_id),
        ).fetchone()

        total_devices = conn.execute(
            "SELECT COUNT(*) AS c FROM devices WHERE user_id = ?",
            (user_id,),
        ).fetchone()["c"]

        if existing:
            now = utc_now_iso()
            # touch last_seen
            conn.execute(
                "UPDATE devices SET last_seen_at = ? WHERE user_id = ? AND device_id = ?",
                (now, user_id, device_id),
            )
            # touch token last_used
            conn.execute(
                "UPDATE tokens SET last_used_at = ? WHERE token = ?",
                (now, token),
            )
            conn.commit()
            return LicenseCheckResponse(
                ok=True,
                mode="active",
                plan=plan,
                device_limit=device_limit,
                devices_registered=int(total_devices),
                message="License OK (existing device)",
            )

        # new device
        if int(total_devices) >= int(device_limit):
            return LicenseCheckResponse(
                ok=False,
                mode="active",
                plan=plan,
                device_limit=device_limit,
                devices_registered=int(total_devices),
                message=f"Device limit reached ({device_limit}).",
            )

        now = utc_now_iso()
        # OR IGNORE: a concurrent request for the same device may have inserted
        # it in the meantime; UNIQUE(user_id, device_id) would otherwise raise
        # IntegrityError and surface as a 500.
        conn.execute(
            "INSERT OR IGNORE INTO devices (user_id, device_id, first_seen_at, last_seen_at) VALUES (?, ?, ?, ?)",
            (user_id, device_id, now, now),
        )
        conn.execute(
            "UPDATE tokens SET last_used_at = ? WHERE token = ?",
            (now, token),
        )
        conn.commit()

        total_devices_after = conn.execute(
            "SELECT COUNT(*) AS c FROM devices WHERE user_id = ?",
            (user_id,),
        ).fetchone()["c"]

        return LicenseCheckResponse(
            ok=True,
            mode="active",
            plan=plan,
            device_limit=device_limit,
            devices_registered=int(total_devices_after),
            message="License OK (new device registered)",
        )
    finally:
        conn.close()


@app.post("/stripe/webhook")
async def stripe_webhook(request: Request):
    """
    Receive a Stripe webhook, verify its signature and sync account state.

    1. Verify the Stripe-Signature header. The secret comes from the
       STRIPE_WEBHOOK_SECRET env var, falling back to the configured
       AZA_STRIPE_WEBHOOK_SECRET value.
    2. Claim the event id in STRIPE_DB_PATH; already-claimed events return
       {"status": "duplicate"}.
    3. Resolve the affected user (Stripe customer id first, then payload e-mail)
       and, for customer.subscription.* events, update plan/status/subscription.
    4. Link the Stripe customer id to the user and the licenses table (best effort).
    5. Write an audit entry.

    Raises:
        HTTPException 503: no webhook secret configured.
        HTTPException 400: missing or invalid signature / payload.
    """
    payload = await request.body()
    sig_header = request.headers.get("stripe-signature")

    # The unprefixed env var keeps priority (original behaviour); the module
    # config value (AZA_STRIPE_WEBHOOK_SECRET) is used when it is not set.
    webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET") or STRIPE_WEBHOOK_SECRET
    if not webhook_secret:
        raise HTTPException(status_code=503, detail="Stripe webhook not configured")

    if not sig_header:
        raise HTTPException(status_code=400, detail="Missing Stripe-Signature")

    try:
        event = stripe.Webhook.construct_event(payload=payload, sig_header=sig_header, secret=webhook_secret)
    except Exception:
        # invalid payload or invalid signature
        raise HTTPException(status_code=400, detail="Invalid webhook signature") from None

    stripe.api_key = STRIPE_SECRET_KEY

    # --- Idempotency: prevent duplicate event processing ---
    # NOTE(review): the event is claimed before it is processed. If processing
    # below fails, Stripe's retry is answered with "duplicate" and the update is
    # lost. Consider claiming only after a successful sync.
    stripe_conn = sqlite3.connect(STRIPE_DB_PATH)
    try:
        claimed = try_claim_event(stripe_conn, event["id"])
        _ensure_licenses_customer_id_column(stripe_conn)
        if not claimed:
            # Duplicate delivery -> already processed
            return {"status": "duplicate"}
        obj = event.get("data", {}).get("object", {}) or {}
        customer_id = obj.get("customer")
        customer_id = str(customer_id).strip() if customer_id else ""
    finally:
        stripe_conn.close()

    event_type = getattr(event, "type", None) or event.get("type", "unknown")
    event_id = getattr(event, "id", None) or event.get("id", "")

    data_container = getattr(event, "data", None) or event.get("data", {})
    data_obj = data_container.get("object", {}) if isinstance(data_container, dict) else {}

    # Minimal business sync (status/plan), plus audit logs
    conn = get_db()
    try:
        email: Optional[str] = None
        new_status: Optional[str] = None
        new_plan: Optional[str] = None
        new_sub_id: Optional[str] = None

        if isinstance(data_obj, dict):
            customer_id = customer_id or str(data_obj.get("customer") or "").strip()

            # E-mail from the payload. Resolved even when a customer id is
            # present, so a Stripe customer not yet linked to a user can be
            # linked below; a DB match on the customer id still takes priority.
            customer_details = data_obj.get("customer_details")
            if isinstance(customer_details, dict):
                customer_email = (customer_details.get("email") or "").strip().lower()
                if customer_email:
                    email = customer_email
            if not email:
                email = (data_obj.get("customer_email") or "").strip().lower() or None

        # Prefer the user already linked to this Stripe customer id
        if customer_id:
            user_row = conn.execute(
                "SELECT email FROM users WHERE stripe_customer_id = ?",
                (customer_id,),
            ).fetchone()
            if user_row:
                email = str(user_row["email"]).strip().lower()

        if event_type in ("customer.subscription.created", "customer.subscription.updated"):
            # Sync plan + status from subscription object
            if isinstance(data_obj, dict):
                new_sub_id = (data_obj.get("id") or "").strip() or None
                price_id = _extract_price_id_from_subscription(data_obj)
                new_plan = _plan_from_price_id(price_id)

                sub_status = (data_obj.get("status") or "").strip().lower()
                cancel_at_period_end = bool(data_obj.get("cancel_at_period_end", False))
                current_period_end = data_obj.get("current_period_end")

                if sub_status in ("active", "trialing"):
                    # Also covers "cancel at period end": access is kept until
                    # the paid period (monthly or yearly) really ends.
                    new_status = "active"
                    if cancel_at_period_end:
                        # Log the scheduled cancellation for support/audit
                        log_admin_action(
                            conn,
                            action="stripe:cancel_at_period_end",
                            email=email,
                            token=event_id,
                            old_value="active",
                            new_value=f"scheduled_end:{current_period_end}" if current_period_end else "scheduled_end",
                        )
                elif sub_status in ("past_due", "unpaid", "paused"):
                    # Applies regardless of cancel_at_period_end (previously a
                    # scheduled cancellation left unpaid accounts active).
                    new_status = "suspended"

        elif event_type == "customer.subscription.deleted":
            # Sent when the subscription has actually ended; for "cancel at
            # period end" Stripe sends this only at the real end.
            new_status = "cancelled"
            if isinstance(data_obj, dict):
                new_sub_id = (data_obj.get("id") or "").strip() or None

        # Persist updates if we can map event to a known user
        if email and (new_status or new_plan or new_sub_id):
            row = conn.execute(
                "SELECT plan, status, stripe_subscription_id FROM users WHERE email = ?",
                (email,),
            ).fetchone()
            if row:
                plan_to_set = new_plan or str(row["plan"] or "basic")
                status_to_set = new_status or str(row["status"] or "active")
                sub_to_set = new_sub_id or row["stripe_subscription_id"]
                conn.execute(
                    "UPDATE users SET plan = ?, status = ?, stripe_subscription_id = ? WHERE email = ?",
                    (plan_to_set, status_to_set, sub_to_set, email),
                )

        if email and customer_id:
            conn.execute(
                "UPDATE users SET stripe_customer_id = ? WHERE email = ?",
                (customer_id, email),
            )

            # Best effort: persist customer_id in licenses table used by Stripe DB logic.
            try:
                stripe_conn = sqlite3.connect(STRIPE_DB_PATH)
                try:
                    _ensure_licenses_customer_id_column(stripe_conn)
                    cur = stripe_conn.execute(
                        "UPDATE licenses SET customer_id = ?, updated_at = strftime('%s','now') WHERE customer_email = ?;",
                        (customer_id, email),
                    )
                    stripe_conn.commit()
                    if cur.rowcount == 0:
                        # No existing license row for this email -> insert minimal row so billing portal works
                        stripe_conn.execute(
                            """
                            INSERT INTO licenses (customer_email, customer_id, updated_at)
                            VALUES (?, ?, strftime('%s','now'));
                            """,
                            (email, customer_id),
                        )
                        stripe_conn.commit()
                finally:
                    stripe_conn.close()
            except Exception:
                # Still best effort, but no longer silent.
                logger.exception("Stripe webhook %s: failed to persist customer_id in licenses table", event_id)

        log_admin_action(
            conn,
            action=f"stripe:{event_type}",
            email=email,
            token=event_id,
            old_value="received",
            new_value="ok",
        )
        conn.commit()
    finally:
        conn.close()

    return {"ok": True}


# --------------------------------
# Stripe Checkout + Billing Portal
# --------------------------------
class CreateCheckoutRequest(BaseModel):
    email: str = Field(..., min_length=3, max_length=255)
    plan: str = Field(..., min_length=1, max_length=64)  # basic | team
    billing_cycle: str = Field(..., min_length=1, max_length=16)  # monthly | yearly


class CreateCheckoutResponse(BaseModel):
    ok: bool
    url: str


@app.post("/stripe/create_checkout_session", response_model=CreateCheckoutResponse)
def create_checkout_session(req: CreateCheckoutRequest):
    """
    Create a Stripe Checkout Session for a subscription.

    Payment method types requested: "card" and "twint" (Apple Pay / Google Pay
    run through "card" when enabled in the Stripe dashboard). Unknown plans map
    to the basic prices; any billing cycle other than "yearly" is monthly.

    Raises:
        HTTPException 503: Stripe SDK missing or secret key not configured.
        HTTPException 400: price not configured or Stripe call failed.
    """
    if stripe is None:
        raise HTTPException(status_code=503, detail="Stripe SDK not installed")

    if not STRIPE_SECRET_KEY:
        raise HTTPException(status_code=503, detail="Stripe not configured")

    stripe.api_key = STRIPE_SECRET_KEY

    email = req.email.strip().lower()
    plan = req.plan.strip().lower()
    billing = req.billing_cycle.strip().lower()

    if plan == "team":
        if billing == "yearly":
            price_id = STRIPE_PRICE_TEAM_YEARLY
        else:
            price_id = STRIPE_PRICE_TEAM
    else:
        if billing == "yearly":
            price_id = STRIPE_PRICE_BASIC_YEARLY
        else:
            price_id = STRIPE_PRICE_BASIC

    if not price_id:
        raise HTTPException(status_code=400, detail="Stripe price not configured")

    try:
        session = stripe.checkout.Session.create(  # type: ignore[attr-defined]
            mode="subscription",
            payment_method_types=["card", "twint"],
            customer_email=email,
            line_items=[{"price": price_id, "quantity": 1}],
            success_url=STRIPE_SUCCESS_URL or "https://example.com/success",
            cancel_url=STRIPE_CANCEL_URL or "https://example.com/cancel",
        )
        return CreateCheckoutResponse(ok=True, url=session.url)
    except Exception:
        logger.exception("Stripe checkout session creation failed")
        raise HTTPException(status_code=400, detail="Stripe checkout creation failed") from None


class CreatePortalRequest(BaseModel):
    email: str = Field(..., min_length=3, max_length=255)


class CreatePortalResponse(BaseModel):
    ok: bool
    url: str


@app.get("/stripe/billing_portal_url")
def stripe_billing_portal_url(request: Request, _: None = Depends(require_api_token)):
    """
    Creates a Stripe Billing Portal session for the current customer and returns the URL.

    Assumes single-tenant local desktop usage: uses the most recently updated
    license row that has a customer_id.

    Raises:
        HTTPException 503: no Stripe secret key configured.
        HTTPException 404: no license row with a Stripe customer id.
        HTTPException 502: Stripe rejected the portal session request.
    """
    # Stripe secret key must be configured server-side. The unprefixed env vars
    # keep priority; the module config (AZA_STRIPE_SECRET_KEY) is the fallback.
    secret_key = os.getenv("STRIPE_SECRET_KEY") or os.getenv("STRIPE_API_KEY") or STRIPE_SECRET_KEY
    if not secret_key:
        raise HTTPException(status_code=503, detail="Stripe API key not configured")

    stripe.api_key = secret_key

    # Return URL (where Stripe sends the user back)
    return_url = os.getenv("AZA_PORTAL_RETURN_URL")
    if not return_url:
        # best effort: send back to API root (safe default)
        base = str(request.base_url).rstrip("/")
        return_url = base + "/"

    conn = sqlite3.connect(STRIPE_DB_PATH)
    try:
        _ensure_licenses_customer_id_column(conn)
        cur = conn.execute(
            "SELECT customer_id FROM licenses WHERE customer_id IS NOT NULL ORDER BY updated_at DESC LIMIT 1;"
        )
        row = cur.fetchone()
        if not row or not row[0]:
            raise HTTPException(status_code=404, detail="No Stripe customer found")
        customer_id = str(row[0])
    finally:
        conn.close()

    try:
        session = stripe.billing_portal.Session.create(
            customer=customer_id,
            return_url=return_url,
        )
    except Exception:
        logger.exception("Stripe billing portal session creation failed")
        raise HTTPException(status_code=502, detail="Failed to create billing portal session") from None

    return {"url": session.url}


class DeviceResetRequest(BaseModel):
    customer_email: str
    user_key: str = "default"


@app.post("/stripe/create_billing_portal", response_model=CreatePortalResponse)
def create_billing_portal(req: CreatePortalRequest):
    """
    Create a Stripe Billing Portal link for the user with the given e-mail.

    In the portal the customer can cancel, change payment methods or upgrade.

    Raises:
        HTTPException 503: Stripe SDK missing or secret key not configured.
        HTTPException 404: user unknown or not linked to a Stripe customer.
        HTTPException 502: Stripe rejected the portal session request.
    """
    if stripe is None:
        raise HTTPException(status_code=503, detail="Stripe SDK not installed")
    if not STRIPE_SECRET_KEY:
        raise HTTPException(status_code=503, detail="Stripe not configured")

    stripe.api_key = STRIPE_SECRET_KEY
    email = req.email.strip().lower()

    conn = get_db()
    try:
        row = conn.execute(
            "SELECT stripe_customer_id FROM users WHERE email = ?",
            (email,),
        ).fetchone()

        if not row or not row["stripe_customer_id"]:
            raise HTTPException(status_code=404, detail="Stripe customer not found")

        try:
            session = stripe.billing_portal.Session.create(  # type: ignore[attr-defined]
                customer=row["stripe_customer_id"],
                return_url=STRIPE_SUCCESS_URL or "https://example.com",
            )
        except Exception:
            # Previously escaped as an unhandled 500; now consistent with
            # /stripe/billing_portal_url.
            logger.exception("Stripe billing portal session creation failed")
            raise HTTPException(status_code=502, detail="Failed to create billing portal session") from None

        return CreatePortalResponse(ok=True, url=session.url)
    finally:
        conn.close()


@app.post("/admin/devices/reset")
def admin_devices_reset(payload: DeviceResetRequest, _: None = Depends(require_admin_token)):
    """
    Support endpoint: remove all device bindings for a given (customer_email, user_key).

    Requires X-Admin-Token. Returns the number of bindings still present
    afterwards (expected 0).
    """
    conn = sqlite3.connect(STRIPE_DB_PATH)
    try:
        conn.execute(
            "DELETE FROM device_bindings WHERE customer_email = ? AND user_key = ?;",
            (payload.customer_email, payload.user_key),
        )
        conn.commit()

        cur = conn.execute(
            "SELECT COUNT(*) FROM device_bindings WHERE customer_email = ? AND user_key = ?;",
            (payload.customer_email, payload.user_key),
        )
        remaining = int(cur.fetchone()[0])
    finally:
        conn.close()

    return {"ok": True, "remaining": remaining}


@app.get("/admin/usage")
def admin_usage(
    customer_email: str = Query(...),
    _: None = Depends(require_admin_token),
):
    """
    Admin-only: Return usage counters for a given customer_email.

    No device identifiers are returned. When no license row exists, both
    limits are reported as 0; NULL limit columns are reported as 1.
    """
    conn = sqlite3.connect(STRIPE_DB_PATH)
    try:
        # seats used
        cur = conn.execute(
            "SELECT COUNT(DISTINCT user_key) FROM device_bindings WHERE customer_email = ?;",
            (customer_email,),
        )
        seats_used = int(cur.fetchone()[0])

        # devices per seat
        cur2 = conn.execute(
            """
            SELECT user_key, COUNT(*) as devices
            FROM device_bindings
            WHERE customer_email = ?
            GROUP BY user_key
            ORDER BY user_key ASC;
            """,
            (customer_email,),
        )
        per_user = [{"user_key": r[0], "devices": int(r[1])} for r in cur2.fetchall()]

        # license limits (best effort)
        cur3 = conn.execute(
            "SELECT allowed_users, devices_per_user FROM licenses WHERE customer_email = ? LIMIT 1;",
            (customer_email,),
        )
        row = cur3.fetchone()
        if row:
            allowed_users = int(row[0]) if row[0] is not None else 1
            devices_per_user = int(row[1]) if row[1] is not None else 1
        else:
            allowed_users = 0
            devices_per_user = 0
    finally:
        conn.close()

    return {
        "customer_email": customer_email,
        "allowed_users": allowed_users,
        "devices_per_user": devices_per_user,
        "seats_used": seats_used,
        "per_user": per_user,
    }