Skip to content

Commit cbbd7ca

Browse files
committed
Add tree_bits function.
1 parent a02de84 commit cbbd7ca

File tree

5 files changed

+58
-0
lines changed

5 files changed

+58
-0
lines changed

docs/api/utilities.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ Tree
120120
tree_scale
121121
tree_set
122122
tree_size
123+
tree_bits
123124
tree_sub
124125
tree_sum
125126
tree_vdot
@@ -226,6 +227,10 @@ Tree size
226227
~~~~~~~~~
227228
.. autofunction:: tree_size
228229

230+
Tree bits
231+
~~~~~~~~~
232+
.. autofunction:: tree_bits
233+
229234
Tree subtract
230235
~~~~~~~~~~~~~
231236
.. autofunction:: tree_sub

optax/tree/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
real = _tree_math.tree_real
4949
scale = _tree_math.tree_scale
5050
size = _tree_math.tree_size
51+
bits = _tree_math.tree_bits
5152
sub = _tree_math.tree_sub
5253
sum = _tree_math.tree_sum # pylint: disable=redefined-builtin
5354
update_infinity_moment = _tree_math.tree_update_infinity_moment

optax/tree_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from optax.tree_utils._tree_math import tree_real
4848
from optax.tree_utils._tree_math import tree_scale
4949
from optax.tree_utils._tree_math import tree_size
50+
from optax.tree_utils._tree_math import tree_bits
5051
from optax.tree_utils._tree_math import tree_sub
5152
from optax.tree_utils._tree_math import tree_sum
5253
from optax.tree_utils._tree_math import tree_update_infinity_moment

optax/tree_utils/_tree_math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,37 @@ def tree_size(tree: Any) -> int:
207207
return sum(jnp.size(leaf) for leaf in jax.tree.leaves(tree))
208208

209209

210+
def _get_bits(dtype):
211+
if jnp.issubdtype(dtype, jnp.integer):
212+
return jnp.iinfo(dtype).bits
213+
elif jnp.issubdtype(dtype, jnp.floating):
214+
return jnp.finfo(dtype).bits
215+
elif dtype is bool:
216+
return 1
217+
else:
218+
raise NotImplementedError(f"_get_bits not implemented for {dtype=}")
219+
220+
221+
def tree_bits(tree: Any) -> int:
222+
r"""Total number of bits in a pytree.
223+
224+
Args:
225+
tree: pytree
226+
227+
Returns:
228+
the total size of the pytree in bits.
229+
230+
.. warning::
231+
It is assumed that every leaf's dtype has an integer byte size.
232+
Fractional byte sizes may yield an incorrect result.
233+
For example, ``int4`` might be only half a byte on device.
234+
"""
235+
return sum(
236+
_get_bits(jnp.asarray(leaf).dtype) * jnp.size(leaf)
237+
for leaf in jax.tree.leaves(tree)
238+
)
239+
240+
210241
def tree_conj(tree: Any) -> Any:
211242
"""Compute the conjugate of a pytree.
212243

optax/tree_utils/_tree_math_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,26 @@ def test_tree_allclose(self):
177177
assert tu.tree_allclose(1, 1 + 1e-7)
178178
assert not tu.tree_allclose(1, 2)
179179

180+
@parameterized.product(
181+
size=[1, 10, 100, 1000],
182+
dtype=[jnp.int16, jnp.int32, jnp.float16, jnp.float32, jnp.bfloat16],
183+
)
184+
def test_tree_bits(self, size, dtype):
185+
tree = jnp.ones(size, dtype)
186+
if dtype == jnp.int16:
187+
bits = 16
188+
elif dtype == jnp.int32:
189+
bits = 32
190+
elif dtype == jnp.float16:
191+
bits = 16
192+
elif dtype == jnp.float32:
193+
bits = 32
194+
elif dtype == jnp.bfloat16:
195+
bits = 16
196+
else:
197+
raise NotImplementedError(f"{dtype=}")
198+
assert tu.tree_bits(tree) == bits * size
199+
180200
def test_tree_conj(self):
181201
expected = jnp.conj(self.array_a)
182202
got = tu.tree_conj(self.array_a)

0 commit comments

Comments
 (0)