Skip to content

Commit 9e00247

Browse files
committed
Fix Bfloat16 support in iree-turbine
1 parent d866fd0 commit 9e00247

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

iree/turbine/aot/support/ir_utils.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,9 @@ def create_tensor_global(
324324
)
325325
else:
326326
# Emit inline initialized.
327-
detached_tensor = t.detach().contiguous().cpu()
328-
array = np.array(detached_tensor)
329-
# We know that a Numpy array is a ReadableBuffer so ignore type error.
330-
contents = memoryview(array) # type: ignore
327+
contents = torch.utils.dlpack.to_dlpack(t)
331328
blob_name = symbol_name
332-
elements_attr = DenseResourceElementsAttr.get_from_buffer(
329+
elements_attr = DenseResourceElementsAttr.get_from_buffer_ndarray(
333330
contents, blob_name, tensor_type
334331
)
335332
ir_attrs["initial_value"] = elements_attr

0 commit comments

Comments
 (0)