Skip to content

Commit 2cc9690

Browse files
authored
[Ascend] Wx/fix the device config for index_put op an ascend (DeepLink-org#854)
* skip test cases when accumulate is True for index_put op
1 parent 300dc31 commit 2cc9690

File tree

1 file changed

+10
-35
lines changed

1 file changed

+10
-35
lines changed

impl/ascend/device_configs.py

+10-35
Original file line numberDiff line numberDiff line change
@@ -1093,61 +1093,36 @@
10931093

10941094
'index_put_acc_three_indices': dict( # llm used
10951095
name=['index_put'],
1096-
tensor_para=dict(
1097-
args=[
1098-
{
1099-
"ins": ['input'],
1100-
"dtype": [Skip(np.float32),Skip(np.float64),Skip(np.float16),Skip(np.int32),Skip(np.int64),Skip(np.uint8),Skip(np.int8),Skip(np.bool_),],
1101-
},
1102-
]
1096+
para=dict(
1097+
accumulate=[Skip(False),],
11031098
),
11041099
),
11051100

11061101
'index_put_acc_two_indices': dict( # llm used
11071102
name=['index_put'],
1108-
tensor_para=dict(
1109-
args=[
1110-
{
1111-
"ins": ['input'],
1112-
"dtype": [Skip(np.float32),Skip(np.float64),],
1113-
},
1114-
]
1103+
para=dict(
1104+
accumulate=[Skip(False),],
11151105
),
11161106
),
11171107

11181108
'index_put_acc_one_indices': dict( # llm used
11191109
name=['index_put'],
1120-
tensor_para=dict(
1121-
args=[
1122-
{
1123-
"ins": ['input'],
1124-
"dtype": [Skip(np.float32),Skip(np.float64),],
1125-
},
1126-
]
1110+
para=dict(
1111+
accumulate=[Skip(False),],
11271112
),
11281113
),
11291114

11301115
'index_put_acc_bool_indices_zeros': dict( # llm used
11311116
name=['index_put'],
1132-
tensor_para=dict(
1133-
args=[
1134-
{
1135-
"ins": ['input'],
1136-
"dtype": [Skip(np.float32),Skip(np.int64),],
1137-
},
1138-
]
1117+
para=dict(
1118+
accumulate=[Skip(False),],
11391119
),
11401120
),
11411121

11421122
'index_put_one_indices': dict( # llm used
11431123
name=['index_put'],
1144-
tensor_para=dict(
1145-
args=[
1146-
{
1147-
"ins": ['input'],
1148-
"dtype": [Skip(np.float32),Skip(np.float64),Skip(np.float16),Skip(np.int32),Skip(np.int64),Skip(np.uint8),Skip(np.int8),Skip(np.bool_),],
1149-
},
1150-
]
1124+
para=dict(
1125+
accumulate=[Skip(False),],
11511126
),
11521127
),
11531128

0 commit comments

Comments
 (0)