Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use nanobind ndarray to make use of bfloat16 weights possible (without coping the data) #5

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
//
//===----------------------------------------------------------------------===//

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
Expand All @@ -18,6 +20,9 @@
#include "mlir-c/BuiltinTypes.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "nanobind/nanobind.h"
#include "nanobind/ndarray.h"
#include "pytypedefs.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"

Expand Down Expand Up @@ -1494,14 +1499,72 @@ class PyDenseResourceElementsAttribute
return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
}

static PyDenseResourceElementsAttribute
getFromNdarray(const nb::ndarray<nb::any_contig> &buffer, const std::string &name, const PyType &type,
std::optional<size_t> alignment, bool isMutable,
DefaultingPyMlirContext contextWrapper) {
if (!mlirTypeIsAShaped(type)) {
throw std::invalid_argument(
"Constructing a DenseResourceElementsAttr requires a ShapedType.");
}

// Do not request any conversions as we must ensure to use caller
// managed memory.
std::unique_ptr<nb::ndarray<nb::any_contig>> view = std::make_unique<nb::ndarray<nb::any_contig>>(buffer);
if (!view->is_valid()) {
throw std::invalid_argument("The buffer should not be a nullptr.");
}

// This scope releaser will only release if we haven't yet transferred
// ownership.
auto freeBuffer = llvm::make_scope_exit([&]() {
if (view)
view.release();
});

// Infer alignment to be the stride of one element if not explicit.
size_t inferredAlignment;
if (alignment)
inferredAlignment = *alignment;
else
inferredAlignment = view->stride_ptr()[view->ndim() - 1];

// The userData is a nb::ndarray<nb::any_contig>* that the deleter owns.
auto deleter = [](void *userData, const void *data, size_t size,
size_t align) {
nb::ndarray<nb::any_contig> *ownedView= static_cast<nb::ndarray<nb::any_contig> *>(userData);
delete ownedView;
};

size_t rawBufferSize = view->size() * view->itemsize();
MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
type, toMlirStringRef(name), view->data(), rawBufferSize,
inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
if (mlirAttributeIsNull(attr)) {
throw std::invalid_argument(
"DenseResourceElementsAttr could not be constructed from the given "
"buffer. "
"This may mean that the Python buffer layout does not match that "
"MLIR expected layout and is a bug.");
}
view.release();
return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
}

static void bindDerived(ClassTy &c) {
c.def_static(
"get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
nb::arg("array"), nb::arg("name"), nb::arg("type"),
nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
nb::arg("context").none() = nb::none(),
kDenseResourceElementsAttrGetFromBufferDocstring);
}
c.def_static(
"get_from_ndarray", PyDenseResourceElementsAttribute::getFromNdarray,
nb::arg("array"), nb::arg("name"), nb::arg("type"),
nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
nb::arg("context").none() = nb::none(),
kDenseResourceElementsAttrGetFromBufferDocstring);
}
};

class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
Expand Down
38 changes: 38 additions & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ else:
class Buffer(abc.ABC):
pass

class DLPack(abc.ABC):
pass

class _OperationBase:
@overload
def __eq__(self, arg0: _OperationBase) -> bool: ...
Expand Down Expand Up @@ -1350,6 +1353,41 @@ class DenseResourceElementsAttr(Attribute):
type or if the buffer does not meet expectations.
"""
@staticmethod
def get_from_ndarray(
array: DLPack,
name: str,
type: Type,
alignment: int | None = None,
is_mutable: bool = False,
context: Context | None = None,
) -> DenseResourceElementsAttr:
"""
Gets a DenseResourceElementsAttr from a Python buffer or DLPack C structure
wrapped in a PyCapsule.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The DLPack data structure and any user objects will be retained over the lifetime
of the resource blob. The used nanobind ndarray is designed as a view of the memory
and any copy of this wrapper will point to the same underlying buffer
and will only increase the reference count until it goes out of scope.

Args:
array: The buffer or DLPack to convert.
name: Name to provide to the resource (may be changed upon collision).
type: The explicit ShapedType to construct the attribute with.
context: Explicit context, if not from context manager.

Returns:
DenseResourceElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
"""
@staticmethod
def isinstance(other: Attribute) -> bool: ...
def __init__(self, cast_from_attr: Attribute) -> None: ...
@property
Expand Down
109 changes: 109 additions & 0 deletions mlir/test/python/ir/array_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mlir.ir import *
import numpy as np
import weakref
import ctypes


def run(f):
Expand Down Expand Up @@ -617,3 +618,111 @@ def test_attribute(context, mview):
# CHECK: BACKING MEMORY DELETED
# CHECK: EXIT FUNCTION
print("EXIT FUNCTION")


# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayI32
@run
def testGetDenseResourceElementsAttrNdarrayI32():
class DLPackWrapper:
def __init__(self, array: np.ndarray):
self.dlpack_capsule = array.__dlpack__()

def __del__(self):
print("DLPACK MEMORY DELETED")

def get_capsule(self):
return self.dlpack_capsule

context = Context()
mview_int32 = DLPackWrapper(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32))

def test_attribute_int32(context, mview_int32):
with context, Location.unknown():
element_type = IntegerType.get_signless(32)
tensor_type = RankedTensorType.get((2, 3), element_type)
resource = DenseResourceElementsAttr.get_from_ndarray(
mview_int32.get_capsule(), "from_py", tensor_type
)
module = Module.parse("module {}")
module.operation.attributes["test.resource"] = resource
# CHECK: test.resource = dense_resource<from_py> : tensor<2x3xi32>
# CHECK: from_py: "0x01000000010000000200000003000000040000000500000006000000"
print(module)

# Verifies type casting.
# CHECK: dense_resource<from_py> : tensor<2x3xi32>
print(
DenseResourceElementsAttr(module.operation.attributes["test.resource"])
)

test_attribute_int32(context, mview_int32)
del mview_int32
gc.collect()
# CHECK: DLPACK MEMORY DELETED
# CHECK: FREEING CONTEXT
print("FREEING CONTEXT")
context = None
gc.collect()
# CHECK: EXIT FUNCTION
print("EXIT FUNCTION")


# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNdarrayF32
@run
def testGetDenseResourceElementsAttrNdarrayF32():
class DLPackWrapper:
def __init__(self, array: np.ndarray):
self.dlpack_capsule = array.__dlpack__()

def __del__(self):
print("DLPACK MEMORY DELETED")

def get_capsule(self):
return self.dlpack_capsule

context = Context()
mview_float32 = DLPackWrapper(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32))

def test_attribute_float32(context, mview_float32):
with context, Location.unknown():
element_type = FloatAttr.get_f32(32.0)
tensor_type = RankedTensorType.get((2, 3), element_type.type)
resource = DenseResourceElementsAttr.get_from_ndarray(
mview_float32.get_capsule(), "from_py", tensor_type
)
module = Module.parse("module {}")
module.operation.attributes["test.resource"] = resource
# CHECK: test.resource = dense_resource<from_py> : tensor<2x3xf32>
# CHECK: from_py: "0x010000000000803F0000004000004040000080400000A0400000C040"
print(module)

# Verifies type casting.
# CHECK: dense_resource<from_py> : tensor<2x3xf32>
print(
DenseResourceElementsAttr(module.operation.attributes["test.resource"])
)

test_attribute_float32(context, mview_float32)
del mview_float32
gc.collect()
# CHECK: DLPACK MEMORY DELETED
# CHECK: FREEING CONTEXT
print("FREEING CONTEXT")
context = None
gc.collect()
# CHECK: EXIT FUNCTION
print("EXIT FUNCTION")


# CHECK-LABEL: TEST: testGetDenseResourceElementsAttrNonShapedType
@run
def testGetDenseResourceElementsAttrNonShapedType():
with Context(), Location.unknown():
mview = np.array([1], dtype=np.int32).__dlpack__()
t = F32Type.get()

try:
attr = DenseResourceElementsAttr.get_from_ndarray(mview, "from_py", t)
except ValueError as e:
# CHECK: Constructing a DenseResourceElementsAttr requires a ShapedType.
print(e)