Skip to content

[not for land] float8 blockwise scaling training prototype using deep_gemm #2386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions benchmarks/float8/bench_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,21 @@ def run(
):
device = "cuda"
# TODO(future PR): this is ugly
assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
assert recipe in (
"tensorwise",
"rowwise",
"mxfp8_cublas",
"deepgemm_128_1_128_128",
), "unsupported"

specs = get_specs()
bf16_peak_tops = specs["bf16_peak_tops"]
fp8_peak_tops = specs["fp8_peak_tops"]
print(f"gpu_name: {torch.cuda.get_device_name(0)}")
print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
# TODO(this PR): make gpu kernel time work with deepgemm kernel
print(f"use_gpu_kernel_time: {use_gpu_kernel_time}")
print(f"recipe: {recipe}")

headers = (
"fast_accum",
Expand Down Expand Up @@ -121,16 +129,31 @@ def run(
elif recipe == "mxfp8_cublas":
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
elif recipe == "deepgemm_128_1_128_128":
scale_a = torch.ones(M, K // 128, device=device)
scale_b = torch.ones(N // 128, K // 128, device=device)
else:
assert False, f"unknown recipe {recipe}"

def do_matmul(A, B):
nonlocal scale_a
nonlocal scale_b
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
if recipe == "deepgemm_128_1_128_128":
from torchao.prototype.deep_gemm_float8_training.deep_gemm_utils import (
scaled_mm_deep_gemm_128_1_128_128,
)

def do_matmul(A, B):
nonlocal scale_a
nonlocal scale_b
return scaled_mm_deep_gemm_128_1_128_128(A, B.t(), scale_a, scale_b)

else:

def do_matmul(A, B):
nonlocal scale_a
nonlocal scale_b
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
)

fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
tops, fp8_peak_tops, use_gpu_kernel_time, do_matmul, A, B
)
Expand Down
128 changes: 128 additions & 0 deletions test/prototype/deep_gemm_float8_training/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import random

import pytest
import torch

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_7,
)

if not TORCH_VERSION_AT_LEAST_2_7:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


from torchao.float8.float8_utils import compute_error
from torchao.prototype.deep_gemm_float8_training.deep_gemm_utils import (
scale_narrow_tiles,
scale_square_tiles,
scaled_mm_deep_gemm_128_1_128_1,
scaled_mm_deep_gemm_128_1_128_128,
unscale_narrow_tiles,
unscale_square_tiles,
)
from torchao.prototype.deep_gemm_float8_training.linear import (
DeepGemmFloat8Linear,
DeepGemmFloat8LinearConfig,
)
from torchao.quantization import quantize_

random.seed(0)
torch.manual_seed(0)


class TestDeepGemmUtils:
@pytest.mark.parametrize("mkn", [(128, 128, 128), (256, 512, 1024)])
def test_128_1_128_128_gemm(self, mkn):
M, K, N = mkn
tile_size = 128
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
xq, xs = scale_narrow_tiles(x, tile_size=tile_size)
wq, ws = scale_square_tiles(w, tile_size=tile_size)
y = scaled_mm_deep_gemm_128_1_128_128(xq, wq, 1.0 / xs, 1.0 / ws)
y_ref = x @ w.T
sqnr = compute_error(y_ref, y)
assert sqnr > 26.0

@pytest.mark.parametrize("mkn", [(128, 128, 128), (256, 512, 1024)])
def test_128_1_128_1_gemm(self, mkn):
M, K, N = mkn
tile_size = 128
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
g = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
xq, xs = scale_narrow_tiles(x, tile_size=tile_size)
gq, gs = scale_narrow_tiles(g, tile_size=tile_size)
gi = scaled_mm_deep_gemm_128_1_128_1(xq, gq, 1.0 / xs, 1.0 / gs)
gi_ref = x @ g.T
sqnr = compute_error(gi_ref, gi)
assert sqnr > 27.0

def test_scale_square_tiles(self):
h, w = 8, 8
tile_size = 4

x = torch.arange(h * w, device="cuda").float().reshape(h, w)
xq, s = scale_square_tiles(x, tile_size=tile_size)
xqdq = unscale_square_tiles(xq, s, tile_size=tile_size)
sqnr = compute_error(x, xqdq)
assert sqnr >= 25.0

def test_scale_narrow_tiles(self):
h, w = 8, 16
tile_size = 4

x = torch.arange(h * w, device="cuda").float().reshape(h, w)
xq, s = scale_narrow_tiles(x, tile_size=tile_size)
xqdq = unscale_narrow_tiles(xq, s, tile_size=tile_size)
sqnr = compute_error(x, xqdq)
assert sqnr >= 32.0


class TestDeepGemmLinear:
@pytest.mark.parametrize("x_rank", [2, 3])
def test_hello_world(self, x_rank):
M, K, N = 128, 256, 512

x_ref = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
while len(x_ref.shape) < x_rank:
x_ref = x_ref.unsqueeze(0)
x_ref.requires_grad_()

m_ref = torch.nn.Linear(K, N, bias=False).bfloat16().cuda()
go_ref = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
while len(go_ref.shape) < x_rank:
go_ref = go_ref.unsqueeze(0)

x = copy.deepcopy(x_ref).requires_grad_()
m = copy.deepcopy(m_ref)
go = copy.deepcopy(go_ref)

m = DeepGemmFloat8Linear.from_float(m)

y_ref = m_ref(x_ref)
y_ref.backward(go_ref)
y = m(x)
y.backward(go)

sqnr_y = compute_error(y_ref, y)
sqnr_gi = compute_error(x_ref.grad, x.grad)
sqnr_gw = compute_error(m_ref.weight.grad, m.weight.grad)
assert sqnr_y >= 25.0
assert sqnr_gi >= 25.0
assert sqnr_gw >= 25.0

def test_api(self):
m = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=False))
quantize_(m, config=DeepGemmFloat8LinearConfig())
assert type(m[0]) == DeepGemmFloat8Linear


if __name__ == "__main__":
pytest.main([__file__])
110 changes: 110 additions & 0 deletions torchao/prototype/deep_gemm_float8_training/deep_gemm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# TODO gate by existence of deep_gemm library
import deep_gemm
import torch


def scaled_mm_deep_gemm_128_1_128_128(a, b, a_scale, b_scale):
M, K = a.shape
N, K = b.shape
out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device)
deep_gemm.gemm_fp8_fp8_bf16_nt((a, a_scale), (b, b_scale), out=out)
return out


def scaled_mm_deep_gemm_128_1_128_1(a, b, a_scale, b_scale):
M, K = a.shape
N, K = b.shape
# Note: the results from `wgrad_gemm_fp8_fp8_fp32_nt` are **accumulated**
# into this tensor. For now, we initialize with `zeros` to get correct
# numerics in toy examples. For a real use case, this will need to pass
# in the gradient tensor directly.
out = torch.zeros((M, N), dtype=torch.float, device=a.device)
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((a, a_scale), (b, b_scale), out=out)
return out


def scale_narrow_tiles(x, tile_size=128):
"""
Input: weight tensor in high precision
Output: weight tensor in float8, and scale, tiled 1 by tile_size

This is one function because logically this should be a fused kernel.
"""
# TODO assert row major
orig_shape = x.shape
x = x.reshape(-1, tile_size)
x_amax = x.abs().max(dim=1).values.unsqueeze(1).clamp(1e-4)
# TODO read from finfo instead of hardcoding
s = 448.0 / x_amax

x = (x * s).clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn)
x = x.reshape(*orig_shape)
s = s.reshape(orig_shape[0], -1).to(torch.float)
return x, s


def unscale_narrow_tiles(x, s, tile_size=128):
# for debugging
orig_shape = x.shape
x = x.reshape(-1, tile_size)
s = s.reshape(-1).unsqueeze(1)
x = x.to(torch.float) / s
x = x.reshape(*orig_shape)
return x


def scale_square_tiles(x, tile_size=128):
"""
Input: weight tensor in high precision
Output: weight tensor in float8, and scale, tiled tile_size by tile_size

This is one function because logically this should be a fused kernel.
`torch.compile` currently has three kernels, we should write a triton
to speed this up kernel and file an issue for compile to catch up.
"""
# TODO assert row major
assert len(x.shape) == 2, "unsupported"
height, width = x.shape

# might be funky with dynamic shapes...
t_h = height // tile_size
t_w = width // tile_size
x = x.reshape(t_h, tile_size, t_w, tile_size)
x = x.permute(0, 2, 1, 3)
x = x.reshape(-1, tile_size * tile_size)
m = x.abs().max(dim=1).values.unsqueeze(1).clamp(1e-4)

# convert to scale
# TODO read from finfo instead of hardcoding
s = 448.0 / m

x = (x * s).clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn)
x = x.reshape(t_h, t_w, tile_size, tile_size)
x = x.permute(0, 2, 1, 3)
x = x.reshape(height, width)
s = s.reshape(t_h, t_w).to(torch.float)

return x, s


def unscale_square_tiles(x, s, tile_size=128):
# for debugging

assert len(x.shape) == 2, "unsupported"
height, width = x.shape

# might be funky with dynamic shapes...
t_h = height // tile_size
t_w = width // tile_size
x = x.reshape(t_h, tile_size, t_w, tile_size)
x = x.permute(0, 2, 1, 3)
x = x.reshape(-1, tile_size * tile_size)

s = s.reshape(-1).unsqueeze(1)

x = x.float() / s

x = x.reshape(t_h, t_w, tile_size, tile_size)
x = x.permute(0, 2, 1, 3)
x = x.reshape(height, width)
return x
Loading
Loading