# Source code for core_redis.rate_limits.leaky_bucket
# -*- coding: utf-8 -*-
"""Leaky-bucket rate limiting."""
import math
import time
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
from core_redis.client import RedisClient
class LeakyBucket:  # pylint: disable=too-few-public-methods
    """
    Leaky-bucket rate limiter backed by Redis.

    Every accepted request adds one unit to a virtual queue (the "bucket")
    kept per identifier. The queue drains continuously at *leak_rate* units
    per second. A request is accepted when the drained queue still has room
    for one more unit; otherwise it is rejected immediately.

    **Characteristics**

    * **Meter, not a scheduler** — this class only decides whether a request
      fits; it does not delay or pace accepted requests. Callers that need a
      strictly constant downstream rate must pace the accepted work
      themselves (e.g. process the accepted requests at *leak_rate*).
    * **Bounded admission** — over any window of ``t`` seconds at most
      ``capacity + leak_rate * t`` requests are accepted per identifier
      (modulo the read/write race described on :meth:`is_allowed`), so the
      long-run admission rate cannot exceed *leak_rate*.
    * **Burst allowance** — when the queue is empty, up to *capacity*
      requests can be accepted back-to-back; anything beyond a full queue is
      rejected.

    **Redis storage**

    Each identifier maps to a Redis hash with two fields:

    * ``queue_size``: current number of requests in the queue (float string).
    * ``last_leak``: Unix timestamp the queue was last drained to
      (float string). It never moves backwards, even if the local clock does.

    The hash TTL is set to ``ceil(capacity / leak_rate)`` — the time it takes
    to fully drain a filled queue. If the key expires (the identifier has been
    idle that long) the next request initialises a fresh empty queue.

    :param key_prefix:
        String prepended to every Redis key.
        Default: ``"rate_limit:leaky_bucket:"``.
    :param redis_kwargs:
        Keyword arguments forwarded verbatim to
        :class:`~core_redis.client.RedisClient`
        (e.g. ``{"host": "localhost", "port": 6379, "db": 0}``).

    Example:

    .. code-block:: python

        from core_redis.rate_limits import LeakyBucket

        limiter = LeakyBucket(redis_kwargs={"host": "localhost", "port": 6379})
        allowed, available = limiter.is_allowed(
            "user_123", capacity=100, leak_rate=10.0
        )
    """

    def __init__(
        self,
        key_prefix: str = "rate_limit:leaky_bucket:",
        redis_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Create a limiter whose Redis keys are ``key_prefix + identifier``.

        :param key_prefix: String prepended to every Redis key.
        :param redis_kwargs: Keyword arguments passed to ``RedisClient``;
            ``None`` means "use the client's defaults".
        """
        self._key_prefix = key_prefix
        self._client = RedisClient(**(redis_kwargs or {}))

    def is_allowed(
        self,
        identifier: str,
        capacity: int = 100,
        leak_rate: float = 10.0,
    ) -> Tuple[bool, int]:
        """
        Check whether a request from *identifier* can be queued.

        The queue state is read with ``HGETALL`` and drained by
        ``elapsed * leak_rate``. If the drained queue has room for one more
        request, the request is enqueued and the state is persisted with an
        ``HSET + EXPIRE`` pipeline. A rejected request leaves Redis untouched.

        :param identifier:
            Unique key for the subject being rate-limited (e.g. a user ID,
            IP address, or API key).
        :param capacity:
            Maximum queue depth (number of requests that can be buffered);
            must be ``>= 1``. Default: ``100``.
        :param leak_rate:
            Requests drained per second; must be ``> 0``. Default: ``10.0``.
        :returns:
            A ``(allowed, available)`` tuple:

            * **allowed** — ``True`` if the request was accepted into the
              queue; ``False`` if the queue was full and it was dropped.
            * **available** — whole free slots remaining in the queue after
              this request (when allowed), or ``0`` (when blocked).
        :raises ValueError:
            If *capacity* is below 1 or *leak_rate* is not positive.

        .. note::
            The read (``HGETALL``) and write (``HSET + EXPIRE``) are two
            separate operations. A concurrent request between them could
            produce a marginal over-count under very high contention. For
            strict correctness, replace the read-modify-write with a Lua
            script evaluated via ``EVAL``.

        .. note::
            Time comes from the local wall clock (:func:`time.time`). If the
            stored ``last_leak`` is ahead of it (a clock step backwards, or
            another host with a faster clock), no draining is applied and the
            stored timestamp is kept, instead of inflating the queue.
        """
        # Validate up front: leak_rate == 0 would divide by zero below, a
        # negative leak_rate would make the queue grow and give a negative
        # TTL, and capacity < 1 could never hold even a single request.
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity!r}")
        if leak_rate <= 0:
            raise ValueError(f"leak_rate must be > 0, got {leak_rate!r}")

        key = f"{self._key_prefix}{identifier}"
        current_time = time.time()
        ttl = math.ceil(capacity / leak_rate)

        state = self._read_state(self._client.client.hgetall(key))
        if state is None:
            # New identifier, expired key, or unusable hash: start empty.
            now = current_time
            current_queue = 0.0
        else:
            stored_queue, last_leak = state
            # Never drain "negative" time: if last_leak is ahead of our clock,
            # treat no time as having passed and keep the later timestamp.
            now = max(current_time, last_leak)
            current_queue = max(0.0, stored_queue - (now - last_leak) * leak_rate)

        # ceil(): a partially drained request still occupies a slot, so this
        # accepts only when current_queue + 1 <= capacity.
        if math.ceil(current_queue) >= capacity:
            return False, 0

        new_queue = current_queue + 1
        self._write_state(key, new_queue, now, ttl)
        # new_queue <= capacity here, so int() acts as floor on a value >= 0.
        return True, int(capacity - new_queue)

    @staticmethod
    def _read_state(bucket: Optional[Dict[Any, Any]]) -> Optional[Tuple[float, float]]:
        """
        Parse ``(queue_size, last_leak)`` from an ``HGETALL`` reply.

        Field names and values are accepted as either ``str`` or ``bytes``,
        since the reply type depends on how the Redis client decodes
        responses.

        :returns: The parsed pair, or ``None`` when *bucket* is empty or a
            field is missing or not a number.
        """
        if not bucket:
            return None
        parsed = []
        for field in ("queue_size", "last_leak"):
            raw = bucket.get(field) or bucket.get(field.encode())
            try:
                parsed.append(float(raw.decode() if isinstance(raw, bytes) else raw))
            except (TypeError, ValueError):
                # TypeError: field missing (raw is None);
                # ValueError (incl. UnicodeDecodeError): not a number.
                return None
        return parsed[0], parsed[1]

    def _write_state(self, key: str, queue_size: float, timestamp: float, ttl: int) -> None:
        """Persist the queue state and refresh the key TTL in one pipeline."""
        pipe = self._client.client.pipeline()
        pipe.hset(key, mapping={
            "queue_size": str(queue_size),
            "last_leak": str(timestamp),
        })
        pipe.expire(key, ttl)
        pipe.execute()