
Commit a2a31eb

not for land: deepgemm hack
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 5bdc25d commit a2a31eb

4 files changed: +387, -6 lines

benchmarks/float8/bench_matmul.py

Lines changed: 29 additions & 6 deletions
@@ -62,13 +62,21 @@ def run(
 ):
     device = "cuda"
     # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
+    assert recipe in (
+        "tensorwise",
+        "rowwise",
+        "mxfp8_cublas",
+        "deepgemm_128_1_128_128",
+    ), "unsupported"
 
     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
     fp8_peak_tops = specs["fp8_peak_tops"]
     print(f"gpu_name: {torch.cuda.get_device_name(0)}")
     print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
+    # TODO(this PR): make gpu kernel time work with deepgemm kernel
+    print(f"use_gpu_kernel_time: {use_gpu_kernel_time}")
+    print(f"recipe: {recipe}")
 
     headers = (
         "fast_accum",
@@ -121,16 +129,31 @@ def run(
     elif recipe == "mxfp8_cublas":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+    elif recipe == "deepgemm_128_1_128_128":
+        scale_a = torch.ones(M, K // 128, device=device)
+        scale_b = torch.ones(N // 128, K // 128, device=device)
     else:
         assert False, f"unknown recipe {recipe}"
 
-    def do_matmul(A, B):
-        nonlocal scale_a
-        nonlocal scale_b
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+    if recipe == "deepgemm_128_1_128_128":
+        from torchao.prototype.deep_gemm_float8_training.deep_gemm_utils import (
+            scaled_mm_deep_gemm_128_1_128_128,
         )
 
+        def do_matmul(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return scaled_mm_deep_gemm_128_1_128_128(A, B.t(), scale_a, scale_b)
+
+    else:
+
+        def do_matmul(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )
+
     fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
         tops, fp8_peak_tops, use_gpu_kernel_time, do_matmul, A, B
     )
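
For context: the recipe name deepgemm_128_1_128_128 encodes the scaling granularity, one scale per 1x128 tile of the first operand and one per 128x128 tile of the second, which is why the branch above builds placeholder scales of shape (M, K // 128) and (N // 128, K // 128). A minimal sketch of that layout, not part of the commit, with illustrative shapes (the benchmark sweeps its own M, K, N):

import torch

# Illustrative shapes only; M, K, N should be divisible by 128.
M, K, N = 2048, 4096, 8192
device = "cuda"

scale_a = torch.ones(M, K // 128, device=device)         # one scale per 1x128 tile of the (M, K) operand
scale_b = torch.ones(N // 128, K // 128, device=device)  # one scale per 128x128 tile of the (N, K) operand
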
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import random
+
+import pytest
+import torch
+
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_7,
+)
+
+if not TORCH_VERSION_AT_LEAST_2_7:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+from torchao.float8.float8_utils import compute_error
+from torchao.prototype.deep_gemm_float8_training.deep_gemm_utils import (
+    scale_narrow_tiles,
+    scale_square_tiles,
+    scaled_mm_deep_gemm_128_1_128_1,
+    scaled_mm_deep_gemm_128_1_128_128,
+    unscale_narrow_tiles,
+    unscale_square_tiles,
+)
+from torchao.prototype.deep_gemm_float8_training.linear import (
+    DeepGemmFloat8Linear,
+    DeepGemmFloat8LinearConfig,
+)
+from torchao.quantization import quantize_
+
+random.seed(0)
+torch.manual_seed(0)
+
+
+class TestDeepGemmUtils:
+    @pytest.mark.parametrize("mkn", [(128, 128, 128), (256, 512, 1024)])
+    def test_128_1_128_128_gemm(self, mkn):
+        M, K, N = mkn
+        tile_size = 128
+        x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+        w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
+        xq, xs = scale_narrow_tiles(x, tile_size=tile_size)
+        wq, ws = scale_square_tiles(w, tile_size=tile_size)
+        y = scaled_mm_deep_gemm_128_1_128_128(xq, wq, 1.0 / xs, 1.0 / ws)
+        y_ref = x @ w.T
+        sqnr = compute_error(y_ref, y)
+        assert sqnr > 26.0
+
+    @pytest.mark.parametrize("mkn", [(128, 128, 128), (256, 512, 1024)])
+    def test_128_1_128_1_gemm(self, mkn):
+        M, K, N = mkn
+        tile_size = 128
+        x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+        g = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
+        xq, xs = scale_narrow_tiles(x, tile_size=tile_size)
+        gq, gs = scale_narrow_tiles(g, tile_size=tile_size)
+        gi = scaled_mm_deep_gemm_128_1_128_1(xq, gq, 1.0 / xs, 1.0 / gs)
+        gi_ref = x @ g.T
+        sqnr = compute_error(gi_ref, gi)
+        assert sqnr > 27.0
+
+    def test_scale_square_tiles(self):
+        h, w = 8, 8
+        tile_size = 4
+
+        x = torch.arange(h * w, device="cuda").float().reshape(h, w)
+        xq, s = scale_square_tiles(x, tile_size=tile_size)
+        xqdq = unscale_square_tiles(xq, s, tile_size=tile_size)
+        sqnr = compute_error(x, xqdq)
+        assert sqnr >= 25.0
+
+    def test_scale_narrow_tiles(self):
+        h, w = 8, 16
+        tile_size = 4
+
+        x = torch.arange(h * w, device="cuda").float().reshape(h, w)
+        xq, s = scale_narrow_tiles(x, tile_size=tile_size)
+        xqdq = unscale_narrow_tiles(xq, s, tile_size=tile_size)
+        sqnr = compute_error(x, xqdq)
+        assert sqnr >= 32.0
+
+
+class TestDeepGemmLinear:
+    @pytest.mark.parametrize("x_rank", [2, 3])
+    def test_hello_world(self, x_rank):
+        M, K, N = 128, 256, 512
+
+        x_ref = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+        while len(x_ref.shape) < x_rank:
+            x_ref = x_ref.unsqueeze(0)
+        x_ref.requires_grad_()
+
+        m_ref = torch.nn.Linear(K, N, bias=False).bfloat16().cuda()
+        go_ref = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+        while len(go_ref.shape) < x_rank:
+            go_ref = go_ref.unsqueeze(0)
+
+        x = copy.deepcopy(x_ref).requires_grad_()
+        m = copy.deepcopy(m_ref)
+        go = copy.deepcopy(go_ref)
+
+        m = DeepGemmFloat8Linear.from_float(m)
+
+        y_ref = m_ref(x_ref)
+        y_ref.backward(go_ref)
+        y = m(x)
+        y.backward(go)
+
+        sqnr_y = compute_error(y_ref, y)
+        sqnr_gi = compute_error(x_ref.grad, x.grad)
+        sqnr_gw = compute_error(m_ref.weight.grad, m.weight.grad)
+        assert sqnr_y >= 25.0
+        assert sqnr_gi >= 25.0
+        assert sqnr_gw >= 25.0
+
+    def test_api(self):
+        m = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=False))
+        quantize_(m, config=DeepGemmFloat8LinearConfig())
+        assert type(m[0]) == DeepGemmFloat8Linear
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
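
Taken together, the two linear tests suggest the intended user flow. A minimal end-to-end sketch, not part of the commit, assuming (as in test_hello_world) a bf16 CUDA module and input, and that the quantize_ path wires up the same DeepGemmFloat8Linear forward/backward; shapes are illustrative and multiples of 128:

import torch

from torchao.prototype.deep_gemm_float8_training.linear import DeepGemmFloat8LinearConfig
from torchao.quantization import quantize_

# Assumed setup: a bf16 CUDA module, as in test_hello_world above.
m = torch.nn.Sequential(torch.nn.Linear(256, 512, bias=False)).bfloat16().cuda()
quantize_(m, config=DeepGemmFloat8LinearConfig())  # swaps in DeepGemmFloat8Linear (see test_api)

x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = m(x)            # fp8 forward through the deep_gemm kernels
y.sum().backward()  # fp8 input and weight gradients
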
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+# TODO gate by existence of deep_gemm library
+import deep_gemm
+import torch
+
+
+def scaled_mm_deep_gemm_128_1_128_128(a, b, a_scale, b_scale):
+    M, K = a.shape
+    N, K = b.shape
+    out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((a, a_scale), (b, b_scale), out=out)
+    return out
+
+
+def scaled_mm_deep_gemm_128_1_128_1(a, b, a_scale, b_scale):
+    M, K = a.shape
+    N, K = b.shape
+    # Note: the results from `wgrad_gemm_fp8_fp8_fp32_nt` are **accumulated**
+    # into this tensor. For now, we initialize with `zeros` to get correct
+    # numerics in toy examples. For a real use case, this will need to pass
+    # in the gradient tensor directly.
+    out = torch.zeros((M, N), dtype=torch.float, device=a.device)
+    deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((a, a_scale), (b, b_scale), out=out)
+    return out
+
+
+def scale_narrow_tiles(x, tile_size=128):
+    """
+    Input: weight tensor in high precision
+    Output: weight tensor in float8, and scale, tiled 1 by tile_size
+
+    This is one function because logically this should be a fused kernel.
+    """
+    # TODO assert row major
+    orig_shape = x.shape
+    x = x.reshape(-1, tile_size)
+    x_amax = x.abs().max(dim=1).values.unsqueeze(1).clamp(1e-4)
+    # TODO read from finfo instead of hardcoding
+    s = 448.0 / x_amax
+
+    x = (x * s).clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn)
+    x = x.reshape(*orig_shape)
+    s = s.reshape(orig_shape[0], -1).to(torch.float)
+    return x, s
+
+
+def unscale_narrow_tiles(x, s, tile_size=128):
+    # for debugging
+    orig_shape = x.shape
+    x = x.reshape(-1, tile_size)
+    s = s.reshape(-1).unsqueeze(1)
+    x = x.to(torch.float) / s
+    x = x.reshape(*orig_shape)
+    return x
+
+
+def scale_square_tiles(x, tile_size=128):
+    """
+    Input: weight tensor in high precision
+    Output: weight tensor in float8, and scale, tiled tile_size by tile_size
+
+    This is one function because logically this should be a fused kernel.
+    `torch.compile` currently emits three kernels for this; we should write a
+    triton kernel to speed this up and file an issue for compile to catch up.
+    """
+    # TODO assert row major
+    assert len(x.shape) == 2, "unsupported"
+    height, width = x.shape
+
+    # might be funky with dynamic shapes...
+    t_h = height // tile_size
+    t_w = width // tile_size
+    x = x.reshape(t_h, tile_size, t_w, tile_size)
+    x = x.permute(0, 2, 1, 3)
+    x = x.reshape(-1, tile_size * tile_size)
+    m = x.abs().max(dim=1).values.unsqueeze(1).clamp(1e-4)
+
+    # convert to scale
+    # TODO read from finfo instead of hardcoding
+    s = 448.0 / m
+
+    x = (x * s).clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn)
+    x = x.reshape(t_h, t_w, tile_size, tile_size)
+    x = x.permute(0, 2, 1, 3)
+    x = x.reshape(height, width)
+    s = s.reshape(t_h, t_w).to(torch.float)
+
+    return x, s
+
+
+def unscale_square_tiles(x, s, tile_size=128):
+    # for debugging
+
+    assert len(x.shape) == 2, "unsupported"
+    height, width = x.shape
+
+    # might be funky with dynamic shapes...
+    t_h = height // tile_size
+    t_w = width // tile_size
+    x = x.reshape(t_h, tile_size, t_w, tile_size)
+    x = x.permute(0, 2, 1, 3)
+    x = x.reshape(-1, tile_size * tile_size)
+
+    s = s.reshape(-1).unsqueeze(1)
+
+    x = x.float() / s
+
+    x = x.reshape(t_h, t_w, tile_size, tile_size)
+    x = x.permute(0, 2, 1, 3)
+    x = x.reshape(height, width)
+    return x
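
Since the unscale_* helpers exist for debugging, a pure-PyTorch reference for the 128_1_128_128 gemm can be built from them, useful for checking numerics without the deep_gemm extension. A sketch, not part of the commit; note that it consumes the scales returned by scale_narrow_tiles / scale_square_tiles directly, whereas the deep_gemm wrapper is passed their reciprocals (1.0 / s) in the tests:

def scaled_mm_128_1_128_128_reference(a_fp8, a_s, b_fp8, b_s, tile_size=128):
    # Dequantize both operands with the debugging helpers above, then matmul
    # in high precision; output dtype matches scaled_mm_deep_gemm_128_1_128_128.
    a_hp = unscale_narrow_tiles(a_fp8, a_s, tile_size=tile_size)   # (M, K), float32
    b_hp = unscale_square_tiles(b_fp8, b_s, tile_size=tile_size)   # (N, K), float32
    return (a_hp @ b_hp.t()).to(torch.bfloat16)                    # (M, N)
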
