Skip to content

Commit 6edce83

Browse files
committed
potential bug in pipeline block
1 parent df3befc commit 6edce83

File tree

1 file changed

+23
-1
lines changed
  • src/nanotron/parallel/pipeline_parallel/block.py

1 file changed

+23
-1
lines changed

Diff for: src/nanotron/parallel/pipeline_parallel/block.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def forward(self, **kwargs):
9393
pipeline_state=self.pipeline_state,
9494
)
9595
continue
96-
9796
if v.requires_grad is True:
9897
raise ValueError(
9998
f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
@@ -152,6 +151,29 @@ def forward(self, **kwargs):
152151
# We don't store result in a buffer
153152
recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank)
154153
name_to_recv_id[name] = recv_id
154+
elif isinstance(tensor, dict):
155+
new_kwargs[name] = {}
156+
for k, v in tensor.items():
157+
# the same as above just looped over the dict
158+
if isinstance(v, TensorPointer):
159+
if isinstance(self.pipeline_state, PipelineTrainBatchState):
160+
for _ in range(len(self.pipeline_state.microbatches_activations_to_send)):
161+
send_activation = self.pipeline_state.microbatches_activations_to_send.popleft()
162+
# Execute
163+
send_activation()
164+
165+
if self.pipeline_state is not None:
166+
new_kwargs[name][k] = recv_from_pipeline_state_buffer(
167+
from_rank=tensor.group_rank,
168+
p2p=self.p2p,
169+
pipeline_state=self.pipeline_state,
170+
)
171+
continue
172+
# We don't store result in a buffer
173+
recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank)
174+
name_to_recv_id[name] = recv_id
175+
else:
176+
new_kwargs[name][k] = v
155177
else:
156178
new_kwargs[name] = tensor
157179

0 commit comments

Comments (0)