We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fed2f37 commit e0c6042Copy full SHA for e0c6042
python/fedml/core/security/common/utils.py
@@ -21,7 +21,9 @@ def is_weight_param(k):
21
)
22
23
24
-def compute_euclidean_distance(v1, v2):
+def compute_euclidean_distance(v1, v2, device='cpu'):
25
+ v1 = v1.to(device)
26
+ v2 = v2.to(device)
27
return (v1 - v2).norm()
28
29
@@ -36,7 +38,7 @@ def compute_middle_point(alphas, model_list):
36
38
"""
37
39
sum_batch = torch.zeros(model_list[0].shape)
40
for a, a_batch_w in zip(alphas, model_list):
- sum_batch += a * a_batch_w
41
+ sum_batch += a * a_batch_w.float().cpu().numpy()
42
return sum_batch
43
44
0 commit comments