
Commit 053ff42

Added support for all_gather object
1 parent 11a1fba commit 053ff42

File tree

5 files changed: +62 -13 lines changed

    ignite/distributed/comp_models/base.py
    ignite/distributed/comp_models/horovod.py
    ignite/distributed/comp_models/native.py
    ignite/distributed/comp_models/xla.py
    tests/ignite/distributed/utils/__init__.py

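In short, idist.all_gather previously raised "TypeError: Unhandled input type ..." for anything that was not a tensor, a number, or a string; with this commit such inputs are routed to a new _do_all_gather_object hook on each computation model. A minimal usage sketch (illustrative only, not part of the diff; assumes a distributed group has already been initialized in each worker, e.g. via idist.initialize("gloo")):

    import ignite.distributed as idist

    rank = idist.get_rank()

    # Arbitrary picklable objects are now gathered instead of raising TypeError.
    payload = {"rank": rank, "values": [rank + 1, rank + 2]}
    gathered = idist.all_gather(payload)

    # One object per process, ordered by rank.
    assert isinstance(gathered, list) and len(gathered) == idist.get_world_size()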

ignite/distributed/comp_models/base.py

Lines changed: 10 additions & 3 deletions
@@ -181,7 +181,7 @@ def _apply_op(
         return tensor

     def _collective_op(
-        self, tensor: Union[torch.Tensor, float, str], fn: Callable, *args: Any, **kwargs: Any
+        self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
     ) -> Union[torch.Tensor, float, List[float], List[str]]:
         tensor_to_number = tensor_to_str = False
         device = self.device()
@@ -216,10 +216,10 @@ def all_reduce(
         return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op, group=group))

     def all_gather(
-        self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
+        self, tensor: Union[torch.Tensor, float, str, Any], group: Optional[Any] = None
     ) -> Union[torch.Tensor, float, List[float], List[str]]:
         if not isinstance(tensor, (torch.Tensor, Number, str)):
-            raise TypeError(f"Unhandled input type {type(tensor)}")
+            return self._do_all_gather_object(tensor, group=group)

         return self._collective_op(tensor, self._do_all_gather, group=group)

@@ -282,6 +282,10 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         pass

+    @abstractmethod
+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        pass
+
     @abstractmethod
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         pass
@@ -373,6 +377,9 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         return tensor

+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Any:
+        return tensor
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return ranks
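The dispatch in ComputationModel.all_gather is purely type-based: tensors, numbers and strings keep going through _collective_op and the tensor-based _do_all_gather, while every other input is forwarded to the new abstract _do_all_gather_object. The non-distributed (serial) model in the last hunk simply returns the object unchanged, so single-process code keeps working. Roughly, the paths look like this (illustrative sketch, not part of the diff):

    import torch
    import ignite.distributed as idist

    idist.all_gather(torch.tensor([1.0]))  # tensor path: _collective_op -> _do_all_gather
    idist.all_gather(3.14)                 # number path: wrapped into a tensor, then gathered
    idist.all_gather("abc")                # string path: encoded to a tensor, gathered, decoded
    idist.all_gather({"a": 1})             # new object path: _do_all_gather_object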

ignite/distributed/comp_models/horovod.py

Lines changed: 6 additions & 0 deletions
@@ -192,6 +192,12 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
             tensor = tensor.unsqueeze(0)
         return hvd.allgather(tensor)

+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        if group is not None:
+            raise NotImplementedError("all_gather with group for horovod is not implemented")
+
+        return hvd.allgather_object(tensor)
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return hvd.ProcessSet(ranks)
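For the Horovod backend the heavy lifting is done by hvd.allgather_object, which already returns one deserialized object per worker, so no output buffer needs to be pre-allocated; gathering within a sub-group is the only unsupported case and raises NotImplementedError. A rough sketch of the underlying call outside Ignite (assuming workers are launched with horovodrun and hvd.init() has been called):

    import horovod.torch as hvd

    hvd.init()
    obj = {"rank": hvd.rank()}
    gathered = hvd.allgather_object(obj)  # list with one object per worker, ordered by rank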

ignite/distributed/comp_models/native.py

Lines changed: 15 additions & 0 deletions
@@ -423,6 +423,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
         if group is not None and not isinstance(group, dist.ProcessGroup):
             raise ValueError("Argument group should be list of int or ProcessGroup")
         reduce_op = self._reduce_op_map[op]
+        # we do if/else here for compatibility with older pytorch versions
         if group is not None:
             dist.all_reduce(tensor, reduce_op, group=group)
         else:
@@ -441,12 +442,26 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
         output = [torch.zeros_like(tensor) for _ in range(group_size)]
+        # we do if/else here for compatibility with older pytorch versions
         if group is not None:
             dist.all_gather(output, tensor, group=group)
         else:
             dist.all_gather(output, tensor)
         return torch.cat(output, dim=0)

+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        if group == dist.GroupMember.NON_GROUP_MEMBER:
+            return tensor
+        elif group is None:
+            group_size = self.get_world_size()
+        elif isinstance(group, dist.ProcessGroup):
+            group_size = group.size()
+        else:
+            raise ValueError("Argument group should be list of int or ProcessGroup")
+        output = [None for _ in range(group_size)]
+        dist.all_gather_object(output, tensor, group=group)
+        return output
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return dist.new_group(ranks=ranks, **kwargs)
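The native backend builds on torch.distributed.all_gather_object, which expects the caller to pre-allocate an output list with one slot per participating rank; hence the group size is resolved first (the world size when group is None, group.size() for a ProcessGroup), and a rank that is not part of the group (dist.GroupMember.NON_GROUP_MEMBER) simply gets its input back. A minimal sketch of the raw call (assuming the default process group is already initialized, e.g. with the gloo backend):

    import torch.distributed as dist

    obj = {"rank": dist.get_rank()}
    output = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(output, obj)  # fills `output` in rank order

Note that tensors nested inside the gathered objects keep the device they had on the sending rank, which is what the updated test below verifies via expected_device.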

ignite/distributed/comp_models/xla.py

Lines changed: 3 additions & 0 deletions
@@ -155,6 +155,9 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
         xm.all_reduce("sum", [output], groups=group)
         return output.reshape(-1, *output.shape[2:])

+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        raise NotImplementedError("all_gather on object is not implemented for xla")
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return [ranks]

tests/ignite/distributed/utils/__init__.py

Lines changed: 28 additions & 10 deletions
@@ -156,21 +156,22 @@ def _test_distrib_all_reduce_group(device):

 def _test_distrib_all_gather(device):
     rank = idist.get_rank()
+    ws = idist.get_world_size()

     res = torch.tensor(idist.all_gather(10), device=device)
-    true_res = torch.tensor([10] * idist.get_world_size(), device=device)
+    true_res = torch.tensor([10] * ws, device=device)
     assert (res == true_res).all()

     t = torch.tensor(rank, device=device)
     res = idist.all_gather(t)
-    true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
+    true_res = torch.tensor([i for i in range(ws)], device=device)
     assert (res == true_res).all()

     x = "test-test"
     if rank == 0:
         x = "abc"
     res = idist.all_gather(x)
-    true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + ["test-test"] * (ws - 1)
     assert res == true_res

     base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
@@ -179,22 +180,39 @@ def _test_distrib_all_gather(device):
         x = "abc"

     res = idist.all_gather(x)
-    true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + [base_x] * (ws - 1)
     assert res == true_res

     t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
     in_dtype = t.dtype
     res = idist.all_gather(t)
-    assert res.shape == (idist.get_world_size() * 4, 25)
+    assert res.shape == (ws * 4, 25)
     assert res.dtype == in_dtype
-    true_res = torch.zeros(idist.get_world_size() * 4, 25, device=device)
-    for i in range(idist.get_world_size()):
+    true_res = torch.zeros(ws * 4, 25, device=device)
+    for i in range(ws):
         true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1)
     assert (res == true_res).all()

-    if idist.get_world_size() > 1:
-        with pytest.raises(TypeError, match=r"Unhandled input type"):
-            idist.all_reduce([0, 1, 2])
+    if ws > 1 and idist.backend() != "xla-tpu":
+        t = {
+            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+        }
+        res = idist.all_gather(t)
+        assert isinstance(res, list) and len(res) == ws
+        for i, obj in enumerate(res):
+            assert isinstance(obj, dict)
+            assert list(obj.keys()) == ["a", "b", "c"], obj
+            expected_device = device if device.type == "cpu" else torch.device(f"{device.type}:{i}")
+            expected = {
+                "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+            }
+            assert obj["a"] == expected["a"]
+            assert (obj["b"] == expected["b"]).all()
+            assert obj["c"] == expected["c"]


 def _test_distrib_all_gather_group(device):
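The updated test can be exercised locally with the native gloo backend, for example with a small runner like the sketch below; the _run wrapper and the 2-process setup are illustrative, not part of the commit:

    import torch
    import ignite.distributed as idist
    from tests.ignite.distributed.utils import _test_distrib_all_gather

    def _run(local_rank, device):
        # idist.spawn calls fn(local_rank, *args) in each spawned process
        _test_distrib_all_gather(torch.device(device))

    if __name__ == "__main__":
        idist.spawn("gloo", _run, args=("cpu",), nproc_per_node=2)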
