Skip to content

Commit 411afae

Browse files
committed
Move bufferization and pm run into schedule application.
1 parent dff29c9 commit 411afae

File tree

1 file changed

+12
-44
lines changed

1 file changed

+12
-44
lines changed

python/examples/llama/test_llama3.py

Lines changed: 12 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,22 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
128128
return schedule
129129

130130

131+
def bufferize_module(ctx: ir.Context, kernel: ir.Module) -> None:
132+
with ctx:
133+
pm = PassManager("builtin.module")
134+
pm.add("one-shot-bufferize{bufferize-function-boundaries}")
135+
pm.run(kernel.operation)
136+
137+
131138
def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
139+
bufferize_module(kernel.context, kernel)
132140
interpreter.apply_named_sequence(
133141
payload_root=kernel,
134142
transform_root=schedule.body.operations[0],
135143
transform_module=schedule,
136144
)
137-
138-
139-
def bufferize_module(ctx: ir.Context, kernel: ir.Module) -> None:
140-
with ctx:
141-
pm = PassManager("builtin.module")
142-
pm.add("one-shot-bufferize{bufferize-function-boundaries}")
143-
pm.run(kernel.operation)
145+
pm = create_pass_pipeline(kernel.context)
146+
pm.run(kernel.operation)
144147

145148

146149
#### IR builders #####
@@ -1196,11 +1199,8 @@ def bin_op(a, b, out):
11961199

11971200
ir_type = to_ir_type(elem_type, ctx)
11981201
module = generate_module(ir_type)
1199-
bufferize_module(ctx, module)
12001202
schedule = create_schedule(ctx)
12011203
apply_schedule(module, schedule)
1202-
pm = create_pass_pipeline(ctx)
1203-
pm.run(module.operation)
12041204

12051205
eng = ExecutionEngine(module, opt_level=2)
12061206
func_ptr = eng.lookup("bin_op")
@@ -1257,11 +1257,8 @@ def unary_op(a, out):
12571257

12581258
ir_type = to_ir_type(elem_type, ctx)
12591259
module = generate_module(ir_type)
1260-
bufferize_module(ctx, module)
12611260
schedule = create_schedule(ctx)
12621261
apply_schedule(module, schedule)
1263-
pm = create_pass_pipeline(ctx)
1264-
pm.run(module.operation)
12651262

12661263
eng = ExecutionEngine(module, opt_level=2)
12671264
func_ptr = eng.lookup("unary_op")
@@ -1299,12 +1296,9 @@ def rms_norm(a, out):
12991296

13001297
ir_type = to_ir_type(elem_type, ctx)
13011298
module = generate_module(ir_type)
1302-
print(module)
1303-
bufferize_module(ctx, module)
13041299
schedule = create_schedule(ctx)
13051300
apply_schedule(module, schedule)
1306-
pm = create_pass_pipeline(ctx)
1307-
pm.run(module.operation)
1301+
13081302
eng = ExecutionEngine(module, opt_level=2)
13091303
func_ptr = eng.lookup("rms_norm")
13101304
torch_dtype = lh_utils.mlir_type_to_torch_dtype(ir_type)
@@ -1352,11 +1346,8 @@ def linear_op(x, w, b, out):
13521346

13531347
ir_type = to_ir_type("f32", ctx)
13541348
module = generate_module(ir_type, shape, in_features, out_features)
1355-
bufferize_module(ctx, module)
13561349
schedule = create_schedule(ctx)
13571350
apply_schedule(module, schedule)
1358-
pm = create_pass_pipeline(ctx)
1359-
pm.run(module.operation)
13601351

13611352
eng = ExecutionEngine(module, opt_level=2)
13621353
func_ptr = eng.lookup("linear_op")
@@ -1398,11 +1389,8 @@ def polar_op(magnitude, angle, out):
13981389

13991390
ir_type = to_ir_type("f32", ctx)
14001391
module = generate_module(ir_type)
1401-
bufferize_module(ctx, module)
14021392
schedule = create_schedule(ctx)
14031393
apply_schedule(module, schedule)
1404-
pm = create_pass_pipeline(ctx)
1405-
pm.run(module.operation)
14061394

14071395
eng = ExecutionEngine(module, opt_level=2)
14081396
func_ptr = eng.lookup("polar_op")
@@ -1438,11 +1426,8 @@ def repeat_kv_op(x, out):
14381426
n_rep = 4
14391427
ir_type = to_ir_type("f32", ctx)
14401428
module = generate_module(ir_type, n_rep)
1441-
bufferize_module(ctx, module)
14421429
schedule = create_schedule(ctx)
14431430
apply_schedule(module, schedule)
1444-
pm = create_pass_pipeline(ctx)
1445-
pm.run(module.operation)
14461431

14471432
eng = ExecutionEngine(module, opt_level=2)
14481433
func_ptr = eng.lookup("repeat_kv_op")
@@ -1481,11 +1466,8 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
14811466

14821467
ir_type = to_ir_type("f32", ctx)
14831468
module = generate_module(ir_type)
1484-
bufferize_module(ctx, module)
14851469
schedule = create_schedule(ctx)
14861470
apply_schedule(module, schedule)
1487-
pm = create_pass_pipeline(ctx)
1488-
pm.run(module.operation)
14891471

14901472
eng = ExecutionEngine(module, opt_level=2)
14911473
func_ptr = eng.lookup("reshape_for_broadcast")
@@ -1528,11 +1510,8 @@ def view_as_complex_op(x, out):
15281510

15291511
ir_type = to_ir_type("f32", ctx)
15301512
module = generate_module(ir_type)
1531-
bufferize_module(ctx, module)
15321513
schedule = create_schedule(ctx)
15331514
apply_schedule(module, schedule)
1534-
pm = create_pass_pipeline(ctx)
1535-
pm.run(module.operation)
15361515

15371516
eng = ExecutionEngine(module, opt_level=2)
15381517
func_ptr = eng.lookup("view_as_complex_op")
@@ -1569,11 +1548,8 @@ def as_real_op(x, out):
15691548

15701549
ir_type = to_ir_type("f32", ctx)
15711550
module = generate_module(ir_type)
1572-
bufferize_module(ctx, module)
15731551
schedule = create_schedule(ctx)
15741552
apply_schedule(module, schedule)
1575-
pm = create_pass_pipeline(ctx)
1576-
pm.run(module.operation)
15771553

15781554
eng = ExecutionEngine(module, opt_level=2)
15791555
func_ptr = eng.lookup("as_real_op")
@@ -1636,11 +1612,8 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
16361612
freqs_cis_shape=freqs_cis_shape,
16371613
elty=ir_type,
16381614
)
1639-
bufferize_module(ctx, module)
16401615
schedule = create_schedule(ctx)
16411616
apply_schedule(module, schedule)
1642-
pm = create_pass_pipeline(ctx)
1643-
pm.run(module.operation)
16441617

16451618
eng = ExecutionEngine(module, opt_level=2)
16461619
func_ptr = eng.lookup("rotary_emb")
@@ -1714,11 +1687,8 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
17141687

17151688
ir_type = to_ir_type("f32", ctx)
17161689
module = generate_module(ir_type)
1717-
bufferize_module(ctx, module)
17181690
schedule = create_schedule(ctx)
17191691
apply_schedule(module, schedule)
1720-
pm = create_pass_pipeline(ctx)
1721-
pm.run(module.operation)
17221692

17231693
eng = ExecutionEngine(module, opt_level=2)
17241694
func_ptr = eng.lookup("feed_forward")
@@ -1866,11 +1836,9 @@ def attention_op(x, wq, wk, wv, wo, freqs_cis, mask, out):
18661836

18671837
ir_type = to_ir_type("f32", ctx)
18681838
module = generate_module(ir_type, model_args)
1869-
bufferize_module(ctx, module)
18701839
schedule = create_schedule(ctx)
18711840
apply_schedule(module, schedule)
1872-
pm = create_pass_pipeline(ctx)
1873-
pm.run(module.operation)
1841+
18741842
eng = ExecutionEngine(module, opt_level=2)
18751843
func_ptr = eng.lookup("attention_op")
18761844

0 commit comments

Comments
 (0)