Skip to content

Commit 0239128

Browse files
author
aaron.liu
committed
update
1 parent b002464 commit 0239128

File tree

1 file changed

+64
-5
lines changed

1 file changed

+64
-5
lines changed

Math/LinkedIn Stratified Sampling.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import List
2525
from collections import defaultdict
2626
import random
27+
import threading
2728

2829
class Instance:
2930
def __init__(self, label: str = ""):
@@ -33,10 +34,13 @@ class InstanceIterator:
3334
def __init__(self, start, end):
3435
self.cur = start
3536
self.end = end
36-
def has_next() -> bool:
37-
pass
38-
def next() -> Instance:
39-
pass
37+
self.lock = threading.Lock()
38+
def has_next(self) -> bool:
39+
with self.lock: # 没要求多线程不写这句
40+
pass
41+
def next(self) -> Instance:
42+
with self.lock: # 没要求多线程不写这句
43+
pass
4044

4145
def sampling(iterator: InstanceIterator, requirement: dict[str, int]) -> dict[str, List[Instance]]:
4246
ret = defaultdict(list)
@@ -72,4 +76,59 @@ def sampling(iterator: InstanceIterator, requirement: dict[str, int]) -> dict[st
7276
sample = order Examples by rnd;
7377
sample = limit sample $M;
7478
----
75-
'''
79+
'''
80+
81+
# 下面这版实现 是考虑线程安全 InstanceIterator类 和 sampling方法里 都相应加了锁
82+
class Instance:
    """A single labelled sample.

    No synchronization is needed for this class:

    1. ``label`` is only assigned in ``__init__``; nothing here changes it
       afterwards (this is a convention, it is not enforced).
    2. ``InstanceIterator.next()`` builds a fresh ``Instance`` on every call,
       so the iterator does not hand the same object out twice.
    3. The class has no methods that modify internal state.
    """

    def __init__(self, label: str = ""):
        # The stratum this sample belongs to; defaults to the empty string.
        self.label = label
89+
90+
class InstanceIterator:
    """Lock-protected cursor that yields ``Instance`` objects for ``[start, end)``.

    Each call to ``next()`` returns a new ``Instance`` whose label is
    ``f"Label_{cur}"`` for the current cursor value, then advances the cursor.

    Both methods take ``self.lock`` while they read or update ``cur``. Note
    that ``has_next()`` followed by ``next()`` is *not* one atomic step: the
    lock is released in between, so another thread may consume the last
    element. Callers sharing one iterator should rely on ``StopIteration``
    from ``next()`` rather than on a preceding ``has_next()`` check.
    """

    def __init__(self, start, end):
        self.cur = start
        self.end = end
        # Serializes access to ``cur`` across has_next() / next().
        self.lock = threading.Lock()

    def has_next(self) -> bool:
        """Return True while the cursor has not reached ``end``."""
        with self.lock:
            return self.cur < self.end

    def next(self) -> "Instance":
        """Return the next ``Instance`` and advance the cursor.

        Raises:
            StopIteration: when the cursor has reached ``end``.
        """
        with self.lock:
            if self.cur >= self.end:
                raise StopIteration
            # A brand-new Instance is created for every element.
            instance = Instance(f"Label_{self.cur}")
            self.cur += 1
            return instance
107+
108+
def sampling(iterator: "InstanceIterator", requirement: dict[str, int]) -> dict[str, List["Instance"]]:
    """Draw a stratified sample from ``iterator`` using per-label reservoirs.

    The iterator is drained with ``iterator.next()`` until it raises
    ``StopIteration``. For every label present in ``requirement`` a reservoir
    of at most ``requirement[label]`` instances is maintained: the first
    ``k`` instances of a label fill the reservoir, and the i-th instance
    afterwards (0-based count ``i`` of instances already seen for that label)
    replaces a uniformly chosen slot with probability ``k / (i + 1)``.

    Args:
        iterator: Source of instances; only its ``next()`` method is used.
        requirement: Maps a label to the number of samples wanted for it.
            Instances whose label is not a key of this mapping are skipped.

    Returns:
        A ``defaultdict(list)`` mapping each requested label that was seen to
        its sampled instances. A label holds fewer than the requested number
        when the stream contained fewer instances of it.
    """
    ret = defaultdict(list)
    counter = defaultdict(int)  # instances seen so far, per label
    # NOTE(review): this lock is created per call and only guards ``ret`` and
    # ``counter``, which are local to this call, so no other thread can reach
    # them. It is kept for parity with the original; consider removing it.
    lock = threading.Lock()

    while True:
        try:
            cur_ins = iterator.next()
        except StopIteration:
            break

        cur_label = cur_ins.label

        # Labels the caller did not ask for are ignored instead of raising
        # KeyError, and they do not create an entry in the result.
        quota = requirement.get(cur_label)
        if quota is None:
            continue

        with lock:
            bucket = ret[cur_label]
            seen = counter[cur_label]

            if len(bucket) < quota:
                # Reservoir not full yet: keep every instance.
                bucket.append(cur_ins)
            else:
                # randint is inclusive on both ends, so idx is uniform over
                # the seen + 1 instances observed for this label so far.
                idx = random.randint(0, seen)
                if idx < quota:
                    bucket[idx] = cur_ins

            counter[cur_label] = seen + 1

    return ret

0 commit comments

Comments
 (0)