File tree Expand file tree Collapse file tree 5 files changed +71
-0
lines changed Expand file tree Collapse file tree 5 files changed +71
-0
lines changed Original file line number Diff line number Diff line change 120120 tree_scale
121121 tree_set
122122 tree_size
123+ tree_bits
123124 tree_sub
124125 tree_sum
125126 tree_vdot
@@ -226,6 +227,10 @@ Tree size
226227~~~~~~~~~
227228.. autofunction :: tree_size
228229
230+ Tree bits
231+ ~~~~~~~~~
232+ .. autofunction :: tree_bits
233+
229234Tree subtract
230235~~~~~~~~~~~~~
231236.. autofunction :: tree_sub
Original file line number Diff line number Diff line change 4848real = _tree_math .tree_real
4949scale = _tree_math .tree_scale
5050size = _tree_math .tree_size
51+ bits = _tree_math .tree_bits
5152sub = _tree_math .tree_sub
5253sum = _tree_math .tree_sum # pylint: disable=redefined-builtin
5354update_infinity_moment = _tree_math .tree_update_infinity_moment
Original file line number Diff line number Diff line change 4747from optax .tree_utils ._tree_math import tree_real
4848from optax .tree_utils ._tree_math import tree_scale
4949from optax .tree_utils ._tree_math import tree_size
50+ from optax .tree_utils ._tree_math import tree_bits
5051from optax .tree_utils ._tree_math import tree_sub
5152from optax .tree_utils ._tree_math import tree_sum
5253from optax .tree_utils ._tree_math import tree_update_infinity_moment
Original file line number Diff line number Diff line change @@ -207,6 +207,37 @@ def tree_size(tree: Any) -> int:
207207 return sum (jnp .size (leaf ) for leaf in jax .tree .leaves (tree ))
208208
209209
210+ def _get_bits (dtype ):
211+ if jnp .issubdtype (dtype , jnp .integer ):
212+ return jnp .iinfo (dtype ).bits
213+ elif jnp .issubdtype (dtype , jnp .floating ):
214+ return jnp .finfo (dtype ).bits
215+ elif dtype is bool :
216+ return 1
217+ else :
218+ raise NotImplementedError (f"_get_bits not implemented for { dtype = } " )
219+
220+
221+ def tree_bits (tree : Any ) -> int :
222+ r"""Total number of bits in a pytree.
223+
224+ Args:
225+ tree: pytree
226+
227+ Returns:
228+ the total size of the pytree in bits.
229+
230+ .. warning::
231+ It is assumed that every leaf's dtype has an integer byte size.
232+ Fractional byte sizes may yield an incorrect result.
233+ For example, ``int4`` might be only half a byte on device.
234+ """
235+ return sum (
236+ _get_bits (jnp .asarray (leaf ).dtype ) * jnp .size (leaf )
237+ for leaf in jax .tree .leaves (tree )
238+ )
239+
240+
210241def tree_conj (tree : Any ) -> Any :
211242 """Compute the conjugate of a pytree.
212243
Original file line number Diff line number Diff line change @@ -177,6 +177,39 @@ def test_tree_allclose(self):
177177 assert tu .tree_allclose (1 , 1 + 1e-7 )
178178 assert not tu .tree_allclose (1 , 2 )
179179
180+ @parameterized .product (
181+ size = [1 , 10 , 100 , 1000 ],
182+ dtype = [
183+ jnp .int4 ,
184+ jnp .int8 ,
185+ jnp .int16 ,
186+ jnp .int32 ,
187+ jnp .uint4 ,
188+ jnp .uint8 ,
189+ jnp .uint16 ,
190+ jnp .uint32 ,
191+ jnp .float16 ,
192+ jnp .float32 ,
193+ jnp .bfloat16 ,
194+ ],
195+ )
196+ def test_tree_bits (self , size , dtype ):
197+ tree = jnp .zeros (size , dtype )
198+ bits = {
199+ jnp .int4 : 4 ,
200+ jnp .int8 : 8 ,
201+ jnp .int16 : 16 ,
202+ jnp .int32 : 32 ,
203+ jnp .uint4 : 4 ,
204+ jnp .uint8 : 8 ,
205+ jnp .uint16 : 16 ,
206+ jnp .uint32 : 32 ,
207+ jnp .float16 : 16 ,
208+ jnp .float32 : 32 ,
209+ jnp .bfloat16 : 16 ,
210+ }[dtype ]
211+ assert tu .tree_bits (tree ) == bits * size
212+
180213 def test_tree_conj (self ):
181214 expected = jnp .conj (self .array_a )
182215 got = tu .tree_conj (self .array_a )
You can’t perform that action at this time.
0 commit comments