@@ -962,6 +962,119 @@ def forward(self, input):
962
962
f"The optimized model results shape and torch model results shape should be equal in empty_stride" ,
963
963
)
964
964
965
+ @parameterized .expand (
966
+ [
967
+ (
968
+ "scatter_add_zero_dim_indexOne_constant" ,
969
+ 0 ,
970
+ torch .tensor ([[0 , 1 , 2 , 0 ]]).cuda (),
971
+ torch .tensor ([[1 , 2 , 3 , 4 ]], dtype = torch .int32 ).cuda (),
972
+ {torch .ops .aten .add .Tensor },
973
+ ),
974
+ (
975
+ "scatter_add_zero_dim_indexTwo_constant" ,
976
+ 0 ,
977
+ torch .tensor ([[0 , 1 , 2 , 0 ], [1 , 2 , 1 , 1 ]]).cuda (),
978
+ torch .tensor ([[1 , 2 , 3 , 4 ], [5 , 6 , 7 , 8 ]], dtype = torch .int32 ).cuda (),
979
+ {torch .ops .aten .add .Tensor , torch .ops .aten .scatter .src },
980
+ ),
981
+ (
982
+ "scatter_add_one_dim_indexOne_constant" ,
983
+ 1 ,
984
+ torch .tensor ([[0 , 1 , 2 , 0 ]]).cuda (),
985
+ torch .tensor ([[1 , 2 , 3 , 1 ]], dtype = torch .int32 ).cuda (),
986
+ {
987
+ torch .ops .aten .add .Tensor ,
988
+ torch .ops .aten .scatter .src ,
989
+ torch .ops .aten .full_like .default ,
990
+ },
991
+ ),
992
+ (
993
+ "scatter_add_one_dim_indexTwo_constant" ,
994
+ 1 ,
995
+ torch .tensor ([[0 , 1 , 2 , 0 ], [1 , 2 , 1 , 1 ]]).cuda (),
996
+ torch .tensor ([[1 , 2 , 3 , 1 ], [5 , 6 , 5 , 5 ]], dtype = torch .int32 ).cuda (),
997
+ {
998
+ torch .ops .aten .add .Tensor ,
999
+ torch .ops .aten .scatter .src ,
1000
+ torch .ops .aten .full_like .default ,
1001
+ },
1002
+ ),
1003
+ (
1004
+ "scatter_add_one_dim_indexTwo_constant" ,
1005
+ 1 ,
1006
+ torch .tensor ([[0 , 1 , 2 , 0 ], [1 , 2 , 1 , 1 ], [3 , 2 , 1 , 2 ]]).cuda (),
1007
+ torch .tensor (
1008
+ [[1 , 2 , 3 , 1 ], [5 , 6 , 5 , 5 ], [2 , 4 , 3 , 2 ]], dtype = torch .int32
1009
+ ).cuda (),
1010
+ {
1011
+ torch .ops .aten .add .Tensor ,
1012
+ torch .ops .aten .scatter .src ,
1013
+ torch .ops .aten .full_like .default ,
1014
+ },
1015
+ ),
1016
+ ]
1017
+ )
1018
+ def test_scatter_add (self , _ , dim , index , src , expected_ops_param ):
1019
+ class TestModule (torch .nn .Module ):
1020
+ def __init__ (self ):
1021
+ super ().__init__ ()
1022
+
1023
+ def forward (self , input ):
1024
+ return torch .ops .aten .scatter_add .default (input , dim , index , src )
1025
+
1026
+ # Operations expected to be included in the traced graph after decompositions
1027
+ expected_ops = expected_ops_param
1028
+ unexpected_ops = {torch .ops .aten .scatter_add .default }
1029
+
1030
+ input = torch .zeros (3 , 5 , dtype = torch .int32 ).cuda ()
1031
+ inputs = [input ]
1032
+
1033
+ fx_graph = torch .fx .symbolic_trace (TestModule ())
1034
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
1035
+ fx_graph ,
1036
+ inputs ,
1037
+ expected_ops = expected_ops ,
1038
+ unexpected_ops = unexpected_ops ,
1039
+ min_block_size = 2 ,
1040
+ )
1041
+
1042
+ self .assertEquals (
1043
+ len (expected_ops_unseen ),
1044
+ 0 ,
1045
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
1046
+ )
1047
+
1048
+ self .assertEquals (
1049
+ len (unexpected_ops_seen ),
1050
+ 0 ,
1051
+ f"The following expected ops were not encountered: { unexpected_ops_seen } " ,
1052
+ )
1053
+
1054
+ torch ._dynamo .reset ()
1055
+
1056
+ # Validate that the results between Torch and Torch-TRT are similar
1057
+ optimized_model = torch_tensorrt .compile (
1058
+ fx_graph ,
1059
+ "torch_compile" ,
1060
+ inputs ,
1061
+ min_block_size = 1 ,
1062
+ truncate_double = True ,
1063
+ pass_through_build_failures = True ,
1064
+ )
1065
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
1066
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
1067
+
1068
+ max_diff = float (
1069
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
1070
+ )
1071
+ self .assertAlmostEqual (
1072
+ max_diff ,
1073
+ 0 ,
1074
+ DECIMALS_OF_AGREEMENT ,
1075
+ f"Scatter_add TRT outputs don't match with the original model." ,
1076
+ )
1077
+
965
1078
966
1079
if __name__ == "__main__" :
967
1080
run_tests ()
0 commit comments