Skip to content

Commit 9a167e2

Browse files
[TOSA] Update tosa.cast check according to TOSA v1.0 spec (#3948)
* Update the checkValidityOfCast function for tosa.cast according to the latest TOSA v1.0 spec: https://www.mlplatform.org/tosa/tosa_spec.html#_cast
* Clean up some dead code in TorchToTosa

Change-Id: I41209c698a694bca57ebf49ed3608cf89a0d8ba8
Signed-off-by: Justin Ngo <[email protected]>
1 parent 98e4eb2 commit 9a167e2

File tree

3 files changed

+75
-51
lines changed

3 files changed

+75
-51
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+2-10
Original file line numberDiff line numberDiff line change
@@ -5513,11 +5513,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
55135513
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
55145514
}
55155515

5516-
rewriter.replaceOpWithNewOp<tensor::CastOp>(
5517-
op, resultTy,
5518-
// OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
5519-
// op.getType()),
5520-
result);
5516+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);
55215517

55225518
return success();
55235519
}
@@ -6451,11 +6447,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
64516447
tosa::getConstTensor<int32_t>(rewriter, op,
64526448
/*vec=*/{0, 3, 1, 2},
64536449
/*shape=*/{static_cast<int32_t>(4)});
6454-
// SmallVector<int64_t> transposedOutputShape(
6455-
// {transposedResizedOpShape[0], transposedResizedOpShape[3],
6456-
// transposedResizedOpShape[1], transposedResizedOpShape[2]});
6457-
// auto transposedOutputType = RankedTensorType::get(
6458-
// makeShapeLLVMCompatible(transposedOutputShape), inputElemTy);
6450+
64596451
rewriter
64606452
.replaceOpWithNewOp<tosa::TransposeOp>(
64616453
op, getTypeConverter()->convertType(resultType), resizeOpResult,

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

+56-22
Original file line numberDiff line numberDiff line change
@@ -264,42 +264,68 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
264264
return const_op.getResult();
265265
}
266266

267-
static LogicalResult checkValidityOfCast(Type src, Type dest) {
267+
// Valid TOSA casting pairs according to TOSA spec:
268+
// https://www.mlplatform.org/tosa/tosa_spec.html#_cast
269+
// Note: currently TOSA doesn't support casting to and from I64 and F64
270+
[[maybe_unused]] static LogicalResult checkValidityOfCast(Type src, Type dest) {
268271
// clang-format off
269272
if ((src == dest) ||
270-
// int64 -> *
271-
(src.isInteger(64) && dest.isInteger(32)) ||
272-
(src.isInteger(64) && dest.isInteger(8)) ||
273-
(src.isInteger(64) && dest.isInteger(1)) ||
274-
(src.isInteger(64) && dest.isF32()) ||
275273
// int32 -> *
276-
(src.isInteger(32) && dest.isInteger(64)) ||
274+
(src.isInteger(32) && dest.isInteger(16)) ||
275+
(src.isInteger(32) && dest.isInteger(8)) ||
277276
(src.isInteger(32) && dest.isInteger(1)) ||
278277
(src.isInteger(32) && dest.isF32()) ||
278+
(src.isInteger(32) && dest.isF16()) ||
279279
(src.isInteger(32) && dest.isBF16()) ||
280280
// int16 -> *
281+
(src.isInteger(16) && dest.isInteger(32)) ||
282+
(src.isInteger(16) && dest.isInteger(8)) ||
283+
(src.isInteger(16) && dest.isInteger(1)) ||
281284
(src.isInteger(16) && dest.isBF16()) ||
285+
(src.isInteger(16) && dest.isF32()) ||
286+
(src.isInteger(16) && dest.isF16()) ||
282287
// int8 -> *
288+
(src.isInteger(8) && dest.isInteger(32)) ||
289+
(src.isInteger(8) && dest.isInteger(16)) ||
283290
(src.isInteger(8) && dest.isInteger(1)) ||
284291
(src.isInteger(8) && dest.isBF16()) ||
292+
(src.isInteger(8) && dest.isF32()) ||
293+
(src.isInteger(8) && dest.isF16()) ||
285294
// int1 -> *
286-
(src.isInteger(1) && dest.isInteger(64)) ||
287-
(src.isInteger(1) && dest.isF32()) ||
288-
// f64 -> *
289-
(src.isF64() && dest.isF32()) ||
290-
(src.isF64() && dest.isBF16()) ||
295+
(src.isInteger(1) && dest.isInteger(32)) ||
296+
(src.isInteger(1) && dest.isInteger(16)) ||
297+
(src.isInteger(1) && dest.isInteger(8)) ||
291298
// f32 -> *
292-
(src.isF32() && dest.isF64()) ||
299+
(src.isF32() && dest.isInteger(32)) ||
300+
(src.isF32() && dest.isInteger(16)) ||
301+
(src.isF32() && dest.isInteger(8)) ||
293302
(src.isF32() && dest.isBF16()) ||
294303
(src.isF32() && dest.isF16()) ||
295-
(src.isF32() && dest.isInteger(8)) ||
296-
(src.isF32() && dest.isInteger(64)) ||
297-
(src.isF32() && dest.isInteger(1)) ||
304+
(src.isF32() && dest.isFloat8E4M3()) ||
305+
(src.isF32() && dest.isFloat8E5M2()) ||
306+
// f16 -> *
307+
(src.isF16() && dest.isInteger(32)) ||
308+
(src.isF16() && dest.isInteger(16)) ||
309+
(src.isF16() && dest.isInteger(8)) ||
310+
(src.isF16() && dest.isBF16()) ||
311+
(src.isF16() && dest.isF32()) ||
312+
(src.isF16() && dest.isFloat8E4M3()) ||
313+
(src.isF16() && dest.isFloat8E5M2()) ||
298314
// bf16 -> *
299-
(src.isBF16() && dest.isInteger(8)) ||
300-
(src.isBF16() && dest.isInteger(16)) ||
301315
(src.isBF16() && dest.isInteger(32)) ||
302-
(src.isBF16() && dest.isF32())) {
316+
(src.isBF16() && dest.isInteger(16)) ||
317+
(src.isBF16() && dest.isInteger(8)) ||
318+
(src.isBF16() && dest.isF32()) ||
319+
(src.isBF16() && dest.isFloat8E4M3()) ||
320+
(src.isBF16() && dest.isFloat8E5M2()) ||
321+
// fp8e4m3 -> *
322+
(src.isFloat8E4M3() && dest.isBF16()) ||
323+
(src.isFloat8E4M3() && dest.isF32()) ||
324+
(src.isFloat8E4M3() && dest.isF16()) ||
325+
// fp8e5m2 -> *
326+
(src.isFloat8E5M2() && dest.isBF16()) ||
327+
(src.isFloat8E5M2() && dest.isF32()) ||
328+
(src.isFloat8E5M2() && dest.isF16())) {
303329
return success();
304330
}
305331
// clang-format on
@@ -313,9 +339,17 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
313339
Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
314340
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
315341

316-
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
317-
return rewriter.notifyMatchFailure(
318-
op, "casting to result dtype is invalid or unsupported");
342+
// Temporarily disable checkValidityOfCast as it's currently strictly
343+
// following TOSA spec and might cause many e2e tests to fail. This is because
344+
// even though there are some casting pairs that do not conform to the TOSA
345+
// spec, they are still permissible. TOSA validation should flag these illegal
346+
// constructs in a per-profile manner. This strict validity check will be
347+
// enabled later in a potential `--strict` mode which checks for strict
348+
// casting only when needed (the default value of `--strict` mode will be
349+
// off).
350+
// if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
351+
// return rewriter.notifyMatchFailure(
352+
// op, "casting to result dtype is invalid or unsupported");
319353

320354
if (destElemTy.isInteger(1)) {
321355
auto srcType = dyn_cast<TensorType>(src.getType());

projects/pt1/e2e_testing/xfail_sets.py

+17-19
Original file line numberDiff line numberDiff line change
@@ -1705,6 +1705,21 @@
17051705
# Write the TOSA set as a "passing" set as it is very early in development
17061706
# and very few tests work yet.
17071707
TOSA_PASS_SET = {
1708+
"AtenEyeMModuleInt2D_basic",
1709+
"AtenEyeModuleInt2D_basic",
1710+
"ElementwiseWhereScalarOtherStaticModule_basic",
1711+
"FullModuleFalsePinMemory_basic",
1712+
"FullModuleInt2D_basic",
1713+
"MaskedFillScalarFloatValueModule_basic",
1714+
"MaskedFillScalarFloatValueStaticModule_basic",
1715+
"NewFullModuleInt2D_basic",
1716+
"NewFullModuleInt3D_basic",
1717+
"Threshold3dIntModule_basic",
1718+
"TrilIndicesModule_basic",
1719+
"TrilIndicesOfssetGreaterThanRowModule_basic",
1720+
"TriuIndicesNegativeOffsetModule_basic",
1721+
"BmmFloat16Module_basic",
1722+
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
17081723
"Unfold_Module_Rank_4",
17091724
"Unfold_Module_Rank_Zero_basic",
17101725
"Unfold_Module_basic",
@@ -2546,6 +2561,8 @@
25462561
}
25472562
) - {
25482563
### Test failing in make_fx_tosa but not in tosa
2564+
"ElementwiseRreluEvalStaticModule_basic",
2565+
"ElementwiseRreluTrainStaticModule_basic",
25492566
"AdaptiveMaxPool1dDimOneStatic_basic",
25502567
"FloatPowerTensorTensorStaticModule_basic",
25512568
# Dynamic shape, has extra unsupported broadcast ops
@@ -3466,7 +3483,6 @@
34663483
"LayerNormFwAndBwModule_basic",
34673484
"LayerNormManualFwAndBwModule_basic",
34683485
"SelfAttentionFwAndBwModule_basic",
3469-
"Threshold3dIntModule_basic",
34703486
"ElementwiseCopysignModule_basic",
34713487
"ElementwiseSignbitModule_basic",
34723488
"Aten_TrilinearModuleVaryingRanks_basic",
@@ -3515,12 +3531,9 @@
35153531
"TensorsConcatComplex64FloatModule_basic",
35163532
"TimeOutModule_basic",
35173533
"TrilIndicesAllZerosModule_basic",
3518-
"TrilIndicesModule_basic",
35193534
"TrilIndicesNegativeOffsetModule_basic",
3520-
"TrilIndicesOfssetGreaterThanRowModule_basic",
35213535
"TriuIndicesAllZerosModule_basic",
35223536
"TriuIndicesModule_basic",
3523-
"TriuIndicesNegativeOffsetModule_basic",
35243537
"TypeConversionUint8ToF32Module_basic",
35253538
"WeightNormInterfaceModule_basic",
35263539
"AdaptiveAvgPool3dDynamicNoBatch_basic",
@@ -3550,8 +3563,6 @@
35503563
"AtenComplexViewModule_basic",
35513564
"AtenEmbeddingBagStaticModule_basic",
35523565
"AtenEmbeddingBagSumExample_basic",
3553-
"AtenEyeMModuleInt2D_basic",
3554-
"AtenEyeModuleInt2D_basic",
35553566
"AtenFloatScalarModule_basic",
35563567
"AtenIntBoolOpConstFalseModule_basic",
35573568
"AtenIntBoolOpConstTrueModule_basic",
@@ -3586,11 +3597,8 @@
35863597
"AvgPool2dIntModule_basic",
35873598
"AvgPool2dStaticModule_basic",
35883599
"BernoulliFloatModule_basic",
3589-
"BernoulliModule_basic",
3590-
"BernoulliOnesModule_basic",
35913600
"BernoulliPModule_basic",
35923601
"BernoulliTensorModule_basic",
3593-
"BernoulliZerosModule_basic",
35943602
"BincountMinlengthModule_basic",
35953603
"BincountModule_basic",
35963604
"BincountStaticSizeModule_basic",
@@ -3680,11 +3688,8 @@
36803688
"ElementwiseSinhModule_basic",
36813689
"ElementwiseToDtypeF32ToI64Module_basic",
36823690
"ElementwiseToDtypeI64ToUI8Module_basic",
3683-
"ElementwiseWhereScalarOtherStaticModule_basic",
36843691
"EqIntModule_basic",
36853692
"FloatImplicitModule_basic",
3686-
"FullLikeModuleInt2D_basic",
3687-
"FullLikeModuleInt3D_basic",
36883693
"GeFloatIntModule_basic",
36893694
"GeFloatModule_basic",
36903695
"GeIntModule_basic",
@@ -3770,8 +3775,6 @@
37703775
"NativeGroupNormBackwardModule_basic",
37713776
"NeFloatIntModule_basic",
37723777
"NeIntModule_basic",
3773-
"NewFullModuleInt2D_basic",
3774-
"NewFullModuleInt3D_basic",
37753778
"NllLossModuleBackward1DMeanWeight_basic",
37763779
"NllLossModuleBackward1DMean_basic",
37773780
"NllLossModuleBackward1DSumWeight_basic",
@@ -3784,7 +3787,6 @@
37843787
"NormalFunctionalModule_basic",
37853788
"NumelModule_basic",
37863789
"NumelZeroRankModule_basic",
3787-
"OnesLikeModule_falsePinMemory",
37883790
"PowIntIntModule_basic",
37893791
"PrimMaxIntModule_basic",
37903792
"PrimMinIntDynamicModule_basic",
@@ -3880,15 +3882,12 @@
38803882
"TorchPrimLoopWhileLikeModule_basic",
38813883
"TraceModule_empty",
38823884
"TraceUnsignedIntModule_empty",
3883-
"TypeConversionI1ToF64Module_basic",
3884-
"TypeConversionI1ToI32Module_basic",
38853885
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
38863886
"UpSampleNearest2dBackwardScalesNone_basic",
38873887
"UpSampleNearest2dBackward_basic",
38883888
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
38893889
"ViewSizeFromOtherTensor_basic",
38903890
"VisionTransformerModule_basic",
3891-
"ZerosLikeModule_falsePinMemory",
38923891
# Unexpected failures due to new PyTorch version update
38933892
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
38943893
"AdaptiveAvgPool1dGeneralDynamic_basic",
@@ -4651,7 +4650,6 @@
46514650
"QuantizedReluUint8_basic",
46524651
"QuantizedSingleLayer_basic",
46534652
"RandIntDtypeModule_basic",
4654-
"RandIntLowDtypeModule_basic",
46554653
"RandIntModule_basic",
46564654
"RandIntPinMemoryModule_basic",
46574655
"RandLikeDtypeModule_basic",

0 commit comments

Comments
 (0)