@@ -209,7 +209,10 @@ def __repr__(self):
209
209
210
210
def get_text (self , enable_debug_info = False ):
211
211
return str (
212
- self .module .operation .get_asm (enable_debug_info = enable_debug_info )
212
+ self .module .operation .get_asm (
213
+ enable_debug_info = enable_debug_info ,
214
+ large_elements_limit = 16 ,
215
+ )
213
216
)
214
217
215
218
@property
@@ -326,8 +329,24 @@ def _convert_q_dq_per_channel_args_to_list(
326
329
327
330
def exported_program_to_mlir (
328
331
exported_program : torch .export .ExportedProgram ,
332
+ * ,
333
+ ir_context : ir .Context | None = None ,
334
+ _pre_lower_pass : (
335
+ Callable [[torch .export .ExportedProgram ], None ] | None
336
+ ) = None ,
329
337
) -> MlirLowered :
330
- """Lower the exported program to MLIR."""
338
+ """Lower the exported program to MLIR.
339
+
340
+ Args:
341
+ exported_program: The exported program to lower.
342
+ ir_context: The MLIR context to use. If not provided, a new context will be
343
+ created.
344
+ _pre_lower_pass: A function to run on exported program before lowering.
345
+
346
+ Returns:
347
+ The lowered MLIR module, metadata, and weight tensors bundle from exported
348
+ program.
349
+ """
331
350
exported_program = fx_infra .safe_run_decompositions (
332
351
exported_program ,
333
352
fx_infra .decomp .pre_lower_decomp (),
@@ -340,10 +359,16 @@ def exported_program_to_mlir(
340
359
# Do not call run_decompositions after applying the passes.
341
360
_convert_q_dq_per_channel_args_to_list (exported_program )
342
361
343
- with export_utils .create_ir_context () as context , ir .Location .unknown ():
362
+ if _pre_lower_pass :
363
+ _pre_lower_pass (exported_program )
364
+
365
+ if not ir_context :
366
+ ir_context = export_utils .create_ir_context ()
367
+
368
+ with ir_context , ir .Location .unknown ():
344
369
345
370
module = ir .Module .create ()
346
- lctx = LoweringContext (context , module )
371
+ lctx = LoweringContext (ir_context , module )
347
372
interpreter = LoweringInterpreter (exported_program .graph_module , lctx )
348
373
ir_flat_inputs , export_flat_args , tensor_metas = _build_flat_inputs (
349
374
exported_program
@@ -382,7 +407,6 @@ def exported_program_to_mlir(
382
407
383
408
main_func .attributes ["sym_visibility" ] = ir .StringAttr .get ("public" )
384
409
temp_func .erase ()
385
-
386
410
module .operation .verify ()
387
411
388
412
input_signature = []
@@ -422,5 +446,5 @@ def exported_program_to_mlir(
422
446
for tensor_meta in _get_output_metas (exported_program )
423
447
]
424
448
return MlirLowered (
425
- context , module , state_dict , input_signature , output_signature
449
+ ir_context , module , state_dict , input_signature , output_signature
426
450
)
0 commit comments