Skip to content

Commit c3a8c30

Browse files
committed
Fix Bfloat16 support in iree-turbine
1 parent d866fd0 commit c3a8c30

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

iree/turbine/aot/support/ir_utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,11 @@ def create_tensor_global(
324324
)
325325
else:
326326
# Emit inline initialized.
327-
detached_tensor = t.detach().contiguous().cpu()
328-
array = np.array(detached_tensor)
329-
# We know that a Numpy array is a ReadableBuffer so ignore type error.
330-
contents = memoryview(array) # type: ignore
327+
contents = torch.utils.dlpack.to_dlpack(t)
328+
#detached_tensor = t.detach().contiguous().cpu()
329+
#array = np.array(detached_tensor)
330+
#array = bytes(detached_tensor.untyped_storage())
331+
#contents = memoryview(array)
331332
blob_name = symbol_name
332333
elements_attr = DenseResourceElementsAttr.get_from_buffer(
333334
contents, blob_name, tensor_type

0 commit comments

Comments
 (0)