feat: register_pytree_node — allow custom classes in mx.compile#3500
st-adam wants to merge 3 commits into
zcbenz
left a comment
I think the register_pytree_node API is a good thing to have, but utils.compile would be a bad addition.
The register_pytree_node API should be implemented in the C++ layer so mx.compile can be made aware of it directly. And we should probably remove the Python versions of the tree utils and expose the C++ ones instead, so we don't have to duplicate the register_pytree_node implementation in two languages.
Adds a JAX-style pytree registration mechanism so third-party Python
classes can flow through mx.compile, tree_visit, tree_map, and the
rest of MLX's tree utilities.
Motivation
----------
mx.compile rejects any function argument that is not a plain array,
list, dict, tuple, or scalar constant:
ValueError: [compile] Function arguments must be trees of arrays or
constants (floats, ints, strings, or None), but received type
mlx_lm.models.cache.ArraysCache.
Any model whose forward pass receives a custom cache object — every
hybrid SSM+attention model in mlx-lm (Qwen 3.5/3.6, Llama 4, Gemma 3n,
etc.) — therefore cannot be compiled, even though the computation is
fully expressible as MLX ops.
Implementation
--------------
The registry, the public API, and all tree-traversal hooks live in
C++ (per review feedback: a Python-side compile wrapper would
duplicate the implementation across two languages).
python/src/trees.h, python/src/trees.cpp:
* PytreeNodeDef — (flatten_fn, unflatten_fn) pair.
* registry() — heap-allocated map keyed by PyTypeObject*, never
freed. Avoids the use-after-finalize segfault
that a function-local static would hit when
Python tears down the interpreter while
stored nb::callables still hold refs. Same
lifetime pattern used by structure_sentinel().
* register_pytree_node(cls, flatten_fn, unflatten_fn) — exposed to
Python as mx.register_pytree_node.
* is_registered_pytree, flatten_registered, unflatten_registered,
registered_pytree_fingerprint — internal helpers.
* tree_visit / tree_map (multi-tree and single-tree overloads) and
tree_visit_update now recurse into registered types, so
tree_unflatten through the compile path reconstructs them.
python/src/transforms.cpp:
* PyCompiledFun::call_impl::recurse adds a pytree_identifier branch:
flattens the registered node into its children and embeds the
type-id + aux hash in the constants vector, so two structurally
different registered instances retrace correctly.
* Error message updated to mention mx.register_pytree_node.
python/src/mlx.cpp:
* Wires init_trees() into NB_MODULE.
python/mlx/utils.py:
* re-exports mlx.core.register_pytree_node so users can do either
`import mlx.core as mx; mx.register_pytree_node(...)` or
`from mlx.utils import register_pytree_node`.
Test
----
python/tests/test_compile.py::test_compile_registered_pytree_node:
* mx.compile rejects an unregistered custom class.
* After registration the compiled forward returns the correct value.
* aux_data tagged differently on two subclasses retraces cleanly.
* flatten_fn returning a malformed value surfaces a clear ValueError.
All existing tests still pass:
- python/tests/test_compile.py — 55 passed
- python/tests/test_tree.py — 4 passed
- python/tests/test_autograd.py + test_vmap.py — full suite green
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
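
To make the registration mechanism in the commit message concrete, here is a minimal pure-Python model of it. This is a sketch only: the PR keeps the registry in C++, and `_REGISTRY`, `Cache`, `_flatten`, and `_unflatten` are hypothetical names chosen for illustration. The `unflatten_fn(aux, children)` argument order is an assumption borrowed from the JAX convention the PR says it mirrors.

```python
from typing import Any, Callable

# Hypothetical stand-in for the C++ registry: maps a class to its
# (flatten_fn, unflatten_fn) pair.
_REGISTRY: dict[type, tuple[Callable, Callable]] = {}


def register_pytree_node(cls: type, flatten_fn: Callable, unflatten_fn: Callable) -> None:
    """Record how to split `cls` into (children, aux) and how to rebuild it."""
    if not isinstance(cls, type):
        raise TypeError("cls must be a class")
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply `fn` to every leaf, recursing into containers and registered types."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, t) for t in tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if type(tree) in _REGISTRY:  # exact-type lookup; subclasses are not matched
        flatten_fn, unflatten_fn = _REGISTRY[type(tree)]
        children, aux = flatten_fn(tree)
        return unflatten_fn(aux, [tree_map(fn, c) for c in children])
    return fn(tree)  # anything else is treated as a leaf


class Cache:  # hypothetical custom class
    def __init__(self, size: int):
        self.size = size
        self.entries = [0.0] * size


def _flatten(c: Cache):
    return c.entries, (c.size,)  # (children, aux)


def _unflatten(aux, children):
    c = Cache(*aux)
    c.entries = list(children)
    return c


register_pytree_node(Cache, _flatten, _unflatten)
mapped = tree_map(lambda x: x + 1.0, {"cache": Cache(3), "step": 1.0})
```

Here `mapped["cache"]` is a fresh `Cache` whose entries were each passed through the mapped function, while `aux` (the size) is carried through unchanged.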
59c91c6 to 727264c
@zcbenz Thanks for the review. I've moved the implementation into C++ as requested:
Verified locally on a CPU build:

    // Combines id(type) and hash(aux) so that compile retraces if either changes.
    uint64_t registered_pytree_fingerprint(nb::handle obj);

    void init_trees(nb::module_& m);

This declaration is not needed, already declared in python/src/mlx.cpp.

    } // namespace

    void register_pytree_node(
        nb::object cls,

You can use nb::type_object and let nanobind do the check.

        throw std::invalid_argument(
            "[register_pytree_node] cls must be a Python class object.");
      }
      PyTypeObject* type = reinterpret_cast<PyTypeObject*>(cls.ptr());

We should probably just use PyObject* as the registry key; forcing a conversion to PyTypeObject* adds code complexity without much benefit.

        nb::handle obj);

    // Calls the registered unflatten_fn for the given type object.
    nb::object unflatten_registered(

flatten_registered and unflatten_registered are helpers that do not need to be exposed in header.

    nb::object aux = seq[1];
    if (!aux.is_none()) {
      try {
        auto h = aux.attr("__hash__")();

    nb::object children_obj = seq[0];
    nb::object aux = seq[1];

    std::vector<nb::object> children;

Can you just do auto children = nb::cast<std::vector<nb::object>>(seq[0]);?

        "[flatten_registered] type is not registered as a pytree node");
    }
    nb::object result = it->second.flatten_fn(obj);
    if (!nb::isinstance<nb::tuple>(result) && !nb::isinstance<nb::list>(result)) {

The nb::cast<nb::sequence> should be able to do the type check, so I think this check is redundant.

    auto seq = nb::cast<nb::sequence>(result);
    if (nb::len(seq) == 2) {
      nb::object aux = seq[1];
      if (!aux.is_none()) {

There is no need to be defensive here since it is inside a try/catch; just do the casts and let it throw when something bad happens.

    for (auto& c : children) {
      new_children.push_back(recurse(c));
    }
    return unflatten_registered(type_handle, aux, new_children);

You can just pass subtree here?
@zcbenz Thanks for the detailed review. All nine inline comments are addressed in 2f7c5e6:

- Header surface trimmed: only `register_pytree_node`, `is_registered_pytree`, `pytree_children`, `registered_pytree_fingerprint` are exposed. `flatten_registered`/`unflatten_registered` are internal helpers in trees.cpp and `init_trees` is no longer redeclared (already in mlx.cpp).
- `register_pytree_node` now takes `nb::type_object` so nanobind enforces the type check; manual `PyType_Check` is gone.
- Registry keyed by `PyObject*` directly, with no `PyTypeObject*` reinterpret cast at the boundary.
- Internal `flatten_registered` uses `nb::cast<std::vector<nb::object>>` for children and lets `nb::cast<nb::sequence>` enforce the list/tuple shape.
- Fingerprint uses `nb::hash` and lets nanobind throw on unhashable aux (no extra defensive casting).
- `tree_visit_update` / `tree_map` pass the subtree handle directly to `unflatten_registered` instead of fabricating a type handle.

Lint also fixed (clang-format + black) in the previous commit. Local CPU run:
Sorry for being late to this but why do we think that is better than

A more real-world example of this is simply the training loop. We don't need the model to be a typed PyTree, we can very quickly make the whole call "pure functional" by passing in a dictionary of parameters

Changing gears after discussing whether we should add this at all, I see two issues to be addressed in the code
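
For readers following the thread, the "pure functional" pattern mentioned in this comment can be sketched as follows. This is not the reviewer's own example, which is not preserved here; `Cache`, `step`, and `call_with_cache` are hypothetical names, and plain floats stand in for arrays. The idea is that the object is unpacked into a dict at the call boundary, so the inner function only ever sees trees of plain containers.

```python
from typing import Any


class Cache:  # hypothetical object holding mutable state
    def __init__(self, keys: list, offset: int = 0):
        self.keys = keys
        self.offset = offset


def step(state: dict[str, Any], token: float) -> tuple[float, dict[str, Any]]:
    """Pure function: state comes in as a plain dict and goes out as a new dict."""
    keys = state["keys"] + [token]
    return sum(keys), {"keys": keys, "offset": state["offset"] + 1}


def call_with_cache(cache: Cache, token: float) -> float:
    # Unpack the object into plain containers at the boundary...
    out, new_state = step({"keys": cache.keys, "offset": cache.offset}, token)
    # ...and write the returned state back afterwards.
    cache.keys, cache.offset = new_state["keys"], new_state["offset"]
    return out


cache = Cache(keys=[1.0, 2.0])
result = call_with_cache(cache, 3.0)
```

Under this pattern, `step` is the function one would hand to a compile transform, since its arguments and return values are only dicts, lists, and leaves. The cost is the pack and unpack code at every call site.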
Thanks @angeloskath. Over-complexity concern noted, and I'm happy to make this as simple as the use case actually requires. Before discussing implementation shape, let me lay out the problem this PR is trying to solve, since I think that context is what makes the API question concrete.
Every model in mlx-lm that uses a cache class other than the plain

17 mlx-lm model architectures use these wrapped caches and therefore cannot be compiled today:

On workarounds. The pure-functional pattern (

The current implementation in this PR is a C++ registry plus a public

The simpler alternative is to drop

    from typing import Any

    class ArraysCache:
        def tree_flatten(self) -> tuple[list, Any]:
            return self.cache, (self.size,)
        @classmethod
        def tree_unflatten(cls, aux, children):
            c = cls(*aux)
            c.cache = children
            return c

MLX side: tree traversal +

Either shape works for me. I can pivot to the simpler alternative and drop the registry, or stay on the current implementation and land the two fixes you flagged. Whichever you and @zcbenz prefer, let me know and I'll push the change.
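
As a rough illustration of what the method-based alternative would ask of the traversal side, the sketch below flattens a tree to its leaves and returns a closure that rebuilds it, treating any object that defines `tree_flatten` and `tree_unflatten` as an interior node. It is an assumption-laden model in pure Python, not MLX code: `flatten` and the `__init__` given to `ArraysCache` here are hypothetical, dicts are omitted for brevity, and integers stand in for arrays.

```python
from typing import Any, Callable


def flatten(tree: Any) -> tuple[list, Callable[[list], Any]]:
    """Return (leaves, rebuild); rebuild(new_leaves) reconstructs the tree shape."""
    if isinstance(tree, (list, tuple)):
        parts = [flatten(t) for t in tree]
        make = type(tree)
    elif hasattr(tree, "tree_flatten") and hasattr(type(tree), "tree_unflatten"):
        # Duck-typed protocol check: no registry lookup is involved.
        children, aux = tree.tree_flatten()
        parts = [flatten(c) for c in children]
        cls = type(tree)
        make = lambda kids: cls.tree_unflatten(aux, kids)
    else:
        return [tree], lambda leaves: leaves[0]  # a single leaf

    leaves = [leaf for sub_leaves, _ in parts for leaf in sub_leaves]

    def rebuild(new_leaves: list) -> Any:
        it = iter(new_leaves)
        # Hand each subtree exactly as many leaves as it contributed.
        kids = [r([next(it) for _ in sub_leaves]) for sub_leaves, r in parts]
        return make(kids)

    return leaves, rebuild


class ArraysCache:
    def __init__(self, size: int):  # hypothetical constructor for this sketch
        self.size = size
        self.cache = [None] * size

    def tree_flatten(self) -> tuple[list, Any]:
        return self.cache, (self.size,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        c = cls(*aux)
        c.cache = children
        return c


cache = ArraysCache(2)
cache.cache = [1, 2]
leaves, rebuild = flatten([cache, 3])
restored = rebuild([x * 10 for x in leaves])
```

The trade-off relative to a registry is that the class itself must carry the two methods, so third-party classes that cannot be edited would need a wrapper or subclass.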
Fixes #3499.
Addresses review feedback from @zcbenz: implementation moved to C++ so `mx.compile` is natively aware of registered pytree types. The previous Python-side `utils.compile` wrapper has been removed.

Summary
- `mx.register_pytree_node(cls, flatten_fn, unflatten_fn)` API (mirrors `jax.tree_util.register_pytree_node`)
- `python/src/trees.cpp`, exposed via `init_trees()` nanobind module
- `PyCompiledFun::call_impl` recurse handles registered types: their type-id + aux hash participate in the compile cache key
- `tree_visit`, `tree_map`, `tree_visit_update` recurse into registered nodes; this is the same code path used by `tree_unflatten` on the compile path
- `mlx.utils` re-exports `register_pytree_node` for the natural `from mlx.utils import …` access pattern

API
Test plan
`python/tests/test_compile.py::TestCompile::test_compile_registered_pytree_node` covers:
- Malformed `flatten_fn` return surfaces a clean `ValueError`

Regression sweep on the CPU build:
- `python/tests/test_compile.py`: 55 passed
- `python/tests/test_tree.py`: 4 passed
- `python/tests/test_autograd.py` + `test_vmap.py`: full suites green

Implementation notes
- The registry is heap-allocated and never freed: a function-local static `std::unordered_map<PyTypeObject*, PytreeNodeDef>` triggers a use-after-finalize segfault at interpreter shutdown because the stored `nb::callable`s outlive Python state. This matches the lifetime trick used by `structure_sentinel()`.
- The fingerprint combines `id(type)` with `hash(aux_data)` (golden-ratio mixing constant) so two structurally distinct registered instances retrace correctly. Unhashable aux falls back to type-only fingerprinting.
- `flatten_fn` may return children as either `list` or `tuple`; the registry rejects non-sequence return values up front.

Files
- `python/src/trees.h` (+38): public API + helpers
- `python/src/trees.cpp` (+263): registry, `init_trees`, tree-traversal hooks
- `python/src/transforms.cpp` (+15): `PyCompiledFun::call_impl` recurse + error-message hint
- `python/src/mlx.cpp` (+2): wire `init_trees()` into `NB_MODULE`
- `python/mlx/utils.py` (+1 net): re-export `register_pytree_node`
- `python/tests/test_compile.py` (+63): coverage

🤖 Generated with Claude Code
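
The fingerprint note above mentions mixing `id(type)` with `hash(aux_data)` using a golden-ratio constant. The exact formula in the PR is not shown in this conversation, so the following is only a sketch of one common way to do such mixing (a boost-style `hash_combine` truncated to 64 bits); `combine` and `fingerprint` are hypothetical names, and in this sketch an unhashable aux simply raises `TypeError`.

```python
MASK64 = (1 << 64) - 1
GOLDEN = 0x9E3779B97F4A7C15  # floor(2**64 / golden ratio), a common mixing constant


def combine(seed: int, value: int) -> int:
    """Boost-style hash_combine, kept within 64 bits."""
    seed &= MASK64
    value &= MASK64
    return (seed ^ (value + GOLDEN + (seed << 6) + (seed >> 2))) & MASK64


def fingerprint(obj, aux) -> int:
    """Mix the identity of obj's type with the hash of its aux data."""
    return combine(id(type(obj)), hash(aux))


class A:  # hypothetical registered class
    pass


same = fingerprint(A(), 1) == fingerprint(A(), 1)
changed = fingerprint(A(), 1) != fingerprint(A(), 2)
```

Because the type identity and the aux hash both feed the result, a change in either one yields a different cache key with overwhelming probability, which is what forces a retrace.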