From 6c85e611ebcea9c40d5636f97d6e335ce42e1008 Mon Sep 17 00:00:00 2001
From: Robin Zhang
Date: Wed, 9 Oct 2024 00:43:39 -0700
Subject: [PATCH] Reuse cudagraph input and output tensor memory

Signed-off-by: Robin Zhang
---
 transformer_engine/pytorch/graph.py | 128 +++++++++++++++++++++++-----
 1 file changed, 109 insertions(+), 19 deletions(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 227144d4a4..67b14b2fa3 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -60,6 +60,8 @@ def _make_graphed_callables(
     allow_unused_input: bool = False,
     fp8_weight_caching: bool = False,
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
+    reuse_graph_inputs: bool = False,
+    reuse_graph_outputs: bool = False,
     _order: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
 ) -> SingleOrTuple[Callable]:
@@ -87,6 +89,17 @@ def _make_graphed_callables(
         callables = (callables,)
         sample_args = (sample_args,)
         sample_kwargs = (sample_kwargs,)
+    if reuse_graph_inputs:
+        len_args = len(sample_args[0])
+        for arg in sample_args:
+            assert len_args == len(arg), "Arguments must have the same length and shape for reuse."
+        sample_args = list(sample_args)
+        len_kwargs = len(sample_kwargs[0])
+        for kwarg in sample_kwargs:
+            assert len_kwargs == len(
+                kwarg
+            ), "Keyword arguments must have the same length and shape for reuse."
+        sample_kwargs = list(sample_kwargs)
 
     # Check sizes of args
     if _order is None:
@@ -228,8 +241,26 @@ def _make_graphed_callables(
         per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
         fwd_idx = [0] * num_model_chunks
         bwd_idx = [0] * num_model_chunks
-        for c_id in _order:
+        # The following variables support input/output buffer reuse to save memory.
+        fwd_order_recorder = {}
+        fwd_order_accu = 0
+        per_callable_fwd_idx_recorder = []
+        static_grad_outputs = None
+        static_grad_inputs = []
+        static_grad_inputs_exists = False
+        for idx, c_id in enumerate(_order):
             if c_id > 0:
+                if reuse_graph_inputs or reuse_graph_outputs:
+                    # Record the forward-order pattern for input buffer reuse.
+                    if c_id in fwd_order_recorder:
+                        fwd_order_recorder[c_id].append(fwd_order_accu)
+                    else:
+                        fwd_order_recorder[c_id] = [fwd_order_accu]
+                    fwd_order_accu += 1
+                    if idx > 1 and _order[idx - 1] < 0:
+                        # This capture can reuse the tensor buffers of a previous one.
+                        reuse_fwd_idx = fwd_order_recorder[abs(_order[idx - 1])].pop(0)
+
                 # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
                 m_chunk = c_id - 1
                 for l_no in range(num_layers):
@@ -237,13 +268,54 @@ def _make_graphed_callables(
                     per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
                         fwd_idx[m_chunk] * num_layers + l_no
                     )
+                    if reuse_graph_inputs or reuse_graph_outputs:
+                        per_callable_fwd_idx_recorder.append(per_callable_fwd_idx)
+                        if idx > 1 and _order[idx - 1] < 0:
+                            # This capture can reuse the tensor buffers of a previous one.
+                            reuse_per_callable_fwd_idx = per_callable_fwd_idx_recorder[
+                                reuse_fwd_idx * num_layers + l_no
+                            ]
+                            if reuse_graph_inputs:
+                                sample_args[per_callable_fwd_idx] = sample_args[
+                                    reuse_per_callable_fwd_idx
+                                ]
+                                sample_kwargs[per_callable_fwd_idx] = sample_kwargs[
+                                    reuse_per_callable_fwd_idx
+                                ]
+                                flatten_sample_args[per_callable_fwd_idx] = flatten_sample_args[
+                                    reuse_per_callable_fwd_idx
+                                ]
+                                per_callable_static_input_surfaces[per_callable_fwd_idx] = (
+                                    per_callable_static_input_surfaces[reuse_per_callable_fwd_idx][
+                                        : len(flatten_sample_args[per_callable_fwd_idx])
+                                    ]
+                                    + per_callable_static_input_surfaces[per_callable_fwd_idx][
+                                        len(flatten_sample_args[per_callable_fwd_idx]) :
+                                    ]
+                                )
+                            if reuse_graph_outputs:
+                                static_outputs = per_callable_static_outputs[
+                                    reuse_per_callable_fwd_idx
+                                ]
+                                detached_static_outputs = tuple(
+                                    so.detach() for so in static_outputs
+                                )
                     args = sample_args[per_callable_fwd_idx]
                     kwargs = sample_kwargs[per_callable_fwd_idx]
                     fwd_graph = fwd_graphs[per_callable_fwd_idx]
                     with torch.cuda.graph(fwd_graph, pool=mempool):
                         outputs = func(*args, **kwargs)
-                        flatten_outputs, spec = _tree_flatten(outputs)
-                    per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
+                        flatten_outputs, spec = _tree_flatten(outputs)
+                    if reuse_graph_outputs and idx > 1 and _order[idx - 1] < 0:
+                        for i, static_output in enumerate(detached_static_outputs):
+                            static_output.copy_(flatten_outputs[i])
+                        per_callable_static_outputs[per_callable_fwd_idx] = (
+                            detached_static_outputs
+                        )
+                    else:
+                        per_callable_static_outputs[per_callable_fwd_idx] = tuple(
+                            flatten_outputs
+                        )
                     per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
                     graph_callables[per_callable_fwd_idx] = func
                 fwd_idx[m_chunk] += 1
@@ -258,9 +330,10 @@ def _make_graphed_callables(
                     static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
                     bwd_graph = bwd_graphs[per_callable_bwd_idx]
                     # For now, assumes all static_outputs require grad
-                    static_grad_outputs = tuple(
-                        torch.empty_like(o) if o.requires_grad else None for o in static_outputs
-                    )
+                    if not reuse_graph_inputs or static_grad_outputs is None:
+                        static_grad_outputs = tuple(
+                            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                        )
                     with torch.cuda.graph(bwd_graph, pool=mempool):
                         grad_inputs = torch.autograd.grad(
                             outputs=tuple(o for o in static_outputs if o.requires_grad),
@@ -269,21 +342,29 @@ def _make_graphed_callables(
                             only_inputs=True,
                             allow_unused=allow_unused_input,
                         )
-                    # Constructs a tuple suitable for returning from Graphed.backward:
-                    # Pads out the actually-needed grads with Nones in gradient slots for inputs
-                    # that don't require grad. I couldn't think of a one-liner for this pattern.
-                    static_grad_inputs = []
-                    grad_idx = 0
-                    for arg in static_input_surface:
-                        if arg.requires_grad:
-                            static_grad_inputs.append(grad_inputs[grad_idx])
-                            grad_idx += 1
-                        else:
-                            static_grad_inputs.append(None)  # type: ignore[arg-type]
-                    static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]
+                    # Constructs a tuple suitable for returning from Graphed.backward:
+                    # Pads out the actually-needed grads with Nones in gradient slots for inputs
+                    # that don't require grad. I couldn't think of a one-liner for this pattern.
+                    if not reuse_graph_outputs:
+                        static_grad_inputs = []
+                    grad_idx = 0
+                    for input_idx, arg in enumerate(static_input_surface):
+                        if arg.requires_grad:
+                            if reuse_graph_outputs and static_grad_inputs_exists:
+                                if static_grad_inputs[input_idx] is not None:
+                                    static_grad_inputs[input_idx].copy_(grad_inputs[grad_idx])
+                            else:
+                                static_grad_inputs.append(grad_inputs[grad_idx])
+                            grad_idx += 1
+                        elif not reuse_graph_outputs or not static_grad_inputs_exists:
+                            static_grad_inputs.append(None)  # type: ignore[arg-type]
+                    if reuse_graph_outputs:
+                        static_grad_inputs_exists = True
                     per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
-                    per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
+                    per_callable_static_grad_inputs[per_callable_bwd_idx] = tuple(
+                        static_grad_inputs
+                    )
                 bwd_idx[m_chunk] += 1
     else:
         # Capture forward graphs
@@ -514,6 +595,8 @@ def make_graphed_callables(
     num_warmup_iters: int = 3,
     allow_unused_input: bool = False,
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
+    reuse_graph_inputs: bool = False,
+    reuse_graph_outputs: bool = False,
     fp8_enabled: bool = False,
     fp8_calibrating: bool = False,
     fp8_recipe: Optional[DelayedScaling] = None,
@@ -543,6 +626,11 @@ def make_graphed_callables(
                         and outputs are disconnected in compute graph.
     sample_kwargs: (tuple of) dict, optional
                    Keyword arguments to callable(s)
+    reuse_graph_inputs: bool, default = `False`
+                        Whether to reuse input data buffers between graphs to reduce memory usage.
+    reuse_graph_outputs: bool, default = `False`
+                         Whether to reuse output data buffers between graphs to reduce memory
+                         usage. Reusing output buffers incurs an extra device-to-device copy.
     pool: (tuple of) int, default = `None`, optional
          An instance returned from function `torch.cuda.graph_pool_handle` that hints
          this graph may share memory with the indicated pool.
@@ -621,6 +709,8 @@ def forward_func(*args, **kwargs):
         allow_unused_input=allow_unused_input,
         fp8_weight_caching=fp8_weight_caching,
         sample_kwargs=sample_kwargs,
+        reuse_graph_inputs=reuse_graph_inputs,
+        reuse_graph_outputs=reuse_graph_outputs,
         _order=_order,
         pool=pool,
     )
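
Reviewer note (not part of the patch): below is a minimal usage sketch showing how the new flags flow through the public API. The `te.Linear` module, the tensor shapes, and enabling both flags at once are illustrative assumptions; the buffer-reuse logic itself only takes effect on the `_order`-based (interleaved pipeline) capture path inside `_make_graphed_callables`.

    import torch
    import transformer_engine.pytorch as te

    # Hypothetical module and shapes, chosen only to demonstrate the new flags.
    model = te.Linear(1024, 1024)
    sample_args = (torch.randn(32, 1024, device="cuda", requires_grad=True),)

    graphed_model = te.make_graphed_callables(
        model,
        sample_args,
        num_warmup_iters=3,
        reuse_graph_inputs=True,   # let captures share static input buffers
        reuse_graph_outputs=True,  # let captures share static output buffers (extra DtoD copy)
    )

    # The graphed callable copies fresh inputs into the (possibly shared) static buffers.
    out = graphed_model(torch.randn(32, 1024, device="cuda", requires_grad=True))
    out.sum().backward()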