Skip to content

Commit 89d298f

Browse files
committed
Add tree_bits function.
1 parent c6d8eba commit 89d298f

File tree

5 files changed

+58
-0
lines changed

5 files changed

+58
-0
lines changed

docs/api/utilities.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ Tree
119119
tree_scale
120120
tree_set
121121
tree_size
122+
tree_bits
122123
tree_sub
123124
tree_sum
124125
tree_vdot
@@ -221,6 +222,10 @@ Tree size
221222
~~~~~~~~~
222223
.. autofunction:: tree_size
223224

225+
Tree bits
226+
~~~~~~~~~
227+
.. autofunction:: tree_bits
228+
224229
Tree subtract
225230
~~~~~~~~~~~~~
226231
.. autofunction:: tree_sub

optax/tree/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
real = _tree_math.tree_real
4848
scale = _tree_math.tree_scale
4949
size = _tree_math.tree_size
50+
bits = _tree_math.tree_bits
5051
sub = _tree_math.tree_sub
5152
sum = _tree_math.tree_sum # pylint: disable=redefined-builtin
5253
update_infinity_moment = _tree_math.tree_update_infinity_moment

optax/tree_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from optax.tree_utils._tree_math import tree_real
4747
from optax.tree_utils._tree_math import tree_scale
4848
from optax.tree_utils._tree_math import tree_size
49+
from optax.tree_utils._tree_math import tree_bits
4950
from optax.tree_utils._tree_math import tree_sub
5051
from optax.tree_utils._tree_math import tree_sum
5152
from optax.tree_utils._tree_math import tree_update_infinity_moment

optax/tree_utils/_tree_math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,37 @@ def tree_size(tree: Any) -> int:
209209
return sum(jnp.size(leaf) for leaf in jax.tree.leaves(tree))
210210

211211

212+
def _get_bits(dtype):
213+
if jnp.issubdtype(dtype, jnp.integer):
214+
return jnp.iinfo(dtype).bits
215+
elif jnp.issubdtype(dtype, jnp.floating):
216+
return jnp.finfo(dtype).bits
217+
elif dtype is bool:
218+
return 1
219+
else:
220+
raise NotImplementedError(f"_get_bits not implemented for {dtype=}")
221+
222+
223+
def tree_bits(tree: Any) -> int:
224+
r"""Total number of bytes in a pytree.
225+
226+
Args:
227+
tree: pytree
228+
229+
Returns:
230+
the total size of the pytree in bytes.
231+
232+
.. warning::
233+
It is assumed that every leaf's dtype has an integer byte size.
234+
Fractional byte sizes may yield an incorrect result.
235+
For example, ``int4`` might be only half a byte on device.
236+
"""
237+
return sum(
238+
_get_bits(jnp.asarray(leaf).dtype) * jnp.size(leaf)
239+
for leaf in jax.tree.leaves(tree)
240+
)
241+
242+
212243
def tree_conj(tree: Any) -> Any:
213244
"""Compute the conjugate of a pytree.
214245

optax/tree_utils/_tree_math_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,26 @@ def test_tree_size(self, key):
172172
got = tu.tree_size(tree)
173173
np.testing.assert_allclose(expected, got)
174174

175+
@parameterized.product(
176+
size=[1, 10, 100, 1000],
177+
dtype=[jnp.int16, jnp.int32, jnp.float16, jnp.float32, jnp.bfloat16],
178+
)
179+
def test_tree_bits(self, size, dtype):
180+
tree = jnp.ones(size, dtype)
181+
if dtype == jnp.int16:
182+
bits = 16
183+
elif dtype == jnp.int32:
184+
bits = 32
185+
elif dtype == jnp.float16:
186+
bits = 16
187+
elif dtype == jnp.float32:
188+
bits = 32
189+
elif dtype == jnp.bfloat16:
190+
bits = 16
191+
else:
192+
raise NotImplementedError(f"{dtype=}")
193+
assert tu.tree_bits(tree) == bits * size
194+
175195
def test_tree_conj(self):
176196
expected = jnp.conj(self.array_a)
177197
got = tu.tree_conj(self.array_a)

0 commit comments

Comments
 (0)