Skip to content

[TOSA] Update tosa.cast check according to TOSA v1.0 spec #3948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5513,11 +5513,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, resultTy,
// OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
// op.getType()),
result);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);

return success();
}
Expand Down Expand Up @@ -6451,11 +6447,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
tosa::getConstTensor<int32_t>(rewriter, op,
/*vec=*/{0, 3, 1, 2},
/*shape=*/{static_cast<int32_t>(4)});
// SmallVector<int64_t> transposedOutputShape(
// {transposedResizedOpShape[0], transposedResizedOpShape[3],
// transposedResizedOpShape[1], transposedResizedOpShape[2]});
// auto transposedOutputType = RankedTensorType::get(
// makeShapeLLVMCompatible(transposedOutputShape), inputElemTy);

rewriter
.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(resultType), resizeOpResult,
Expand Down
78 changes: 56 additions & 22 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,42 +264,68 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
return const_op.getResult();
}

static LogicalResult checkValidityOfCast(Type src, Type dest) {
// Valid TOSA casting pairs according to TOSA spec:
// https://www.mlplatform.org/tosa/tosa_spec.html#_cast
// Note: currently TOSA doesn't support casting to and from I64 and F64
[[maybe_unused]] static LogicalResult checkValidityOfCast(Type src, Type dest) {
// clang-format off
if ((src == dest) ||
// int64 -> *
(src.isInteger(64) && dest.isInteger(32)) ||
(src.isInteger(64) && dest.isInteger(8)) ||
(src.isInteger(64) && dest.isInteger(1)) ||
(src.isInteger(64) && dest.isF32()) ||
// int32 -> *
(src.isInteger(32) && dest.isInteger(64)) ||
(src.isInteger(32) && dest.isInteger(16)) ||
(src.isInteger(32) && dest.isInteger(8)) ||
(src.isInteger(32) && dest.isInteger(1)) ||
(src.isInteger(32) && dest.isF32()) ||
(src.isInteger(32) && dest.isF16()) ||
(src.isInteger(32) && dest.isBF16()) ||
// int16 -> *
(src.isInteger(16) && dest.isInteger(32)) ||
(src.isInteger(16) && dest.isInteger(8)) ||
(src.isInteger(16) && dest.isInteger(1)) ||
(src.isInteger(16) && dest.isBF16()) ||
(src.isInteger(16) && dest.isF32()) ||
(src.isInteger(16) && dest.isF16()) ||
// int8 -> *
(src.isInteger(8) && dest.isInteger(32)) ||
(src.isInteger(8) && dest.isInteger(16)) ||
(src.isInteger(8) && dest.isInteger(1)) ||
(src.isInteger(8) && dest.isBF16()) ||
(src.isInteger(8) && dest.isF32()) ||
(src.isInteger(8) && dest.isF16()) ||
// int1 -> *
(src.isInteger(1) && dest.isInteger(64)) ||
(src.isInteger(1) && dest.isF32()) ||
// f64 -> *
(src.isF64() && dest.isF32()) ||
(src.isF64() && dest.isBF16()) ||
(src.isInteger(1) && dest.isInteger(32)) ||
(src.isInteger(1) && dest.isInteger(16)) ||
(src.isInteger(1) && dest.isInteger(8)) ||
// f32 -> *
(src.isF32() && dest.isF64()) ||
(src.isF32() && dest.isInteger(32)) ||
(src.isF32() && dest.isInteger(16)) ||
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isBF16()) ||
(src.isF32() && dest.isF16()) ||
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isInteger(64)) ||
(src.isF32() && dest.isInteger(1)) ||
(src.isF32() && dest.isFloat8E4M3()) ||
(src.isF32() && dest.isFloat8E5M2()) ||
// f16 -> *
(src.isF16() && dest.isInteger(32)) ||
(src.isF16() && dest.isInteger(16)) ||
(src.isF16() && dest.isInteger(8)) ||
(src.isF16() && dest.isBF16()) ||
(src.isF16() && dest.isF32()) ||
(src.isF16() && dest.isFloat8E4M3()) ||
(src.isF16() && dest.isFloat8E5M2()) ||
// bf16 -> *
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(32)) ||
(src.isBF16() && dest.isF32())) {
(src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(8)) ||
(src.isBF16() && dest.isF32()) ||
(src.isBF16() && dest.isFloat8E4M3()) ||
(src.isBF16() && dest.isFloat8E5M2()) ||
// fp8e4m3 -> *
(src.isFloat8E4M3() && dest.isBF16()) ||
(src.isFloat8E4M3() && dest.isF32()) ||
(src.isFloat8E4M3() && dest.isF16()) ||
// fp8e5m2 -> *
(src.isFloat8E5M2() && dest.isBF16()) ||
(src.isFloat8E5M2() && dest.isF32()) ||
(src.isFloat8E5M2() && dest.isF16())) {
return success();
}
// clang-format on
Expand All @@ -313,9 +339,17 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();

if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
return rewriter.notifyMatchFailure(
op, "casting to result dtype is invalid or unsupported");
// Temporarily disable checkValidityOfCast: it currently follows the TOSA spec
// strictly, and enabling it would cause many e2e tests to fail. Some casting
// pairs that are not listed in the TOSA spec are still permissible here; TOSA
// validation should flag such illegal constructs on a per-profile basis. This
// strict validity check will be re-enabled later behind a potential
// `--strict` mode (off by default), which enforces strict casting only when
// needed.
// if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
// return rewriter.notifyMatchFailure(
// op, "casting to result dtype is invalid or unsupported");

if (destElemTy.isInteger(1)) {
auto srcType = dyn_cast<TensorType>(src.getType());
Expand Down
36 changes: 17 additions & 19 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,6 +1704,21 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleInt2D_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"FullModuleFalsePinMemory_basic",
"FullModuleInt2D_basic",
"MaskedFillScalarFloatValueModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"Threshold3dIntModule_basic",
"TrilIndicesModule_basic",
"TrilIndicesOfssetGreaterThanRowModule_basic",
"TriuIndicesNegativeOffsetModule_basic",
"BmmFloat16Module_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_basic",
Expand Down Expand Up @@ -2541,6 +2556,8 @@
}
) - {
### Test failing in make_fx_tosa but not in tosa
"ElementwiseRreluEvalStaticModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"FloatPowerTensorTensorStaticModule_basic",
# Dynamic shape, has extra unsupported broadcast ops
Expand Down Expand Up @@ -3461,7 +3478,6 @@
"LayerNormFwAndBwModule_basic",
"LayerNormManualFwAndBwModule_basic",
"SelfAttentionFwAndBwModule_basic",
"Threshold3dIntModule_basic",
"ElementwiseCopysignModule_basic",
"ElementwiseSignbitModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
Expand Down Expand Up @@ -3510,12 +3526,9 @@
"TensorsConcatComplex64FloatModule_basic",
"TimeOutModule_basic",
"TrilIndicesAllZerosModule_basic",
"TrilIndicesModule_basic",
"TrilIndicesNegativeOffsetModule_basic",
"TrilIndicesOfssetGreaterThanRowModule_basic",
"TriuIndicesAllZerosModule_basic",
"TriuIndicesModule_basic",
"TriuIndicesNegativeOffsetModule_basic",
"TypeConversionUint8ToF32Module_basic",
"WeightNormInterfaceModule_basic",
"AdaptiveAvgPool3dDynamicNoBatch_basic",
Expand Down Expand Up @@ -3545,8 +3558,6 @@
"AtenComplexViewModule_basic",
"AtenEmbeddingBagStaticModule_basic",
"AtenEmbeddingBagSumExample_basic",
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenFloatScalarModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
Expand Down Expand Up @@ -3581,11 +3592,8 @@
"AvgPool2dIntModule_basic",
"AvgPool2dStaticModule_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliOnesModule_basic",
"BernoulliPModule_basic",
"BernoulliTensorModule_basic",
"BernoulliZerosModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
"BincountStaticSizeModule_basic",
Expand Down Expand Up @@ -3679,11 +3687,8 @@
"ElementwiseSpecialExpm1Module_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"EqIntModule_basic",
"FloatImplicitModule_basic",
"FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
Expand Down Expand Up @@ -3769,8 +3774,6 @@
"NativeGroupNormBackwardModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"NllLossModuleBackward1DMeanWeight_basic",
"NllLossModuleBackward1DMean_basic",
"NllLossModuleBackward1DSumWeight_basic",
Expand All @@ -3783,7 +3786,6 @@
"NormalFunctionalModule_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_falsePinMemory",
"PowIntIntModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
Expand Down Expand Up @@ -3879,15 +3881,12 @@
"TorchPrimLoopWhileLikeModule_basic",
"TraceModule_empty",
"TraceUnsignedIntModule_empty",
"TypeConversionI1ToF64Module_basic",
"TypeConversionI1ToI32Module_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"UpSampleNearest2dBackwardScalesNone_basic",
"UpSampleNearest2dBackward_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic",
"ZerosLikeModule_falsePinMemory",
# Unexpected failures due to new PyTorch version update
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
Expand Down Expand Up @@ -4650,7 +4649,6 @@
"QuantizedReluUint8_basic",
"QuantizedSingleLayer_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RandLikeDtypeModule_basic",
Expand Down
Loading