53 lines
1.5 KiB
Python
53 lines
1.5 KiB
Python
|
|
import math
import time
from dataclasses import dataclass
from typing import Dict

from fastapi import HTTPException, status
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass()
class Bucket:
    """State of a single token bucket.

    Fields:
        tokens: number of tokens currently available to spend.
        last_ts: clock reading, in seconds, at which ``tokens`` was last
            refilled (or the bucket was created).
    """

    tokens: float
    last_ts: float
|
||
|
|
|
||
|
|
|
||
|
|
class TokenBucketRateLimiter:
    """
    Simple in-memory token bucket limiter.

    Every key gets its own bucket. A bucket starts full and refills
    continuously at ``refill_rate`` tokens per second, never exceeding
    ``capacity``. A request of a given ``cost`` is admitted only when its
    bucket holds at least ``cost`` tokens.

    capacity: max burst
    refill_rate: tokens per second

    Elapsed time is measured with ``time.monotonic()``, so wall-clock
    adjustments cannot drain or refill buckets.

    NOTE(review): the state is a plain per-instance dict with no locking.
    Each process enforces its own limit, and concurrent threads calling
    ``consume`` on the same key may occasionally both pass the check.
    """

    def __init__(self, capacity: int, refill_rate: float):
        if capacity < 0:
            raise ValueError("capacity must be non-negative")
        if refill_rate < 0:
            raise ValueError("refill_rate must be non-negative")
        self.capacity = float(capacity)
        self.refill_rate = float(refill_rate)
        self._buckets: Dict[str, Bucket] = {}
        # Monotonic time of the last idle-bucket sweep (see _sweep_idle).
        self._last_sweep = time.monotonic()

    def _sweep_idle(self, now: float) -> None:
        """Discard buckets that have been idle long enough to refill completely.

        A bucket idle for at least ``capacity / refill_rate`` seconds would be
        topped up to ``capacity``, which is exactly the state of a freshly
        created bucket. Dropping it therefore does not change admission
        decisions, but it keeps memory proportional to recently active keys
        rather than to every key ever seen.
        """
        if self.refill_rate <= 0:
            return  # Buckets never refill, so none can be discarded safely.
        full_refill = self.capacity / self.refill_rate
        # Sweep at most once per full-refill period to keep the cost amortized.
        if now - self._last_sweep < full_refill:
            return
        self._last_sweep = now
        stale = [
            key
            for key, bucket in self._buckets.items()
            if now - bucket.last_ts >= full_refill
        ]
        for key in stale:
            del self._buckets[key]

    def _get_bucket(self, key: str) -> "Bucket":
        """Return the bucket for ``key``, creating it full or refilling it up to now."""
        now = time.monotonic()
        self._sweep_idle(now)
        b = self._buckets.get(key)
        if b is None:
            b = Bucket(tokens=self.capacity, last_ts=now)
            self._buckets[key] = b
            return b

        # Credit tokens for the time since the last refill, capped at capacity.
        elapsed = now - b.last_ts
        b.tokens = min(self.capacity, b.tokens + elapsed * self.refill_rate)
        b.last_ts = now
        return b

    def consume(self, key: str, cost: float = 1.0) -> None:
        """
        Spend ``cost`` tokens from the bucket for ``key``, or reject the request.

        Raises:
            ValueError: if ``cost`` is negative (it would add tokens).
            HTTPException: 429 when the bucket holds fewer than ``cost`` tokens.
                A ``Retry-After`` header (whole seconds) is attached when
                waiting can make the request succeed, i.e. the bucket refills
                and ``cost`` does not exceed ``capacity``.
        """
        if cost < 0:
            raise ValueError("cost must be non-negative")
        b = self._get_bucket(key)
        if b.tokens < cost:
            headers = None
            if self.refill_rate > 0 and cost <= self.capacity:
                wait = (cost - b.tokens) / self.refill_rate
                headers = {"Retry-After": str(max(1, math.ceil(wait)))}
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Rate limit exceeded",
                headers=headers,
            )
        b.tokens -= cost
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level default limiters:
# - per token: burst of 10, refilling 0.5 tokens/s (about 30 requests/minute)
# - per IP:    burst of 20, refilling 1.0 tokens/s (about 60 requests/minute)
default_token_limiter = TokenBucketRateLimiter(refill_rate=0.5, capacity=10)
default_ip_limiter = TokenBucketRateLimiter(refill_rate=1.0, capacity=20)
|