Files
aza/AzA march 2026/tests/test_ai_budget_phase1b.py
2026-05-20 00:09:28 +02:00

309 lines
10 KiB
Python

# -*- coding: utf-8 -*-
"""Phase 1b: Empfang-Gate, Practice-Mapping, Audio-Schätzung, Admin-CSV-Spalten (ohne Secrets)."""
from __future__ import annotations

import csv
import hashlib
import io
import os
import sqlite3
import struct
import tempfile
import unittest
from contextlib import closing
from pathlib import Path
from unittest.mock import MagicMock, patch

from aza_ai_budget import (
    budget_gate_blocked_payload_or_none,
    ensure_ai_budget_schema,
    estimate_audio_seconds_for_transcription,
    resolve_license_for_empfang,
    resolve_license_for_practice_id,
)
def _mk_minimal_wav(path: Path, duration_sec: float = 0.5, sample_rate: int = 8000) -> None:
    """Write a silent 16-bit mono PCM WAV file to *path* for duration tests.

    The file has a complete RIFF/``fmt ``/``data`` header (44 bytes) followed by
    ``int(duration_sec * sample_rate)`` zero-valued 2-byte samples, so a real WAV
    parser can derive its length from the header.
    """
    sample_bytes = 2  # 16-bit mono: one 2-byte sample per frame
    frame_count = int(duration_sec * sample_rate)
    data_len = frame_count * sample_bytes
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        36 + data_len,  # RIFF chunk size = rest of header (36) + sample data
        b"WAVE",
        b"fmt ",
        16,  # fmt chunk size for plain PCM
        1,  # audio format 1 = PCM
        1,  # channel count
        sample_rate,
        sample_rate * sample_bytes,  # byte rate
        sample_bytes,  # block align
        16,  # bits per sample
        b"data",
        data_len,
    )
    path.write_bytes(header + b"\x00\x00" * frame_count)
def _mk_db_device_and_practice(path: Path) -> None:
    """Create a fixture SQLite DB at *path* with one license and one device binding.

    Tables created: ``device_bindings`` and ``licenses``. Rows inserted:

    * license ``sub_ef`` (status ``active``, lookup key ``aza_basic_monthly``) for
      ``ef@example.test`` with ``practice_id = 'prac_ef_test'``; its period starts
      at a fixed timestamp and ends 30 days later.
    * a device binding for ``ef@example.test`` whose ``device_hash`` is the
      SHA-256 hex digest of ``b"dev-ef-test"``.

    The connection is closed in ``finally`` so a failing statement does not leave
    the file open (an open handle would block deleting the temp file on Windows).
    """
    now = 1_700_000_000  # fixed timestamp keeps the fixture deterministic
    device_hash = hashlib.sha256(b"dev-ef-test").hexdigest()
    con = sqlite3.connect(str(path))
    try:
        with con:  # commit on success, roll back on error
            con.execute(
                """
                CREATE TABLE device_bindings (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    customer_email TEXT NOT NULL,
                    user_key TEXT NOT NULL,
                    device_hash TEXT NOT NULL,
                    first_seen_at INTEGER NOT NULL,
                    last_seen_at INTEGER NOT NULL,
                    is_active INTEGER DEFAULT 1,
                    UNIQUE(customer_email, user_key, device_hash)
                )
                """
            )
            con.execute(
                """
                CREATE TABLE licenses (
                    subscription_id TEXT PRIMARY KEY,
                    customer_id TEXT,
                    status TEXT,
                    lookup_key TEXT,
                    allowed_users INTEGER,
                    devices_per_user INTEGER,
                    customer_email TEXT,
                    client_reference_id TEXT,
                    current_period_end INTEGER,
                    current_period_start INTEGER,
                    updated_at INTEGER NOT NULL,
                    license_key TEXT,
                    practice_id TEXT
                )
                """
            )
            con.execute(
                """
                INSERT INTO licenses(subscription_id, customer_id, status, lookup_key, allowed_users, devices_per_user,
                customer_email, client_reference_id, current_period_end, current_period_start, updated_at, license_key, practice_id)
                VALUES ('sub_ef', 'cus_x', 'active', 'aza_basic_monthly', 1, 2,
                'ef@example.test', NULL, ?, ?, ?, 'KEY', 'prac_ef_test')
                """,
                (now + 86400 * 30, now, now),  # period end = start + 30 days
            )
            con.execute(
                """
                INSERT INTO device_bindings(customer_email, user_key, device_hash, first_seen_at, last_seen_at)
                VALUES ('ef@example.test', 'uk', ?, ?, ?)
                """,
                (device_hash, now, now),
            )
    finally:
        con.close()
class TestPhase1bResolve(unittest.TestCase):
    """License lookup by practice id and for the Empfang (reception) flow."""

    def setUp(self):
        # Fresh fixture DB per test; delete=False so SQLite can reopen the path.
        self.tmp = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False)
        self.tmp.close()
        self.db_path = Path(self.tmp.name)
        _mk_db_device_and_practice(self.db_path)

    def tearDown(self):
        # Best-effort cleanup. The tests close every connection, so the unlink
        # also succeeds on platforms that refuse to delete open files.
        try:
            os.unlink(self.db_path)
        except OSError:
            pass

    def test_resolve_by_practice(self):
        # closing() really closes the handle; the inner ``con`` context only
        # commits/rolls back (sqlite3.Connection.__exit__ never closes).
        with closing(sqlite3.connect(str(self.db_path))) as con, con:
            lic = resolve_license_for_practice_id(con, "prac_ef_test")
            self.assertIsNotNone(lic)
            self.assertEqual(lic.subscription_id, "sub_ef")

    def test_empfang_prefers_device_then_practice(self):
        # Device and session practice both lead to the same fixture license.
        with closing(sqlite3.connect(str(self.db_path))) as con, con:
            lic = resolve_license_for_empfang(
                con, x_device_id="dev-ef-test", session_practice_id="prac_ef_test"
            )
            self.assertIsNotNone(lic)
            self.assertEqual(lic.customer_email, "ef@example.test")

    def test_empfang_device_practice_conflict_drops_device(self):
        # The device is bound to the customer whose license has practice
        # prac_ef_test, but the session claims another practice: no license.
        with closing(sqlite3.connect(str(self.db_path))) as con, con:
            lic = resolve_license_for_empfang(
                con, x_device_id="dev-ef-test", session_practice_id="prac_other"
            )
            self.assertIsNone(lic)
class TestPhase1bGateAndAudio(unittest.TestCase):
    """Budget gate blocking and audio-duration estimation for transcription."""

    def setUp(self):
        # Fresh fixture DB per test; delete=False so SQLite can reopen the path.
        self.tmp = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False)
        self.tmp.close()
        self.db_path = Path(self.tmp.name)
        _mk_db_device_and_practice(self.db_path)

    def tearDown(self):
        # Best-effort cleanup. The tests close every connection, so the unlink
        # also succeeds on platforms that refuse to delete open files.
        try:
            os.unlink(self.db_path)
        except OSError:
            pass

    def test_gate_blocked_payload(self):
        from aza_ai_budget import insert_usage_event, sum_usage_usd_for_period

        # closing() really closes the handle; the inner ``con`` context only
        # commits/rolls back (sqlite3.Connection.__exit__ never closes).
        with closing(sqlite3.connect(str(self.db_path))) as con, con:
            ensure_ai_budget_schema(con)
            lic = resolve_license_for_practice_id(con, "prac_ef_test")
            # assertIsNotNone rather than a bare assert: still checked under -O.
            self.assertIsNotNone(lic)
            ps, pe = int(lic.period_start or 0), int(lic.period_end or 0)
            # Record one 50 USD usage event in the current period so that the
            # following gate check is expected to report the budget as exhausted.
            insert_usage_event(
                con,
                lic=lic,
                device_id=None,
                period_start=ps,
                period_end=pe,
                operation_type="chat",
                model="gpt-4o-mini",
                input_tokens=0,
                output_tokens=0,
                total_tokens=0,
                audio_seconds=0.0,
                estimated_cost_usd=50.0,
                request_id="x",
                status="success",
            )
            self.assertGreater(sum_usage_usd_for_period(con, lic.subscription_id, ps, pe), 9.0)
            blocked = budget_gate_blocked_payload_or_none(
                con,
                lic,
                device_id=None,
                request_id="r1",
                operation_type="transcription",
                model="whisper-1",
                gate_meta={"route": "test"},
            )
            self.assertIsNotNone(blocked)
            self.assertEqual(blocked.get("error_code"), "AI_BUDGET_EXCEEDED")

    def test_wav_duration_not_heuristic_only(self):
        # The 2 s / 8 kHz WAV is 32044 bytes. The .webm rate checked below
        # (100_000 B -> 10 s) would give ~3.2 s, above the 3.01 bound — assuming
        # that rate would also apply to WAV — so the header must be used.
        wf = Path(self.tmp.name + ".wav")
        try:
            _mk_minimal_wav(wf, duration_sec=2.0, sample_rate=8000)
            sec = estimate_audio_seconds_for_transcription(
                byte_size=wf.stat().st_size,
                file_path=str(wf),
                suffix=".wav",
            )
            self.assertGreaterEqual(sec, 2.0)
            self.assertLessEqual(sec, 3.01)
        finally:
            try:
                wf.unlink()
            except OSError:
                pass

    def test_webm_uses_byte_heuristic(self):
        # Without a file path, a .webm upload is estimated from its byte size.
        sec = estimate_audio_seconds_for_transcription(byte_size=100_000, suffix=".webm")
        self.assertEqual(sec, 10.0)
class TestPhase1bEmpfangBudgeted(unittest.TestCase):
    """Budget-gated Empfang transcription wrapper (core transcription is mocked)."""

    def setUp(self):
        # Fresh fixture DB per test; delete=False so SQLite can reopen the path.
        self.tmp = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False)
        self.tmp.close()
        self.db_path = Path(self.tmp.name)
        _mk_db_device_and_practice(self.db_path)

    def tearDown(self):
        # Best-effort cleanup. The tests close every connection, so the unlink
        # also succeeds on platforms that refuse to delete open files.
        try:
            os.unlink(self.db_path)
        except OSError:
            pass

    def test_when_stripe_db_missing_uses_core_only(self):
        from empfang_routes import _empfang_transcribe_openai_budgeted

        # Point the Stripe DB path at a file that does not exist: the wrapper is
        # expected to skip gating and return the (mocked) core result.
        ghost = Path(self.tmp.name + "_nodb")
        with patch("backend_main._stripe_db_path", return_value=ghost):
            with patch(
                "empfang_routes._empfang_transcribe_openai_from_bytes_core",
                return_value=("ok", 5.0),
            ):
                req = MagicMock()
                req.headers.get.return_value = None  # no device header
                br, tx = _empfang_transcribe_openai_budgeted(
                    req, b"\x00\x01", practice_id="prac_ef_test", filename_suffix=".webm"
                )
        self.assertIsNone(br)
        self.assertEqual(tx, "ok")

    def test_budgeted_gate_blocks_without_license_mapping(self):
        from empfang_routes import _empfang_transcribe_openai_budgeted

        # Unknown practice id -> no license -> 402 and the core is never called.
        req = MagicMock()
        req.headers.get.return_value = None
        with patch("backend_main._stripe_db_path", return_value=self.db_path):
            with patch("empfang_routes._empfang_transcribe_openai_from_bytes_core") as core:
                br, tx = _empfang_transcribe_openai_budgeted(
                    req, b"\x00", practice_id="prac_unknown", filename_suffix=".webm"
                )
                self.assertIsNotNone(br)
                self.assertIsNone(tx)
                self.assertEqual(br.status_code, 402)
                core.assert_not_called()

    def test_budgeted_records_success_after_core(self):
        from empfang_routes import _empfang_transcribe_openai_budgeted

        req = MagicMock()
        req.headers.get.return_value = None
        with patch("backend_main._stripe_db_path", return_value=self.db_path):
            with patch(
                "empfang_routes._empfang_transcribe_openai_from_bytes_core",
                return_value=("txt", 12.3),
            ):
                br, tx = _empfang_transcribe_openai_budgeted(
                    req, b"x" * 5000, practice_id="prac_ef_test", filename_suffix=".webm"
                )
        self.assertIsNone(br)
        self.assertEqual(tx, "txt")
        # Read back the usage log; closing() ensures the handle is released so
        # tearDown can delete the file (Connection.__exit__ does not close).
        with closing(sqlite3.connect(str(self.db_path))) as con:
            n = con.execute(
                "SELECT COUNT(*) FROM ai_usage_events WHERE status='success' AND operation_type='transcription'"
            ).fetchone()[0]
        self.assertGreaterEqual(n, 1)
class TestAdminCsvColumns(unittest.TestCase):
    """The admin budget CSV (header + a sample row) must not contain secret-like tokens.

    NOTE(review): the header is built inside this test instead of being taken
    from the admin export code, so a new sensitive column in the real export
    would go unnoticed. Consider reusing the exporter's header definition —
    verify where it lives.
    """

    # Tokens that must never appear in the export. Matching is case-insensitive:
    # both the tokens and the CSV text are upper-cased before comparing.
    FORBIDDEN_SUBSTRINGS = ("OPENAI", "API_KEY", "SECRET", "ADMIN_TOKEN", "MEDWORK", "password")

    def _assert_no_forbidden(self, text: str) -> None:
        """Fail if *text* contains any of ``FORBIDDEN_SUBSTRINGS``, ignoring case."""
        haystack = text.upper()
        for bad in self.FORBIDDEN_SUBSTRINGS:
            # Upper-case the token too; otherwise a lower-case entry such as
            # "password" could never match the upper-cased text.
            self.assertNotIn(bad.upper(), haystack, msg=f"forbidden token {bad!r} in CSV")

    def test_expected_headers_no_secret_tokens(self):
        header = [
            "subscription_id", "customer_email", "license_status", "lookup_key",
            "period_start_utc", "period_end_utc", "budget_usd", "used_usd", "remaining_usd",
            "available_percent", "event_count",
        ]
        row = [
            "sub_x",
            "user@example.test",
            "active",
            "aza_basic_monthly",
            "2026-01-01 00:00",
            "2026-02-01 00:00",
            10.0,
            3.7,
            6.3,
            63,
            5,
        ]
        buf = io.StringIO()
        writer = csv.writer(buf)
        writer.writerow(header)
        writer.writerow(row)
        self._assert_no_forbidden(buf.getvalue())
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main(module="__main__")