Skip to content

Commit 6179551

Browse files
committed
Fix Bfloat16 support in iree-turbine
1 parent d866fd0 commit 6179551

File tree

2 files changed

+2
-5
lines changed

2 files changed

+2
-5
lines changed

.github/workflows/ci.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
- name: Run unit tests
6363
if: ${{ !cancelled() }}
6464
run: |
65-
pytest -n 4 --capture=tee-sys -vv .
65+
pytest --capture=tee-sys -vv .
6666
6767
- name: Run LIT tests
6868
if: ${{ !cancelled() }}

iree/turbine/aot/support/ir_utils.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,7 @@ def create_tensor_global(
324324
)
325325
else:
326326
# Emit inline initialized.
327-
detached_tensor = t.detach().contiguous().cpu()
328-
array = np.array(detached_tensor)
329-
# We know that a Numpy array is a ReadableBuffer so ignore type error.
330-
contents = memoryview(array) # type: ignore
327+
contents = torch.to_dlpack(t)
331328
blob_name = symbol_name
332329
elements_attr = DenseResourceElementsAttr.get_from_buffer(
333330
contents, blob_name, tensor_type

0 commit comments

Comments
 (0)