Draft: reduce cudagraph mem via preallocations #1253

Open

wants to merge 12 commits into main
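For orientation, the idea behind the draft: allocations made while a CUDA graph is being captured are recorded in a persistent cache, and once the cache is locked, later requests are served from the recorded buffers instead of fresh allocations, so graph replays keep reusing the same memory. A minimal Python sketch of that mechanism follows; the `GraphBufferCache` name and its methods are illustrative stand-ins, not part of the diff.

```python
import torch


class GraphBufferCache:
    """Illustrative sketch only: persistent buffers for CUDA graph capture."""

    def __init__(self):
        self.buffers = []    # tensors recorded during capture
        self.locked = False  # once True, serve buffers round-robin instead of allocating
        self.index = 0

    def empty(self, *size, dtype=torch.float32, device="cuda"):
        if not self.locked:
            # Capture phase: allocate normally and remember the buffer.
            buf = torch.empty(*size, dtype=dtype, device=device)
            self.buffers.append(buf)
            return buf
        # Locked phase: hand back a previously recorded buffer
        # (assumes the replayed call order matches the capture order).
        buf = self.buffers[self.index % len(self.buffers)]
        self.index += 1
        return buf


cache = GraphBufferCache()
a = cache.empty(4, 4, device="cuda")   # recorded during "capture"
cache.locked = True
b = cache.empty(4, 4, device="cuda")   # reuses the recorded tensor
assert a.data_ptr() == b.data_ptr()
```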
16 changes: 15 additions & 1 deletion transformer_engine/pytorch/attention.py
@@ -25,6 +25,7 @@
from transformer_engine.pytorch.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
empty_cached
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd_qkvpacked,
@@ -72,7 +73,6 @@
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing


_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.5.8")
@@ -3630,6 +3630,7 @@ def forward(
)
return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)


@staticmethod
def backward(ctx, *grad_outputs):
assert len(grad_outputs) > 0, "No gradients received for backprop!"
@@ -3711,6 +3712,19 @@ def backward(ctx, *grad_outputs):
)
return ret, None, None

if is_graph_capturing():
total_shape = list(grad_outputs[0].shape)
total_shape[split_dim] = sum([g.shape[split_dim] for g in grad_outputs])

grad_input = empty_cached(
total_shape,
dtype=grad_outputs[0].dtype,
device=grad_outputs[0].device,
requires_grad=any([g.requires_grad for g in grad_outputs]),
)
torch.cat(grad_outputs, dim=split_dim, out=grad_input)
return grad_input, None, None

return torch.cat(grad_outputs, dim=split_dim), None, None


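The new branch above builds the concatenated gradient inside a cached buffer rather than letting `torch.cat` allocate one. A standalone sketch of that step, using plain `torch.empty` in place of `empty_cached` and made-up shapes:

```python
import torch


def cat_grads_into_buffer(grad_outputs, split_dim, empty_func=torch.empty):
    # Recover the pre-split shape: same as each chunk, with the split
    # dimension summed across all chunks.
    total_shape = list(grad_outputs[0].shape)
    total_shape[split_dim] = sum(g.shape[split_dim] for g in grad_outputs)
    # Under graph capture the PR routes this allocation through empty_cached,
    # so the concatenated gradient lives in a persistent, graph-safe buffer.
    grad_input = empty_func(total_shape,
                            dtype=grad_outputs[0].dtype,
                            device=grad_outputs[0].device)
    torch.cat(grad_outputs, dim=split_dim, out=grad_input)
    return grad_input


# Example: two (2, 3) gradient chunks rejoined along dim=1 into a (2, 6) buffer.
grads = (torch.randn(2, 3), torch.randn(2, 3))
assert cat_grads_into_buffer(grads, split_dim=1).shape == (2, 6)
```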
1 change: 1 addition & 0 deletions transformer_engine/pytorch/cpp_extensions/__init__.py
@@ -11,3 +11,4 @@
from .activation import *
from .normalization import *
from .cast import *
from .graph_cache import *
53 changes: 53 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/graph_cache.py
@@ -0,0 +1,53 @@

import torch
import torch._prims_common as utils

import transformer_engine_torch as tex

def empty_like_cached(input, *, dtype=None, layout=None, device=None, requires_grad=False,
memory_format=torch.preserve_format):

dtype = input.dtype if dtype is None else dtype
# layout = input.layout if layout is None else layout
device = input.device if device is None else device

if isinstance(device, int):
device = torch.device(device)
if isinstance(device, str):
device = torch.device(device, torch.cuda.current_device())

copy = tex.empty_like_cached(
input,
dtype=dtype,
layout=layout,
device=device,
pin_memory=False,
memory_format=None) #TODO
wrapper = torch.Tensor()
wrapper.data = copy
wrapper.requires_grad = input.requires_grad

return wrapper

def empty_cached(*size, out=None, dtype=None, layout=torch.strided, device=None,
requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format):

size = utils.extract_shape_from_varargs(size)
dtype = torch.get_default_dtype() if dtype is None else dtype
device = torch.device("cpu") if device is None else device

if isinstance(size, torch.Size):
size = tuple(size)
if isinstance(device, int):
device = torch.device(device)
if isinstance(device, str):
device = torch.device(device, torch.cuda.current_device())

copy = tex.empty_cached(
size=size,
dtype=dtype,
device=device,
pin_memory=False,
memory_format=None)
copy.requires_grad = requires_grad
return copy
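A hedged usage sketch of the two wrappers above, assuming this branch is installed so that the `transformer_engine_torch` extension exposes the underlying cached allocators:

```python
import torch
from transformer_engine.pytorch.cpp_extensions import empty_cached, empty_like_cached

# Shaped like torch.empty, but backed by the C++-side cache: calls made while
# a graph is being captured hand out persistent, reusable storage.
scratch = empty_cached(128, 64, dtype=torch.uint8,
                       device=torch.cuda.current_device())

# Same idea when mirroring an existing tensor's shape and device.
ref = torch.randn(16, 32, device="cuda")
mirror = empty_like_cached(ref, dtype=torch.uint8)
```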
9 changes: 6 additions & 3 deletions transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -10,6 +10,7 @@
import transformer_engine_torch as tex
from ..constants import TE_DType
from ._common import canonicalize_fp8_scales, empty_tensor
from .graph_cache import empty_cached, empty_like_cached


__all__ = [
@@ -32,15 +33,17 @@ def fp8_cast_transpose_fused(
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
noop_flag: Optional[torch.Tensor] = None,
graph_cache: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Cast + Transpose with FP8 output"""

# Allocate outputs if needed
if transpose_out is None:
transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8)
empty_func = empty_cached if graph_cache else torch.empty
transpose_out = empty_func(inp.shape[1], inp.shape[0], device=torch.cuda.current_device(), dtype=torch.uint8)
if cast_out is None:
cast_out = torch.empty_like(inp, dtype=torch.uint8)

empty_like_func = empty_like_cached if graph_cache else torch.empty_like
cast_out = empty_like_func(inp, device=torch.cuda.current_device(), dtype=torch.uint8)
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
35 changes: 25 additions & 10 deletions transformer_engine/pytorch/csrc/common.cu
@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/

#include "extensions.h"
#include "common.h"
#include "transformer_engine/transformer_engine.h"

@@ -72,7 +73,7 @@ size_t product(const std::vector<size_t>& shape) {
}

at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
bool init_to_zeros) {
bool init_to_zeros, bool graph_cache) {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
c10::IntArrayRef ar_shape(shape_int64);
if (init_to_zeros) {
@@ -83,29 +84,43 @@ at::Tensor allocateSpace(const NVTEShape& shape, const transformer_eng
}

at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
bool init_to_zeros) {
bool init_to_zeros, bool graph_cache) {

at::Tensor (*empty_func)(at::IntArrayRef, at::TensorOptions, ::std::optional<at::MemoryFormat> memory_format);
if (is_graph_capturing() && graph_cache)
empty_func = &empty_cached;
else
empty_func = &at::empty;

auto size = shape.ndim;
if (size == 2 && init_to_zeros) {
return at::zeros({static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
at::CUDA(GetATenDType(type)));
} else if (size == 2) {
return at::empty({static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
at::CUDA(GetATenDType(type)));
return empty_func({static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
at::CUDA(GetATenDType(type)), std::nullopt);
} else if (size == 1 && init_to_zeros) {
return at::zeros({static_cast<int64_t>(shape.data[0])}, at::CUDA(GetATenDType(type)));
} else if (size == 1) {
return at::empty({static_cast<int64_t>(shape.data[0])}, at::CUDA(GetATenDType(type)));
return empty_func({static_cast<int64_t>(shape.data[0])}, at::CUDA(GetATenDType(type)), std::nullopt);
}
NVTE_CHECK(false, "Should never reach here! func: allocateSpace");
}

at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype) {
return at::empty({static_cast<int64_t>(M), static_cast<int64_t>(N)},
at::CUDA(GetATenDType(dtype)));
at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype, bool graph_cache) {
if (is_graph_capturing() && graph_cache)
return empty_cached({static_cast<int64_t>(M), static_cast<int64_t>(N)},
at::CUDA(GetATenDType(dtype)));
else
return at::empty({static_cast<int64_t>(M), static_cast<int64_t>(N)},
at::CUDA(GetATenDType(dtype)));
}

at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype) {
return at::empty({static_cast<int64_t>(M)}, at::CUDA(GetATenDType(dtype)));
at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype, bool graph_cache) {
if (is_graph_capturing() && graph_cache)
return empty_cached({static_cast<int64_t>(M)}, at::CUDA(GetATenDType(dtype)));
else
return at::empty({static_cast<int64_t>(M)}, at::CUDA(GetATenDType(dtype)));
}

void* getDataPtr(at::Tensor tensor, int offset) {
29 changes: 24 additions & 5 deletions transformer_engine/pytorch/csrc/common.h
@@ -43,7 +43,7 @@
#include <random>
#include <stdexcept>
#include <vector>

#include <map>
#include "common/util/logging.h"

namespace transformer_engine {
@@ -84,6 +84,25 @@ enum FP8BwdTensors {

} // namespace transformer_engine


class GraphCache {
public:
std::vector<at::Tensor> cache;
bool graph_locked;
bool graph_capturing;
int cache_index;

GraphCache() {
graph_locked = false;
graph_capturing = false;
cache_index = 0;
}

void insert(at::Tensor tensor);
at::Tensor retrieve();
};


transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe);

@@ -156,14 +175,14 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor,
size_t product(const std::vector<size_t>& shape);

at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
bool init_to_zeros);
bool init_to_zeros, bool graph_cache = false);

at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
bool init_to_zeros = false);
bool init_to_zeros = false, bool graph_cache = false);

at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype);
at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype, bool graph_cache = false);

at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype);
at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype, bool graph_cache= false);

void* getDataPtr(at::Tensor tensor, int offset = 0);

21 changes: 20 additions & 1 deletion transformer_engine/pytorch/csrc/extensions.h
@@ -205,7 +205,7 @@ std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> fused_multi_cast_tr
std::vector<int> scale_indices, std::vector<int> amax_indices,
std::vector<int> scale_inv_indices, transformer_engine::DType otype);

at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype);
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, bool graph_cache=false);

void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype);

@@ -475,4 +475,23 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale);

at::Tensor empty_like_cached(const at::Tensor &self, ::std::optional<at::ScalarType> dtype,
::std::optional<at::Layout> layout, ::std::optional<at::Device> device,
::std::optional<bool> pin_memory,
::std::optional<at::MemoryFormat> memory_format);
at::Tensor empty_like_cached(const at::Tensor &self , at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=std::nullopt);

at::Tensor empty_cached(at::IntArrayRef size, ::std::optional<at::ScalarType> dtype,
::std::optional<at::Layout> layout, ::std::optional<at::Device> device,
::std::optional<bool> pin_memory,
::std::optional<at::MemoryFormat> memory_format);
at::Tensor empty_cached(at::IntArrayRef size, at::TensorOptions options={}, ::std::optional<at::MemoryFormat> memory_format=std::nullopt);


void set_capture_start();
void set_capture_end();
void set_graph_cached_locked();
bool is_graph_capturing();


#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
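The four declarations at the bottom suggest the intended capture lifecycle. A hedged Python sketch of that flow, assuming the functions are also exposed through the `transformer_engine_torch` bindings (only the C++ declarations are visible in this diff), and not the PR's actual graph.py code:

```python
import torch
import transformer_engine_torch as tex  # assumes this PR's bindings expose the calls below


def capture_with_cached_buffers(fn, static_input):
    """Illustrative sketch of the capture lifecycle under stated assumptions."""
    # Start recording: cached allocators begin inserting fresh buffers into the GraphCache.
    tex.set_capture_start()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)
    tex.set_capture_end()
    # Lock the cache so later cached allocations retrieve the recorded buffers
    # instead of growing memory on every replayed iteration.
    tex.set_graph_cached_locked()
    return graph, static_output
```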
14 changes: 7 additions & 7 deletions transformer_engine/pytorch/csrc/extensions/activation.cu
@@ -31,7 +31,7 @@ at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType ot
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;

auto output = allocateTorchTensor(M, N, otype);
auto output = allocateTorchTensor(M, N, otype, true);

auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
@@ -69,7 +69,7 @@ at::Tensor drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType ot
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;

auto output = allocateTorchTensor(M, N, otype);
auto output = allocateTorchTensor(M, N, otype, true);

auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
@@ -108,7 +108,7 @@ at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType o
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;

auto output = allocateTorchTensor(M, N, otype);
auto output = allocateTorchTensor(M, N, otype, true);

auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
@@ -147,7 +147,7 @@ at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType o
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;

auto output = allocateTorchTensor(M, N, otype);
auto output = allocateTorchTensor(M, N, otype, true);

auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
@@ -186,7 +186,7 @@ at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;

auto output = allocateTorchTensor(M, N, otype);
auto output = allocateTorchTensor(M, N, otype, true);

auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
@@ -224,7 +224,7 @@ at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType o
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;

auto output = allocateTorchTensor(M, N, otype);
auto output = allocateTorchTensor(M, N, otype, true);

auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
@@ -262,7 +262,7 @@ at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType o
size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;

auto output = allocateTorchTensor(M, N, otype);
auto output = allocateTorchTensor(M, N, otype, true);

auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
21 changes: 17 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/apply_rope.cu
@@ -47,6 +47,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
} else {
output = torch::empty({s, b, h, d}, act_options);
}

// output strides
const int o_stride_s = output.stride(0);
const int o_stride_b = output.stride(1);
@@ -100,10 +101,22 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Ten
auto act_options = output_grads.options().requires_grad(false);
at::Tensor input_grads;
if (transpose_output_memory) {
input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
input_grads = torch::empty({s, b, h, d}, act_options);
if (is_graph_capturing()){
input_grads = empty_cached({b, s, h, d}, act_options).transpose(0, 1);
}
else{
input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
}
}
else {
if (is_graph_capturing()){
input_grads = empty_cached({s, b, h, d}, act_options);
}
else{
input_grads = torch::empty({s, b, h, d}, act_options);
}
}

const int o_stride_s = input_grads.stride(0);
const int o_stride_b = input_grads.stride(1);
const int o_stride_h = input_grads.stride(2);
@@ -219,4 +232,4 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
at::cuda::getCurrentCUDAStream());

return input_grads;
}
}