@@ -49,7 +49,7 @@ def partition(nums:List[float], k:int, start:int, end:int) -> int:
49
49
50
50
'''
51
51
followup: 如果数组太大 无法放到一台机器上 如何分布式求解?
52
- 利用p-persentile distributed calcuation求解
52
+ 利用p-percentile distributed calculation求解
53
53
54
54
步骤 1:数据分割
55
55
将大数组分割成若干小块,每块数据可以放入单台机器进行处理。假设有 N 台机器,那么将数组分割成 N 块,每块由一个机器负责处理。
@@ -75,9 +75,10 @@ def split_data(data, num_chunks):
75
75
return np.array_split(data, num_chunks)
76
76
77
77
# 生成初始候选中位数
78
+ # replace=False: 无放回抽样,避免同一个元素被重复抽为候选(若数据本身含重复值,候选值仍可能相同)
78
79
def initial_candidates(data_chunks, num_candidates):
79
80
all_data = np.concatenate(data_chunks)
80
- return random.sample(list( all_data) , num_candidates)
81
+ return np. random.choice( all_data, num_candidates, replace=False )
81
82
82
83
# 在每个机器上计算小于等于候选的个数
83
84
def count_less_equal(data_chunk, candidates):
@@ -87,7 +88,10 @@ def count_less_equal(data_chunk, candidates):
87
88
def aggregate_counts(counts_per_machine):
88
89
return np.sum(counts_per_machine, axis=0)
89
90
90
- def find_median_distributed(data, num_machines, num_candidates):
91
+ # 引入error_tolerance参数 允许近似解
92
+ # 添加precision_threshold 当上下界差异很小时结束搜索
93
+ # 使用max_iterations限制最大迭代次数
94
+ def find_median_distributed(data, num_machines, num_candidates, max_iterations=100, error_tolerance=1, precision_threshold=1e-6):
91
95
# 将数据分成若干块
92
96
data_chunks = split_data(data, num_machines)
93
97
@@ -96,39 +100,47 @@ def find_median_distributed(data, num_machines, num_candidates):
96
100
97
101
# 目标中位数的位置
98
102
median_position = len(data) // 2
103
+ lower_bound, upper_bound = min(data), max(data)
99
104
100
- while True:
101
- # 在每个机器上计算小于等于候选的个数
105
+ for iteration in range(max_iterations):
102
106
counts_per_machine = [count_less_equal(chunk, candidates) for chunk in data_chunks]
103
-
104
- # 汇总所有机器的统计结果
105
107
total_counts = aggregate_counts(counts_per_machine)
106
108
107
- # 找到累计个数刚好超过中位数位置的候选
108
109
for i, count in enumerate(total_counts):
109
110
if count >= median_position:
110
111
current_median = candidates[i]
111
112
break
112
113
113
- # 检查是否满足中位数条件
114
- if total_counts[i] == median_position:
114
+ # 改进的收敛条件
115
+ if abs( total_counts[i] - median_position) <= error_tolerance :
115
116
return current_median
116
117
117
- # 更新候选范围
118
118
if total_counts[i] < median_position:
119
- lower_bound = candidates[i]
119
+ lower_bound = current_median
120
120
else:
121
- upper_bound = candidates[i]
121
+ upper_bound = current_median
122
122
123
- # 生成新的候选
124
- candidates = [random.uniform(lower_bound, upper_bound) for _ in range(num_candidates)]
125
-
126
- # 示例数据
127
- data = np.random.randint(0, 100, size=1000)
128
- num_machines = 10
129
- num_candidates = 5
130
-
131
- # 求解中位数
132
- median = find_median_distributed(data, num_machines, num_candidates)
133
- print("Estimated median is:", median)
123
+ # 检查精度阈值
124
+ if upper_bound - lower_bound < precision_threshold:
125
+ return (upper_bound + lower_bound) / 2
126
+
127
+ # 生成新的候选.
128
+ # 在当前范围内线性插值生成新候选 而不是随机均匀分布
129
+ candidates = np.linspace(lower_bound, upper_bound, num_candidates)
130
+
131
+ # 如果达到最大迭代次数,返回最佳近似值
132
+ return current_median
133
+
134
+ # unit test
135
+ def test_distributed_median():
136
+ np.random.seed(42) # 为了可重复性
137
+ data = np.random.randint(0, 1000, 100000) # 生成大量随机数据
138
+ true_median = np.median(data)
139
+
140
+ distributed_result = find_median_distributed(data, num_machines=5, num_candidates=10)
141
+
142
+ print(f"True median: {true_median}")
143
+ print(f"Distributed algorithm result: {distributed_result}")
144
+ print(f"Absolute error: {abs(true_median - distributed_result)}")
145
+ test_distributed_median()
134
146
'''
0 commit comments