|
1354 | 1354 | ),
|
1355 | 1355 | ),
|
1356 | 1356 |
|
1357 |
| - 'pow': dict( |
| 1357 | + 'pow_scalar_base_float_exp': dict( |
1358 | 1358 | name=['pow'],
|
1359 | 1359 | interface=['torch'],
|
1360 | 1360 | is_inplace=True,
|
|
1376 | 1376 | ),
|
1377 | 1377 | ),
|
1378 | 1378 |
|
1379 |
| - 'pow_int': dict( |
| 1379 | + 'pow_scalar_base_int_exp': dict( |
1380 | 1380 | name=['pow'],
|
1381 | 1381 | interface=['torch'],
|
1382 | 1382 | is_inplace=True,
|
|
1391 | 1391 | (2, 128, 3072), (2, 512, 38, 38),
|
1392 | 1392 | (0,), (0, 8), (7, 0, 9)),
|
1393 | 1393 | "dtype": [np.int16, np.int32, np.int64,
|
1394 |
| - np.int8, np.uint8], |
1395 |
| - "gen_fn": dict(fn='Genfunc.randint', low=-4, high=4), |
| 1394 | + np.int8, np.uint8, np.bool_], |
| 1395 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
1396 | 1396 | }
|
1397 | 1397 | ],
|
1398 | 1398 | ),
|
1399 | 1399 | ),
|
1400 | 1400 |
|
1401 |
| - 'pow_bool': dict( |
| 1401 | + # attention: Integers to negative integer powers are not allowed. |
| 1402 | + # may cause overflow if both base and exponet are uint8. |
| 1403 | + # int zero to negative int exp powers are not defined. |
| 1404 | + 'pow_tensor_base_positive_exp': dict( |
1402 | 1405 | name=['pow'],
|
1403 | 1406 | interface=['torch'],
|
1404 | 1407 | is_inplace=True,
|
1405 |
| - para=dict( |
1406 |
| - exponent=[0, -1.2, 2, 0.6, 1.2, 0.], |
1407 |
| - ), |
| 1408 | + dtype=[np.float16, np.float32, np.float64, |
| 1409 | + np.int16, np.int32, np.int64, |
| 1410 | + np.int8], |
1408 | 1411 | tensor_para=dict(
|
1409 | 1412 | args=[
|
1410 | 1413 | {
|
1411 | 1414 | "ins": ['input'],
|
1412 |
| - "shape": ((), (20267, 80), |
| 1415 | + "shape": ((), (1, ), (20267, 80), |
1413 | 1416 | (2, 128, 3072),
|
1414 | 1417 | (2, 512, 38, 38),
|
1415 |
| - (0,), (0, 8)), |
1416 |
| - "dtype": [np.bool_], |
1417 |
| - "gen_fn": 'Genfunc.mask', |
1418 |
| - } |
| 1418 | + (0,), (0, 4), (9, 0, 3)), |
| 1419 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
| 1420 | + }, |
| 1421 | + { |
| 1422 | + "ins": ['exponent'], |
| 1423 | + "shape": ((), (1, ), (20267, 80), |
| 1424 | + (2, 128, 3072), |
| 1425 | + (2, 512, 38, 38), |
| 1426 | + (0,), (0, 4), (9, 0, 3)), |
| 1427 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
| 1428 | + }, |
1419 | 1429 | ],
|
1420 | 1430 | ),
|
1421 | 1431 | ),
|
1422 | 1432 |
|
1423 |
| - 'pow_tensor': dict( |
| 1433 | + # attention: Integers to negative integer powers are not allowed. |
| 1434 | + # int zero to negative int exp powers are not defined. |
| 1435 | + 'pow_tensor_base_negative_exp': dict( |
1424 | 1436 | name=['pow'],
|
1425 | 1437 | interface=['torch'],
|
1426 | 1438 | is_inplace=True,
|
1427 |
| - dtype=[np.float16, np.float32, np.float64, |
1428 |
| - np.int16, np.int32, np.int64, |
1429 |
| - np.int8, np.uint8], |
| 1439 | + dtype=[np.float16, np.float32, np.float64], |
1430 | 1440 | tensor_para=dict(
|
1431 |
| - gen_fn=dict(fn='Genfunc.randn_int', low=-4, high=4), |
1432 | 1441 | args=[
|
1433 | 1442 | {
|
1434 | 1443 | "ins": ['input'],
|
1435 | 1444 | "shape": ((), (1, ), (20267, 80),
|
1436 | 1445 | (2, 128, 3072),
|
1437 | 1446 | (2, 512, 38, 38),
|
1438 | 1447 | (0,), (0, 4), (9, 0, 3)),
|
| 1448 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
1439 | 1449 | },
|
1440 | 1450 | {
|
1441 | 1451 | "ins": ['exponent'],
|
1442 | 1452 | "shape": ((), (1, ), (20267, 80),
|
1443 | 1453 | (2, 128, 3072),
|
1444 | 1454 | (2, 512, 38, 38),
|
1445 | 1455 | (0,), (0, 4), (9, 0, 3)),
|
| 1456 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=-1), |
1446 | 1457 | },
|
1447 | 1458 | ],
|
1448 | 1459 | ),
|
1449 | 1460 | ),
|
1450 | 1461 |
|
| 1462 | + # int zero to negative int exp powers are not defined. |
1451 | 1463 | 'pow_tensor_only_0_1': dict(
|
1452 | 1464 | name=['pow'],
|
1453 | 1465 | interface=['torch'],
|
1454 | 1466 | is_inplace=True,
|
1455 | 1467 | dtype=[np.int16, np.int32, np.int64,
|
1456 | 1468 | np.int8, np.uint8],
|
1457 | 1469 | tensor_para=dict(
|
1458 |
| - gen_fn='Genfunc.randn', |
| 1470 | + gen_fn=dict(fn='Genfunc.uniform', low=0, high=2), |
1459 | 1471 | args=[
|
1460 | 1472 | {
|
1461 | 1473 | "ins": ['input'],
|
|
1520 | 1532 | ),
|
1521 | 1533 | ),
|
1522 | 1534 |
|
| 1535 | + # attention: Integers to negative integer powers are not allowed. |
| 1536 | + # may cause overflow if both base and exponet are uint8 |
| 1537 | + # int zero to negative int exp powers are not defined. |
1523 | 1538 | 'pow_diff_dtype_cast': dict(
|
1524 | 1539 | name=['pow'],
|
1525 | 1540 | interface=['torch'],
|
1526 | 1541 | tensor_para=dict(
|
1527 |
| - gen_fn=dict(fn='Genfunc.randn_int', low=-4, high=4), |
1528 | 1542 | args=[
|
1529 | 1543 | {
|
1530 | 1544 | "ins": ['input'],
|
1531 | 1545 | "shape": ((1024, ),),
|
1532 | 1546 | "dtype": [np.int64, np.int32, np.int16,
|
1533 | 1547 | np.bool_, np.bool_, np.bool_, np.bool_],
|
| 1548 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
1534 | 1549 | },
|
1535 | 1550 | {
|
1536 | 1551 | "ins": ['exponent'],
|
1537 | 1552 | "shape": ((1024, ),),
|
1538 | 1553 | "dtype": [np.float32, np.float64, np.float16,
|
1539 | 1554 | np.int32, np.float32, np.int8, np.uint8],
|
| 1555 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
1540 | 1556 | },
|
1541 | 1557 | ],
|
1542 | 1558 | ),
|
1543 | 1559 | ),
|
1544 | 1560 |
|
1545 |
| - # FIXME pow的input与exponent输入uint8和int8,结果不一致 |
| 1561 | + # attention: Integers to negative integer powers are not allowed. |
| 1562 | + # may cause overflow if both base and exponet are uint8 |
| 1563 | + # int zero to negative int exp powers are not defined. |
1546 | 1564 | 'pow_diff_dtype': dict(
|
1547 | 1565 | name=['pow'],
|
1548 | 1566 | interface=['torch'],
|
1549 | 1567 | is_inplace=True,
|
1550 | 1568 | tensor_para=dict(
|
1551 |
| - gen_fn=dict(fn='Genfunc.randn_int', low=-4, high=4), |
1552 | 1569 | args=[
|
1553 | 1570 | {
|
1554 | 1571 | "ins": ['input'],
|
1555 | 1572 | "shape": ((1024, ),),
|
1556 |
| - # "dtype":[np.float64, np.float32, np.float16, |
1557 |
| - # np.int32, np.float64, np.float64, |
1558 |
| - # np.int8, np.float32, np.uint8], |
1559 | 1573 | "dtype": [np.float64, np.float32, np.float16,
|
1560 | 1574 | np.int32, np.float64, np.float32,
|
1561 | 1575 | np.float32, np.int16, np.int64],
|
| 1576 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=4), |
1562 | 1577 | },
|
1563 | 1578 | {
|
1564 | 1579 | "ins": ['exponent'],
|
1565 | 1580 | "shape": ((1024, ),),
|
1566 |
| - # "dtype":[np.int32, np.uint8, np.bool_, |
1567 |
| - # np.int64, np.float16, np.float32, |
1568 |
| - # np.uint8, np.bool_, np.int8], |
1569 | 1581 | "dtype": [np.int32, np.uint8, np.bool_,
|
1570 | 1582 | np.int64, np.float16, np.float64,
|
1571 | 1583 | np.bool_, np.uint8, np.bool_],
|
| 1584 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
1572 | 1585 | },
|
1573 | 1586 | ],
|
1574 | 1587 | ),
|
1575 | 1588 | ),
|
1576 | 1589 |
|
1577 |
| - 'pow_input_scalar': dict( |
| 1590 | + # attention: Integers to negative integer powers are not allowed. |
| 1591 | + # may cause overflow if exponet are uint8 |
| 1592 | + # int zero to negative int exp powers are not defined. |
| 1593 | + 'pow_input_scalar_positive_exp': dict( |
1578 | 1594 | name=['pow'],
|
1579 | 1595 | interface=['torch'],
|
1580 | 1596 | para=dict(
|
|
1589 | 1605 | (0,), (0, 4), (9, 0, 6)),
|
1590 | 1606 | "dtype": [np.float16, np.float32, np.float64,
|
1591 | 1607 | np.int16, np.int32, np.int64,
|
1592 |
| - np.int8, np.uint8, np.bool_], |
1593 |
| - "gen_fn": dict(fn='Genfunc.randn_int', low=-4, high=4), |
| 1608 | + np.int8, np.bool_], |
| 1609 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
1594 | 1610 | }
|
1595 | 1611 | ],
|
1596 | 1612 | ),
|
1597 | 1613 | ),
|
1598 | 1614 |
|
| 1615 | + 'pow_input_scalar_negative_exp': dict( |
| 1616 | + name=['pow'], |
| 1617 | + interface=['torch'], |
| 1618 | + para=dict( |
| 1619 | + self=[-2, -0.5, 0, 0.6, 2, 3, 4., 1.], |
| 1620 | + ), |
| 1621 | + tensor_para=dict( |
| 1622 | + args=[ |
| 1623 | + { |
| 1624 | + "ins": ['exponent'], |
| 1625 | + "shape": ((), (8,), (125, 1), |
| 1626 | + (70, 1, 2), (4, 256, 16, 16), |
| 1627 | + (0,), (0, 4), (9, 0, 6)), |
| 1628 | + "dtype": [np.float16, np.float32, np.float64], |
| 1629 | + "gen_fn": dict(fn='Genfunc.uniform', low=-4, high=-1), |
| 1630 | + } |
| 1631 | + ], |
| 1632 | + ), |
| 1633 | + ), |
| 1634 | + |
| 1635 | + # attention: Integers to negative integer powers are not allowed. |
1599 | 1636 | 'pow_input_scalar_bool': dict(
|
1600 | 1637 | name=['pow'],
|
1601 | 1638 | interface=['torch'],
|
|
1610 | 1647 | "dtype": [np.float16, np.float32, np.float64,
|
1611 | 1648 | np.int16, np.int32, np.int64,
|
1612 | 1649 | np.int8, np.uint8],
|
1613 |
| - "gen_fn": dict(fn='Genfunc.randn_int', low=-4, high=4), |
| 1650 | + "gen_fn": dict(fn='Genfunc.uniform', low=1, high=4), |
1614 | 1651 | }
|
1615 | 1652 | ],
|
1616 | 1653 | ),
|
|
0 commit comments