Skip to content

Commit f49e81e

Browse files
committed
use nanobind ndarray to make use of bfloat16 weights possible (without
coping the data)
1 parent edef63c commit f49e81e

File tree

2 files changed

+171
-1
lines changed

2 files changed

+171
-1
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

+54-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <cstddef>
910
#include <cstdint>
1011
#include <optional>
1112
#include <string>
@@ -18,6 +19,9 @@
1819
#include "mlir-c/BuiltinTypes.h"
1920
#include "mlir/Bindings/Python/NanobindAdaptors.h"
2021
#include "mlir/Bindings/Python/Nanobind.h"
22+
#include "nanobind/nanobind.h"
23+
#include "nanobind/ndarray.h"
24+
#include "pytypedefs.h"
2125
#include "llvm/ADT/ScopeExit.h"
2226
#include "llvm/Support/raw_ostream.h"
2327

@@ -1494,14 +1498,63 @@ class PyDenseResourceElementsAttribute
14941498
return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
14951499
}
14961500

1501+
static PyDenseResourceElementsAttribute
1502+
getFromNdarray(nb::ndarray<nb::any_contig>& buffer, const std::string &name, const PyType &type,
1503+
std::optional<size_t> alignment, bool isMutable,
1504+
DefaultingPyMlirContext contextWrapper) {
1505+
if (!mlirTypeIsAShaped(type)) {
1506+
throw std::invalid_argument(
1507+
"Constructing a DenseResourceElementsAttr requires a ShapedType.");
1508+
}
1509+
1510+
// Do not request any conversions as we must ensure to use caller
1511+
// managed memory.
1512+
if (!buffer.is_valid()) {
1513+
throw std::invalid_argument("The buffer should not be a nullptr.");
1514+
}
1515+
1516+
// Infer alignment to be the stride of one element if not explicit.
1517+
size_t inferredAlignment;
1518+
if (alignment)
1519+
inferredAlignment = *alignment;
1520+
else
1521+
inferredAlignment = buffer.stride_ptr()[buffer.ndim() - 1];
1522+
1523+
// The userData is a Py_buffer* that the deleter owns.
1524+
auto deleter = [](void *userData, const void *data, size_t size,
1525+
size_t align) {
1526+
nb::ndarray<nb::any_contig> *ownedBuffer= static_cast<nb::ndarray<nb::any_contig> *>(userData);
1527+
//delete ownedBuffer;
1528+
};
1529+
1530+
size_t rawBufferSize = buffer.size();
1531+
MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1532+
type, toMlirStringRef(name), buffer.data(), rawBufferSize,
1533+
inferredAlignment, isMutable, deleter, static_cast<void *>(&buffer)); //deleter, static_cast<void *>(&buffer));
1534+
if (mlirAttributeIsNull(attr)) {
1535+
throw std::invalid_argument(
1536+
"DenseResourceElementsAttr could not be constructed from the given "
1537+
"buffer. "
1538+
"This may mean that the Python buffer layout does not match that "
1539+
"MLIR expected layout and is a bug.");
1540+
}
1541+
return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1542+
}
1543+
14971544
static void bindDerived(ClassTy &c) {
14981545
c.def_static(
14991546
"get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
15001547
nb::arg("array"), nb::arg("name"), nb::arg("type"),
15011548
nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
15021549
nb::arg("context").none() = nb::none(),
15031550
kDenseResourceElementsAttrGetFromBufferDocstring);
1504-
}
1551+
c.def_static(
1552+
"get_from_ndarray", PyDenseResourceElementsAttribute::getFromNdarray,
1553+
nb::arg("array"), nb::arg("name"), nb::arg("type"),
1554+
nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
1555+
nb::arg("context").none() = nb::none(),
1556+
kDenseResourceElementsAttrGetFromBufferDocstring);
1557+
}
15051558
};
15061559

15071560
class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {

mlir/test/python/ir/array_attributes.py

+117
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from mlir.ir import *
77
import numpy as np
88
import weakref
9+
import ctypes
910

1011

1112
def run(f):
@@ -617,3 +618,119 @@ def test_attribute(context, mview):
617618
# CHECK: BACKING MEMORY DELETED
618619
# CHECK: EXIT FUNCTION
619620
print("EXIT FUNCTION")
621+
622+
623+
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayI32
624+
@run
625+
def testGetDenseResourceElementsAttrNdarrayI32():
626+
class DLPackWrapper:
627+
def __init__(self, array: np.ndarray):
628+
self.dlpack_capsule = array.__dlpack__()
629+
630+
def __del__(self):
631+
print("BACKING MEMORY DELETED")
632+
633+
def get_capsule(self):
634+
return self.dlpack_capsule
635+
636+
context = Context()
637+
mview_int32 = DLPackWrapper(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
638+
639+
def test_attribute_int32(context, mview_int32):
640+
with context, Location.unknown():
641+
element_type = IntegerType.get_signless(32)
642+
tensor_type = RankedTensorType.get((2, 3), element_type)
643+
resource = DenseResourceElementsAttr.get_from_ndarray(
644+
mview_int32.get_capsule(), "from_py", tensor_type
645+
)
646+
module = Module.parse("module {}")
647+
module.operation.attributes["test.resource"] = resource
648+
# CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32>
649+
# CHECK: from_py: "0x01000000010000000200"
650+
print(module)
651+
652+
# Verifies type casting.
653+
# CHECK: dense_resource<from_py> : tensor<2x3xi32>
654+
print(
655+
DenseResourceElementsAttr(module.operation.attributes["test.resource"])
656+
)
657+
658+
test_attribute_int32(context, mview_int32)
659+
del mview_int32
660+
gc.collect()
661+
# CHECK: FREEING CONTEXT
662+
print("FREEING CONTEXT")
663+
context = None
664+
gc.collect()
665+
# CHECK: EXIT FUNCTION
666+
print("EXIT FUNCTION")
667+
668+
669+
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayF32
670+
@run
671+
def testGetDenseResourceElementsAttrNdarrayF32():
672+
class DLPackWrapper:
673+
def __init__(self, array: np.ndarray):
674+
self.dlpack_capsule = array.__dlpack__()
675+
676+
def __del__(self):
677+
print("BACKING MEMORY DELETED")
678+
679+
def get_capsule(self):
680+
return self.dlpack_capsule
681+
682+
context = Context()
683+
mview_float32 = DLPackWrapper(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
684+
685+
def test_attribute_float32(context, mview_float32):
686+
with context, Location.unknown():
687+
element_type = FloatAttr.get_f32(32.0)
688+
tensor_type = RankedTensorType.get((2, 3), element_type.type)
689+
resource = DenseResourceElementsAttr.get_from_ndarray(
690+
mview_float32.get_capsule(), "from_py", tensor_type
691+
)
692+
module = Module.parse("module {}")
693+
module.operation.attributes["test.resource"] = resource
694+
# CHECK: test.resource = dense_resource<from_py> : tensor<2x3xf32>
695+
# CHECK: from_py: "0x010000000000803F0000"
696+
print(module)
697+
698+
# Verifies type casting.
699+
# CHECK: dense_resource<from_py> : tensor<2x3xf32>
700+
print(
701+
DenseResourceElementsAttr(module.operation.attributes["test.resource"])
702+
)
703+
704+
test_attribute_float32(context, mview_float32)
705+
del mview_float32
706+
gc.collect()
707+
# CHECK: FREEING CONTEXT
708+
print("FREEING CONTEXT")
709+
context = None
710+
gc.collect()
711+
# CHECK: EXIT FUNCTION
712+
print("EXIT FUNCTION")
713+
714+
715+
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNonShapedType
716+
@run
717+
def testGetDenseResourceElementsAttrNonShapedType():
718+
class DLPackWrapper:
719+
def __init__(self, array: np.ndarray):
720+
self.dlpack_capsule = array.__dlpack__()
721+
722+
def __del__(self):
723+
print("BACKING MEMORY DELETED")
724+
725+
def get_capsule(self):
726+
return self.dlpack_capsule
727+
728+
with Context(), Location.unknown():
729+
mview = DLPackWrapper(np.array([1], dtype=np.int32))
730+
t = F32Type.get()
731+
732+
try:
733+
attr = DenseResourceElementsAttr.get_from_ndarray(mview.get_capsule(), "from_py", t)
734+
except ValueError as e:
735+
# CHECK: Constructing a DenseResourceElementsAttr requires a ShapedType.
736+
print(e)

0 commit comments

Comments
 (0)