Skip to content

Commit 42c3cad

Browse files
committed
Use nanobind ndarray to make it possible to use bfloat16 weights (without
copying the data)
1 parent edef63c commit 42c3cad

File tree

2 files changed

+168
-1
lines changed

2 files changed

+168
-1
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

+59-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <cstddef>
910
#include <cstdint>
11+
#include <memory>
1012
#include <optional>
1113
#include <string>
1214
#include <string_view>
@@ -18,6 +20,9 @@
1820
#include "mlir-c/BuiltinTypes.h"
1921
#include "mlir/Bindings/Python/NanobindAdaptors.h"
2022
#include "mlir/Bindings/Python/Nanobind.h"
23+
#include "nanobind/nanobind.h"
24+
#include "nanobind/ndarray.h"
25+
#include "pytypedefs.h"
2126
#include "llvm/ADT/ScopeExit.h"
2227
#include "llvm/Support/raw_ostream.h"
2328

@@ -1494,14 +1499,67 @@ class PyDenseResourceElementsAttribute
14941499
return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
14951500
}
14961501

1502+
/// Creates a DenseResourceElementsAttr whose blob points directly at the
/// memory of a contiguous nanobind ndarray; `view->data()` is passed to MLIR
/// as-is.
///
/// A heap-allocated copy of the ndarray handle is passed to MLIR as the
/// blob's user data and is destroyed by the deleter registered below.
///
/// \param buffer         Array whose storage backs the resource.
/// \param name           Name passed to MLIR for the resource.
/// \param type           Attribute type; must be a ShapedType.
/// \param alignment      Explicit blob alignment. When absent it is inferred
///                       from the innermost stride (1 for 0-d arrays).
/// \param isMutable      Forwarded to MLIR unchanged.
/// \param contextWrapper Context the returned attribute is bound to.
/// \throws std::invalid_argument if `type` is not shaped, if the array handle
///         is invalid, or if MLIR returns a null attribute.
static PyDenseResourceElementsAttribute
getFromNdarray(nb::ndarray<nb::any_contig> &buffer, const std::string &name,
               const PyType &type, std::optional<size_t> alignment,
               bool isMutable, DefaultingPyMlirContext contextWrapper) {
  if (!mlirTypeIsAShaped(type)) {
    throw std::invalid_argument(
        "Constructing a DenseResourceElementsAttr requires a ShapedType.");
  }

  // Do not request any conversions as we must ensure to use caller
  // managed memory. The handle stays owned by this unique_ptr until the
  // attribute has been created, so the early `throw`s below cannot leak it.
  auto view = std::make_unique<nb::ndarray<nb::any_contig>>(buffer);
  if (!view->is_valid()) {
    throw std::invalid_argument("The buffer should not be a nullptr.");
  }

  // Infer alignment from the innermost stride if not explicit.
  // NOTE(review): nanobind appears to report strides in elements rather than
  // bytes, so this is presumably 1 for C-contiguous input and not the element
  // byte size - confirm whether `itemsize()` was intended.
  size_t inferredAlignment;
  if (alignment) {
    inferredAlignment = *alignment;
  } else if (view->ndim() > 0) {
    inferredAlignment =
        static_cast<size_t>(view->stride_ptr()[view->ndim() - 1]);
  } else {
    // A 0-d array has no strides; indexing [ndim() - 1] would read out of
    // bounds. An alignment of 1 holds for any address.
    inferredAlignment = 1;
  }

  // The userData is a nb::ndarray<nb::any_contig>* that the deleter owns.
  auto deleter = [](void *userData, const void * /*data*/, size_t /*size*/,
                    size_t /*align*/) {
    delete static_cast<nb::ndarray<nb::any_contig> *>(userData);
  };

  size_t rawBufferSize = view->size() * view->itemsize();
  MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
      type, toMlirStringRef(name), view->data(), rawBufferSize,
      inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
  if (mlirAttributeIsNull(attr)) {
    throw std::invalid_argument(
        "DenseResourceElementsAttr could not be constructed from the given "
        "buffer. "
        "This may mean that the Python buffer layout does not match that "
        "MLIR expected layout and is a bug.");
  }

  // Success: the deleter registered with MLIR is now responsible for the
  // handle, so give up ownership here.
  view.release();
  return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
}
1548+
14971549
// Registers this class's static factory functions on the binding class `c`:
//   "get_from_buffer"  -> PyDenseResourceElementsAttribute::getFromBuffer
//   "get_from_ndarray" -> PyDenseResourceElementsAttribute::getFromNdarray
// Both are bound with the arguments (array, name, type, alignment = None,
// is_mutable = false, context = None).
static void bindDerived(ClassTy &c) {
14981550
c.def_static(
14991551
"get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
15001552
nb::arg("array"), nb::arg("name"), nb::arg("type"),
15011553
nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
15021554
nb::arg("context").none() = nb::none(),
15031555
kDenseResourceElementsAttrGetFromBufferDocstring);
// The "1504-" marker below tags the following brace as the line this commit
// removed; the function body continues with the ndarray binding.
1504-
}
1556+
// NOTE(review): this binding reuses the get_from_buffer docstring constant;
// confirm that its text also describes the ndarray entry point.
c.def_static(
1557+
"get_from_ndarray", PyDenseResourceElementsAttribute::getFromNdarray,
1558+
nb::arg("array"), nb::arg("name"), nb::arg("type"),
1559+
nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
1560+
nb::arg("context").none() = nb::none(),
1561+
kDenseResourceElementsAttrGetFromBufferDocstring);
1562+
}
15051563
};
15061564

15071565
class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {

mlir/test/python/ir/array_attributes.py

+109
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from mlir.ir import *
77
import numpy as np
88
import weakref
9+
import ctypes
910

1011

1112
def run(f):
@@ -617,3 +618,111 @@ def test_attribute(context, mview):
617618
# CHECK: BACKING MEMORY DELETED
618619
# CHECK: EXIT FUNCTION
619620
print("EXIT FUNCTION")
621+
622+
623+
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayI32
624+
@run
def testGetDenseResourceElementsAttrNdarrayI32():
    """Build a resource attribute from a 2x3 int32 array passed as a DLPack capsule.

    The attribute is attached to an empty module and printed; the lifecycle
    messages printed afterwards are matched by the lit directives in this body.
    """

    class CapsuleHolder:
        """Keeps the DLPack capsule of an array and announces its own deletion."""

        def __init__(self, array: np.ndarray):
            self.dlpack_capsule = array.__dlpack__()

        def get_capsule(self):
            return self.dlpack_capsule

        def __del__(self):
            print("BACKING MEMORY DELETED")

    def build_and_print(ctx, holder):
        # Runs in its own frame so that every local created here is dropped
        # as soon as the call returns.
        with ctx, Location.unknown():
            i32 = IntegerType.get_signless(32)
            shaped = RankedTensorType.get((2, 3), i32)
            resource = DenseResourceElementsAttr.get_from_ndarray(
                holder.get_capsule(), "from_py", shaped
            )
            module = Module.parse("module {}")
            module.operation.attributes["test.resource"] = resource
            # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32>
            # CHECK: from_py: "0x01000000010000000200000003000000040000000500000006000000"
            print(module)

            # Verifies type casting.
            # CHECK: dense_resource<from_py> : tensor<2x3xi32>
            stored = module.operation.attributes["test.resource"]
            print(DenseResourceElementsAttr(stored))

    context = Context()
    holder_i32 = CapsuleHolder(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))

    build_and_print(context, holder_i32)
    del holder_i32
    gc.collect()
    # CHECK: BACKING MEMORY DELETED
    # CHECK: FREEING CONTEXT
    print("FREEING CONTEXT")
    context = None
    gc.collect()
    # CHECK: EXIT FUNCTION
    print("EXIT FUNCTION")
668+
669+
670+
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayF32
671+
@run
def testGetDenseResourceElementsAttrNdarrayF32():
    """Build a resource attribute from a 2x3 float32 array passed as a DLPack capsule.

    The attribute is attached to an empty module and printed; the lifecycle
    messages printed afterwards are matched by the lit directives in this body.
    """

    class CapsuleHolder:
        """Keeps the DLPack capsule of an array and announces its own deletion."""

        def __init__(self, array: np.ndarray):
            self.dlpack_capsule = array.__dlpack__()

        def get_capsule(self):
            return self.dlpack_capsule

        def __del__(self):
            print("BACKING MEMORY DELETED")

    def build_and_print(ctx, holder):
        # Runs in its own frame so that every local created here is dropped
        # as soon as the call returns.
        with ctx, Location.unknown():
            # The f32 element type is taken from a float attribute's type.
            f32_attr = FloatAttr.get_f32(32.0)
            shaped = RankedTensorType.get((2, 3), f32_attr.type)
            resource = DenseResourceElementsAttr.get_from_ndarray(
                holder.get_capsule(), "from_py", shaped
            )
            module = Module.parse("module {}")
            module.operation.attributes["test.resource"] = resource
            # CHECK: test.resource = dense_resource<from_py> : tensor<2x3xf32>
            # CHECK: from_py: "0x010000000000803F0000004000004040000080400000A0400000C040"
            print(module)

            # Verifies type casting.
            # CHECK: dense_resource<from_py> : tensor<2x3xf32>
            stored = module.operation.attributes["test.resource"]
            print(DenseResourceElementsAttr(stored))

    context = Context()
    holder_f32 = CapsuleHolder(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))

    build_and_print(context, holder_f32)
    del holder_f32
    gc.collect()
    # CHECK: BACKING MEMORY DELETED
    # CHECK: FREEING CONTEXT
    print("FREEING CONTEXT")
    context = None
    gc.collect()
    # CHECK: EXIT FUNCTION
    print("EXIT FUNCTION")
715+
716+
717+
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNonShapedType
718+
@run
def testGetDenseResourceElementsAttrNonShapedType():
    """Call ``get_from_ndarray`` with a non-shaped type (f32) and print the ValueError."""
    with Context(), Location.unknown():
        scalar_type = F32Type.get()
        capsule = np.array([1], dtype=np.int32).__dlpack__()

        try:
            attr = DenseResourceElementsAttr.get_from_ndarray(
                capsule, "from_py", scalar_type
            )
        except ValueError as err:
            # CHECK: Constructing a DenseResourceElementsAttr requires a ShapedType.
            print(err)

0 commit comments

Comments
 (0)