Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit f10cbf3

Browse files
Test for literally in overload_method
1 parent 433fa97 commit f10cbf3

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

sdc/tests/test_hpat_jit.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
import pandas as pd
3434
from sdc import *
3535
from numba.typed import Dict
36+
from numba.extending import (overload_method, overload, models, register_model, intrinsic)
37+
from numba.special import literally
38+
from numba.typing import signature
39+
from numba import cgutils
3640
from collections import defaultdict
3741
from sdc.tests.test_base import TestCase
3842
from sdc.tests.test_utils import skip_numba_jit
@@ -436,6 +440,56 @@ def test_impl():
436440
hpat_func = self.jit(test_impl)
437441
pd.testing.assert_frame_equal(hpat_func(), test_impl())
438442

443+
@unittest.expectedFailure
444+
def test_literally_with_overload_method(self):
445+
class Dummy:
446+
def lit(self, a):
447+
return a
448+
449+
class DummyType(numba.types.Type):
450+
def __init__(self):
451+
super().__init__(name="dummy")
452+
453+
@register_model(DummyType)
454+
class DummyTypeModel(models.StructModel):
455+
def __init__(self, dmm, fe_type):
456+
members = []
457+
super().__init__(dmm, fe_type, members)
458+
459+
@intrinsic
460+
def init_dummy(typingctx):
461+
def codegen(context, builder, signature, args):
462+
dummy = cgutils.create_struct_proxy(
463+
signature.return_type)(context, builder)
464+
465+
return dummy._getvalue()
466+
467+
sig = signature(DummyType())
468+
return sig, codegen
469+
470+
@overload(Dummy)
471+
def dummy_overload():
472+
def ctor():
473+
return init_dummy()
474+
475+
return ctor
476+
477+
@overload_method(DummyType, 'lit')
478+
def lit_overload(self, a):
479+
def impl(self, a):
480+
return literally(a)
481+
# return a
482+
483+
return impl
484+
485+
def test_impl(a):
486+
d = Dummy()
487+
488+
return d.lit(a)
489+
490+
jtest = numba.njit(test_impl)
491+
test_impl(5)
492+
439493

440494
if __name__ == "__main__":
441495
unittest.main()

0 commit comments

Comments
 (0)