|
33 | 33 | import pandas as pd
|
34 | 34 | from sdc import *
|
35 | 35 | from numba.typed import Dict
|
| 36 | +from numba.extending import (overload_method, overload, models, register_model, intrinsic) |
| 37 | +from numba.special import literally |
| 38 | +from numba.typing import signature |
| 39 | +from numba import cgutils |
36 | 40 | from collections import defaultdict
|
37 | 41 | from sdc.tests.test_base import TestCase
|
38 | 42 | from sdc.tests.test_utils import skip_numba_jit
|
@@ -436,6 +440,56 @@ def test_impl():
|
436 | 440 | hpat_func = self.jit(test_impl)
|
437 | 441 | pd.testing.assert_frame_equal(hpat_func(), test_impl())
|
438 | 442 |
|
| 443 | + @unittest.expectedFailure |
| 444 | + def test_literally_with_overload_method(self): |
| 445 | + class Dummy: |
| 446 | + def lit(self, a): |
| 447 | + return a |
| 448 | + |
| 449 | + class DummyType(numba.types.Type): |
| 450 | + def __init__(self): |
| 451 | + super().__init__(name="dummy") |
| 452 | + |
| 453 | + @register_model(DummyType) |
| 454 | + class DummyTypeModel(models.StructModel): |
| 455 | + def __init__(self, dmm, fe_type): |
| 456 | + members = [] |
| 457 | + super().__init__(dmm, fe_type, members) |
| 458 | + |
| 459 | + @intrinsic |
| 460 | + def init_dummy(typingctx): |
| 461 | + def codegen(context, builder, signature, args): |
| 462 | + dummy = cgutils.create_struct_proxy( |
| 463 | + signature.return_type)(context, builder) |
| 464 | + |
| 465 | + return dummy._getvalue() |
| 466 | + |
| 467 | + sig = signature(DummyType()) |
| 468 | + return sig, codegen |
| 469 | + |
| 470 | + @overload(Dummy) |
| 471 | + def dummy_overload(): |
| 472 | + def ctor(): |
| 473 | + return init_dummy() |
| 474 | + |
| 475 | + return ctor |
| 476 | + |
| 477 | + @overload_method(DummyType, 'lit') |
| 478 | + def lit_overload(self, a): |
| 479 | + def impl(self, a): |
| 480 | + return literally(a) |
| 481 | + # return a |
| 482 | + |
| 483 | + return impl |
| 484 | + |
| 485 | + def test_impl(a): |
| 486 | + d = Dummy() |
| 487 | + |
| 488 | + return d.lit(a) |
| 489 | + |
| 490 | + jtest = numba.njit(test_impl) |
| 491 | + test_impl(5) |
| 492 | + |
439 | 493 |
|
440 | 494 | if __name__ == "__main__":
|
441 | 495 | unittest.main()
|
0 commit comments