Source code for core_redis.rate_limits.token_bucket
# -*- coding: utf-8 -*-
"""Token-bucket rate limiting."""
import math
import time
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
from core_redis.client import RedisClient
[docs]
class TokenBucket:  # pylint: disable=too-few-public-methods
    """
    Token-bucket rate limiter backed by Redis.

    A virtual bucket holds up to *capacity* tokens. Each request consumes
    *tokens_per_request* tokens. Tokens are replenished continuously at
    *refill_rate* tokens per second up to *capacity*. A request is allowed
    when the bucket contains enough tokens; otherwise it is rejected.

    **Characteristics**

    * **Burst-friendly**: the bucket can absorb a short burst of up to
      *capacity* requests before throttling begins, unlike fixed or sliding
      windows that spread the budget evenly.
    * **Smooth long-term rate**: over time, throughput converges to
      *refill_rate* requests/second regardless of burst patterns.
    * **Variable cost**: *tokens_per_request* lets different operations
      consume different amounts (e.g. a bulk export costs 10 tokens, a
      lightweight read costs 1).

    **Redis storage**

    Each identifier maps to a Redis hash with two fields:

    * ``tokens``: current token count (float string).
    * ``last_refill``: Unix timestamp of the last write (float string).

    The hash TTL is set to ``ceil(capacity / refill_rate)``, the time it
    takes to fully refill an empty bucket. If the key expires (the identifier
    has been idle that long) the next request re-initializes the bucket as
    full, which is the correct behavior.

    :param key_prefix:
        String prepended to every Redis key.
        Default: ``"rate_limit:token_bucket:"``.
    :param redis_kwargs:
        Keyword arguments forwarded verbatim to
        :class:`~core_redis.client.RedisClient`
        (e.g. ``{"host": "localhost", "port": 6379, "db": 0}``).

    Example:

    .. code-block:: python

        from core_redis.rate_limits import TokenBucket

        limiter = TokenBucket(redis_kwargs={"host": "localhost", "port": 6379})
        allowed, tokens = limiter.is_allowed(
            "user_123", capacity=100, refill_rate=10.0
        )
    """

    def __init__(
        self,
        key_prefix: str = "rate_limit:token_bucket:",
        redis_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._key_prefix = key_prefix
        self._client = RedisClient(**(redis_kwargs or {}))

    def is_allowed(  # pylint: disable=too-many-locals
        self,
        identifier: str,
        capacity: int = 100,
        refill_rate: float = 10.0,
        tokens_per_request: int = 1,
    ) -> Tuple[bool, int]:
        """
        Check whether a request from *identifier* can be served.

        The bucket state is read with ``HGETALL``, tokens are refilled
        proportionally to the elapsed time since the last write, and if
        enough tokens are available, the updated state is persisted with a
        ``HSET + EXPIRE`` pipeline. A rejected request writes nothing.

        :param identifier:
            Unique key for the subject being rate-limited (e.g. a user ID,
            IP address, or API key).
        :param capacity:
            Maximum number of tokens the bucket can hold (burst ceiling).
            Default: ``100``.
        :param refill_rate:
            Tokens added per second; must be greater than zero. Long-term
            throughput converges to this value. Default: ``10.0``.
        :param tokens_per_request:
            Tokens consumed by this request. Use values greater than ``1``
            for expensive operations. Default: ``1``.

        :returns:
            A ``(allowed, available_tokens)`` tuple:

            * **allowed**: ``True`` if the bucket had enough tokens and the
              request is permitted; ``False`` if it should be rejected.
            * **available_tokens**: tokens remaining after this request
              (when allowed) or tokens currently in the bucket (when
              blocked), truncated to an integer.

        :raises ValueError:
            If *tokens_per_request* exceeds *capacity* (a request that can
            never be satisfied), or if *refill_rate* is not greater than
            zero (the bucket would never refill and no TTL can be derived).

        .. note::
            The read (``HGETALL``) and write (``HSET + EXPIRE``) are two
            separate operations. A concurrent request between them could
            produce a marginal over-count under very high contention. For
            strict correctness, replace the read-modify-write with a Lua
            script evaluated via ``EVAL``.
        """
        if tokens_per_request > capacity:
            raise ValueError(
                f"tokens_per_request ({tokens_per_request}) exceeds "
                f"capacity ({capacity})"
            )
        # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
        if not refill_rate > 0:
            raise ValueError(
                f"refill_rate ({refill_rate}) must be greater than zero"
            )

        key = f"{self._key_prefix}{identifier}"
        current_time = time.time()
        # Time needed to refill an empty bucket; an idle key may expire
        # after that long because it would be full again anyway.
        ttl = math.ceil(capacity / refill_rate)

        bucket = self._client.client.hgetall(key) or {}
        stored_tokens = self._read_float(bucket, "tokens")
        last_refill = self._read_float(bucket, "last_refill")

        if stored_tokens is None or last_refill is None:
            # No usable state (missing key or incomplete hash): start full.
            current_tokens: float = capacity
        else:
            # Clamp at zero so a timestamp ahead of the local clock (clock
            # skew between hosts, or a clock step back) never drains tokens.
            elapsed = max(0.0, current_time - last_refill)
            current_tokens = min(stored_tokens + elapsed * refill_rate, capacity)

        if current_tokens < tokens_per_request:
            return False, int(current_tokens)

        new_tokens = current_tokens - tokens_per_request
        self._store(key, new_tokens, current_time, ttl)
        return True, int(new_tokens)

    @staticmethod
    def _read_float(bucket: Dict[Any, Any], field: str) -> Optional[float]:
        """Return *field* of *bucket* as a float, or ``None`` if absent.

        The field is looked up under both its ``str`` and ``bytes`` name,
        and a ``bytes`` value is decoded before conversion, so the result
        is the same whether or not the client decodes responses.
        """
        raw = bucket.get(field)
        if raw is None:
            raw = bucket.get(field.encode())
        if raw is None:
            return None
        if isinstance(raw, bytes):
            raw = raw.decode()
        return float(raw)

    def _store(self, key: str, tokens: float, timestamp: float, ttl: int) -> None:
        """Persist the bucket state and refresh its TTL in one pipeline."""
        pipe = self._client.client.pipeline()
        pipe.hset(key, mapping={
            "tokens": str(tokens),
            "last_refill": str(timestamp),
        })
        pipe.expire(key, ttl)
        pipe.execute()