Skip to content

Commit 0a63cf4

Browse files
committed
Add support for constant multivectors in jitted functions.
This changes the approach for storing `Layout` objects to be using globals via `add_dynamic_addr`. This will disable numba caching, but we don't use that anyway.
1 parent 439e905 commit 0a63cf4

File tree

3 files changed

+37
-9
lines changed

3 files changed

+37
-9
lines changed

clifford/numba/_layout.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
try:
44
# module locations as of numba 0.49.0
55
from numba.core import cgutils, types
6+
from numba.core.imputils import lower_constant
67
except ImportError:
78
# module locations prior to numba 0.49.0
89
from numba import cgutils, types
10+
from numba.targets.imputils import lower_constant
911

1012
from .._layout import Layout
1113

1214

13-
# Taken from numba_passthru
14-
opaque_pyobject = types.Opaque('Opaque(PyObject)')
15+
opaque_layout = types.Opaque('Opaque(Layout)')
1516

1617

1718
class LayoutType(types.Type):
@@ -23,7 +24,7 @@ def __init__(self):
2324
class LayoutModel(numba.extending.models.StructModel):
2425
def __init__(self, dmm, fe_typ):
2526
members = [
26-
('meminfo', types.MemInfoPointer(opaque_pyobject)),
27+
('obj', opaque_layout),
2728
]
2829
super().__init__(dmm, fe_typ, members)
2930

@@ -33,18 +34,25 @@ def _typeof_Layout(val: Layout, c) -> LayoutType:
3334
return LayoutType()
3435

3536

36-
# Derived from numba_passthru
37+
# Derived from the `Dispatcher` boxing
38+
39+
@lower_constant(LayoutType)
40+
def lower_constant_dispatcher(context, builder, typ, pyval):
41+
layout = cgutils.create_struct_proxy(typ)(context, builder)
42+
layout.obj = context.add_dynamic_addr(builder, id(pyval), info=type(pyval).__name__)
43+
return layout._getvalue()
44+
3745

3846
@numba.extending.unbox(LayoutType)
3947
def unbox_Layout(typ, obj, context):
4048
layout = cgutils.create_struct_proxy(typ)(context.context, context.builder)
41-
layout.meminfo = context.pyapi.nrt_meminfo_new_from_pyobject(obj, obj)
49+
layout.obj = obj
4250
return numba.extending.NativeValue(layout._getvalue())
4351

4452

4553
@numba.extending.box(LayoutType)
4654
def box_Layout(typ, val, context):
4755
val = cgutils.create_struct_proxy(typ)(context.context, context.builder, value=val)
48-
obj = context.context.nrt.meminfo_data(context.builder, val.meminfo)
56+
obj = val.obj
4957
context.pyapi.incref(obj)
5058
return obj

clifford/numba/_multivector.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
try:
1111
# module locations as of numba 0.49.0
1212
import numba.np.numpy_support as _numpy_support
13-
from numba.core.imputils import impl_ret_borrowed
13+
from numba.core.imputils import impl_ret_borrowed, lower_constant
1414
from numba.core import cgutils, types
1515
except ImportError:
1616
# module locations prior to numba 0.49.0
1717
import numba.numpy_support as _numpy_support
18-
from numba.targets.imputils import impl_ret_borrowed
18+
from numba.targets.imputils import impl_ret_borrowed, lower_constant
1919
from numba import cgutils, types
2020

2121
from .._multivector import MultiVector
@@ -46,7 +46,8 @@ def layout_type(self):
4646

4747
@numba.extending.typeof_impl.register(MultiVector)
4848
def _typeof_MultiVector(val: MultiVector, c) -> MultiVectorType:
49-
return MultiVectorType(dtype=val.value.dtype)
49+
val._numba_type_ = MultiVectorType(dtype=val.value.dtype)
50+
return val._numba_type_
5051

5152

5253
@numba.extending.register_model(MultiVectorType)
@@ -77,6 +78,18 @@ def impl_MultiVector(context, builder, sig, args):
7778
return impl_ret_borrowed(context, builder, sig.return_type, mv._getvalue())
7879

7980

81+
@lower_constant(MultiVectorType)
82+
def lower_constant_MultiVector(context, builder, typ: MultiVectorType, pyval: MultiVector):
83+
value = context.get_constant_generic(builder, typ.value_type, pyval.value)
84+
layout = context.get_constant_generic(builder, typ.layout_type, pyval.layout)
85+
return impl_ret_borrowed(
86+
context,
87+
builder,
88+
typ,
89+
cgutils.pack_struct(builder, (layout, value)),
90+
)
91+
92+
8093
@numba.extending.unbox(MultiVectorType)
8194
def unbox_MultiVector(typ: MultiVectorType, obj: MultiVector, c) -> NativeValue:
8295
value = c.pyapi.object_getattr_string(obj, "value")

clifford/test/test_numba_extensions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,10 @@ def add(a, b):
4141
ab = add(e1, e2)
4242
assert ab == e1 + e2
4343
assert ab.layout is e1.layout
44+
45+
def test_constant_multivector(self):
46+
@numba.njit
47+
def add_e1(a):
48+
return cf.MultiVector(a.layout, a.value + e1.value)
49+
50+
assert add_e1(e2) == e1 + e2

0 commit comments

Comments
 (0)