@@ -379,9 +379,10 @@ def __init__(
379
379
380
380
Args:
381
381
value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
382
- When the dtype is not one of the numpy native dtypes, the value needs
383
- to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
384
- when the value is a numpy array; ``dtype`` must be specified in this case.
382
+ When the dtype is not one of the numpy native dtypes, the value can
383
+ be ``uint8`` (unpacked) or ml_dtypes types for 4-bit and 8-bit data types,
384
+ and ``uint16`` or ml_dtypes types for bfloat16 when the value is a numpy array;
385
+ ``dtype`` must be specified in this case.
385
386
dtype: The data type of the tensor. It can be None only when value is a numpy array.
386
387
Users are responsible for making sure the dtype matches the value when value is not a numpy array.
387
388
shape: The shape of the tensor. If None, the shape is obtained from the value.
@@ -955,6 +956,156 @@ def tobytes(self) -> bytes:
955
956
return self ._evaluate ().tobytes ()
956
957
957
958
959
class PackedTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """A tensor that stores 4bit datatypes in packed format."""

    __slots__ = (
        "_dtype",
        "_metadata",
        "_metadata_props",
        "_raw",
        "_shape",
        "doc_string",
        "name",
    )

    def __init__(
        self,
        value: _protocols.ArrayCompatible | _protocols.DLPackCompatible,
        dtype: _enums.DataType,
        *,
        shape: Shape,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a packed tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array compatible
                object or a DLPack compatible object. The value is stored as given
                (no copy); 4-bit data is packed only when :meth:`tobytes` is called.
            dtype: The data type of the tensor. Intended to be one of INT4, UINT4,
                FLOAT4E2M1 (not enforced here). If ``None`` is passed and ``value``
                is a numpy array, the dtype is inferred from the array's dtype.
            shape: The shape of the tensor. It is frozen by this constructor.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            TypeError: If the value is neither numpy array compatible nor DLPack
                compatible.
            ValueError: If ``dtype`` is ``None`` and the value is not a numpy array.
        """
        # NOTE: We should not do any copying here for performance reasons
        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
            raise TypeError(f"Expected an array compatible object, got {type(value)}")
        self._shape = shape
        self._shape.freeze()
        if dtype is None:
            if isinstance(value, np.ndarray):
                self._dtype = _enums.DataType.from_numpy(value.dtype)
            else:
                raise ValueError(
                    "The dtype must be specified when the value is not a numpy array."
                )
        else:
            # Only take the explicit dtype when one was given; assigning it
            # unconditionally would overwrite the inferred dtype with None.
            self._dtype = dtype

        # View the bfloat16, float8 and int4 types using ml_dtypes
        if isinstance(value, np.ndarray):
            value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype)  # type: ignore[assignment]

        self._raw = value
        self.name = name
        self.doc_string = doc_string
        # The metadata store is created lazily by the ``meta`` property.
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props = metadata_props

    def __array__(self, dtype: Any = None) -> np.ndarray:
        """Return the backing data as a numpy array.

        Numpy-compatible values are converted through their own ``__array__``;
        anything else must be DLPack compatible and goes through ``np.from_dlpack``.
        """
        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
            return self._raw.__array__(dtype)
        assert _compatible_with_dlpack(self._raw), (
            f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
        )
        return np.from_dlpack(self._raw)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        """Export via DLPack, preferring the raw value's own implementation."""
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack__(stream=stream)
        return self.__array__().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        """Return the DLPack device, preferring the raw value's own implementation."""
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack_device__()
        return self.__array__().__dlpack_device__()

    def __repr__(self) -> str:
        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return self._dtype

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def raw(self) -> TArrayCompatible:
        """Backing data of the tensor. Immutable."""
        return self._raw  # type: ignore[return-value]

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        When the data type is not supported by numpy, the dtypes from the ``ml_dtypes``
        package are used. The values can be reinterpreted as bit representations
        using the ``.view()`` method.
        """
        if isinstance(self._raw, np.ndarray):
            return self._raw
        # We do not cache the value to save memory
        return self.__array__()

    def tobytes(self) -> bytes:
        """Returns the value as bytes encoded in little endian.

        For INT4, UINT4 and FLOAT4E2M1 the array is first passed through
        ``_type_casting.pack_int4``. For any other dtype the array's itemsize
        is asserted to match the dtype's itemsize.

        Override this method for more efficient serialization when the raw
        value is not a numpy array.
        """
        # TODO(justinchuby): Support DLPack
        array = self.numpy()
        if self.dtype in {
            _enums.DataType.INT4,
            _enums.DataType.UINT4,
            _enums.DataType.FLOAT4E2M1,
        }:
            # Pack the array into int4
            # NOTE(review): if the value was supplied already packed as uint8,
            # this looks like it would pack a second time - confirm against
            # _maybe_view_np_array_with_ml_dtypes and pack_int4.
            array = _type_casting.pack_int4(array)
        else:
            assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
        if not _IS_LITTLE_ENDIAN:
            # NOTE(review): ``view`` reinterprets the buffer without swapping
            # bytes; confirm whether ``astype`` is intended for multi-byte dtypes.
            array = array.view(array.dtype.newbyteorder("<"))
        return array.tobytes()

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties, created as an empty dict on first access."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata
1107
+
1108
+
958
1109
class SymbolicDim (_protocols .SymbolicDimProtocol , _display .PrettyPrintable ):
959
1110
"""Immutable symbolic dimension that can be shared across multiple shapes."""
960
1111
0 commit comments