# -*- coding: utf-8 -*- """Lokale Unit-Tests KI-Budget Phase 1 (kein Netzwerk).""" from __future__ import annotations import os import sqlite3 import tempfile import unittest from pathlib import Path from aza_ai_budget import ( LicenseBudgetRow, check_allows_openai_call, compute_budget_snapshot, ensure_ai_budget_schema, estimate_openai_cost_usd, insert_usage_event, resolve_license_for_device, sum_usage_usd_for_period, ) def _mk_db(path: Path) -> None: con = sqlite3.connect(str(path)) con.execute( """ CREATE TABLE device_bindings ( id INTEGER PRIMARY KEY AUTOINCREMENT, customer_email TEXT NOT NULL, user_key TEXT NOT NULL, device_hash TEXT NOT NULL, first_seen_at INTEGER NOT NULL, last_seen_at INTEGER NOT NULL, is_active INTEGER DEFAULT 1, UNIQUE(customer_email, user_key, device_hash) ) """ ) con.execute( """ CREATE TABLE licenses ( subscription_id TEXT PRIMARY KEY, customer_id TEXT, status TEXT, lookup_key TEXT, allowed_users INTEGER, devices_per_user INTEGER, customer_email TEXT, client_reference_id TEXT, current_period_end INTEGER, current_period_start INTEGER, updated_at INTEGER NOT NULL, license_key TEXT, practice_id TEXT ) """ ) now = 1_700_000_000 con.execute( """ INSERT INTO licenses(subscription_id, customer_id, status, lookup_key, allowed_users, devices_per_user, customer_email, client_reference_id, current_period_end, current_period_start, updated_at, license_key, practice_id) VALUES ('sub_t1', 'cus_x', 'active', 'aza_basic_monthly', 1, 2, 'cli@example.test', NULL, ?, ?, ?, 'KEY', 'prac_test') """, (now + 86400 * 30, now, now), ) import hashlib dh = hashlib.sha256(b"device-unit-test").hexdigest() con.execute( """ INSERT INTO device_bindings(customer_email, user_key, device_hash, first_seen_at, last_seen_at) VALUES ('cli@example.test', 'prac_test', ?, ?, ?) """, (dh, now, now), ) con.commit() con.close() class TestAiBudgetPhase1(unittest.TestCase): def setUp(self): self.tmp = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) self.tmp.close() self.db_path = Path(self.tmp.name) _mk_db(self.db_path) def tearDown(self): try: os.unlink(self.db_path) except OSError: pass def test_schema_idempotent(self): with sqlite3.connect(str(self.db_path)) as con: ensure_ai_budget_schema(con) ensure_ai_budget_schema(con) n = con.execute( "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='ai_usage_events'" ).fetchone()[0] self.assertEqual(n, 1) def test_resolve_and_percent(self): with sqlite3.connect(str(self.db_path)) as con: ensure_ai_budget_schema(con) lic = resolve_license_for_device(con, "device-unit-test") self.assertIsNotNone(lic) snap = compute_budget_snapshot(con, lic) self.assertEqual(snap["available_percent"], 100) ok, _ = check_allows_openai_call(con, lic) self.assertTrue(ok) def test_usage_reduces_percent(self): with sqlite3.connect(str(self.db_path)) as con: ensure_ai_budget_schema(con) lic = resolve_license_for_device(con, "device-unit-test") ps, pe = lic.period_start or 0, lic.period_end or 0 insert_usage_event( con, lic=lic, device_id="device-unit-test", period_start=ps, period_end=pe, operation_type="chat", model="gpt-4o-mini", input_tokens=1_000_000, output_tokens=0, total_tokens=1_000_000, audio_seconds=0.0, estimated_cost_usd=estimate_openai_cost_usd( model="gpt-4o-mini", input_tokens=1_000_000, output_tokens=0 ), request_id="t1", status="success", ) used = sum_usage_usd_for_period(con, lic.subscription_id, ps, pe) self.assertGreater(used, 0) snap2 = compute_budget_snapshot(con, lic) self.assertLess(snap2["available_percent"], 100) def test_block_at_budget(self): with sqlite3.connect(str(self.db_path)) as con: ensure_ai_budget_schema(con) lic = resolve_license_for_device(con, "device-unit-test") ps, pe = int(lic.period_start or 0), int(lic.period_end or 0) big = 50.0 insert_usage_event( con, lic=lic, device_id="x", period_start=ps, period_end=pe, operation_type="chat", model="gpt-4o-mini", input_tokens=0, output_tokens=0, total_tokens=0, audio_seconds=0.0, estimated_cost_usd=big, request_id="t2", status="success", ) ok, info = check_allows_openai_call(con, lic) self.assertFalse(ok) self.assertEqual(info.get("error_code"), "AI_BUDGET_EXCEEDED") if __name__ == "__main__": unittest.main()