@@ -128,19 +128,22 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
128128 return schedule
129129
130130
131+ def bufferize_module (ctx : ir .Context , kernel : ir .Module ) -> None :
132+ with ctx :
133+ pm = PassManager ("builtin.module" )
134+ pm .add ("one-shot-bufferize{bufferize-function-boundaries}" )
135+ pm .run (kernel .operation )
136+
137+
131138def apply_schedule (kernel : ir .Module , schedule : ir .Module ) -> None :
139+ bufferize_module (kernel .context , kernel )
132140 interpreter .apply_named_sequence (
133141 payload_root = kernel ,
134142 transform_root = schedule .body .operations [0 ],
135143 transform_module = schedule ,
136144 )
137-
138-
139- def bufferize_module (ctx : ir .Context , kernel : ir .Module ) -> None :
140- with ctx :
141- pm = PassManager ("builtin.module" )
142- pm .add ("one-shot-bufferize{bufferize-function-boundaries}" )
143- pm .run (kernel .operation )
145+ pm = create_pass_pipeline (kernel .context )
146+ pm .run (kernel .operation )
144147
145148
146149#### IR builders #####
@@ -1196,11 +1199,8 @@ def bin_op(a, b, out):
11961199
11971200 ir_type = to_ir_type (elem_type , ctx )
11981201 module = generate_module (ir_type )
1199- bufferize_module (ctx , module )
12001202 schedule = create_schedule (ctx )
12011203 apply_schedule (module , schedule )
1202- pm = create_pass_pipeline (ctx )
1203- pm .run (module .operation )
12041204
12051205 eng = ExecutionEngine (module , opt_level = 2 )
12061206 func_ptr = eng .lookup ("bin_op" )
@@ -1257,11 +1257,8 @@ def unary_op(a, out):
12571257
12581258 ir_type = to_ir_type (elem_type , ctx )
12591259 module = generate_module (ir_type )
1260- bufferize_module (ctx , module )
12611260 schedule = create_schedule (ctx )
12621261 apply_schedule (module , schedule )
1263- pm = create_pass_pipeline (ctx )
1264- pm .run (module .operation )
12651262
12661263 eng = ExecutionEngine (module , opt_level = 2 )
12671264 func_ptr = eng .lookup ("unary_op" )
@@ -1299,12 +1296,9 @@ def rms_norm(a, out):
12991296
13001297 ir_type = to_ir_type (elem_type , ctx )
13011298 module = generate_module (ir_type )
1302- print (module )
1303- bufferize_module (ctx , module )
13041299 schedule = create_schedule (ctx )
13051300 apply_schedule (module , schedule )
1306- pm = create_pass_pipeline (ctx )
1307- pm .run (module .operation )
1301+
13081302 eng = ExecutionEngine (module , opt_level = 2 )
13091303 func_ptr = eng .lookup ("rms_norm" )
13101304 torch_dtype = lh_utils .mlir_type_to_torch_dtype (ir_type )
@@ -1352,11 +1346,8 @@ def linear_op(x, w, b, out):
13521346
13531347 ir_type = to_ir_type ("f32" , ctx )
13541348 module = generate_module (ir_type , shape , in_features , out_features )
1355- bufferize_module (ctx , module )
13561349 schedule = create_schedule (ctx )
13571350 apply_schedule (module , schedule )
1358- pm = create_pass_pipeline (ctx )
1359- pm .run (module .operation )
13601351
13611352 eng = ExecutionEngine (module , opt_level = 2 )
13621353 func_ptr = eng .lookup ("linear_op" )
@@ -1398,11 +1389,8 @@ def polar_op(magnitude, angle, out):
13981389
13991390 ir_type = to_ir_type ("f32" , ctx )
14001391 module = generate_module (ir_type )
1401- bufferize_module (ctx , module )
14021392 schedule = create_schedule (ctx )
14031393 apply_schedule (module , schedule )
1404- pm = create_pass_pipeline (ctx )
1405- pm .run (module .operation )
14061394
14071395 eng = ExecutionEngine (module , opt_level = 2 )
14081396 func_ptr = eng .lookup ("polar_op" )
@@ -1438,11 +1426,8 @@ def repeat_kv_op(x, out):
14381426 n_rep = 4
14391427 ir_type = to_ir_type ("f32" , ctx )
14401428 module = generate_module (ir_type , n_rep )
1441- bufferize_module (ctx , module )
14421429 schedule = create_schedule (ctx )
14431430 apply_schedule (module , schedule )
1444- pm = create_pass_pipeline (ctx )
1445- pm .run (module .operation )
14461431
14471432 eng = ExecutionEngine (module , opt_level = 2 )
14481433 func_ptr = eng .lookup ("repeat_kv_op" )
@@ -1481,11 +1466,8 @@ def reshape_for_broadcast_op(freqs_cis, x, out):
14811466
14821467 ir_type = to_ir_type ("f32" , ctx )
14831468 module = generate_module (ir_type )
1484- bufferize_module (ctx , module )
14851469 schedule = create_schedule (ctx )
14861470 apply_schedule (module , schedule )
1487- pm = create_pass_pipeline (ctx )
1488- pm .run (module .operation )
14891471
14901472 eng = ExecutionEngine (module , opt_level = 2 )
14911473 func_ptr = eng .lookup ("reshape_for_broadcast" )
@@ -1528,11 +1510,8 @@ def view_as_complex_op(x, out):
15281510
15291511 ir_type = to_ir_type ("f32" , ctx )
15301512 module = generate_module (ir_type )
1531- bufferize_module (ctx , module )
15321513 schedule = create_schedule (ctx )
15331514 apply_schedule (module , schedule )
1534- pm = create_pass_pipeline (ctx )
1535- pm .run (module .operation )
15361515
15371516 eng = ExecutionEngine (module , opt_level = 2 )
15381517 func_ptr = eng .lookup ("view_as_complex_op" )
@@ -1569,11 +1548,8 @@ def as_real_op(x, out):
15691548
15701549 ir_type = to_ir_type ("f32" , ctx )
15711550 module = generate_module (ir_type )
1572- bufferize_module (ctx , module )
15731551 schedule = create_schedule (ctx )
15741552 apply_schedule (module , schedule )
1575- pm = create_pass_pipeline (ctx )
1576- pm .run (module .operation )
15771553
15781554 eng = ExecutionEngine (module , opt_level = 2 )
15791555 func_ptr = eng .lookup ("as_real_op" )
@@ -1636,11 +1612,8 @@ def rotary_emb(xq, xk, freqs_cis, xq_out, xk_out):
16361612 freqs_cis_shape = freqs_cis_shape ,
16371613 elty = ir_type ,
16381614 )
1639- bufferize_module (ctx , module )
16401615 schedule = create_schedule (ctx )
16411616 apply_schedule (module , schedule )
1642- pm = create_pass_pipeline (ctx )
1643- pm .run (module .operation )
16441617
16451618 eng = ExecutionEngine (module , opt_level = 2 )
16461619 func_ptr = eng .lookup ("rotary_emb" )
@@ -1714,11 +1687,8 @@ def feed_forward(x, w1, b1, w2, b2, w3, b3, out):
17141687
17151688 ir_type = to_ir_type ("f32" , ctx )
17161689 module = generate_module (ir_type )
1717- bufferize_module (ctx , module )
17181690 schedule = create_schedule (ctx )
17191691 apply_schedule (module , schedule )
1720- pm = create_pass_pipeline (ctx )
1721- pm .run (module .operation )
17221692
17231693 eng = ExecutionEngine (module , opt_level = 2 )
17241694 func_ptr = eng .lookup ("feed_forward" )
@@ -1866,11 +1836,9 @@ def attention_op(x, wq, wk, wv, wo, freqs_cis, mask, out):
18661836
18671837 ir_type = to_ir_type ("f32" , ctx )
18681838 module = generate_module (ir_type , model_args )
1869- bufferize_module (ctx , module )
18701839 schedule = create_schedule (ctx )
18711840 apply_schedule (module , schedule )
1872- pm = create_pass_pipeline (ctx )
1873- pm .run (module .operation )
1841+
18741842 eng = ExecutionEngine (module , opt_level = 2 )
18751843 func_ptr = eng .lookup ("attention_op" )
18761844
0 commit comments