Skip to content

Commit f61dde4

Browse files
committed
Fix mask issue in distributed nt_xent loss
1 parent cd85c43 commit f61dde4

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

simclr/modules/nt_xent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ def forward(self, z_i, z_j):
3131
"""
3232
N = 2 * self.batch_size * self.world_size
3333

34-
z = torch.cat((z_i, z_j), dim=0)
3534
if self.world_size > 1:
36-
z = torch.cat(GatherLayer.apply(z), dim=0)
35+
z_i = torch.cat(GatherLayer.apply(z_i), dim=0)
36+
z_j = torch.cat(GatherLayer.apply(z_j), dim=0)
37+
z = torch.cat((z_i, z_j), dim=0)
3738

3839
sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
3940

0 commit comments

Comments
 (0)