Skip to content

Commit 5eb1bb0

Browse files
chunnienccopybara-github
authored andcommitted
Add experimental API to odml torch export
PiperOrigin-RevId: 744813361
1 parent 94f36d6 commit 5eb1bb0

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

ai_edge_torch/odml_torch/export.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@ def __repr__(self):
209209

210210
def get_text(self, enable_debug_info=False):
211211
return str(
212-
self.module.operation.get_asm(enable_debug_info=enable_debug_info)
212+
self.module.operation.get_asm(
213+
enable_debug_info=enable_debug_info,
214+
large_elements_limit=16,
215+
)
213216
)
214217

215218
@property
@@ -326,8 +329,24 @@ def _convert_q_dq_per_channel_args_to_list(
326329

327330
def exported_program_to_mlir(
328331
exported_program: torch.export.ExportedProgram,
332+
*,
333+
ir_context: ir.Context | None = None,
334+
_pre_lower_pass: (
335+
Callable[[torch.export.ExportedProgram], None] | None
336+
) = None,
329337
) -> MlirLowered:
330-
"""Lower the exported program to MLIR."""
338+
"""Lower the exported program to MLIR.
339+
340+
Args:
341+
exported_program: The exported program to lower.
342+
ir_context: The MLIR context to use. If not provided, a new context will be
343+
created.
344+
_pre_lower_pass: A function to run on exported program before lowering.
345+
346+
Returns:
347+
The lowered MLIR module, metadata, and weight tensors bundle from exported
348+
program.
349+
"""
331350
exported_program = fx_infra.safe_run_decompositions(
332351
exported_program,
333352
fx_infra.decomp.pre_lower_decomp(),
@@ -340,10 +359,16 @@ def exported_program_to_mlir(
340359
# Do not call run_decompositions after applying the passes.
341360
_convert_q_dq_per_channel_args_to_list(exported_program)
342361

343-
with export_utils.create_ir_context() as context, ir.Location.unknown():
362+
if _pre_lower_pass:
363+
_pre_lower_pass(exported_program)
364+
365+
if not ir_context:
366+
ir_context = export_utils.create_ir_context()
367+
368+
with ir_context, ir.Location.unknown():
344369

345370
module = ir.Module.create()
346-
lctx = LoweringContext(context, module)
371+
lctx = LoweringContext(ir_context, module)
347372
interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
348373
ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
349374
exported_program
@@ -382,7 +407,6 @@ def exported_program_to_mlir(
382407

383408
main_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
384409
temp_func.erase()
385-
386410
module.operation.verify()
387411

388412
input_signature = []
@@ -422,5 +446,5 @@ def exported_program_to_mlir(
422446
for tensor_meta in _get_output_metas(exported_program)
423447
]
424448
return MlirLowered(
425-
context, module, state_dict, input_signature, output_signature
449+
ir_context, module, state_dict, input_signature, output_signature
426450
)

0 commit comments

Comments
 (0)