|
| 1 | +""" |
| 2 | +Numba support for MultiVector objects. |
| 3 | +
|
| 4 | +For now, this just supports .value wrapping / unwrapping |
| 5 | +""" |
| 6 | +import numpy as np |
| 7 | +import numba |
| 8 | +from numba.extending import NativeValue |
| 9 | + |
| 10 | +try: |
| 11 | + # module locations as of numba 0.49.0 |
| 12 | + import numba.np.numpy_support as _numpy_support |
| 13 | + from numba.core.imputils import impl_ret_borrowed, lower_constant |
| 14 | + from numba.core import cgutils, types |
| 15 | +except ImportError: |
| 16 | + # module locations prior to numba 0.49.0 |
| 17 | + import numba.numpy_support as _numpy_support |
| 18 | + from numba.targets.imputils import impl_ret_borrowed, lower_constant |
| 19 | + from numba import cgutils, types |
| 20 | + |
| 21 | +from .._multivector import MultiVector |
| 22 | + |
| 23 | +from ._layout import LayoutType |
| 24 | + |
| 25 | +__all__ = ['MultiVectorType'] |
| 26 | + |
| 27 | + |
| 28 | +class MultiVectorType(types.Type): |
| 29 | + def __init__(self, dtype: np.dtype): |
| 30 | + assert isinstance(dtype, np.dtype) |
| 31 | + self.dtype = dtype |
| 32 | + super().__init__(name='MultiVector[{!r}]'.format(numba.from_dtype(dtype))) |
| 33 | + |
| 34 | + @property |
| 35 | + def key(self): |
| 36 | + return self.dtype |
| 37 | + |
| 38 | + @property |
| 39 | + def value_type(self): |
| 40 | + return numba.from_dtype(self.dtype)[:] |
| 41 | + |
| 42 | + @property |
| 43 | + def layout_type(self): |
| 44 | + return LayoutType() |
| 45 | + |
| 46 | + |
| 47 | +# The docs say we should use register a function to determine the numba type |
| 48 | +# with `@numba.extending.typeof_impl.register(MultiVector)`, but this is way |
| 49 | +# too slow (https://github.com/numba/numba/issues/5839). Instead, we use the |
| 50 | +# undocumented `_numba_type_` attribute, and use our own cache. In future |
| 51 | +# this may need to be a weak cache, but for now the objects are tiny anyway. |
| 52 | +_cache = {} |
| 53 | + |
| 54 | +@property |
| 55 | +def _numba_type_(self): |
| 56 | + dt = self.value.dtype |
| 57 | + try: |
| 58 | + return _cache[dt] |
| 59 | + except KeyError: |
| 60 | + ret = _cache[dt] = MultiVectorType(dtype=dt) |
| 61 | + return ret |
| 62 | + |
| 63 | +MultiVector._numba_type_ = _numba_type_ |
| 64 | + |
| 65 | + |
| 66 | +@numba.extending.register_model(MultiVectorType) |
| 67 | +class MultiVectorModel(numba.extending.models.StructModel): |
| 68 | + def __init__(self, dmm, fe_type): |
| 69 | + members = [ |
| 70 | + ('layout', fe_type.layout_type), |
| 71 | + ('value', fe_type.value_type), |
| 72 | + ] |
| 73 | + super().__init__(dmm, fe_type, members) |
| 74 | + |
| 75 | + |
| 76 | +@numba.extending.type_callable(MultiVector) |
| 77 | +def type_MultiVector(context): |
| 78 | + def typer(layout, value): |
| 79 | + if isinstance(layout, LayoutType) and isinstance(value, types.Array): |
| 80 | + return MultiVectorType(_numpy_support.as_dtype(value.dtype)) |
| 81 | + return typer |
| 82 | + |
| 83 | + |
| 84 | +@numba.extending.lower_builtin(MultiVector, LayoutType, types.Any) |
| 85 | +def impl_MultiVector(context, builder, sig, args): |
| 86 | + typ = sig.return_type |
| 87 | + layout, value = args |
| 88 | + mv = cgutils.create_struct_proxy(typ)(context, builder) |
| 89 | + mv.layout = layout |
| 90 | + mv.value = value |
| 91 | + return impl_ret_borrowed(context, builder, sig.return_type, mv._getvalue()) |
| 92 | + |
| 93 | + |
| 94 | +@lower_constant(MultiVectorType) |
| 95 | +def lower_constant_MultiVector(context, builder, typ: MultiVectorType, pyval: MultiVector): |
| 96 | + value = context.get_constant_generic(builder, typ.value_type, pyval.value) |
| 97 | + layout = context.get_constant_generic(builder, typ.layout_type, pyval.layout) |
| 98 | + return impl_ret_borrowed( |
| 99 | + context, |
| 100 | + builder, |
| 101 | + typ, |
| 102 | + cgutils.pack_struct(builder, (layout, value)), |
| 103 | + ) |
| 104 | + |
| 105 | + |
| 106 | +@numba.extending.unbox(MultiVectorType) |
| 107 | +def unbox_MultiVector(typ: MultiVectorType, obj: MultiVector, c) -> NativeValue: |
| 108 | + value = c.pyapi.object_getattr_string(obj, "value") |
| 109 | + layout = c.pyapi.object_getattr_string(obj, "layout") |
| 110 | + mv = cgutils.create_struct_proxy(typ)(c.context, c.builder) |
| 111 | + mv.layout = c.unbox(typ.layout_type, layout).value |
| 112 | + mv.value = c.unbox(typ.value_type, value).value |
| 113 | + c.pyapi.decref(value) |
| 114 | + c.pyapi.decref(layout) |
| 115 | + is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred()) |
| 116 | + return NativeValue(mv._getvalue(), is_error=is_error) |
| 117 | + |
| 118 | + |
| 119 | +@numba.extending.box(MultiVectorType) |
| 120 | +def box_MultiVector(typ: MultiVectorType, val: NativeValue, c) -> MultiVector: |
| 121 | + mv = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) |
| 122 | + mv_obj = c.box(typ.value_type, mv.value) |
| 123 | + layout_obj = c.box(typ.layout_type, mv.layout) |
| 124 | + |
| 125 | + # All the examples use `c.pyapi.unserialize(c.pyapi.serialize_object(MultiVector))` here. |
| 126 | + # Doing so is much slower, as it incurs pickle. This is probably safe. |
| 127 | + class_obj_ptr = c.context.add_dynamic_addr(c.builder, id(MultiVector), info=MultiVector.__name__) |
| 128 | + class_obj = c.builder.bitcast(class_obj_ptr, c.pyapi.pyobj) |
| 129 | + res = c.pyapi.call_function_objargs(class_obj, (layout_obj, mv_obj)) |
| 130 | + c.pyapi.decref(mv_obj) |
| 131 | + c.pyapi.decref(layout_obj) |
| 132 | + return res |
| 133 | + |
| 134 | + |
| 135 | +numba.extending.make_attribute_wrapper(MultiVectorType, 'value', 'value') |
| 136 | +numba.extending.make_attribute_wrapper(MultiVectorType, 'layout', 'layout') |
0 commit comments