Skip to content

Commit

Permalink
[onnx] Enable --iree-input-type=onnx (iree-org#15995)
Browse files Browse the repository at this point in the history
Includes a one-line fix to torch-mlir (already landed) that fixes an
unused variable warning.

With this, the following now works (for a model.onnx from the test
suite):

```
python -m iree.compiler.tools.import_onnx ~/tmp/model.onnx | ./tools/iree-compile --iree-input-type=onnx --iree-hal-target-backends=vmvx - -o /dev/null
```

We've still got a way to go on the conversions before there is enough
coverage to be generally usable.
  • Loading branch information
stellaraccident authored Dec 21, 2023
1 parent 7b7ffeb commit 5e33995
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
1 change: 1 addition & 0 deletions compiler/plugins/input/Torch/torch-iree/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ iree_cc_library(
torch-mlir::TorchConversionDialectIR
torch-mlir::TorchDialectIR
torch-mlir::TorchDialectPasses
torch-mlir::TorchOnnxToTorchPasses
torch-mlir::ConversionPasses
torch-mlir-dialects::TMTensorDialectIR
PUBLIC
Expand Down
12 changes: 11 additions & 1 deletion compiler/plugins/input/Torch/torch-iree/PluginRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "torch-iree/InputConversion/Passes.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir/Conversion/Passes.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
Expand Down Expand Up @@ -39,6 +40,7 @@ struct TorchSession
mlir::torch::registerTorchPasses();
mlir::torch::registerTorchConversionPasses();
mlir::torch::registerConversionPasses();
mlir::torch::onnx_c::registerTorchOnnxToTorchPasses();
TorchInput::registerTMTensorConversionPasses();
}

Expand All @@ -51,14 +53,21 @@ struct TorchSession

bool extendCustomInputConversionPassPipeline(
OpPassManager &passManager, std::string_view typeMnemonic) override {
if (typeMnemonic == "torch") {
if (typeMnemonic == "onnx") {
// ONNX input is a pre-processing step to torch.
passManager.addNestedPass<func::FuncOp>(
mlir::torch::onnx_c::createTorchOnnxToTorchPass());
}

if (typeMnemonic == "torch" || typeMnemonic == "onnx") {
TorchInput::TorchToIREELoweringPipelineOptions torchOptions;
torchOptions.strictSymbolicShapes = options.strictSymbolicShapes;
TorchInput::createTorchToIREEPipeline(passManager, torchOptions);
passManager.addNestedPass<func::FuncOp>(
TorchInput::createConvertTMTensorToLinalgExtPass());
return true;
}

// TODO: Retire the tm_tensor input pipeline once we are fully switched
// to the 'torch' pipeline, which handles everything from the 'torch'
// dialect down (vs just 'tm_tensor' which was converting a couple of
Expand All @@ -74,6 +83,7 @@ struct TorchSession
void populateCustomInputConversionTypes(StringSet<> &typeMnemonics) override {
typeMnemonics.insert("tm_tensor");
typeMnemonics.insert("torch");
typeMnemonics.insert("onnx");
}

void populateDetectedCustomInputConversionTypes(
Expand Down
35 changes: 35 additions & 0 deletions compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,41 @@ iree_tablegen_library(
-gen-pass-capi-impl Conversion/Passes.capi.cpp.inc
)

###############################################################################
# TorchOnnxToTorch
###############################################################################

file(GLOB _TorchOnnxToTorchPasses_SRCS
"${TORCH_MLIR_ROOT_DIR}/lib/Conversion/TorchOnnxToTorch/*.cpp"
)
iree_cc_library(
NAME
TorchOnnxToTorchPasses
SRCS
"${TORCH_MLIR_ROOT_DIR}/lib/Conversion/TorchOnnxToTorch/Passes.cpp"
${_TorchOnnxToTorchPasses_SRCS}
DEPS
::defs
::TorchConversionDialectIR
::TorchDialectIR
::TorchOnnxToTorchPassesGen
MLIRArithDialect
MLIRFuncDialect
MLIRIR
MLIRPass
)

iree_tablegen_library(
NAME
TorchOnnxToTorchPassesGen
TD_FILE
"${TORCH_MLIR_ROOT_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td"
OUTS
-gen-pass-decls Conversion/TorchOnnxToTorch/Passes.h.inc
-gen-pass-capi-header Conversion/TorchOnnxToTorch/Passes.capi.h.inc
-gen-pass-capi-impl Conversion/TorchOnnxToTorch/Passes.capi.cpp.inc
)

###############################################################################
# CAPI
###############################################################################
Expand Down
2 changes: 1 addition & 1 deletion third_party/torch-mlir

0 comments on commit 5e33995

Please sign in to comment.