-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcluster.py
73 lines (65 loc) · 2.31 KB
/
cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import base64
import bisect
import hashlib
from datetime import datetime

import httpx
# TODO: Give nodes states: pending->ok->grace->dead
# For now we assume they are all always healthy
class Cluster:
    """Track the configured nodes and map keys onto them via a hash ring.

    Each physical node is represented by several virtual nodes placed on a
    ring of MD5-derived integers. A key is routed to the physical node that
    owns the first virtual node whose hash is greater than or equal to the
    key's hash, wrapping around to the start of the ring if there is none.

    ``config`` must expose ``current()`` returning a mapping that contains a
    ``"nodes"`` iterable of host strings and an integer
    ``"vnodes_per_node"``.
    """

    def __init__(self, config):
        """Store ``config`` and immediately build the node list and ring."""
        self.config = config
        self.nodes = []
        self.virtual_nodes = {}  # Virtual node hash -> physical node host
        self.sorted_vnode_hashes = []  # Virtual node hashes, ascending (the ring)
        self.update_nodes()

    def update_nodes(self):
        """Re-read the node list from config, probe each node, rebuild the ring.

        Every configured node is health-checked synchronously, one after the
        other, through :meth:`node_health`.
        """
        # TODO: Give grace period, timeouts
        self.nodes = [
            {
                # Base64 of the host string serves as the node identifier.
                "id": str(base64.b64encode(bytes(node, "utf-8")), "utf-8"),
                "host": node,
                "ts": datetime.now(),
                "status": self.node_health(node),
            }
            for node in self.config.current()["nodes"]
        ]
        self._generate_virtual_nodes()

    def healthy_nodes(self):
        """Return the node records whose last health probe succeeded."""
        return [n for n in self.nodes if n["status"]["healthy"]]

    def node_health(self, node):
        """Probe ``{node}/ready`` and report whether the node is healthy.

        Returns a dict with a boolean ``"healthy"`` and a human-readable
        ``"message"``. Never raises: every failure is reported in the dict.
        """
        try:
            res = httpx.get(f"{node}/ready")
        except (httpx.ConnectError, ConnectionError):
            # httpx reports connection failures as httpx.ConnectError, which
            # does not derive from the builtin ConnectionError; the builtin is
            # still accepted for backward compatibility.
            return {"healthy": False, "message": f"Could not connect to {node}"}
        except Exception as err:
            # Deliberately broad: a failing probe must never break the caller.
            return {"healthy": False, "message": f"Unexpected {err=}, {type(err)=}"}
        if res.status_code == 200:
            return {"healthy": True, "message": "ok"}
        return {
            "healthy": False,
            "message": f"GET {node}/ready returned {res.status_code}",
        }

    def vnodes_per_node(self):
        """Return the configured number of virtual nodes per physical node."""
        return self.config.current()["vnodes_per_node"]

    def _generate_virtual_nodes(self):
        """
        Create virtual nodes for each physical node.

        Rebuilds both ``virtual_nodes`` (ordered by ascending hash) and
        ``sorted_vnode_hashes`` from ``self.nodes``.
        """
        # Read the count once so every node gets the same number of vnodes
        # even if the configuration changes while the ring is being rebuilt.
        vnode_count = self.vnodes_per_node()
        ring = {}
        for node in self.nodes:
            host = node["host"]
            for i in range(vnode_count):
                ring[self.hash(f"{host}_vnode_{i}")] = host
        ordered_hashes = sorted(ring)
        self.virtual_nodes = {h: ring[h] for h in ordered_hashes}
        self.sorted_vnode_hashes = ordered_hashes

    def hash(self, key):
        """
        Hash function for consistent hashing.

        Returns the MD5 digest of the UTF-8 encoded ``key`` as an integer.
        """
        return int(hashlib.md5(key.encode("utf-8")).hexdigest(), 16)

    def get_physical_node(self, key):
        """
        Find the physical node responsible for the given key.

        Returns the host of the first virtual node whose hash is greater than
        or equal to the key's hash, wrapping around to the first virtual node
        when the key hashes past the end of the ring.

        Raises:
            LookupError: if the ring is empty (no nodes are configured).
        """
        ring = self.sorted_vnode_hashes
        if not ring:
            raise LookupError("Cluster has no virtual nodes; cannot route key")
        # Binary search for the first vnode hash >= the key's hash.
        index = bisect.bisect_left(ring, self.hash(key))
        if index == len(ring):
            # Wrap around to the first virtual node
            index = 0
        return self.virtual_nodes[ring[index]]