Skip to content

Commit cd85c43

Browse files
authored
Fixed world_size issue in nt_xent loss
1 parent 04bcf2b commit cd85c43

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

simclr/modules/nt_xent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def mask_correlated_samples(self, batch_size, world_size):
2020
mask = torch.ones((N, N), dtype=bool)
2121
mask = mask.fill_diagonal_(0)
2222
for i in range(batch_size * world_size):
23-
mask[i, batch_size + i] = 0
24-
mask[batch_size + i, i] = 0
23+
mask[i, batch_size * world_size + i] = 0
24+
mask[batch_size * world_size + i, i] = 0
2525
return mask
2626

2727
def forward(self, z_i, z_j):

0 commit comments

Comments
 (0)