9
9
import torch
10
10
from executorch .backends .arm .test import common
11
11
from executorch .backends .arm .test .tester .test_pipeline import (
12
+ EthosU55PipelineINT ,
13
+ EthosU85PipelineINT ,
12
14
TosaPipelineFP ,
13
15
TosaPipelineINT ,
14
16
VgfPipeline ,
@@ -30,8 +32,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
32
return torch .unflatten (x , self .dim , self .sizes )
31
33
32
34
test_data : dict [str , test_data_t ] = {
33
- "randn_4d" : (lambda : (Unflatten (1 , (2 , 2 )), (torch .randn (3 , 4 , 5 , 1 ),))),
34
- "rand_3d" : (lambda : (Unflatten (1 , (- 1 , 2 )), (torch .rand (3 , 4 , 4 ),))),
35
+ "rand_3d_batch3" : (lambda : (Unflatten (1 , (- 1 , 2 )), (torch .rand (3 , 4 , 4 ),))),
36
+ "rand_3d_batch1" : (lambda : (Unflatten (1 , (- 1 , 2 )), (torch .rand (1 , 4 , 4 ),))),
37
+ "randn_4d_dim1" : (lambda : (Unflatten (1 , (2 , 2 )), (torch .randn (3 , 4 , 5 , 1 ),))),
38
+ "randn_4d_dim3" : (lambda : (Unflatten (3 , (2 , 2 )), (torch .randn (1 , 1 , 5 , 4 ),))),
35
39
}
36
40
37
41
@@ -49,7 +53,33 @@ def test_unflatten_int_tosa_FP(test_data: test_data_t):
49
53
@common .parametrize ("test_data" , Unflatten .test_data )
50
54
def test_unflatten_int_tosa_INT (test_data : test_data_t ):
51
55
module , inputs = test_data ()
52
- pipeline = TosaPipelineINT [input_t ](
56
+ pipeline = TosaPipelineINT [input_t ](module , inputs , Unflatten .aten_op )
57
+ pipeline .run ()
58
+
59
+
60
+ xfails = {
61
+ "rand_3d_batch3" : "Batch size > 1 currently not supported for FVP tests" ,
62
+ "randn_4d_dim1" : "Batch size > 1 currently not supported for FVP tests" ,
63
+ }
64
+
65
+
66
+ @common .parametrize ("test_data" , Unflatten .test_data , xfails = xfails , strict = False )
67
+ @common .XfailIfNoCorstone300
68
+ def test_unflatten_int_u55_INT (test_data : test_data_t ):
69
+ module , inputs = test_data ()
70
+ pipeline = EthosU55PipelineINT [input_t ](
71
+ module ,
72
+ inputs ,
73
+ Unflatten .aten_op ,
74
+ )
75
+ pipeline .run ()
76
+
77
+
78
+ @common .parametrize ("test_data" , Unflatten .test_data , xfails = xfails , strict = False )
79
+ @common .XfailIfNoCorstone320
80
+ def test_unflatten_int_u85_INT (test_data : test_data_t ):
81
+ module , inputs = test_data ()
82
+ pipeline = EthosU85PipelineINT [input_t ](
53
83
module ,
54
84
inputs ,
55
85
Unflatten .aten_op ,
0 commit comments