Skip to content

Commit 6bbd6cd

Browse files
authored
Merge pull request #331 from eric-wieser/layout-in-type
numba: Change layout to be stored in type itself
2 parents d45b964 + bdd0276 commit 6bbd6cd

File tree

5 files changed

+97
-62
lines changed

5 files changed

+97
-62
lines changed

clifford/_layout_helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
99
"""
1010

11-
from typing import TypeVar, Generic, Sequence, Tuple, List
11+
from typing import TypeVar, Generic, Sequence, Tuple, List, Optional
1212
import numpy as np
1313
import functools
1414
import operator
@@ -252,3 +252,10 @@ def __reduce__(self):
252252
return __class__, (self._n,)
253253
else:
254254
return __class__, (self._n, self._first_index)
255+
256+
257+
def layout_short_name(layout) -> Optional[str]:
258+
""" helper to get the short name of a layout """
259+
if hasattr(layout, '__name__') and '__module__' in layout.__dict__:
260+
return "{l.__module__}.{l.__name__}".format(l=layout)
261+
return None

clifford/_multivector.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import clifford as cf
88
from . import general_exp
99
from . import _settings
10+
from ._layout_helpers import layout_short_name
1011

1112

1213
class MultiVector(object):
@@ -507,11 +508,12 @@ def __repr__(self) -> str:
507508
else:
508509
dtype_str = None
509510

510-
if hasattr(self.layout, '__name__') and '__module__' in self.layout.__dict__:
511-
fmt = "{l.__module__}.{l.__name__}.MultiVector({v!r}{d})"
511+
l_name = layout_short_name(self.layout)
512+
args = dict(v=list(self.value), d=dtype_str)
513+
if l_name is not None:
514+
return "{l}.MultiVector({v!r}{d})".format(l=l_name, **args)
512515
else:
513-
fmt = "{l!r}.MultiVector({v!r}{d})"
514-
return fmt.format(l=self.layout, v=list(self.value), d=dtype_str)
516+
return "{l!r}.MultiVector({v!r}{d})".format(l=self.layout, **args)
515517

516518
def _repr_pretty_(self, p, cycle):
517519
if cycle:
@@ -521,8 +523,9 @@ def _repr_pretty_(self, p, cycle):
521523
p.text(str(self))
522524
return
523525

524-
if hasattr(self.layout, '__name__') and '__module__' in self.layout.__dict__:
525-
prefix = "{l.__module__}.{l.__name__}.MultiVector(".format(l=self.layout)
526+
l_name = layout_short_name(self.layout)
527+
if l_name is not None:
528+
prefix = "{}.MultiVector(".format(l_name)
526529
include_layout = False
527530
else:
528531
include_layout = True

clifford/numba/_layout.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,75 @@
11
import numba
22
import numba.extending
3+
from numba.extending import NativeValue
4+
import llvmlite.ir
35
try:
46
# module locations as of numba 0.49.0
5-
from numba.core import cgutils, types
7+
from numba.core import types
68
from numba.core.imputils import lower_constant
79
except ImportError:
810
# module locations prior to numba 0.49.0
9-
from numba import cgutils, types
11+
from numba import types
1012
from numba.targets.imputils import lower_constant
1113

12-
from .._layout import Layout
14+
from .._layout import Layout, _cached_property
15+
from .._layout_helpers import layout_short_name
16+
from .._multivector import MultiVector
1317

1418

15-
opaque_layout = types.Opaque('Opaque(Layout)')
19+
# In future we want to store some of the layout in the type (the `order` etc),
20+
# but store the `names` in the layout instances, so that we can reuse jitted
21+
# functions across different basis vector names.
1622

17-
18-
class LayoutType(types.Type):
19-
def __init__(self):
20-
super().__init__("LayoutType")
23+
class LayoutType(types.Dummy):
24+
def __init__(self, layout):
25+
self.obj = layout
26+
# cache of multivector types for this layout
27+
self._cache = {}
28+
layout_name = layout_short_name(layout)
29+
if layout_name is not None:
30+
name = "LayoutType({})".format(layout_name)
31+
else:
32+
name = "LayoutType({!r})".format(layout)
33+
super().__init__(name)
2134

2235

2336
@numba.extending.register_model(LayoutType)
24-
class LayoutModel(numba.extending.models.StructModel):
25-
def __init__(self, dmm, fe_typ):
26-
members = [
27-
('obj', opaque_layout),
28-
]
29-
super().__init__(dmm, fe_typ, members)
37+
class LayoutModel(numba.extending.models.OpaqueModel):
38+
pass
3039

40+
# The docs say we should use register a function to determine the numba type
41+
# with `@numba.extending.typeof_impl.register(LayoutType)`, but this is way
42+
# too slow (https://github.com/numba/numba/issues/5839). Instead, we use the
43+
# undocumented `_numba_type_` attribute, and use our own cache.
3144

32-
@numba.extending.typeof_impl.register(Layout)
33-
def _typeof_Layout(val: Layout, c) -> LayoutType:
34-
return LayoutType()
45+
@_cached_property
46+
def _numba_type_(self):
47+
return LayoutType(self)
3548

49+
Layout._numba_type_ = _numba_type_
3650

37-
# Derived from the `Dispatcher` boxing
3851

3952
@lower_constant(LayoutType)
40-
def lower_constant_dispatcher(context, builder, typ, pyval):
41-
layout = cgutils.create_struct_proxy(typ)(context, builder)
42-
layout.obj = context.add_dynamic_addr(builder, id(pyval), info=type(pyval).__name__)
43-
return layout._getvalue()
53+
def lower_constant_Layout(context, builder, typ: LayoutType, pyval: Layout) -> llvmlite.ir.Value:
54+
return context.get_dummy_value()
4455

4556

4657
@numba.extending.unbox(LayoutType)
47-
def unbox_Layout(typ, obj, context):
48-
layout = cgutils.create_struct_proxy(typ)(context.context, context.builder)
49-
layout.obj = obj
50-
return numba.extending.NativeValue(layout._getvalue())
58+
def unbox_Layout(typ: LayoutType, obj: Layout, c) -> NativeValue:
59+
return NativeValue(c.context.get_dummy_value())
5160

61+
# Derived from the `Dispatcher` boxing
5262

5363
@numba.extending.box(LayoutType)
54-
def box_Layout(typ, val, context):
55-
val = cgutils.create_struct_proxy(typ)(context.context, context.builder, value=val)
56-
obj = val.obj
57-
context.pyapi.incref(obj)
64+
def box_Layout(typ: LayoutType, val: llvmlite.ir.Value, c) -> Layout:
65+
obj = c.context.add_dynamic_addr(c.builder, id(typ.obj), info=typ.name)
66+
c.pyapi.incref(obj)
5867
return obj
68+
69+
# methods
70+
71+
@numba.extending.overload_method(LayoutType, 'MultiVector')
72+
def Layout_MultiVector(self, value):
73+
def impl(self, value):
74+
return MultiVector(self, value)
75+
return impl

clifford/numba/_multivector.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
44
For now, this just supports .value wrapping / unwrapping
55
"""
6-
import numpy as np
76
import numba
87
from numba.extending import NativeValue
8+
import llvmlite.ir
99

1010
try:
1111
# module locations as of numba 0.49.0
@@ -26,38 +26,43 @@
2626

2727

2828
class MultiVectorType(types.Type):
29-
def __init__(self, dtype: np.dtype):
30-
assert isinstance(dtype, np.dtype)
31-
self.dtype = dtype
32-
super().__init__(name='MultiVector[{!r}]'.format(numba.from_dtype(dtype)))
29+
def __init__(self, layout: LayoutType, dtype: types.DType):
30+
self.layout_type = layout
31+
self._scalar_type = dtype
32+
super().__init__(name='MultiVector({!r}, {!r})'.format(
33+
self.layout_type, self._scalar_type
34+
))
3335

3436
@property
3537
def key(self):
36-
return self.dtype
38+
return self.layout_type, self._scalar_type
3739

3840
@property
3941
def value_type(self):
40-
return numba.from_dtype(self.dtype)[:]
41-
42-
@property
43-
def layout_type(self):
44-
return LayoutType()
42+
return self._scalar_type[:]
4543

4644

4745
# The docs say we should use register a function to determine the numba type
4846
# with `@numba.extending.typeof_impl.register(MultiVector)`, but this is way
4947
# too slow (https://github.com/numba/numba/issues/5839). Instead, we use the
5048
# undocumented `_numba_type_` attribute, and use our own cache. In future
5149
# this may need to be a weak cache, but for now the objects are tiny anyway.
52-
_cache = {}
5350

5451
@property
5552
def _numba_type_(self):
53+
layout_type = self.layout._numba_type_
54+
55+
cache = layout_type._cache
5656
dt = self.value.dtype
57+
58+
# now use the dtype to key that cache.
5759
try:
58-
return _cache[dt]
60+
return cache[dt]
5961
except KeyError:
60-
ret = _cache[dt] = MultiVectorType(dtype=dt)
62+
# Computing and hashing `dtype_type` is slow, so we do not use it as a
63+
# hash key. The raw numpy dtype is much faster to use as a key.
64+
dtype_type = _numpy_support.from_dtype(dt)
65+
ret = cache[dt] = MultiVectorType(layout_type, dtype_type)
6166
return ret
6267

6368
MultiVector._numba_type_ = _numba_type_
@@ -77,7 +82,7 @@ def __init__(self, dmm, fe_type):
7782
def type_MultiVector(context):
7883
def typer(layout, value):
7984
if isinstance(layout, LayoutType) and isinstance(value, types.Array):
80-
return MultiVectorType(_numpy_support.as_dtype(value.dtype))
85+
return MultiVectorType(layout, value.dtype)
8186
return typer
8287

8388

@@ -92,15 +97,11 @@ def impl_MultiVector(context, builder, sig, args):
9297

9398

9499
@lower_constant(MultiVectorType)
95-
def lower_constant_MultiVector(context, builder, typ: MultiVectorType, pyval: MultiVector):
96-
value = context.get_constant_generic(builder, typ.value_type, pyval.value)
97-
layout = context.get_constant_generic(builder, typ.layout_type, pyval.layout)
98-
return impl_ret_borrowed(
99-
context,
100-
builder,
101-
typ,
102-
cgutils.pack_struct(builder, (layout, value)),
103-
)
100+
def lower_constant_MultiVector(context, builder, typ: MultiVectorType, pyval: MultiVector) -> llvmlite.ir.Value:
101+
mv = cgutils.create_struct_proxy(typ)(context, builder)
102+
mv.value = context.get_constant_generic(builder, typ.value_type, pyval.value)
103+
mv.layout = context.get_constant_generic(builder, typ.layout_type, pyval.layout)
104+
return mv._getvalue()
104105

105106

106107
@numba.extending.unbox(MultiVectorType)
@@ -117,7 +118,7 @@ def unbox_MultiVector(typ: MultiVectorType, obj: MultiVector, c) -> NativeValue:
117118

118119

119120
@numba.extending.box(MultiVectorType)
120-
def box_MultiVector(typ: MultiVectorType, val: NativeValue, c) -> MultiVector:
121+
def box_MultiVector(typ: MultiVectorType, val: llvmlite.ir.Value, c) -> MultiVector:
121122
mv = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
122123
mv_obj = c.box(typ.value_type, mv.value)
123124
layout_obj = c.box(typ.layout_type, mv.layout)

clifford/test/test_numba_extensions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,10 @@ def add_e1(a):
4848
return cf.MultiVector(a.layout, a.value + e1.value)
4949

5050
assert add_e1(e2) == e1 + e2
51+
52+
def test_multivector_shorthand(self):
53+
@numba.njit
54+
def double(a):
55+
return a.layout.MultiVector(a.value*2)
56+
57+
assert double(e2) == 2 * e2

0 commit comments

Comments
 (0)