
Commit 64e2ea0 (1 parent: 3219393)

[ghstack] Add support for more shapes

ghstack-source-id: 8a2af2b
ghstack-comment-id: 2779402838
Pull Request resolved: #2021

File tree: 6 files changed, +204 -40 lines


benchmarks/microbenchmarks/README.md

Lines changed: 17 additions & 0 deletions

@@ -46,6 +46,16 @@ model_params:
           [2048, 4096, 1024],
           [4096, 4096, 1024]
         ]
+      - name: "llama"
+      - name: "pow2"
+        min_power: 10 # Optional, default is 10 (1024)
+        max_power: 14 # Optional, default is 14 (16,384)
+      - name: "pow2_extended"
+        min_power: 10 # Optional, default is 10 (1024)
+        max_power: 14 # Optional, default is 14 (16,384)
+      - name: "sweep"
+        min_power: 8 # Optional, default is 8 (256)
+        max_power: 15 # Optional, default is 15 (32,768)
     high_precision_dtype: "torch.bfloat16"
     compile: "max-autotune" # Options: "default", "max-autotune", "false"
     device: "cuda" # Options: "cuda", "mps", "xpu", "cpu"

@@ -54,6 +64,13 @@ model_params:
 
 ## Configuration Options
 
+### Shape Generation Options
+- `custom`: Manually specify shapes as a list of [m, k, n] dimensions
+- `llama`: Use LLaMa 2 70B single-node weight shapes (assumes fused attn.wqkv and ffn.w13)
+- `pow2`: Generate shapes with dimensions that are powers of 2 (e.g., 1024, 2048, 4096, etc.)
+- `pow2_extended`: Generate shapes with dimensions that are powers of 2 and powers of 2 + half (e.g., 1024, 1536, 2048, 3072, etc.)
+- `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions
+
 ### Quantization Methods
 Currently, quantization string is in same format as the one being passed in llama/generate.py.
 - `baseline`: No quantization
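
To make the power-of-2 options concrete, here is a minimal standalone sketch (illustrative only, not code from this change) of the dimension values `pow2` and `pow2_extended` expand to for a given `min_power`/`max_power` range:

# Illustrative helper (not part of this PR): enumerate the square-shape
# dimensions that the "pow2" and "pow2_extended" options would generate.
def pow2_dims(min_power: int, max_power: int) -> list:
    return [2**p for p in range(min_power, max_power + 1)]

def pow2_extended_dims(min_power: int, max_power: int) -> list:
    dims = []
    for p in range(min_power, max_power + 1):
        dims.append(2**p)                 # e.g. 2^10 = 1024
        dims.append(2**p + 2 ** (p - 1))  # e.g. 1024 + 512 = 1536
    return dims

print(pow2_dims(10, 14))           # [1024, 2048, 4096, 8192, 16384]
print(pow2_extended_dims(10, 11))  # [1024, 1536, 2048, 3072]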

benchmarks/microbenchmarks/benchmark_runner.py

Lines changed: 42 additions & 1 deletion

@@ -48,9 +48,50 @@ def get_shapes_for_config(
         name = shape_config["name"]
         if name == "custom":
             shapes.extend([(name, shape) for shape in shape_config["shapes"]])
+        elif name == "llama":
+            # LLaMa 2 70B single-node weight shapes
+            # assumes fused attn.wqkv and ffn.w13
+            bsz, seq_len = 4, 4096
+            M = bsz * seq_len
+            llama_shapes = {
+                "attn.wqkv": (M, 8192, 1280),
+                "attn.w0": (M, 1024, 8192),
+                "ffn.w13": (M, 8192, 7168),
+                "ffn.w2": (M, 3584, 8192),
+            }
+            shapes.extend([(f"{name}_{k}", v) for k, v in llama_shapes.items()])
+        elif name == "pow2":
+            # Generate shapes with dimensions that are powers of 2
+            min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+            max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+            for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+                val = 2**power_of_2
+                shapes.append((f"{name}_{idx}", [val, val, val]))
+        elif name == "pow2_extended":
+            # Generate shapes with dimensions that are powers of 2 and powers of 2 + half
+            min_power_of_2 = shape_config.get("min_power", 10)  # 1024
+            max_power_of_2 = shape_config.get("max_power", 14)  # 16,384
+            for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+                val1 = 2**power_of_2
+                val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
+                shapes.append((f"{name}_{idx * 2}", [val1, val1, val1]))
+                shapes.append((f"{name}_{idx * 2 + 1}", [val2, val2, val2]))
+        elif name == "sweep":
+            # Generate a sweep of shapes with different powers of 2 for M, K, N
+            min_p2 = shape_config.get("min_power", 8)  # 256
+            max_p2 = shape_config.get("max_power", 15)  # 32,768
+            counter = 0
+            for M_p2 in range(min_p2, max_p2 + 1):
+                M = 2**M_p2
+                for K_p2 in range(min_p2, max_p2 + 1):
+                    K = 2**K_p2
+                    for N_p2 in range(min_p2, max_p2 + 1):
+                        N = 2**N_p2
+                        shapes.append((f"{name}_{counter}", [M, K, N]))
+                        counter += 1
         else:
             raise NotImplementedError(
-                f"Shape config {name} not supported. Currently only supports custom shapes."
+                f"Shape config {name} not supported. Supported options: custom, llama, pow2, pow2_extended, sweep."
             )
     return shapes
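
For reference, a quick usage sketch of the updated `get_shapes_for_config` (the import path is an assumption based on the repository layout; the calls mirror the tests further down):

# Usage sketch; import path assumed from the repository layout.
from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

# Each entry is a (shape_name, [M, K, N]) pair.
print(get_shapes_for_config([{"name": "pow2", "min_power": 10, "max_power": 12}]))
# [('pow2_0', [1024, 1024, 1024]), ('pow2_1', [2048, 2048, 2048]), ('pow2_2', [4096, 4096, 4096])]

print(get_shapes_for_config([{"name": "llama"}]))
# [('llama_attn.wqkv', (16384, 8192, 1280)), ('llama_attn.w0', (16384, 1024, 8192)), ...]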

benchmarks/microbenchmarks/test/benchmark_config.yml

Lines changed: 14 additions & 0 deletions

@@ -31,6 +31,20 @@ model_params:
           [2048, 4096, 1024],
           # [4096, 4096, 1024]
         ]
+      # Example of using LLaMa shapes
+      - name: "llama"
+      # Example of using power of 2 shapes
+      - name: "pow2"
+        min_power: 10 # 1024
+        max_power: 12 # 4096
+      # Example of using extended power of 2 shapes
+      - name: "pow2_extended"
+        min_power: 10 # 1024
+        max_power: 11 # 2048
+      # Example of using sweep shapes (commented out as it generates many shapes)
+      # - name: "sweep"
+      #   min_power: 8 # 256
+      #   max_power: 9 # 512
     high_precision_dtype: "torch.bfloat16"
     use_torch_compile: true
     torch_compile_mode: "max-autotune"
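
As an illustrative sketch (not code from this change), a config like the one above can be expanded into concrete shapes by loading the YAML and passing each `matrix_shapes` list to `get_shapes_for_config`; the file path, import path, and wiring below are assumptions, since the real runner does this internally:

# Illustrative only; paths and wiring are assumptions.
import yaml

from benchmarks.microbenchmarks.benchmark_runner import get_shapes_for_config

with open("benchmarks/microbenchmarks/test/benchmark_config.yml") as f:
    config = yaml.safe_load(f)

for model_param in config["model_params"]:
    shapes = get_shapes_for_config(model_param["matrix_shapes"])
    print(f"{len(shapes)} shapes: {[name for name, _ in shapes]}")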

benchmarks/microbenchmarks/test/test_benchmark_runner.py

Lines changed: 60 additions & 0 deletions

@@ -57,12 +57,72 @@ def tearDown(self):
         shutil.rmtree(self.temp_dir)
 
     def test_get_shapes_for_config(self):
+        # Test custom shapes
         shapes = get_shapes_for_config(
             self.test_config["model_params"][0]["matrix_shapes"]
         )
         self.assertEqual(len(shapes), 1)
         self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024]))
 
+        # Test llama shapes
+        llama_shapes = get_shapes_for_config([{"name": "llama"}])
+        self.assertEqual(len(llama_shapes), 4)  # 4 LLaMa shapes
+        self.assertTrue(
+            any(name.startswith("llama_attn.wqkv") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_attn.w0") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w13") for name, _ in llama_shapes)
+        )
+        self.assertTrue(
+            any(name.startswith("llama_ffn.w2") for name, _ in llama_shapes)
+        )
+
+        # Test pow2 shapes
+        pow2_shapes = get_shapes_for_config(
+            [{"name": "pow2", "min_power": 10, "max_power": 12}]
+        )
+        self.assertEqual(len(pow2_shapes), 3)  # 3 powers of 2 (10, 11, 12)
+        self.assertEqual(pow2_shapes[0], ("pow2_0", [1024, 1024, 1024]))  # 2^10
+        self.assertEqual(pow2_shapes[1], ("pow2_1", [2048, 2048, 2048]))  # 2^11
+        self.assertEqual(pow2_shapes[2], ("pow2_2", [4096, 4096, 4096]))  # 2^12
+
+        # Test pow2_extended shapes
+        pow2_extended_shapes = get_shapes_for_config(
+            [{"name": "pow2_extended", "min_power": 10, "max_power": 11}]
+        )
+        self.assertEqual(
+            len(pow2_extended_shapes), 4
+        )  # 2 powers of 2, each with 2 variants
+        self.assertEqual(
+            pow2_extended_shapes[0], ("pow2_extended_0", [1024, 1024, 1024])
+        )  # 2^10
+        self.assertEqual(
+            pow2_extended_shapes[1], ("pow2_extended_1", [1536, 1536, 1536])
+        )  # 2^10 + 2^9
+        self.assertEqual(
+            pow2_extended_shapes[2], ("pow2_extended_2", [2048, 2048, 2048])
+        )  # 2^11
+        self.assertEqual(
+            pow2_extended_shapes[3], ("pow2_extended_3", [3072, 3072, 3072])
+        )  # 2^11 + 2^10
+
+        # Test sweep shapes (limited to a small range for testing)
+        sweep_shapes = get_shapes_for_config(
+            [{"name": "sweep", "min_power": 8, "max_power": 9}]
+        )
+        # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations)
+        self.assertEqual(len(sweep_shapes), 8)
+        # Check that all shapes have the expected format
+        for name, shape in sweep_shapes:
+            self.assertTrue(name.startswith("sweep_"))
+            self.assertEqual(len(shape), 3)  # [M, K, N]
+            # Check that all dimensions are powers of 2 between 2^8 and 2^9
+            for dim in shape:
+                self.assertTrue(dim in [256, 512])  # 2^8, 2^9
+
     def test_get_param_combinations(self):
         model_param = self.test_config["model_params"][0]
         shapes, params = get_param_combinations(model_param)
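
The expected sweep count in the test above is plain combinatorics: with `min_power=8` and `max_power=9`, each of M, K, N can be 256 or 512, giving 2^3 = 8 combinations. A quick check of that arithmetic (illustrative, not part of the test file):

# Sanity check of the 2^3 = 8 sweep combinations for min_power=8, max_power=9.
from itertools import product

dims = [2**p for p in range(8, 10)]     # [256, 512]
combos = list(product(dims, repeat=3))  # all (M, K, N) combinations
assert len(combos) == 8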

benchmarks/microbenchmarks/test/test_utils.py

Lines changed: 42 additions & 20 deletions

@@ -171,7 +171,7 @@ def test_rms_norm(self):
         x = torch.randn(16, 64)
         out = rms_norm(x)
         self.assertEqual(out.shape, (16, 64))
-        
+
         # Test with different eps
         rms_norm = RMSNorm(dim=64, eps=1e-5)
         out = rms_norm(x)

@@ -184,38 +184,50 @@ def test_rms_norm_linear_activation(self):
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
         self.assertEqual(out.dtype, torch.float32)
-        
+
         # Test with ReLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
+        model = RMSNormLinearActivation(
+            fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu"
+        )
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
         self.assertTrue(torch.all(out >= 0))  # Check ReLU output range
-        
+
         # Test with SiLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
+        model = RMSNormLinearActivation(
+            fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu"
+        )
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
-        
+
         # Test with invalid activation
         with self.assertRaises(ValueError):
-            RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
+            RMSNormLinearActivation(
+                fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid"
+            )
 
     def test_transformer_block(self):
         # Test with default parameters
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32
+        )
         x = torch.randn(16, 16, 64)  # [batch_size, seq_len, hidden_dim]
         out = model(x)
         self.assertEqual(out.shape, (16, 16, 64))
         self.assertEqual(out.dtype, torch.float32)
-        
+
         # Test with different parameters
-        model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32
+        )
         x = torch.randn(8, 32, 128)
         out = model(x)
         self.assertEqual(out.shape, (8, 32, 128))
-        
+
         # Test with different head dimensions
-        model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32
+        )
         x = torch.randn(4, 8, 96)
         out = model(x)
         self.assertEqual(out.shape, (4, 8, 96))

@@ -255,7 +267,7 @@ def test_create_model_and_input(self):
         )
         self.assertIsInstance(model, RMSNormLinearActivation)
         self.assertEqual(input_data.shape, (m, k))
-        
+
         # Test TransformerBlock
         model, input_data = create_model_and_input(
             model_type="transformer_block",

@@ -266,40 +278,50 @@ def test_create_model_and_input(self):
             device="cpu",
         )
         self.assertIsInstance(model, TransformerBlock)
-        self.assertEqual(input_data.shape, (m, 16, k))  # [batch_size, seq_len, hidden_dim]
+        self.assertEqual(
+            input_data.shape, (m, 16, k)
+        )  # [batch_size, seq_len, hidden_dim]
 
     def test_quantization_on_models(self):
         # Test quantization on RMSNormLinearActivation
         model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
         x = torch.randn(16, 64)
-        
+
         # Test with Int8WeightOnlyConfig
         config = string_to_config(quantization="int8wo", sparsity=None)
         if config is not None:
             # Skip quantization test if torchao.quantization.quantize is not available
             try:
                 from torchao.quantization import quantize
+
                 quantized_model = quantize(model, config)
                 out = quantized_model(x)
                 self.assertEqual(out.shape, (16, 32))
             except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
-        
+                print(
+                    "Skipping quantization test: torchao.quantization.quantize not available"
+                )
+
         # Test quantization on TransformerBlock
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32
+        )
         x = torch.randn(16, 16, 64)
-        
+
         # Test with Int8WeightOnlyConfig
         config = string_to_config(quantization="int8wo", sparsity=None)
         if config is not None:
             # Skip quantization test if torchao.quantization.quantize is not available
             try:
                 from torchao.quantization import quantize
+
                 quantized_model = quantize(model, config)
                 out = quantized_model(x)
                 self.assertEqual(out.shape, (16, 16, 64))
             except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
+                print(
+                    "Skipping quantization test: torchao.quantization.quantize not available"
+                )
 
     def test_generate_results_csv(self):
         results = [
