24
24
from typing import List
25
25
from collections import defaultdict
26
26
import random
27
+ import threading
27
28
28
29
class Instance :
29
30
def __init__ (self , label : str = "" ):
@@ -33,10 +34,13 @@ class InstanceIterator:
33
34
def __init__ (self , start , end ):
34
35
self .cur = start
35
36
self .end = end
36
- def has_next () -> bool :
37
- pass
38
- def next () -> Instance :
39
- pass
37
+ self .lock = threading .Lock ()
38
+ def has_next (self ) -> bool :
39
+ with self .lock : # 没要求多线程不写这句
40
+ pass
41
+ def next (self ) -> Instance :
42
+ with self .lock : # 没要求多线程不写这句
43
+ pass
40
44
41
45
def sampling (iterator : InstanceIterator , requirement : dict [str , int ]) -> dict [str , List [Instance ]]:
42
46
ret = defaultdict (list )
@@ -72,4 +76,59 @@ def sampling(iterator: InstanceIterator, requirement: dict[str, int]) -> dict[st
72
76
sample = order Examples by rnd;
73
77
sample = limit sample $M;
74
78
----
75
- '''
79
+ '''
80
+
81
+ # 下面这版实现 是考虑线程安全 InstanceIterator类 和 sampling方法里 都相应加了锁
82
class Instance:
    """A labelled sample, as handed out by ``InstanceIterator.next()``.

    No synchronization is used for this class, for three reasons:

    1. ``label`` is assigned once in ``__init__`` and is treated as
       read-only afterwards (a convention; nothing here enforces it).
    2. Every call to ``InstanceIterator.next()`` builds a brand-new
       ``Instance`` rather than reusing one.
    3. The class defines no method that modifies its own state.
    """

    def __init__(self, label: str = ""):
        # The only piece of state carried by an instance.
        self.label = label
89
+
90
class InstanceIterator:
    """Lock-protected cursor that yields one ``Instance`` per position.

    The cursor starts at ``start`` and stops before ``end``.  Both
    ``has_next()`` and ``next()`` take ``self.lock`` while they read or
    advance the cursor.  The two calls are locked separately, so a
    ``has_next()`` followed by ``next()`` is not atomic; callers sharing
    one iterator should call ``next()`` and handle ``StopIteration``.
    """

    def __init__(self, start, end):
        # Guards every read and write of ``self.cur``.
        self.lock = threading.Lock()
        self.cur = start
        self.end = end

    def has_next(self) -> bool:
        """Report whether the cursor is still below ``end``."""
        with self.lock:
            remaining = self.cur < self.end
        return remaining

    def next(self) -> Instance:
        """Build the ``Instance`` for the current position and advance.

        Raises:
            StopIteration: if the cursor has already reached ``end``.
        """
        with self.lock:
            position = self.cur
            if position >= self.end:
                raise StopIteration
            # A fresh object is created on every call, labelled after
            # the cursor value it was produced from.
            created = Instance(f"Label_{position}")
            self.cur += 1
            return created
107
+
108
def sampling(iterator: InstanceIterator, requirement: dict[str, int]) -> dict[str, List[Instance]]:
    """Draw a per-label uniform random sample in one pass over ``iterator``.

    Each label in ``requirement`` gets its own reservoir (Algorithm R):
    the first ``requirement[label]`` instances are kept as they arrive,
    and each later one replaces a random slot with probability
    ``quota / (number seen so far)``.

    Args:
        iterator: Source of instances.  It is drained with ``next()``
            until ``StopIteration`` is raised.
        requirement: Maps a label to the number of instances wanted.

    Returns:
        A ``defaultdict(list)`` mapping each requested label that was
        encountered to its sampled instances (at most the requested
        number).  Labels missing from ``requirement`` are skipped.
    """
    ret = defaultdict(list)
    # Number of instances already seen per label (excluding the current one).
    seen = defaultdict(int)

    # No lock is taken here: ``ret`` and ``seen`` are local to this call
    # and are not shared until the function returns, so a lock created
    # per call would not synchronize anything.
    while True:
        try:
            cur_ins = iterator.next()
        except StopIteration:
            break

        cur_label = cur_ins.label
        quota = requirement.get(cur_label)
        if quota is None:
            # Nothing was requested for this label; ignore it instead of
            # failing the whole pass with a KeyError.
            continue

        reservoir = ret[cur_label]
        already_seen = seen[cur_label]

        if len(reservoir) < quota:
            # Still filling the reservoir: keep everything.
            reservoir.append(cur_ins)
        else:
            # randint is inclusive on both ends, so this picks uniformly
            # among ``already_seen + 1`` slots; only the first ``quota``
            # of them map to a reservoir position.
            idx = random.randint(0, already_seen)
            if idx < quota:
                reservoir[idx] = cur_ins

        seen[cur_label] = already_seen + 1

    return ret
0 commit comments