Skip to content

Commit d45b964

Browse files
authored
Merge pull request #189 from eric-wieser/numba-extension
Add primitive numba support for Layout and MultiVector
2 parents bf74bad + 8f4f247 commit d45b964

File tree

5 files changed

+251
-2
lines changed

5 files changed

+251
-2
lines changed

clifford/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181

8282
# Major library imports.
8383
import numpy as np
84-
import numba
84+
import numba as _numba # to avoid clashing with clifford.numba
8585
import sparse
8686
try:
8787
from numba.np import numpy_support as _numpy_support
@@ -147,7 +147,7 @@ def get_mult_function(mt: sparse.COO, gradeList,
147147
return _get_mult_function_runtime_sparse(mt)
148148

149149

150-
def _get_mult_function_result_type(a: numba.types.Type, b: numba.types.Type, mt: np.dtype):
150+
def _get_mult_function_result_type(a: _numba.types.Type, b: _numba.types.Type, mt: np.dtype):
151151
a_dt = _numpy_support.as_dtype(getattr(a, 'dtype', a))
152152
b_dt = _numpy_support.as_dtype(getattr(b, 'dtype', b))
153153
return np.result_type(a_dt, mt, b_dt)
@@ -325,6 +325,9 @@ def val_get_right_gmt_matrix(mt: sparse.COO, x):
325325
from ._layout_helpers import BasisVectorIds, BasisBladeOrder # noqa: F401
326326
from ._mvarray import MVArray, array # noqa: F401
327327
from ._frame import Frame # noqa: F401
328+
329+
# this registers the extension type
330+
from . import numba # noqa: F401
328331
from ._blademap import BladeMap # noqa: F401
329332

330333

clifford/numba/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from ._multivector import MultiVectorType
2+
from ._layout import LayoutType

clifford/numba/_layout.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import numba
2+
import numba.extending
3+
try:
4+
# module locations as of numba 0.49.0
5+
from numba.core import cgutils, types
6+
from numba.core.imputils import lower_constant
7+
except ImportError:
8+
# module locations prior to numba 0.49.0
9+
from numba import cgutils, types
10+
from numba.targets.imputils import lower_constant
11+
12+
from .._layout import Layout
13+
14+
15+
opaque_layout = types.Opaque('Opaque(Layout)')
16+
17+
18+
class LayoutType(types.Type):
19+
def __init__(self):
20+
super().__init__("LayoutType")
21+
22+
23+
@numba.extending.register_model(LayoutType)
24+
class LayoutModel(numba.extending.models.StructModel):
25+
def __init__(self, dmm, fe_typ):
26+
members = [
27+
('obj', opaque_layout),
28+
]
29+
super().__init__(dmm, fe_typ, members)
30+
31+
32+
@numba.extending.typeof_impl.register(Layout)
33+
def _typeof_Layout(val: Layout, c) -> LayoutType:
34+
return LayoutType()
35+
36+
37+
# Derived from the `Dispatcher` boxing
38+
39+
@lower_constant(LayoutType)
40+
def lower_constant_dispatcher(context, builder, typ, pyval):
41+
layout = cgutils.create_struct_proxy(typ)(context, builder)
42+
layout.obj = context.add_dynamic_addr(builder, id(pyval), info=type(pyval).__name__)
43+
return layout._getvalue()
44+
45+
46+
@numba.extending.unbox(LayoutType)
47+
def unbox_Layout(typ, obj, context):
48+
layout = cgutils.create_struct_proxy(typ)(context.context, context.builder)
49+
layout.obj = obj
50+
return numba.extending.NativeValue(layout._getvalue())
51+
52+
53+
@numba.extending.box(LayoutType)
54+
def box_Layout(typ, val, context):
55+
val = cgutils.create_struct_proxy(typ)(context.context, context.builder, value=val)
56+
obj = val.obj
57+
context.pyapi.incref(obj)
58+
return obj

clifford/numba/_multivector.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""
2+
Numba support for MultiVector objects.
3+
4+
For now, this just supports .value wrapping / unwrapping
5+
"""
6+
import numpy as np
7+
import numba
8+
from numba.extending import NativeValue
9+
10+
try:
11+
# module locations as of numba 0.49.0
12+
import numba.np.numpy_support as _numpy_support
13+
from numba.core.imputils import impl_ret_borrowed, lower_constant
14+
from numba.core import cgutils, types
15+
except ImportError:
16+
# module locations prior to numba 0.49.0
17+
import numba.numpy_support as _numpy_support
18+
from numba.targets.imputils import impl_ret_borrowed, lower_constant
19+
from numba import cgutils, types
20+
21+
from .._multivector import MultiVector
22+
23+
from ._layout import LayoutType
24+
25+
__all__ = ['MultiVectorType']
26+
27+
28+
class MultiVectorType(types.Type):
29+
def __init__(self, dtype: np.dtype):
30+
assert isinstance(dtype, np.dtype)
31+
self.dtype = dtype
32+
super().__init__(name='MultiVector[{!r}]'.format(numba.from_dtype(dtype)))
33+
34+
@property
35+
def key(self):
36+
return self.dtype
37+
38+
@property
39+
def value_type(self):
40+
return numba.from_dtype(self.dtype)[:]
41+
42+
@property
43+
def layout_type(self):
44+
return LayoutType()
45+
46+
47+
# The docs say we should use register a function to determine the numba type
48+
# with `@numba.extending.typeof_impl.register(MultiVector)`, but this is way
49+
# too slow (https://github.com/numba/numba/issues/5839). Instead, we use the
50+
# undocumented `_numba_type_` attribute, and use our own cache. In future
51+
# this may need to be a weak cache, but for now the objects are tiny anyway.
52+
_cache = {}
53+
54+
@property
55+
def _numba_type_(self):
56+
dt = self.value.dtype
57+
try:
58+
return _cache[dt]
59+
except KeyError:
60+
ret = _cache[dt] = MultiVectorType(dtype=dt)
61+
return ret
62+
63+
MultiVector._numba_type_ = _numba_type_
64+
65+
66+
@numba.extending.register_model(MultiVectorType)
67+
class MultiVectorModel(numba.extending.models.StructModel):
68+
def __init__(self, dmm, fe_type):
69+
members = [
70+
('layout', fe_type.layout_type),
71+
('value', fe_type.value_type),
72+
]
73+
super().__init__(dmm, fe_type, members)
74+
75+
76+
@numba.extending.type_callable(MultiVector)
77+
def type_MultiVector(context):
78+
def typer(layout, value):
79+
if isinstance(layout, LayoutType) and isinstance(value, types.Array):
80+
return MultiVectorType(_numpy_support.as_dtype(value.dtype))
81+
return typer
82+
83+
84+
@numba.extending.lower_builtin(MultiVector, LayoutType, types.Any)
85+
def impl_MultiVector(context, builder, sig, args):
86+
typ = sig.return_type
87+
layout, value = args
88+
mv = cgutils.create_struct_proxy(typ)(context, builder)
89+
mv.layout = layout
90+
mv.value = value
91+
return impl_ret_borrowed(context, builder, sig.return_type, mv._getvalue())
92+
93+
94+
@lower_constant(MultiVectorType)
95+
def lower_constant_MultiVector(context, builder, typ: MultiVectorType, pyval: MultiVector):
96+
value = context.get_constant_generic(builder, typ.value_type, pyval.value)
97+
layout = context.get_constant_generic(builder, typ.layout_type, pyval.layout)
98+
return impl_ret_borrowed(
99+
context,
100+
builder,
101+
typ,
102+
cgutils.pack_struct(builder, (layout, value)),
103+
)
104+
105+
106+
@numba.extending.unbox(MultiVectorType)
107+
def unbox_MultiVector(typ: MultiVectorType, obj: MultiVector, c) -> NativeValue:
108+
value = c.pyapi.object_getattr_string(obj, "value")
109+
layout = c.pyapi.object_getattr_string(obj, "layout")
110+
mv = cgutils.create_struct_proxy(typ)(c.context, c.builder)
111+
mv.layout = c.unbox(typ.layout_type, layout).value
112+
mv.value = c.unbox(typ.value_type, value).value
113+
c.pyapi.decref(value)
114+
c.pyapi.decref(layout)
115+
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
116+
return NativeValue(mv._getvalue(), is_error=is_error)
117+
118+
119+
@numba.extending.box(MultiVectorType)
120+
def box_MultiVector(typ: MultiVectorType, val: NativeValue, c) -> MultiVector:
121+
mv = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
122+
mv_obj = c.box(typ.value_type, mv.value)
123+
layout_obj = c.box(typ.layout_type, mv.layout)
124+
125+
# All the examples use `c.pyapi.unserialize(c.pyapi.serialize_object(MultiVector))` here.
126+
# Doing so is much slower, as it incurs pickle. This is probably safe.
127+
class_obj_ptr = c.context.add_dynamic_addr(c.builder, id(MultiVector), info=MultiVector.__name__)
128+
class_obj = c.builder.bitcast(class_obj_ptr, c.pyapi.pyobj)
129+
res = c.pyapi.call_function_objargs(class_obj, (layout_obj, mv_obj))
130+
c.pyapi.decref(mv_obj)
131+
c.pyapi.decref(layout_obj)
132+
return res
133+
134+
135+
numba.extending.make_attribute_wrapper(MultiVectorType, 'value', 'value')
136+
numba.extending.make_attribute_wrapper(MultiVectorType, 'layout', 'layout')
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numba
2+
3+
from clifford.g3c import layout, e1, e2
4+
import clifford as cf
5+
6+
7+
@numba.njit
8+
def identity(x):
9+
return x
10+
11+
12+
class TestBasic:
13+
""" Test very simple construction and field access """
14+
15+
def test_roundtrip_layout(self):
16+
layout_r = identity(layout)
17+
assert type(layout_r) is type(layout)
18+
assert layout_r is layout
19+
20+
def test_roundtrip_mv(self):
21+
e1_r = identity(e1)
22+
assert type(e1_r) is type(e1_r)
23+
24+
# mvs are values, and not preserved by identity
25+
assert e1_r.layout is e1.layout
26+
assert e1_r == e1
27+
28+
def test_piecewise_construction(self):
29+
@numba.njit
30+
def negate(a):
31+
return cf.MultiVector(a.layout, -a.value)
32+
33+
n_e1 = negate(e1)
34+
assert n_e1.layout is e1.layout
35+
assert n_e1 == -e1
36+
37+
@numba.njit
38+
def add(a, b):
39+
return cf.MultiVector(a.layout, a.value + b.value)
40+
41+
ab = add(e1, e2)
42+
assert ab == e1 + e2
43+
assert ab.layout is e1.layout
44+
45+
def test_constant_multivector(self):
46+
@numba.njit
47+
def add_e1(a):
48+
return cf.MultiVector(a.layout, a.value + e1.value)
49+
50+
assert add_e1(e2) == e1 + e2

0 commit comments

Comments
 (0)