Skip to content

Commit 680d971

Browse files
[TOSA] Update tosa.cast check according to TOSA v1.0 spec
* Update checkValidityOfCast function for tosa.cast according to the latest TOSA v1.0 spec: https://www.mlplatform.org/tosa/tosa_spec.html#_cast * Clean up some dead code in TorchToTosa Signed-off-by: Justin Ngo <[email protected]> Change-Id: I41209c698a694bca57ebf49ed3608cf89a0d8ba8
1 parent bf594b0 commit 680d971

File tree

3 files changed

+27
-12
lines changed

3 files changed

+27
-12
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+2-10
Original file line numberDiff line numberDiff line change
@@ -5513,11 +5513,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
55135513
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
55145514
}
55155515

5516-
rewriter.replaceOpWithNewOp<tensor::CastOp>(
5517-
op, resultTy,
5518-
// OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
5519-
// op.getType()),
5520-
result);
5516+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);
55215517

55225518
return success();
55235519
}
@@ -6451,11 +6447,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
64516447
tosa::getConstTensor<int32_t>(rewriter, op,
64526448
/*vec=*/{0, 3, 1, 2},
64536449
/*shape=*/{static_cast<int32_t>(4)});
6454-
// SmallVector<int64_t> transposedOutputShape(
6455-
// {transposedResizedOpShape[0], transposedResizedOpShape[3],
6456-
// transposedResizedOpShape[1], transposedResizedOpShape[2]});
6457-
// auto transposedOutputType = RankedTensorType::get(
6458-
// makeShapeLLVMCompatible(transposedOutputShape), inputElemTy);
6450+
64596451
rewriter
64606452
.replaceOpWithNewOp<tosa::TransposeOp>(
64616453
op, getTypeConverter()->convertType(resultType), resizeOpResult,

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

+21-1
Original file line numberDiff line numberDiff line change
@@ -274,16 +274,28 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
274274
(src.isInteger(64) && dest.isF32()) ||
275275
// int32 -> *
276276
(src.isInteger(32) && dest.isInteger(64)) ||
277+
(src.isInteger(32) && dest.isInteger(16)) ||
278+
(src.isInteger(32) && dest.isInteger(8)) ||
277279
(src.isInteger(32) && dest.isInteger(1)) ||
278280
(src.isInteger(32) && dest.isF32()) ||
281+
(src.isInteger(32) && dest.isF16()) ||
279282
(src.isInteger(32) && dest.isBF16()) ||
280283
// int16 -> *
281284
(src.isInteger(16) && dest.isBF16()) ||
285+
(src.isInteger(16) && dest.isF32()) ||
286+
(src.isInteger(16) && dest.isF16()) ||
282287
// int8 -> *
288+
(src.isInteger(8) && dest.isInteger(32)) ||
289+
(src.isInteger(8) && dest.isInteger(16)) ||
283290
(src.isInteger(8) && dest.isInteger(1)) ||
284291
(src.isInteger(8) && dest.isBF16()) ||
292+
(src.isInteger(8) && dest.isF32()) ||
293+
(src.isInteger(8) && dest.isF16()) ||
285294
// int1 -> *
286295
(src.isInteger(1) && dest.isInteger(64)) ||
296+
(src.isInteger(1) && dest.isInteger(32)) ||
297+
(src.isInteger(1) && dest.isInteger(16)) ||
298+
(src.isInteger(1) && dest.isInteger(8)) ||
287299
(src.isInteger(1) && dest.isF32()) ||
288300
// f64 -> *
289301
(src.isF64() && dest.isF32()) ||
@@ -292,9 +304,17 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
292304
(src.isF32() && dest.isF64()) ||
293305
(src.isF32() && dest.isBF16()) ||
294306
(src.isF32() && dest.isF16()) ||
295-
(src.isF32() && dest.isInteger(8)) ||
296307
(src.isF32() && dest.isInteger(64)) ||
308+
(src.isF32() && dest.isInteger(32)) ||
309+
(src.isF32() && dest.isInteger(16)) ||
310+
(src.isF32() && dest.isInteger(8)) ||
297311
(src.isF32() && dest.isInteger(1)) ||
312+
// f16 -> *
313+
(src.isF16() && dest.isF32()) ||
314+
(src.isF16() && dest.isBF16()) ||
315+
(src.isF16() && dest.isInteger(32)) ||
316+
(src.isF16() && dest.isInteger(16)) ||
317+
(src.isF16() && dest.isInteger(8)) ||
298318
// bf16 -> *
299319
(src.isBF16() && dest.isInteger(8)) ||
300320
(src.isBF16() && dest.isInteger(16)) ||

projects/pt1/e2e_testing/xfail_sets.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1704,6 +1704,8 @@
17041704
# Write the TOSA set as a "passing" set as it is very early in development
17051705
# and very few tests work yet.
17061706
TOSA_PASS_SET = {
1707+
"BmmFloat16Module_basic",
1708+
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
17071709
"Unfold_Module_Rank_4",
17081710
"Unfold_Module_Rank_Zero_basic",
17091711
"Unfold_Module_basic",
@@ -2541,6 +2543,8 @@
25412543
}
25422544
) - {
25432545
### Test failing in make_fx_tosa but not in tosa
2546+
"ElementwiseRreluEvalStaticModule_basic",
2547+
"ElementwiseRreluTrainStaticModule_basic",
25442548
"AdaptiveMaxPool1dDimOneStatic_basic",
25452549
"FloatPowerTensorTensorStaticModule_basic",
25462550
# Dynamic shape, has extra unsupported broadcast ops
@@ -3880,7 +3884,6 @@
38803884
"TraceModule_empty",
38813885
"TraceUnsignedIntModule_empty",
38823886
"TypeConversionI1ToF64Module_basic",
3883-
"TypeConversionI1ToI32Module_basic",
38843887
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
38853888
"UpSampleNearest2dBackwardScalesNone_basic",
38863889
"UpSampleNearest2dBackward_basic",

0 commit comments

Comments
 (0)