Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,49 @@ Hessian vector product
.. autofunction:: hvp


State utilities
-------------------

.. currentmodule:: optax.tree_utils

.. autosummary::
ParamsShapedState
reshape_params_shaped_state
shape_state_like_params
tree_map_params
tree_get
tree_get_all_with_path
tree_set

ParamsShapedState
~~~~~~~~~~~~~~~~~
.. autoclass:: ParamsShapedState

Reshape params-shaped state
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: reshape_params_shaped_state

Shape state like params
~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: shape_state_like_params

Fetch single value that match a given key
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_get

Fetch all values that match a given key
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_get_all_with_path

Tree map parameters
~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_map_params

Set values in a tree
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_set


Tree
----

Expand All @@ -106,10 +149,7 @@ Tree
tree_div
tree_dtype
tree_full_like
tree_get
tree_get_all_with_path
tree_norm
tree_map_params
tree_max
tree_min
tree_mul
Expand Down Expand Up @@ -174,22 +214,11 @@ Tree divide
~~~~~~~~~~~
.. autofunction:: tree_div

Fetch single value that match a given key
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_get

Fetch all values that match a given key
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_get_all_with_path

Tree norm
~~~~~~~~~
.. autofunction:: tree_norm

Tree map parameters
~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_map_params

Tree max
~~~~~~~~
.. autofunction:: tree_max
Expand Down Expand Up @@ -222,10 +251,6 @@ Tree scalar multiply
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_scale

Set values in a tree
~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_set

Tree size
~~~~~~~~~
.. autofunction:: tree_size
Expand Down
3 changes: 3 additions & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from optax.tree_utils._random import tree_split_key_like
from optax.tree_utils._random import tree_unwrap_random_key_data
from optax.tree_utils._state_utils import NamedTupleKey
from optax.tree_utils._state_utils import ParamsShapedState
from optax.tree_utils._state_utils import reshape_params_shaped_state
from optax.tree_utils._state_utils import shape_state_like_params
from optax.tree_utils._state_utils import tree_get
from optax.tree_utils._state_utils import tree_get_all_with_path
from optax.tree_utils._state_utils import tree_map_params
Expand Down
249 changes: 205 additions & 44 deletions optax/tree_utils/_state_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import dataclasses
import functools
import typing
from typing import Any, Optional, Protocol, Tuple, Union, cast
from typing import Any, NamedTuple, Optional, Protocol, Tuple, Union, cast

import jax
from optax._src import base
Expand All @@ -36,49 +36,6 @@ def tree_unflatten(cls, aux, children):
return cls()


@dataclasses.dataclass(frozen=True)
class NamedTupleKey:
"""KeyType for a NamedTuple in a tree.

When using a function ``filtering(path: KeyPath, value: Any) -> bool: ...``
in a tree in :func:`optax.tree_utils.tree_get_all_with_path`,
:func:`optax.tree_utils.tree_get`, or :func:`optax.tree_utils.tree_set`, can
filter the path to check if of the KeyEntry is a NamedTupleKey and then check
if the name of named tuple is the one intended to be searched.

Attributes:
tuple_name (str): name of the tuple containing the key.
name (str): name of the key.

.. seealso:: :class:`jax.tree_util.DictKey`,
:class:`jax.tree_util.FlattenedIndexKey`,
:class:`jax.tree_util.GetAttrKey`,
:class:`jax.tree_util.SequenceKey`,
:func:`optax.tree_utils.tree_get_all_with_path`,
:func:`optax.tree_utils.tree_get`,
:func:`optax.tree_utils.tree_set`,

.. versionadded:: 0.2.2
"""

tuple_name: str
name: str

def __str__(self):
return f"{self.tuple_name}.{self.name}"


_KeyEntry = Union[
jax.tree_util.DictKey,
jax.tree_util.FlattenedIndexKey,
jax.tree_util.GetAttrKey,
jax.tree_util.SequenceKey,
NamedTupleKey,
]

_KeyPath = Tuple[_KeyEntry, ...]


@typing.runtime_checkable
class Initable(Protocol):
"""An object with an init function."""
Expand Down Expand Up @@ -164,6 +121,210 @@ def map_params(maybe_placeholder_value, value):
)


class ParamsShapedState(NamedTuple):
"""State separating params-like fields from other fields."""

non_params_like: Any
params_like: Any


# pylint: disable=line-too-long
def shape_state_like_params(
initable: Union[
Callable[[base.Params], base.OptState],
Initable,
],
state: base.OptState,
) -> ParamsShapedState:
"""Reshape a state to separate out params-like fields from other fields.

Args:
initable: A callable taking parameters and returning an optimizer state, or
an object with an `init` attribute having the same function.
state: The optimizer state to reshape.

Returns:
A `ParamsShapedState` named tuple where the `params_like` attribute contains
a new state tree with the same structure as the parameter tree, and the
`non_params_like` attribute contains a state tree with the same structure as
the original state tree but with all leaf nodes replaced by None if they
were previously a parameter leaf.

Example:

>>> opt = optax.adam(1e-3)
>>> params = {'a': jnp.ones(2), 'b': jnp.ones(1)}
>>> state = opt.init(params)
>>> params_shaped_state = shape_state_like_params(opt, state)
>>> # Non-params-like part of the state (shaped as the original state):
>>> print(params_shaped_state.non_params_like)
(ScaleByAdamState(count=Array(0, dtype=int32), mu=None, nu=None), EmptyState())
>>> # Params-like part of the state (shaped as a pytree of states having tree structure of params):
>>> print(params_shaped_state.params_like)
{'a': (ScaleByAdamState(count=None, mu=Array([0., 0.], dtype=float32), nu=Array([0., 0.], dtype=float32)), EmptyState()), 'b': (ScaleByAdamState(count=None, mu=Array([0.], dtype=float32), nu=Array([0.], dtype=float32)), EmptyState())}

.. seealso:: :func:`optax.tree_utils.reshape_params_shaped_state`
"""
# pylint: enable=line-too-long

# Cast for pytype checks (no-op for other usages).
placeholder = cast(base.chex.ArrayTree, _ParamsPlaceholder())

if isinstance(initable, Initable):
initable = cast(Initable, initable) # for pytype checks
state_with_placeholders = initable.init(placeholder)
else:
state_with_placeholders = initable(placeholder)

non_params_like = jax.tree.map(
lambda x, y: None if isinstance(x, _ParamsPlaceholder) else y,
state_with_placeholders,
state,
is_leaf=lambda v: isinstance(v, _ParamsPlaceholder),
)

state_tree_def = jax.tree.structure(
state_with_placeholders,
is_leaf=lambda x: isinstance(x, _ParamsPlaceholder),
)
flat_state = state_tree_def.flatten_up_to(state)
flat_state_with_placeholders = state_tree_def.flatten_up_to(
state_with_placeholders
)

params_tree_def = None
for field, maybe_placeholder in zip(flat_state, flat_state_with_placeholders):
if isinstance(maybe_placeholder, _ParamsPlaceholder):
params_tree_def = jax.tree.structure(field)
break

if params_tree_def is None:
return ParamsShapedState(non_params_like, None)

filled_flat_state = []
for field, maybe_placeholder in zip(flat_state, flat_state_with_placeholders):
if isinstance(maybe_placeholder, _ParamsPlaceholder):
flat_subtree, _ = jax.tree.flatten(field)
else:
flat_subtree = [None] * params_tree_def.num_leaves
filled_flat_state.append(flat_subtree)
transpose_flat_state = zip(*filled_flat_state)
subtrees = map(
functools.partial(jax.tree.unflatten, state_tree_def),
transpose_flat_state,
)
params_like = jax.tree.unflatten(params_tree_def, subtrees)

return ParamsShapedState(non_params_like, params_like)


# pylint: disable=line-too-long
def reshape_params_shaped_state(
params_shaped_state: ParamsShapedState,
) -> base.OptState:
"""Reshape a params-shaped state to a regular state.

Args:
params_shaped_state: A `ParamsShapedState` named tuple returned by
`shape_state_like_params`.

Returns:
A regular optimizer state with the same structure as the original state.

Example:
>>> import jax.numpy as jnp
>>> import optax
>>> opt = optax.adam(1e-3)
>>> params = {'a': jnp.ones(2)), 'b': jnp.ones(1))}
>>> state = opt.init(params)
>>> params_shaped_state = shape_state_like_params(opt, state)
>>> reshaped_state = reshape_params_shaped_state(params_shaped_state)
>>> print(reshaped_state)
(ScaleByAdamState(count=Array(0, dtype=int32), mu={'a': Array([0., 0.], dtype=float32), 'b': Array([0.], dtype=float32)}, nu={'a': Array([0., 0.], dtype=float32), 'b': Array([0.], dtype=float32)}), EmptyState())

.. seealso:: :func:`optax.tree_utils.shape_state_like_params`
"""
# pylint: enable=line-too-long
if params_shaped_state.params_like is None:
return params_shaped_state.non_params_like
placeholder = object()

def tree_fill_nones(tree):
return jax.tree.map(
lambda x: placeholder if x is None else x,
tree,
is_leaf=lambda x: x is None,
)

filled_non_params_like = tree_fill_nones(params_shaped_state.non_params_like)
state_tree_def = jax.tree.structure(filled_non_params_like)
filled_params_like = tree_fill_nones(params_shaped_state.params_like)
params_tree_def = jax.tree.structure(
filled_params_like,
is_leaf=lambda x: jax.tree.structure(x) == state_tree_def,
)
state_like_params = jax.tree.transpose(
params_tree_def, state_tree_def, filled_params_like
)

state = jax.tree.map(
lambda x, y: y if x is placeholder else x,
filled_non_params_like,
state_like_params,
is_leaf=lambda x: (jax.tree.structure(x) == params_tree_def)
or (x is placeholder),
)
return state


################################################################################
# Tree get/set utiltities
################################################################################


@dataclasses.dataclass(frozen=True)
class NamedTupleKey:
"""KeyType for a NamedTuple in a tree.

When using a function ``filtering(path: KeyPath, value: Any) -> bool: ...``
in a tree in :func:`optax.tree_utils.tree_get_all_with_path`,
:func:`optax.tree_utils.tree_get`, or :func:`optax.tree_utils.tree_set`, can
filter the path to check if of the KeyEntry is a NamedTupleKey and then check
if the name of named tuple is the one intended to be searched.

Attributes:
tuple_name (str): name of the tuple containing the key.
name (str): name of the key.

.. seealso:: :class:`jax.tree_util.DictKey`,
:class:`jax.tree_util.FlattenedIndexKey`,
:class:`jax.tree_util.GetAttrKey`,
:class:`jax.tree_util.SequenceKey`,
:func:`optax.tree_utils.tree_get_all_with_path`,
:func:`optax.tree_utils.tree_get`,
:func:`optax.tree_utils.tree_set`,

.. versionadded:: 0.2.2
"""

tuple_name: str
name: str

def __str__(self):
return f"{self.tuple_name}.{self.name}"


_KeyEntry = Union[
jax.tree_util.DictKey,
jax.tree_util.FlattenedIndexKey,
jax.tree_util.GetAttrKey,
jax.tree_util.SequenceKey,
NamedTupleKey,
]

_KeyPath = Tuple[_KeyEntry, ...]


def tree_get_all_with_path(
tree: base.PyTree,
key: Any,
Expand Down
Loading
Loading