Skip to content

Commit e0c6042

Browse files
committed
rfa defense - to device
1 parent fed2f37 commit e0c6042

File tree

1 file changed

+4
-2
lines changed
  • python/fedml/core/security/common

1 file changed

+4
-2
lines changed

python/fedml/core/security/common/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def is_weight_param(k):
2121
)
2222

2323

24-
def compute_euclidean_distance(v1, v2):
24+
def compute_euclidean_distance(v1, v2, device='cpu'):
25+
v1 = v1.to(device)
26+
v2 = v2.to(device)
2527
return (v1 - v2).norm()
2628

2729

@@ -36,7 +38,7 @@ def compute_middle_point(alphas, model_list):
3638
"""
3739
sum_batch = torch.zeros(model_list[0].shape)
3840
for a, a_batch_w in zip(alphas, model_list):
39-
sum_batch += a * a_batch_w
41+
sum_batch += a * a_batch_w.float().cpu().numpy()
4042
return sum_batch
4143

4244

0 commit comments

Comments
 (0)