Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 893acad

Browse files
committedMar 17, 2025·
use nanobind ndarray to make use of bfloat16 weights possible (without
coping the data)
1 parent edef63c commit 893acad

File tree

3 files changed

+206
-1
lines changed

3 files changed

+206
-1
lines changed
 

‎mlir/lib/Bindings/Python/IRAttributes.cpp

+59-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <cstddef>
910
#include <cstdint>
11+
#include <memory>
1012
#include <optional>
1113
#include <string>
1214
#include <string_view>
@@ -18,6 +20,9 @@
1820
#include "mlir-c/BuiltinTypes.h"
1921
#include "mlir/Bindings/Python/NanobindAdaptors.h"
2022
#include "mlir/Bindings/Python/Nanobind.h"
23+
#include "nanobind/nanobind.h"
24+
#include "nanobind/ndarray.h"
25+
#include "pytypedefs.h"
2126
#include "llvm/ADT/ScopeExit.h"
2227
#include "llvm/Support/raw_ostream.h"
2328

@@ -1494,14 +1499,67 @@ class PyDenseResourceElementsAttribute
14941499
return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
14951500
}
14961501

1502+
static PyDenseResourceElementsAttribute
1503+
getFromNdarray(const nb::ndarray<nb::ro, nb::any_contig> &buffer, const std::string &name, const PyType &type,
1504+
std::optional<size_t> alignment, bool isMutable,
1505+
DefaultingPyMlirContext contextWrapper) {
1506+
if (!mlirTypeIsAShaped(type)) {
1507+
throw std::invalid_argument(
1508+
"Constructing a DenseResourceElementsAttr requires a ShapedType.");
1509+
}
1510+
1511+
// Do not request any conversions as we must ensure to use caller
1512+
// managed memory.
1513+
//nb::ndarray<nb::any_contig>* view = new nb::ndarray<nb::any_contig>(buffer);
1514+
//std::unique_ptr<nb::ndarray<nb::any_contig>> view (new nb::ndarray<nb::any_contig>(buffer));
1515+
//std::unique_ptr<nb::ndarray<nb::any_contig>> view = std::make_unique<nb::ndarray<nb::any_contig>>(buffer);
1516+
nb::ndarray<nb::any_contig> *view = new nb::ndarray<nb::any_contig>(buffer);
1517+
if (!view->is_valid()) {
1518+
throw std::invalid_argument("The buffer should not be a nullptr.");
1519+
}
1520+
1521+
// Infer alignment to be the stride of one element if not explicit.
1522+
size_t inferredAlignment;
1523+
if (alignment)
1524+
inferredAlignment = *alignment;
1525+
else
1526+
inferredAlignment = view->stride_ptr()[view->ndim() - 1];
1527+
1528+
// The userData is a nb::ndarray<nb::any_contig>* that the deleter owns.
1529+
auto deleter = [](void *userData, const void *data, size_t size,
1530+
size_t align) {
1531+
nb::ndarray<nb::any_contig> *ownedView= static_cast<nb::ndarray<nb::any_contig> *>(userData);
1532+
delete ownedView;
1533+
};
1534+
1535+
size_t rawBufferSize = view->size() * view->itemsize();
1536+
MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1537+
type, toMlirStringRef(name), view->data(), rawBufferSize,
1538+
inferredAlignment, isMutable, deleter, static_cast<void *>(view));
1539+
if (mlirAttributeIsNull(attr)) {
1540+
throw std::invalid_argument(
1541+
"DenseResourceElementsAttr could not be constructed from the given "
1542+
"buffer. "
1543+
"This may mean that the Python buffer layout does not match that "
1544+
"MLIR expected layout and is a bug.");
1545+
}
1546+
return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1547+
}
1548+
14971549
static void bindDerived(ClassTy &c) {
14981550
c.def_static(
14991551
"get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
15001552
nb::arg("array"), nb::arg("name"), nb::arg("type"),
15011553
nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
15021554
nb::arg("context").none() = nb::none(),
15031555
kDenseResourceElementsAttrGetFromBufferDocstring);
1504-
}
1556+
c.def_static(
1557+
"get_from_ndarray", PyDenseResourceElementsAttribute::getFromNdarray,
1558+
nb::arg("array"), nb::arg("name"), nb::arg("type"),
1559+
nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
1560+
nb::arg("context").none() = nb::none(),
1561+
kDenseResourceElementsAttrGetFromBufferDocstring);
1562+
}
15051563
};
15061564

15071565
class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {

‎mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

+38
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ else:
180180
class Buffer(abc.ABC):
181181
pass
182182

183+
class DLPack(abc.ABC):
184+
pass
185+
183186
class _OperationBase:
184187
@overload
185188
def __eq__(self, arg0: _OperationBase) -> bool: ...
@@ -1350,6 +1353,41 @@ class DenseResourceElementsAttr(Attribute):
13501353
type or if the buffer does not meet expectations.
13511354
"""
13521355
@staticmethod
1356+
def get_from_ndarray(
1357+
array: DLPack,
1358+
name: str,
1359+
type: Type,
1360+
alignment: int | None = None,
1361+
is_mutable: bool = False,
1362+
context: Context | None = None,
1363+
) -> DenseResourceElementsAttr:
1364+
"""
1365+
Gets a DenseResourceElementsAttr from a Python buffer or DLPack C structure
1366+
wrapped in a PyCapsule.
1367+
1368+
This function does minimal validation or massaging of the data, and it is
1369+
up to the caller to ensure that the buffer meets the characteristics
1370+
implied by the shape.
1371+
1372+
The backing buffer and any user objects will be retained for the lifetime
1373+
of the resource blob. This is typically bounded to the context but the
1374+
resource can have a shorter lifespan depending on how it is used in
1375+
subsequent processing.
1376+
1377+
Args:
1378+
array: The buffer or DLPack to convert.
1379+
name: Name to provide to the resource (may be changed upon collision).
1380+
type: The explicit ShapedType to construct the attribute with.
1381+
context: Explicit context, if not from context manager.
1382+
1383+
Returns:
1384+
DenseResourceElementsAttr on success.
1385+
1386+
Raises:
1387+
ValueError: If the type of the buffer or array cannot be matched to an MLIR
1388+
type or if the buffer does not meet expectations.
1389+
"""
1390+
@staticmethod
13531391
def isinstance(other: Attribute) -> bool: ...
13541392
def __init__(self, cast_from_attr: Attribute) -> None: ...
13551393
@property

‎mlir/test/python/ir/array_attributes.py

+109
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from mlir.ir import *
77
import numpy as np
88
import weakref
9+
import ctypes
910

1011

1112
def run(f):
@@ -617,3 +618,111 @@ def test_attribute(context, mview):
617618
# CHECK: BACKING MEMORY DELETED
618619
# CHECK: EXIT FUNCTION
619620
print("EXIT FUNCTION")
621+
622+
623+
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayI32
624+
@run
625+
def testGetDenseResourceElementsAttrNdarrayI32():
626+
class DLPackWrapper:
627+
def __init__(self, array: np.ndarray):
628+
self.dlpack_capsule = array.__dlpack__()
629+
630+
def __del__(self):
631+
print("BACKING MEMORY DELETED")
632+
633+
def get_capsule(self):
634+
return self.dlpack_capsule
635+
636+
context = Context()
637+
mview_int32 = DLPackWrapper(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))
638+
639+
def test_attribute_int32(context, mview_int32):
640+
with context, Location.unknown():
641+
element_type = IntegerType.get_signless(32)
642+
tensor_type = RankedTensorType.get((2, 3), element_type)
643+
resource = DenseResourceElementsAttr.get_from_ndarray(
644+
mview_int32.get_capsule(), "from_py", tensor_type
645+
)
646+
module = Module.parse("module {}")
647+
module.operation.attributes["test.resource"] = resource
648+
# CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32>
649+
# CHECK: from_py: "0x01000000010000000200000003000000040000000500000006000000"
650+
print(module)
651+
652+
# Verifies type casting.
653+
# CHECK: dense_resource<from_py> : tensor<2x3xi32>
654+
print(
655+
DenseResourceElementsAttr(module.operation.attributes["test.resource"])
656+
)
657+
658+
test_attribute_int32(context, mview_int32)
659+
del mview_int32
660+
gc.collect()
661+
# CHECK: BACKING MEMORY DELETED
662+
# CHECK: FREEING CONTEXT
663+
print("FREEING CONTEXT")
664+
context = None
665+
gc.collect()
666+
# CHECK: EXIT FUNCTION
667+
print("EXIT FUNCTION")
668+
669+
670+
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayF32
671+
@run
672+
def testGetDenseResourceElementsAttrNdarrayF32():
673+
class DLPackWrapper:
674+
def __init__(self, array: np.ndarray):
675+
self.dlpack_capsule = array.__dlpack__()
676+
677+
def __del__(self):
678+
print("BACKING MEMORY DELETED")
679+
680+
def get_capsule(self):
681+
return self.dlpack_capsule
682+
683+
context = Context()
684+
mview_float32 = DLPackWrapper(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))
685+
686+
def test_attribute_float32(context, mview_float32):
687+
with context, Location.unknown():
688+
element_type = FloatAttr.get_f32(32.0)
689+
tensor_type = RankedTensorType.get((2, 3), element_type.type)
690+
resource = DenseResourceElementsAttr.get_from_ndarray(
691+
mview_float32.get_capsule(), "from_py", tensor_type
692+
)
693+
module = Module.parse("module {}")
694+
module.operation.attributes["test.resource"] = resource
695+
# CHECK: test.resource = dense_resource<from_py> : tensor<2x3xf32>
696+
# CHECK: from_py: "0x010000000000803F0000004000004040000080400000A0400000C040"
697+
print(module)
698+
699+
# Verifies type casting.
700+
# CHECK: dense_resource<from_py> : tensor<2x3xf32>
701+
print(
702+
DenseResourceElementsAttr(module.operation.attributes["test.resource"])
703+
)
704+
705+
test_attribute_float32(context, mview_float32)
706+
del mview_float32
707+
gc.collect()
708+
# CHECK: BACKING MEMORY DELETED
709+
# CHECK: FREEING CONTEXT
710+
print("FREEING CONTEXT")
711+
context = None
712+
gc.collect()
713+
# CHECK: EXIT FUNCTION
714+
print("EXIT FUNCTION")
715+
716+
717+
# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNonShapedType
718+
@run
719+
def testGetDenseResourceElementsAttrNonShapedType():
720+
with Context(), Location.unknown():
721+
mview = np.array([1], dtype=np.int32).__dlpack__()
722+
t = F32Type.get()
723+
724+
try:
725+
attr = DenseResourceElementsAttr.get_from_ndarray(mview, "from_py", t)
726+
except ValueError as e:
727+
# CHECK: Constructing a DenseResourceElementsAttr requires a ShapedType.
728+
print(e)

0 commit comments

Comments
 (0)
Please sign in to comment.