Skip to content

Commit 6ae4ed2

Browse files
committed
wip
1 parent a8f56c2 commit 6ae4ed2

File tree

1 file changed

+154
-3
lines changed

1 file changed

+154
-3
lines changed

onnxscript/ir/_core.py

Lines changed: 154 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,10 @@ def __init__(
379379
380380
Args:
381381
value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
382-
When the dtype is not one of the numpy native dtypes, the value needs
383-
to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
384-
when the value is a numpy array; ``dtype`` must be specified in this case.
382+
When the dtype is not one of the numpy native dtypes, the value can
383+
be ``uint8`` (unpacked) or ml_dtypes types for 4-bit and 8-bit data types,
384+
and ``uint16`` or ml_dtype types for bfloat16 when the value is a numpy array;
385+
``dtype`` must be specified in this case.
385386
dtype: The data type of the tensor. It can be None only when value is a numpy array.
386387
Users are responsible for making sure the dtype matches the value when value is not a numpy array.
387388
shape: The shape of the tensor. If None, the shape is obtained from the value.
@@ -955,6 +956,156 @@ def tobytes(self) -> bytes:
955956
return self._evaluate().tobytes()
956957

957958

959+
class PackedTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
960+
"""A tensor that stores 4bit datatypes in packed format."""
961+
962+
__slots__ = (
963+
"_dtype",
964+
"_metadata",
965+
"_metadata_props",
966+
"_raw",
967+
"_shape",
968+
"doc_string",
969+
"name",
970+
)
971+
972+
def __init__(
973+
self,
974+
value: _protocols.ArrayCompatible | _protocols.DLPackCompatible,
975+
dtype: _enums.DataType,
976+
*,
977+
shape: Shape,
978+
name: str | None = None,
979+
doc_string: str | None = None,
980+
metadata_props: dict[str, str] | None = None,
981+
) -> None:
982+
"""Initialize a tensor.
983+
984+
Args:
985+
value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
986+
The value MUST be in ``uint8`` packed format, or in one of the ml_dtypes dtypes, which
987+
will be packed when constructing the tensor.
988+
dtype: The data type of the tensor. Must be one of INT4, UINT4, FLOAT4E2M1.
989+
shape: The shape of the tensor.
990+
name: The name of the tensor.
991+
doc_string: The documentation string.
992+
metadata_props: The metadata properties.
993+
994+
Raises:
995+
TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
996+
TypeError: If the value is a numpy array and the dtype is not uint8 or one of the ml_dtypes dtypes.
997+
"""
998+
# NOTE: We should not do any copying here for performance reasons
999+
if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
1000+
raise TypeError(f"Expected an array compatible object, got {type(value)}")
1001+
self._shape = shape
1002+
self._shape.freeze()
1003+
if dtype is None:
1004+
if isinstance(value, np.ndarray):
1005+
self._dtype = _enums.DataType.from_numpy(value.dtype)
1006+
else:
1007+
raise ValueError(
1008+
"The dtype must be specified when the value is not a numpy array."
1009+
)
1010+
self._dtype = dtype
1011+
1012+
# View the bfloat16, float8 and int4 types using ml_dtypes
1013+
if isinstance(value, np.ndarray):
1014+
value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype) # type: ignore[assignment]
1015+
1016+
self._raw = value
1017+
self.name = name
1018+
self.doc_string = doc_string
1019+
self._metadata: _metadata.MetadataStore | None = None
1020+
self._metadata_props = metadata_props
1021+
1022+
def __array__(self, dtype: Any = None) -> np.ndarray:
1023+
if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
1024+
return self._raw.__array__(dtype)
1025+
assert _compatible_with_dlpack(self._raw), (
1026+
f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
1027+
)
1028+
return np.from_dlpack(self._raw)
1029+
1030+
def __dlpack__(self, *, stream: Any = None) -> Any:
1031+
if _compatible_with_dlpack(self._raw):
1032+
return self._raw.__dlpack__(stream=stream)
1033+
return self.__array__().__dlpack__(stream=stream)
1034+
1035+
def __dlpack_device__(self) -> tuple[int, int]:
1036+
if _compatible_with_dlpack(self._raw):
1037+
return self._raw.__dlpack_device__()
1038+
return self.__array__().__dlpack_device__()
1039+
1040+
def __repr__(self) -> str:
1041+
return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
1042+
1043+
@property
1044+
def dtype(self) -> _enums.DataType:
1045+
"""The data type of the tensor. Immutable."""
1046+
return self._dtype
1047+
1048+
@property
1049+
def shape(self) -> Shape:
1050+
"""The shape of the tensor. Immutable."""
1051+
return self._shape
1052+
1053+
@property
1054+
def raw(self) -> TArrayCompatible:
1055+
"""Backing data of the tensor. Immutable."""
1056+
return self._raw # type: ignore[return-value]
1057+
1058+
def numpy(self) -> np.ndarray:
1059+
"""Return the tensor as a numpy array.
1060+
1061+
When the data type is not supported by numpy, the dtypes from the ``ml_dtype``
1062+
package are used. The values can be reinterpreted as bit representations
1063+
using the ``.view()`` method.
1064+
"""
1065+
if isinstance(self._raw, np.ndarray):
1066+
return self._raw
1067+
# We do not cache the value to save memory
1068+
return self.__array__()
1069+
1070+
def tobytes(self) -> bytes:
1071+
"""Returns the value as bytes encoded in little endian.
1072+
1073+
Override this method for more efficient serialization when the raw
1074+
value is not a numpy array.
1075+
"""
1076+
# TODO(justinchuby): Support DLPack
1077+
array = self.numpy()
1078+
if self.dtype in {
1079+
_enums.DataType.INT4,
1080+
_enums.DataType.UINT4,
1081+
_enums.DataType.FLOAT4E2M1,
1082+
}:
1083+
# Pack the array into int4
1084+
array = _type_casting.pack_int4(array)
1085+
else:
1086+
assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
1087+
if not _IS_LITTLE_ENDIAN:
1088+
array = array.view(array.dtype.newbyteorder("<"))
1089+
return array.tobytes()
1090+
1091+
@property
1092+
def metadata_props(self) -> dict[str, str]:
1093+
if self._metadata_props is None:
1094+
self._metadata_props = {}
1095+
return self._metadata_props
1096+
1097+
@property
1098+
def meta(self) -> _metadata.MetadataStore:
1099+
"""The metadata store for intermediate analysis.
1100+
1101+
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
1102+
to the ONNX proto.
1103+
"""
1104+
if self._metadata is None:
1105+
self._metadata = _metadata.MetadataStore()
1106+
return self._metadata
1107+
1108+
9581109
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
9591110
"""Immutable symbolic dimension that can be shared across multiple shapes."""
9601111

0 commit comments

Comments
 (0)