Python Consistent Hashing — Deep Dive

Implementation from scratch

A production-quality consistent hash ring needs: a sorted ring data structure, virtual nodes for balance, and efficient lookup via binary search.

import bisect
import hashlib
from typing import Generic, Optional, TypeVar

T = TypeVar("T")


class ConsistentHashRing(Generic[T]):
    """Consistent hash ring with virtual nodes for even distribution.

    Every node is placed on a 32-bit ring at many pseudo-random positions
    ("virtual nodes"). A key belongs to the first virtual node found when
    walking clockwise from the key's own hash position, wrapping around at
    the end of the ring.
    """

    def __init__(self, replicas: int = 150):
        """Create an empty ring.

        Args:
            replicas: Number of virtual nodes created per unit of weight.
        """
        self.replicas = replicas
        self._ring: list[int] = []  # Sorted hash positions
        self._node_map: dict[int, T] = {}  # Hash position → node
        self._nodes: set[T] = set()

    @staticmethod
    def _hash(key: str) -> int:
        """Return a 32-bit ring position taken from the MD5 digest of *key*."""
        digest = hashlib.md5(key.encode()).hexdigest()
        # First 8 hex digits == first 32 bits of the digest
        return int(digest[:8], 16)

    def add_node(self, node: T, weight: int = 1) -> None:
        """Add a node with optional weight (more virtual nodes = more keys).

        Args:
            node: The node to add; it must be hashable and is rendered with
                ``str()`` formatting to derive its virtual-node keys.
            weight: Multiplier applied to ``replicas``.
        """
        self._nodes.add(node)
        num_replicas = self.replicas * weight
        for i in range(num_replicas):
            virtual_key = f"{node}:vn{i}"
            h = self._hash(virtual_key)
            # A position that is already taken keeps its first owner
            if h not in self._node_map:
                self._node_map[h] = node
                bisect.insort(self._ring, h)

    def remove_node(self, node: T) -> None:
        """Remove a node and all its virtual nodes from the ring.

        Removing a node that is not on the ring is a no-op.
        """
        self._nodes.discard(node)
        doomed = {h for h, owner in self._node_map.items() if owner == node}
        if not doomed:
            return
        for h in doomed:
            del self._node_map[h]
        # One filtering pass keeps the list sorted and avoids a separate
        # bisect + pop (each O(n)) for every virtual node.
        self._ring[:] = [h for h in self._ring if h not in doomed]

    def get_node(self, key: str) -> Optional[T]:
        """Find the node responsible for a given key.

        Returns:
            The owning node, or ``None`` when the ring is empty.
        """
        if not self._ring:
            return None
        h = self._hash(key)
        idx = bisect.bisect_right(self._ring, h)
        if idx == len(self._ring):
            idx = 0  # Wrap around the ring
        return self._node_map[self._ring[idx]]

    def get_nodes(self, key: str, count: int = 3) -> list[T]:
        """Get multiple distinct nodes for replication.

        Walks clockwise from the key's position and collects the first
        ``count`` distinct nodes, in ring order.

        Returns:
            Up to ``count`` distinct nodes (fewer if the ring holds fewer);
            an empty list when the ring is empty or ``count`` is not positive.
        """
        if count <= 0 or not self._ring:
            return []
        # Never look for more nodes than exist, so the walk stops as soon as
        # every node has been seen instead of scanning the whole ring.
        wanted = min(count, len(self._nodes))
        result: list[T] = []
        seen: set = set()
        ring_size = len(self._ring)
        start = bisect.bisect_right(self._ring, self._hash(key))

        for offset in range(ring_size):
            # Modulo implements the wrap-around past the last position
            node = self._node_map[self._ring[(start + offset) % ring_size]]
            if node not in seen:
                seen.add(node)
                result.append(node)
                if len(result) == wanted:
                    break

        return result

    @property
    def node_count(self) -> int:
        """Number of nodes that have been added and not removed."""
        return len(self._nodes)

    def distribution(self) -> dict[T, float]:
        """Calculate the fraction of the key space each node owns.

        Returns:
            Mapping of node → owned fraction of the 32-bit ring; empty when
            the ring is empty.
        """
        if not self._ring:
            return {}
        total_space = 2**32
        ownership: dict[T, int] = {n: 0 for n in self._nodes}

        # Each position owns the arc from the previous position up to itself.
        # Seeding with the last position shifted back one full turn makes the
        # first arc cover the wrap-around segment.
        prev_h = self._ring[-1] - total_space
        for h in self._ring:
            ownership[self._node_map[h]] += h - prev_h
            prev_h = h

        return {n: s / total_space for n, s in ownership.items()}

Testing distribution quality

def test_distribution_balance():
    """Verify that keys are evenly distributed across nodes.

    Every server is asserted on, including one that receives no keys at all.
    """
    servers = [f"server-{i}" for i in range(5)]
    ring = ConsistentHashRing(replicas=150)
    for server in servers:
        ring.add_node(server)

    # Start every server at zero: a server that is never chosen must still
    # show up below and fail the lower bound instead of being skipped.
    counts: dict[str, int] = dict.fromkeys(servers, 0)
    for i in range(100_000):
        node = ring.get_node(f"key:{i}")
        counts[node] += 1

    # Each of 5 servers should get ~20,000 keys; accept a 14%–26% share
    for node, count in sorted(counts.items()):
        pct = count / 100_000
        assert 0.14 < pct < 0.26, f"{node} got {pct:.1%}, expected ~20%"
        print(f"  {node}: {count:,} keys ({pct:.1%})")


def test_minimal_redistribution():
    """Check that dropping one of five servers relocates roughly 1/5 of keys."""
    ring = ConsistentHashRing(replicas=150)
    for index in range(5):
        ring.add_node(f"server-{index}")

    keys = [f"key:{i}" for i in range(100_000)]

    # Snapshot ownership, remove one server, then snapshot again
    owners_before = [ring.get_node(key) for key in keys]
    ring.remove_node("server-2")
    owners_after = [ring.get_node(key) for key in keys]

    moved = sum(
        1 for old, new in zip(owners_before, owners_after) if old != new
    )
    move_pct = moved / len(keys)
    print(f"  Keys moved: {moved:,} ({move_pct:.1%})")
    # Expect about 20% (1/5); a naive modulo scheme would move ~80%
    assert move_pct < 0.30

Weighted nodes

When servers have different capacities (one has 64 GB RAM, another has 16 GB), assign proportional weights:

# Virtual nodes per server = replicas × weight
ring = ConsistentHashRing(replicas=100)
for server, capacity in (
    ("large-server", 4),   # 400 virtual nodes
    ("medium-server", 2),  # 200 virtual nodes
    ("small-server", 1),   # 100 virtual nodes
):
    ring.add_node(server, weight=capacity)

dist = ring.distribution()
# Shares follow the 4:2:1 weights —
# large-server: ~57%, medium-server: ~29%, small-server: ~14%

Integration with a distributed cache

import redis
from typing import Any


class ConsistentHashCache:
    """Distributed cache using consistent hashing for routing.

    Each key is routed to exactly one Redis server chosen by a
    ``ConsistentHashRing``. Reads and writes are best-effort: a Redis error
    is treated as a cache miss / dropped write rather than raised.
    """

    def __init__(self, servers: dict[str, str], replicas: int = 150):
        """
        servers: mapping of name → redis URL
        Example: {"cache-1": "redis://host1:6379", "cache-2": "redis://host2:6379"}
        replicas: virtual nodes per server on the hash ring
        """
        self.ring = ConsistentHashRing(replicas=replicas)
        self.clients: dict[str, redis.Redis] = {}

        for name, url in servers.items():
            # Build the client first: if the URL is rejected, the ring must
            # not be left holding a node that has no client behind it.
            self.clients[name] = redis.from_url(url, decode_responses=True)
            self.ring.add_node(name)

    def _get_client(self, key: str) -> Optional[redis.Redis]:
        """Return the client owning *key*, or None when no servers remain."""
        node = self.ring.get_node(key)
        if node is None:
            return None
        return self.clients[node]

    def get(self, key: str) -> Optional[str]:
        """Return the cached value for *key*, or None on a miss or Redis error."""
        client = self._get_client(key)
        if client is None:
            return None
        try:
            return client.get(key)
        except redis.RedisError:
            # Best-effort cache: an unreachable server is just a miss
            return None

    def set(self, key: str, value: Any, ttl: int = 300) -> None:
        """Store *value* under *key* for *ttl* seconds; failures are ignored."""
        client = self._get_client(key)
        if client is None:
            return
        try:
            client.setex(key, ttl, value)
        except redis.RedisError:
            # Best-effort cache: a failed write must not break the caller
            pass

    def remove_server(self, name: str) -> None:
        """Gracefully remove a server (keys auto-route to neighbors)."""
        # Stop routing to the server before tearing its client down
        self.ring.remove_node(name)
        client = self.clients.pop(name, None)
        if client is not None:
            client.close()  # Release the pooled connections

    def add_server(self, name: str, url: str) -> None:
        """Add a new server to the cluster, replacing any client with that name."""
        # Create the client before touching the ring so a rejected URL
        # leaves the cluster unchanged.
        client = redis.from_url(url, decode_responses=True)
        previous = self.clients.get(name)
        self.clients[name] = client
        self.ring.add_node(name)
        if previous is not None:
            previous.close()  # Do not leak the replaced client's connections

Jump consistent hash

For clusters whose nodes can be numbered 0…N−1 and that only grow or shrink at the end of that range (no removal of arbitrary nodes), Google’s jump consistent hash is simpler and faster:

def jump_consistent_hash(key: int, num_buckets: int) -> int:
    """Google's jump consistent hash — O(ln(n)) with perfect balance.

    Args:
        key: Integer key; only its low 64 bits are used.
        num_buckets: Number of buckets, must be at least 1.

    Returns:
        A bucket index in ``range(num_buckets)``.

    Raises:
        ValueError: If ``num_buckets`` is less than 1.
    """
    if num_buckets < 1:
        # Without this the loop never runs and -1 comes back, which Python
        # would happily accept as an index for "the last bucket".
        raise ValueError(f"num_buckets must be >= 1, got {num_buckets}")

    mask = 0xFFFFFFFFFFFFFFFF  # Emulate unsigned 64-bit wrap-around
    key &= mask
    b, j = -1, 0
    while j < num_buckets:
        b = j
        key = (key * 2862933555777941757 + 1) & mask
        # Divide first, then scale, exactly as the reference implementation
        # does in double precision; a different grouping can round to another
        # integer and disagree with other implementations on rare keys.
        j = int((b + 1) * (float(1 << 31) / float((key >> 33) + 1)))
    return b

Jump hash produces perfectly balanced output in O(ln n) time with zero memory overhead. The trade-off: it doesn’t support weighted nodes or efficient single-node removal (removing a node reshuffles keys from all higher-numbered buckets).

Hash function selection

The choice of hash function affects both speed and distribution quality:

| Hash function | Speed (ns/op) | Distribution | Notes |
|---|---|---|---|
| MD5 | ~350 | Excellent | Overkill for hashing, but well-distributed |
| MurmurHash3 | ~45 | Excellent | Best balance of speed and quality |
| xxHash | ~25 | Excellent | Fastest option, needs xxhash package |
| Python built-in hash() | ~20 | Variable | Not deterministic across processes |

Never use Python’s built-in hash() for consistent hashing — it’s randomized per process (PYTHONHASHSEED) and varies across Python versions.

Production considerations

  • Health checks — monitor each node and automatically remove unresponsive nodes from the ring. Re-add them when they recover.
  • Replication — use get_nodes(key, count=3) to write to multiple nodes. Read from any replica.
  • Graceful migration — when adding a node, read from both the old and new node during a transition period to handle keys that haven’t been migrated yet.
  • Monitoring — track distribution skew and key movement rates during topology changes.

The one thing to remember: consistent hashing turns cluster resizing from a catastrophic “move everything” event into a surgical “move 1/N of the data” operation — implement it with virtual nodes for balance and binary search for O(log n) lookups.

Tags: python, distributed-systems, algorithms

See Also

  • Python Sharding Strategies: Understand database sharding through a library card catalog analogy that makes splitting data across servers intuitive.
  • CI/CD: Why big apps can ship updates every day without turning your phone into a glitchy mess — CI/CD is the behind-the-scenes quality gate and delivery truck.
  • Containerization: Why does software that works on your computer break on everyone else's? Containers fix that — and they're why Netflix can deploy 100 updates a day without the site going down.
  • Python 3.10 New Features: Python 3.10 gave programmers a shape-sorting machine, friendlier error messages, and cleaner ways to say 'this or that' in type hints.
  • Python 3.11 New Features: Python 3.11 made everything faster, error messages smarter, and let you catch several mistakes at once instead of stopping at the first one.