Skip to content

Commit 110e1f4

Browse files
kauvar01wonjeon
authored andcommitted
[mlir][tosa] TosaInputShape supports functions with multiple arguments
The new command-line syntax is --experimental-tosa-input-shape="args=arg0:5x5,arg8:2x9" etc. Signed-off-by: Kaushik Varadharajan <[email protected]> Change-Id: I393d51a89a9017212437bda40a0100c881198777
1 parent f242360 commit 110e1f4

File tree

5 files changed

+214
-0
lines changed

5 files changed

+214
-0
lines changed

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ void populateTosaConstantReduction(MLIRContext *ctx,
4242
void populateTosaTypeConversion(TypeConverter &converter);
4343

4444
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
45+
std::unique_ptr<Pass>
46+
createTosaInputShapePass(std::vector<std::string> args = {});
4547

4648
#define GEN_PASS_REGISTRATION
4749
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,25 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
127127
}];
128128
}
129129

130+
def TosaInputShape : Pass<"experimental-tosa-input-shape", "func::FuncOp"> {
131+
let summary = "Override dynamic input shapes of function arguments to specified static shapes.";
132+
let description = [{
133+
Pass that overrides the dynamic input shapes of function arguments to specified static shapes.
134+
It is an error if a specified static shape conflicts with the static dimensions in an original input shape.
135+
}];
136+
137+
let constructor = "tosa::createTosaInputShapePass()";
138+
let dependentDialects = [
139+
"func::FuncDialect",
140+
"tensor::TensorDialect",
141+
"tosa::TosaDialect",
142+
];
143+
let options = [
144+
ListOption<"args", "args", "std::string",
145+
"Comma-separated list of shape descriptions. Each description contains the "
146+
"argument name, a colon, and a shape with dimensions separated by x "
147+
"(e.g. arg0:5x5,arg3:2x64)">,
148+
];
149+
}
150+
130151
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
1010
TosaTypeConverters.cpp
1111
TosaProfileCompliance.cpp
1212
TosaValidation.cpp
13+
TosaInputShape.cpp
1314

1415
ADDITIONAL_HEADER_DIRS
1516
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
//===- TosaInputShape.cpp -------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Change input shape of function argument to specified shape.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
14+
15+
#include "mlir/Dialect/Func/IR/FuncOps.h"
16+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
17+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
18+
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
19+
#include "mlir/IR/Builders.h"
20+
#include "mlir/IR/BuiltinOps.h"
21+
#include "mlir/IR/IRMapping.h"
22+
#include "mlir/IR/Matchers.h"
23+
#include "mlir/Interfaces/InferTypeOpInterface.h"
24+
#include "mlir/Pass/Pass.h"
25+
#include "mlir/Transforms/DialectConversion.h"
26+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27+
#include "llvm/Support/FormatVariadic.h"
28+
29+
namespace mlir {
30+
namespace tosa {
31+
#define GEN_PASS_DEF_TOSAINPUTSHAPE
32+
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
33+
} // namespace tosa
34+
} // namespace mlir
35+
36+
using namespace mlir;
37+
using namespace mlir::tosa;
38+
39+
namespace {
40+
41+
std::pair<std::vector<std::pair<size_t, std::vector<int64_t>>>, std::string>
42+
parse_input_shapes(std::vector<std::string> args) {
43+
/**
44+
* This function returns two values: a vector of parsed arguments, and an
45+
* optional error message. Each arguments contains its argument number and the
46+
* shape. For example:
47+
* "args=arg0:5x10,arg8:3x9" => {{{0, {5, 10}}, {8, {3, 9}}}, ""}
48+
* "args=arg0:" => {{}, "error message"}
49+
*/
50+
51+
std::vector<std::pair<size_t, std::vector<int64_t>>> shapes;
52+
53+
for (std::string arg : args) {
54+
if (arg.substr(0, 3) != "arg") {
55+
return {{}, "Arguments must start with 'arg'"};
56+
}
57+
58+
char *endptr;
59+
size_t argnum = std::strtoul(&arg[3], &endptr, /*base=*/10);
60+
if (*endptr != ':') {
61+
return {{}, "Invalid argument name"};
62+
}
63+
std::string shape_str = endptr + 1;
64+
65+
std::vector<int64_t> curr;
66+
while (!shape_str.empty()) {
67+
size_t dim = std::strtoul(shape_str.data(), &endptr, /*base=*/10);
68+
if ((*endptr != '\0' && *endptr != 'x') || shape_str == endptr) {
69+
return {{}, "Invalid input shape description"};
70+
}
71+
curr.push_back(dim);
72+
if (*endptr == '\0') {
73+
break;
74+
}
75+
shape_str = endptr + 1;
76+
}
77+
shapes.push_back({argnum, curr});
78+
}
79+
return {shapes, ""};
80+
}
81+
82+
/// Pass that change function input shapes to specified static input shapes
83+
struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
84+
public:
85+
TosaInputShape() = default;
86+
explicit TosaInputShape(std::vector<std::string> args) : TosaInputShape() {
87+
this->args = args;
88+
}
89+
void runOnOperation() override {
90+
func::FuncOp func = getOperation();
91+
auto [args_parsed, args_parse_err] = parse_input_shapes(args);
92+
93+
if (!args_parse_err.empty()) {
94+
func.emitError() << args_parse_err;
95+
return;
96+
}
97+
98+
for (auto &block : func.getBody()) {
99+
100+
for (auto [argnum, shape] : args_parsed) {
101+
if (argnum >= block.getNumArguments()) {
102+
func.emitError() << "arg" << argnum << " doesn't exist.";
103+
return;
104+
}
105+
BlockArgument block_arg = block.getArgument(argnum);
106+
Type arg_type = block_arg.getType();
107+
TensorType tensor_type = cast<TensorType>(arg_type);
108+
if (failed(mlir::verifyCompatibleShape(tensor_type.getShape(), shape))) {
109+
func->emitError()
110+
<< "arg" << argnum << " has incompatible shape with input shape.";
111+
return;
112+
}
113+
SmallVector<int64_t> new_shape(shape.begin(), shape.end());
114+
auto new_tensor_type =
115+
tensor_type.cloneWith(new_shape, tensor_type.getElementType());
116+
block_arg.setType(new_tensor_type);
117+
}
118+
119+
bool found_func_op = false;
120+
121+
for (Operation &op : block) {
122+
// Update result shape for func.func
123+
func::FuncOp funcOp = mlir::dyn_cast<func::FuncOp>(op.getParentOp());
124+
if (funcOp && !found_func_op) {
125+
FunctionType old_function_type = funcOp.getFunctionType();
126+
std::vector<Type> inputs = old_function_type.getInputs();
127+
128+
for (auto [argnum, shape] : args_parsed) {
129+
if ((size_t)argnum >= inputs.size()) {
130+
func.emitError() << "arg" << argnum << " doesn't exist.";
131+
return;
132+
}
133+
auto tensor_type = cast<TensorType>(inputs[argnum]);
134+
135+
if (failed(mlir::verifyCompatibleShape(tensor_type.getShape(), shape))) {
136+
funcOp->emitError()
137+
<< "arg" << argnum
138+
<< " has incompatible shape with input shape.";
139+
return;
140+
}
141+
SmallVector<int64_t> new_shape(shape.begin(), shape.end());
142+
auto new_tensor_type =
143+
tensor_type.cloneWith(new_shape, tensor_type.getElementType());
144+
inputs[argnum] = cast<Type>(new_tensor_type);
145+
}
146+
147+
FunctionType new_function_type = old_function_type.clone(
148+
TypeRange{ArrayRef(inputs)},
149+
TypeRange{old_function_type.getResults()});
150+
funcOp.setFunctionType(new_function_type);
151+
found_func_op = true;
152+
}
153+
// Update result shape of func.return
154+
func::ReturnOp returnOp = mlir::dyn_cast<func::ReturnOp>(op);
155+
if (returnOp) {
156+
func::FuncOp funcOp = dyn_cast<func::FuncOp>(op.getParentOp());
157+
if (funcOp) {
158+
FunctionType old_function_type = funcOp.getFunctionType();
159+
FunctionType new_function_type = old_function_type.clone(
160+
TypeRange{old_function_type.getInputs()},
161+
returnOp.getOperandTypes());
162+
funcOp.setFunctionType(new_function_type);
163+
}
164+
}
165+
}
166+
}
167+
}
168+
};
169+
170+
} // namespace
171+
172+
std::unique_ptr<Pass>
173+
mlir::tosa::createTosaInputShapePass(std::vector<std::string> args) {
174+
return std::make_unique<TosaInputShape>(args);
175+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-opt --split-input-file --experimental-tosa-input-shape="args=arg0:2x16,arg3:64x9" %s | FileCheck %s
2+
3+
func.func @test_input_shape(
4+
// CHECK: %arg0: tensor<2x16xi32>
5+
%arg0: tensor<2x?xi32>,
6+
// CHECK: %arg1: tensor<?x256xf32>
7+
%arg1: tensor<?x256xf32>,
8+
// CHECK: %arg2: tensor<2x?xi32>
9+
%arg2: tensor<2x?xi32>,
10+
// CHECK: %arg3: tensor<64x9xf32>
11+
%arg3: tensor<?x9xf32>) -> (tensor<2x?xi32>, tensor<?x9xf32>) {
12+
13+
// CHECK: %arg0, %arg3 : tensor<2x16xi32>, tensor<64x9xf32>
14+
return %arg0, %arg3 : tensor<2x?xi32>, tensor<?x9xf32>
15+
}

0 commit comments

Comments
 (0)