update
This commit is contained in:
308
AzA march 2026/tests/test_ai_budget_phase1b.py
Normal file
308
AzA march 2026/tests/test_ai_budget_phase1b.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Phase 1b: Empfang-Gate, Practice-Mapping, Audio-Schätzung, Admin-CSV-Spalten (ohne Secrets)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import sqlite3
|
||||
import struct
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from aza_ai_budget import (
|
||||
budget_gate_blocked_payload_or_none,
|
||||
ensure_ai_budget_schema,
|
||||
estimate_audio_seconds_for_transcription,
|
||||
resolve_license_for_empfang,
|
||||
resolve_license_for_practice_id,
|
||||
)
|
||||
|
||||
|
||||
def _mk_minimal_wav(path: Path, duration_sec: float = 0.5, sample_rate: int = 8000) -> None:
|
||||
"""16-bit mono PCM WAV für Dauer-Test."""
|
||||
nframes = int(duration_sec * sample_rate)
|
||||
buf = io.BytesIO()
|
||||
buf.write(b"RIFF")
|
||||
buf.write(struct.pack("<I", 36 + nframes * 2))
|
||||
buf.write(b"WAVEfmt ")
|
||||
buf.write(
|
||||
struct.pack(
|
||||
"<IHHIIHH",
|
||||
16,
|
||||
1,
|
||||
1,
|
||||
sample_rate,
|
||||
sample_rate * 2,
|
||||
2,
|
||||
16,
|
||||
)
|
||||
)
|
||||
buf.write(b"data")
|
||||
buf.write(struct.pack("<I", nframes * 2))
|
||||
buf.write(b"\x00\x00" * nframes)
|
||||
path.write_bytes(buf.getvalue())
|
||||
|
||||
|
||||
def _mk_db_device_and_practice(path: Path) -> None:
|
||||
now = 1_700_000_000
|
||||
con = sqlite3.connect(str(path))
|
||||
con.execute(
|
||||
"""
|
||||
CREATE TABLE device_bindings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
customer_email TEXT NOT NULL,
|
||||
user_key TEXT NOT NULL,
|
||||
device_hash TEXT NOT NULL,
|
||||
first_seen_at INTEGER NOT NULL,
|
||||
last_seen_at INTEGER NOT NULL,
|
||||
is_active INTEGER DEFAULT 1,
|
||||
UNIQUE(customer_email, user_key, device_hash)
|
||||
)
|
||||
"""
|
||||
)
|
||||
con.execute(
|
||||
"""
|
||||
CREATE TABLE licenses (
|
||||
subscription_id TEXT PRIMARY KEY,
|
||||
customer_id TEXT,
|
||||
status TEXT,
|
||||
lookup_key TEXT,
|
||||
allowed_users INTEGER,
|
||||
devices_per_user INTEGER,
|
||||
customer_email TEXT,
|
||||
client_reference_id TEXT,
|
||||
current_period_end INTEGER,
|
||||
current_period_start INTEGER,
|
||||
updated_at INTEGER NOT NULL,
|
||||
license_key TEXT,
|
||||
practice_id TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
con.execute(
|
||||
"""
|
||||
INSERT INTO licenses(subscription_id, customer_id, status, lookup_key, allowed_users, devices_per_user,
|
||||
customer_email, client_reference_id, current_period_end, current_period_start, updated_at, license_key, practice_id)
|
||||
VALUES ('sub_ef', 'cus_x', 'active', 'aza_basic_monthly', 1, 2,
|
||||
'ef@example.test', NULL, ?, ?, ?, 'KEY', 'prac_ef_test')
|
||||
""",
|
||||
(now + 86400 * 30, now, now),
|
||||
)
|
||||
dh = hashlib.sha256(b"dev-ef-test").hexdigest()
|
||||
con.execute(
|
||||
"""
|
||||
INSERT INTO device_bindings(customer_email, user_key, device_hash, first_seen_at, last_seen_at)
|
||||
VALUES ('ef@example.test', 'uk', ?, ?, ?)
|
||||
""",
|
||||
(dh, now, now),
|
||||
)
|
||||
con.commit()
|
||||
con.close()
|
||||
|
||||
|
||||
class TestPhase1bResolve(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False)
|
||||
self.tmp.close()
|
||||
self.db_path = Path(self.tmp.name)
|
||||
_mk_db_device_and_practice(self.db_path)
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
os.unlink(self.db_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def test_resolve_by_practice(self):
|
||||
with sqlite3.connect(str(self.db_path)) as con:
|
||||
lic = resolve_license_for_practice_id(con, "prac_ef_test")
|
||||
self.assertIsNotNone(lic)
|
||||
self.assertEqual(lic.subscription_id, "sub_ef")
|
||||
|
||||
def test_empfang_prefers_device_then_practice(self):
|
||||
with sqlite3.connect(str(self.db_path)) as con:
|
||||
lic = resolve_license_for_empfang(
|
||||
con, x_device_id="dev-ef-test", session_practice_id="prac_ef_test"
|
||||
)
|
||||
self.assertIsNotNone(lic)
|
||||
self.assertEqual(lic.customer_email, "ef@example.test")
|
||||
|
||||
def test_empfang_device_practice_conflict_drops_device(self):
|
||||
with sqlite3.connect(str(self.db_path)) as con:
|
||||
lic = resolve_license_for_empfang(
|
||||
con, x_device_id="dev-ef-test", session_practice_id="prac_other"
|
||||
)
|
||||
self.assertIsNone(lic)
|
||||
|
||||
|
||||
class TestPhase1bGateAndAudio(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False)
|
||||
self.tmp.close()
|
||||
self.db_path = Path(self.tmp.name)
|
||||
_mk_db_device_and_practice(self.db_path)
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
os.unlink(self.db_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def test_gate_blocked_payload(self):
|
||||
from aza_ai_budget import LicenseBudgetRow, insert_usage_event, sum_usage_usd_for_period
|
||||
|
||||
with sqlite3.connect(str(self.db_path)) as con:
|
||||
ensure_ai_budget_schema(con)
|
||||
lic = resolve_license_for_practice_id(con, "prac_ef_test")
|
||||
assert lic is not None
|
||||
ps, pe = int(lic.period_start or 0), int(lic.period_end or 0)
|
||||
insert_usage_event(
|
||||
con,
|
||||
lic=lic,
|
||||
device_id=None,
|
||||
period_start=ps,
|
||||
period_end=pe,
|
||||
operation_type="chat",
|
||||
model="gpt-4o-mini",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
total_tokens=0,
|
||||
audio_seconds=0.0,
|
||||
estimated_cost_usd=50.0,
|
||||
request_id="x",
|
||||
status="success",
|
||||
)
|
||||
self.assertGreater(sum_usage_usd_for_period(con, lic.subscription_id, ps, pe), 9.0)
|
||||
blocked = budget_gate_blocked_payload_or_none(
|
||||
con,
|
||||
lic,
|
||||
device_id=None,
|
||||
request_id="r1",
|
||||
operation_type="transcription",
|
||||
model="whisper-1",
|
||||
gate_meta={"route": "test"},
|
||||
)
|
||||
self.assertIsNotNone(blocked)
|
||||
self.assertEqual(blocked.get("error_code"), "AI_BUDGET_EXCEEDED")
|
||||
|
||||
def test_wav_duration_not_heuristic_only(self):
|
||||
wf = Path(self.tmp.name + ".wav")
|
||||
try:
|
||||
_mk_minimal_wav(wf, duration_sec=2.0, sample_rate=8000)
|
||||
sec = estimate_audio_seconds_for_transcription(
|
||||
byte_size=wf.stat().st_size,
|
||||
file_path=str(wf),
|
||||
suffix=".wav",
|
||||
)
|
||||
self.assertGreaterEqual(sec, 2.0)
|
||||
self.assertLessEqual(sec, 3.01)
|
||||
finally:
|
||||
try:
|
||||
wf.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def test_webm_uses_byte_heuristic(self):
|
||||
sec = estimate_audio_seconds_for_transcription(byte_size=100_000, suffix=".webm")
|
||||
self.assertEqual(sec, 10.0)
|
||||
|
||||
|
||||
class TestPhase1bEmpfangBudgeted(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmp = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False)
|
||||
self.tmp.close()
|
||||
self.db_path = Path(self.tmp.name)
|
||||
_mk_db_device_and_practice(self.db_path)
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
os.unlink(self.db_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def test_when_stripe_db_missing_uses_core_only(self):
|
||||
from empfang_routes import _empfang_transcribe_openai_budgeted
|
||||
|
||||
ghost = Path(self.tmp.name + "_nodb")
|
||||
with patch("backend_main._stripe_db_path", return_value=ghost):
|
||||
with patch(
|
||||
"empfang_routes._empfang_transcribe_openai_from_bytes_core",
|
||||
return_value=("ok", 5.0),
|
||||
):
|
||||
req = MagicMock()
|
||||
req.headers.get.return_value = None
|
||||
br, tx = _empfang_transcribe_openai_budgeted(
|
||||
req, b"\x00\x01", practice_id="prac_ef_test", filename_suffix=".webm"
|
||||
)
|
||||
self.assertIsNone(br)
|
||||
self.assertEqual(tx, "ok")
|
||||
|
||||
def test_budgeted_gate_blocks_without_license_mapping(self):
|
||||
from empfang_routes import _empfang_transcribe_openai_budgeted
|
||||
|
||||
req = MagicMock()
|
||||
req.headers.get.return_value = None
|
||||
with patch("backend_main._stripe_db_path", return_value=self.db_path):
|
||||
with patch("empfang_routes._empfang_transcribe_openai_from_bytes_core") as core:
|
||||
br, tx = _empfang_transcribe_openai_budgeted(
|
||||
req, b"\x00", practice_id="prac_unknown", filename_suffix=".webm"
|
||||
)
|
||||
self.assertIsNotNone(br)
|
||||
self.assertIsNone(tx)
|
||||
self.assertEqual(br.status_code, 402)
|
||||
core.assert_not_called()
|
||||
|
||||
def test_budgeted_records_success_after_core(self):
|
||||
from empfang_routes import _empfang_transcribe_openai_budgeted
|
||||
|
||||
req = MagicMock()
|
||||
req.headers.get.return_value = None
|
||||
with patch("backend_main._stripe_db_path", return_value=self.db_path):
|
||||
with patch("empfang_routes._empfang_transcribe_openai_from_bytes_core", return_value=("txt", 12.3)):
|
||||
br, tx = _empfang_transcribe_openai_budgeted(
|
||||
req, b"x" * 5000, practice_id="prac_ef_test", filename_suffix=".webm"
|
||||
)
|
||||
self.assertIsNone(br)
|
||||
self.assertEqual(tx, "txt")
|
||||
with sqlite3.connect(str(self.db_path)) as con:
|
||||
n = con.execute(
|
||||
"SELECT COUNT(*) FROM ai_usage_events WHERE status='success' AND operation_type='transcription'"
|
||||
).fetchone()[0]
|
||||
self.assertGreaterEqual(n, 1)
|
||||
|
||||
|
||||
class TestAdminCsvColumns(unittest.TestCase):
|
||||
def test_expected_headers_no_secret_tokens(self):
|
||||
forbidden_substrings = ("OPENAI", "API_KEY", "SECRET", "ADMIN_TOKEN", "MEDWORK", "password")
|
||||
row = [
|
||||
"sub_x",
|
||||
"user@example.test",
|
||||
"active",
|
||||
"aza_basic_monthly",
|
||||
"2026-01-01 00:00",
|
||||
"2026-02-01 00:00",
|
||||
10.0,
|
||||
3.7,
|
||||
6.3,
|
||||
63,
|
||||
5,
|
||||
]
|
||||
buf = io.StringIO()
|
||||
w = csv.writer(buf)
|
||||
w.writerow([
|
||||
"subscription_id", "customer_email", "license_status", "lookup_key",
|
||||
"period_start_utc", "period_end_utc", "budget_usd", "used_usd", "remaining_usd",
|
||||
"available_percent", "event_count",
|
||||
])
|
||||
w.writerow(row)
|
||||
out = buf.getvalue().upper()
|
||||
for bad in forbidden_substrings:
|
||||
self.assertNotIn(bad, out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user