# Source code for core_redis.rate_limits.token_bucket

# -*- coding: utf-8 -*-

"""Token-bucket rate limiting."""

import math
import time
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple

from core_redis.client import RedisClient


class TokenBucket:  # pylint: disable=too-few-public-methods
    """
    Token-bucket rate limiter backed by Redis.

    A virtual bucket holds up to *capacity* tokens. Each request consumes
    *tokens_per_request* tokens. Tokens are replenished continuously at
    *refill_rate* tokens per second up to *capacity*. A request is allowed
    when the bucket contains enough tokens; otherwise it is rejected.

    **Characteristics**

    * **Burst-friendly**: the bucket can absorb a short burst of up to
      *capacity* requests before throttling begins, unlike fixed or sliding
      windows that spread the budget evenly.
    * **Smooth long-term rate**: over time, throughput converges to
      *refill_rate* requests/second regardless of burst patterns.
    * **Variable cost**: *tokens_per_request* lets different operations
      consume different amounts (e.g. a bulk export costs 10 tokens, a
      lightweight read costs 1).

    **Redis storage**

    Each identifier maps to a Redis hash with two fields:

    * ``tokens``: current token count (float string).
    * ``last_refill``: Unix timestamp of the last write (float string).

    The hash TTL is set to ``ceil(capacity / refill_rate)``, the time it
    takes to fully refill an empty bucket. If the key expires (the
    identifier has been idle that long) the next request re-initializes the
    bucket as full, which is the correct behavior.

    :param key_prefix:
        String prepended to every Redis key.
        Default: ``"rate_limit:token_bucket:"``.
    :param redis_kwargs:
        Keyword arguments forwarded verbatim to
        :class:`~core_redis.client.RedisClient`
        (e.g. ``{"host": "localhost", "port": 6379, "db": 0}``).

    Example:

    .. code-block:: python

        from core_redis.rate_limits import TokenBucket

        limiter = TokenBucket(redis_kwargs={"host": "localhost", "port": 6379})
        allowed, tokens = limiter.is_allowed(
            "user_123", capacity=100, refill_rate=10.0
        )
    """

    def __init__(
        self,
        key_prefix: str = "rate_limit:token_bucket:",
        redis_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Store the key prefix and build the Redis client wrapper.

        :param key_prefix: String prepended to every Redis key.
        :param redis_kwargs: Keyword arguments passed to ``RedisClient``;
            ``None`` is treated as an empty mapping.
        """
        self._key_prefix = key_prefix
        self._client = RedisClient(**(redis_kwargs or {}))

    def is_allowed(
        self,
        identifier: str,
        capacity: int = 100,
        refill_rate: float = 10.0,
        tokens_per_request: int = 1,
    ) -> Tuple[bool, int]:
        """
        Check whether a request from *identifier* can be served.

        The bucket state is read with ``HGETALL``, tokens are refilled
        proportionally to the elapsed time since the last write, and if
        enough tokens are available, the updated state is persisted with a
        ``HSET + EXPIRE`` pipeline. A rejected request writes nothing.

        A missing hash, or one whose ``tokens`` / ``last_refill`` fields are
        absent or not finite numbers, is treated as a brand-new full bucket.
        A ``last_refill`` that lies in the future (wall-clock adjustment or
        skew between hosts) counts as zero elapsed time, so it can never
        remove tokens from the bucket.

        :param identifier:
            Unique key for the subject being rate-limited (e.g. a user ID,
            IP address, or API key).
        :param capacity:
            Maximum number of tokens the bucket can hold (burst ceiling).
            Default: ``100``.
        :param refill_rate:
            Tokens added per second. Long-term throughput converges to this
            value. Must be positive and finite. Default: ``10.0``.
        :param tokens_per_request:
            Tokens consumed by this request. Use values greater than ``1``
            for expensive operations. Default: ``1``.

        :returns:
            A ``(allowed, available_tokens)`` tuple:

            * **allowed**: ``True`` if the bucket had enough tokens and the
              request is permitted; ``False`` if it should be rejected.
            * **available_tokens**: tokens remaining after this request
              (when allowed) or tokens currently in the bucket (when
              blocked), truncated to an integer.

        :raises ValueError:
            If *tokens_per_request* exceeds *capacity* (a request that can
            never be satisfied), or if *refill_rate* is not a positive,
            finite number.

        .. note::
            The read (``HGETALL``) and write (``HSET + EXPIRE``) are two
            separate operations. A concurrent request between them could
            produce a marginal over-count under very high contention. For
            strict correctness, replace the read-modify-write with a Lua
            script evaluated via ``EVAL``.
        """
        if tokens_per_request > capacity:
            raise ValueError(
                f"tokens_per_request ({tokens_per_request}) exceeds "
                f"capacity ({capacity})"
            )
        # Reject rates that would divide by zero or yield a meaningless TTL.
        if not (math.isfinite(refill_rate) and refill_rate > 0):
            raise ValueError(
                f"refill_rate ({refill_rate}) must be a positive, finite number"
            )

        key = f"{self._key_prefix}{identifier}"
        now = time.time()
        ttl = math.ceil(capacity / refill_rate)

        bucket = self._client.client.hgetall(key) or {}
        stored_tokens = self._read_float(bucket, "tokens")
        last_refill = self._read_float(bucket, "last_refill")

        available: float
        if stored_tokens is None or last_refill is None:
            # No usable state: start from a full bucket.
            available = capacity
        else:
            # time.time() is not monotonic; never let a negative delta
            # subtract tokens.
            elapsed = max(0.0, now - last_refill)
            available = min(stored_tokens + elapsed * refill_rate, capacity)

        if available < tokens_per_request:
            return False, int(available)

        remaining = available - tokens_per_request
        self._store(key, remaining, now, ttl)
        return True, int(remaining)

    @staticmethod
    def _read_float(bucket: Dict[Any, Any], field: str) -> Optional[float]:
        """Return *field* from a hash reply as a finite float, else ``None``.

        Looks the field up under its ``str`` name first and its ``bytes``
        name second, so both decoded and raw replies are handled.
        """
        raw = bucket.get(field)
        if raw is None:
            raw = bucket.get(field.encode())
        try:
            value = float(raw.decode() if isinstance(raw, bytes) else raw)
        except (TypeError, ValueError):
            # Missing field, undecodable bytes, or non-numeric text.
            return None
        return value if math.isfinite(value) else None

    def _store(self, key: str, tokens: float, timestamp: float, ttl: int) -> None:
        """Write the bucket state and refresh its TTL in one pipeline."""
        pipe = self._client.client.pipeline()
        pipe.hset(
            key,
            mapping={
                "tokens": str(tokens),
                "last_refill": str(timestamp),
            },
        )
        pipe.expire(key, ttl)
        pipe.execute()