Skip to content

Commit 2acdecd

Browse files
NeosZhangyangbofun
andauthored
[ascend]zq/fix device_config for ascend pow (DeepLink-org#831)
* fix device_config for ascend pow --------- Co-authored-by: yangbofun <[email protected]>
1 parent 20ce541 commit 2acdecd

File tree

2 files changed

+70
-69
lines changed

2 files changed

+70
-69
lines changed

diopi_test/python/configs/diopi_configs.py

100644100755
Lines changed: 69 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,7 @@
13541354
),
13551355
),
13561356

1357-
'pow': dict(
1357+
'pow_scalar_base_float_exp': dict(
13581358
name=['pow'],
13591359
interface=['torch'],
13601360
is_inplace=True,
@@ -1376,7 +1376,7 @@
13761376
),
13771377
),
13781378

1379-
'pow_int': dict(
1379+
'pow_scalar_base_int_exp': dict(
13801380
name=['pow'],
13811381
interface=['torch'],
13821382
is_inplace=True,
@@ -1391,71 +1391,83 @@
13911391
(2, 128, 3072), (2, 512, 38, 38),
13921392
(0,), (0, 8), (7, 0, 9)),
13931393
"dtype": [np.int16, np.int32, np.int64,
1394-
np.int8, np.uint8],
1395-
"gen_fn": dict(fn='Genfunc.randint', low=-4, high=4),
1394+
np.int8, np.uint8, np.bool_],
1395+
"gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4),
13961396
}
13971397
],
13981398
),
13991399
),
14001400

1401-
'pow_bool': dict(
1401+
# attention: Integers to negative integer powers are not allowed.
1402+
# may cause overflow if both base and exponet are uint8.
1403+
# int zero to negative int exp powers are not defined.
1404+
'pow_tensor_base_positive_exp': dict(
14021405
name=['pow'],
14031406
interface=['torch'],
14041407
is_inplace=True,
1405-
para=dict(
1406-
exponent=[0, -1.2, 2, 0.6, 1.2, 0.],
1407-
),
1408+
dtype=[np.float16, np.float32, np.float64,
1409+
np.int16, np.int32, np.int64,
1410+
np.int8],
14081411
tensor_para=dict(
14091412
args=[
14101413
{
14111414
"ins": ['input'],
1412-
"shape": ((), (20267, 80),
1415+
"shape": ((), (1, ), (20267, 80),
14131416
(2, 128, 3072),
14141417
(2, 512, 38, 38),
1415-
(0,), (0, 8)),
1416-
"dtype": [np.bool_],
1417-
"gen_fn": 'Genfunc.mask',
1418-
}
1418+
(0,), (0, 4), (9, 0, 3)),
1419+
"gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4),
1420+
},
1421+
{
1422+
"ins": ['exponent'],
1423+
"shape": ((), (1, ), (20267, 80),
1424+
(2, 128, 3072),
1425+
(2, 512, 38, 38),
1426+
(0,), (0, 4), (9, 0, 3)),
1427+
"gen_fn": dict(fn='Genfunc.uniform', low=1, high=4),
1428+
},
14191429
],
14201430
),
14211431
),
14221432

1423-
'pow_tensor': dict(
1433+
# attention: Integers to negative integer powers are not allowed.
1434+
# int zero to negative int exp powers are not defined.
1435+
'pow_tensor_base_negative_exp': dict(
14241436
name=['pow'],
14251437
interface=['torch'],
14261438
is_inplace=True,
1427-
dtype=[np.float16, np.float32, np.float64,
1428-
np.int16, np.int32, np.int64,
1429-
np.int8, np.uint8],
1439+
dtype=[np.float16, np.float32, np.float64],
14301440
tensor_para=dict(
1431-
gen_fn=dict(fn='Genfunc.randn_int', low=-4, high=4),
14321441
args=[
14331442
{
14341443
"ins": ['input'],
14351444
"shape": ((), (1, ), (20267, 80),
14361445
(2, 128, 3072),
14371446
(2, 512, 38, 38),
14381447
(0,), (0, 4), (9, 0, 3)),
1448+
"gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4),
14391449
},
14401450
{
14411451
"ins": ['exponent'],
14421452
"shape": ((), (1, ), (20267, 80),
14431453
(2, 128, 3072),
14441454
(2, 512, 38, 38),
14451455
(0,), (0, 4), (9, 0, 3)),
1456+
"gen_fn": dict(fn='Genfunc.uniform', low=-4, high=-1),
14461457
},
14471458
],
14481459
),
14491460
),
14501461

1462+
# int zero to negative int exp powers are not defined.
14511463
'pow_tensor_only_0_1': dict(
14521464
name=['pow'],
14531465
interface=['torch'],
14541466
is_inplace=True,
14551467
dtype=[np.int16, np.int32, np.int64,
14561468
np.int8, np.uint8],
14571469
tensor_para=dict(
1458-
gen_fn='Genfunc.randn',
1470+
gen_fn=dict(fn='Genfunc.uniform', low=0, high=2),
14591471
args=[
14601472
{
14611473
"ins": ['input'],
@@ -1520,61 +1532,65 @@
15201532
),
15211533
),
15221534

1535+
# attention: Integers to negative integer powers are not allowed.
1536+
# may cause overflow if both base and exponet are uint8
1537+
# int zero to negative int exp powers are not defined.
15231538
'pow_diff_dtype_cast': dict(
15241539
name=['pow'],
15251540
interface=['torch'],
15261541
tensor_para=dict(
1527-
gen_fn=dict(fn='Genfunc.randn_int', low=-4, high=4),
15281542
args=[
15291543
{
15301544
"ins": ['input'],
15311545
"shape": ((1024, ),),
15321546
"dtype": [np.int64, np.int32, np.int16,
15331547
np.bool_, np.bool_, np.bool_, np.bool_],
1548+
"gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4),
15341549
},
15351550
{
15361551
"ins": ['exponent'],
15371552
"shape": ((1024, ),),
15381553
"dtype": [np.float32, np.float64, np.float16,
15391554
np.int32, np.float32, np.int8, np.uint8],
1555+
"gen_fn": dict(fn='Genfunc.uniform', low=1, high=4),
15401556
},
15411557
],
15421558
),
15431559
),
15441560

1545-
# FIXME pow的input与exponent输入uint8和int8,结果不一致
1561+
# attention: Integers to negative integer powers are not allowed.
1562+
# may cause overflow if both base and exponet are uint8
1563+
# int zero to negative int exp powers are not defined.
15461564
'pow_diff_dtype': dict(
15471565
name=['pow'],
15481566
interface=['torch'],
15491567
is_inplace=True,
15501568
tensor_para=dict(
1551-
gen_fn=dict(fn='Genfunc.randn_int', low=-4, high=4),
15521569
args=[
15531570
{
15541571
"ins": ['input'],
15551572
"shape": ((1024, ),),
1556-
# "dtype":[np.float64, np.float32, np.float16,
1557-
# np.int32, np.float64, np.float64,
1558-
# np.int8, np.float32, np.uint8],
15591573
"dtype": [np.float64, np.float32, np.float16,
15601574
np.int32, np.float64, np.float32,
15611575
np.float32, np.int16, np.int64],
1576+
"gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4),
15621577
},
15631578
{
15641579
"ins": ['exponent'],
15651580
"shape": ((1024, ),),
1566-
# "dtype":[np.int32, np.uint8, np.bool_,
1567-
# np.int64, np.float16, np.float32,
1568-
# np.uint8, np.bool_, np.int8],
15691581
"dtype": [np.int32, np.uint8, np.bool_,
15701582
np.int64, np.float16, np.float64,
15711583
np.bool_, np.uint8, np.bool_],
1584+
"gen_fn": dict(fn='Genfunc.uniform', low=1, high=4),
15721585
},
15731586
],
15741587
),
15751588
),
15761589

1577-
'pow_input_scalar': dict(
1590+
# attention: Integers to negative integer powers are not allowed.
1591+
# may cause overflow if exponet are uint8
1592+
# int zero to negative int exp powers are not defined.
1593+
'pow_input_scalar_positive_exp': dict(
15781594
name=['pow'],
15791595
interface=['torch'],
15801596
para=dict(
@@ -1589,13 +1605,34 @@
15891605
(0,), (0, 4), (9, 0, 6)),
15901606
"dtype": [np.float16, np.float32, np.float64,
15911607
np.int16, np.int32, np.int64,
1592-
np.int8, np.uint8, np.bool_],
1593-
"gen_fn": dict(fn='Genfunc.randn_int', low=-4, high=4),
1608+
np.int8, np.bool_],
1609+
"gen_fn": dict(fn='Genfunc.uniform', low=1, high=4),
15941610
}
15951611
],
15961612
),
15971613
),
15981614

1615+
'pow_input_scalar_negative_exp': dict(
1616+
name=['pow'],
1617+
interface=['torch'],
1618+
para=dict(
1619+
self=[-2, -0.5, 0, 0.6, 2, 3, 4., 1.],
1620+
),
1621+
tensor_para=dict(
1622+
args=[
1623+
{
1624+
"ins": ['exponent'],
1625+
"shape": ((), (8,), (125, 1),
1626+
(70, 1, 2), (4, 256, 16, 16),
1627+
(0,), (0, 4), (9, 0, 6)),
1628+
"dtype": [np.float16, np.float32, np.float64],
1629+
"gen_fn": dict(fn='Genfunc.uniform', low=-4, high=-1),
1630+
}
1631+
],
1632+
),
1633+
),
1634+
1635+
# attention: Integers to negative integer powers are not allowed.
15991636
'pow_input_scalar_bool': dict(
16001637
name=['pow'],
16011638
interface=['torch'],
@@ -1610,7 +1647,7 @@
16101647
"dtype": [np.float16, np.float32, np.float64,
16111648
np.int16, np.int32, np.int64,
16121649
np.int8, np.uint8],
1613-
"gen_fn": dict(fn='Genfunc.randn_int', low=-4, high=4),
1650+
"gen_fn": dict(fn='Genfunc.uniform', low=1, high=4),
16141651
}
16151652
],
16161653
),

impl/ascend/device_configs.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from skip import Skip
44

5-
# topk, normal, norm, nll_loss, gather, fill_, triu, bmm, mm llm used
5+
# topk, normal, norm, nll_loss, gather, fill_, triu, bmm, mm, pow llm used
66

77
device_configs = {
88
# temp for 910B
@@ -433,42 +433,6 @@
433433
),
434434
),
435435

436-
'pow_tensor': dict( # llm used
437-
name=['pow'],
438-
tensor_para=dict(
439-
args=[
440-
{
441-
"ins": ['input'],
442-
"shape": [Skip(()),Skip((1,)),Skip((20267, 80)),Skip((2, 128, 3072)),Skip((2, 512, 38, 38)),Skip((0,)),Skip((0, 4)),Skip((9, 0, 3)),],
443-
},
444-
]
445-
),
446-
),
447-
448-
'pow_tensor_only_0_1': dict( # llm used
449-
name=['pow'],
450-
tensor_para=dict(
451-
args=[
452-
{
453-
"ins": ['input'],
454-
"shape": [Skip(()),Skip((1,)),Skip((20267, 80)),Skip((2, 128, 3072)),Skip((2, 512, 38, 38)),Skip((0,)),Skip((0, 4)),Skip((9, 0, 3)),],
455-
},
456-
]
457-
),
458-
),
459-
460-
'pow_diff_dtype': dict( # llm used
461-
name=['pow'],
462-
tensor_para=dict(
463-
args=[
464-
{
465-
"ins": ['input'],
466-
"dtype": [Skip(np.float64),Skip(np.float32),Skip(np.float16),Skip(np.int32),Skip(np.float64),Skip(np.float32),Skip(np.float32),Skip(np.int16),Skip(np.int64),],
467-
},
468-
]
469-
),
470-
),
471-
472436
'matmul': dict(
473437
name=['matmul'],
474438
tensor_para=dict(

0 commit comments

Comments
 (0)