
tests/pjit_test.py::PJitErrorTest::testAxisResourcesMismatch fails with "SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!" #26106

Open
booxter opened this issue Jan 26, 2025 · 0 comments
Labels
bug Something isn't working


booxter commented Jan 26, 2025

Description

This is happening on Darwin aarch64. The build environment is nixpkgs.

(ins)bash-5.2$ pytest tests/pjit_test.py::PJitErrorTest::testAxisResourcesMismatch
============================= test session starts ==============================
platform darwin -- Python 3.12.8, pytest-8.3.3, pluggy-1.5.0
rootdir: /private/tmp/nix-shell.7TgxxA/tmp.6tnVwLWke8/source
configfile: pyproject.toml
plugins: hypothesis-6.112.2, xdist-3.6.1
collected 1 item

tests/pjit_test.py F                                                     [100%]

=================================== FAILURES ===================================
___________________ PJitErrorTest.testAxisResourcesMismatch ____________________

self = <pjit_test.PJitErrorTest testMethod=testAxisResourcesMismatch>

    @jtu.with_mesh([('x', 2)])
    def testAxisResourcesMismatch(self):
      x = jnp.ones([])
      p = [None, None, None]

      pjit(lambda x: x, (p,), p)([x, x, x])  # OK

      error = re.escape(
          "pjit in_shardings specification must be a tree prefix of the "
          "positional arguments tuple passed to the `pjit`-decorated function. "
          "In particular, pjit in_shardings must either be a None, a "
          "PartitionSpec, or a tuple of length equal to the number of positional "
          "arguments. But pjit in_shardings is the wrong length: got a "
          "tuple or list of length 3 for an args tuple of length 2.")
      with self.assertRaisesRegex(ValueError, error):
>       pjit(lambda x, y: x, p, p)(x, x)

tests/pjit_test.py:6308:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jax/_src/traceback_util.py:180: in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
jax/_src/pjit.py:340: in cache_miss
    pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
jax/_src/pjit.py:180: in _python_pjit_helper
    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
jax/_src/pjit.py:740: in _infer_params
    p, args_flat = _infer_params_impl(
jax/_src/pjit.py:618: in _infer_params_impl
    in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
jax/_src/util.py:302: in wrapper
    return cached(config.trace_context() if trace_context_in_key else _ignore(),
jax/_src/util.py:296: in cached
    return f(*args, **kwargs)
jax/_src/pjit.py:1137: in _process_in_axis_resources
    in_shardings_flat = flatten_axis_resources(
jax/_src/pjit.py:1067: in flatten_axis_resources
    return tuple(flatten_axes(what, tree, shardings, tupled_args=tupled_args))
jax/_src/api_util.py:404: in flatten_axes
    tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

    @export
    def tree_map(f: Callable[..., Any],
                 tree: Any,
                 *rest: Any,
                 is_leaf: Callable[[Any], bool] | None = None) -> Any:
      """Alias of :func:`jax.tree.map`."""
      leaves, treedef = tree_flatten(tree, is_leaf)
>     all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
E     SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!

jax/_src/tree_util.py:357: SystemError
=========================== short test summary info ============================
FAILED tests/pjit_test.py::PJitErrorTest::testAxisResourcesMismatch - SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
============================== 1 failed in 1.60s ===============================
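For reference, the error the test expects comes from a tree-prefix comparison between `in_shardings` and the positional arguments. The sketch below is a simplified pure-Python model of that rule, not JAX's implementation; `is_tree_prefix` is a hypothetical helper. It treats plain tuples and lists as internal nodes and everything else as a leaf that matches any subtree. That includes `None`, which JAX swaps for a proxy leaf via `_replace_nones` before calling `tree_map` (visible in the traceback above).

```python
def is_tree_prefix(prefix, tree):
    """Simplified model (hypothetical helper): is `prefix` a tree prefix of `tree`?

    Only exact tuples/lists are internal nodes; anything else (None, or a
    tuple subclass such as a PartitionSpec) is a leaf matching any subtree.
    """
    if type(prefix) not in (tuple, list):
        return True  # a leaf in the prefix covers the whole corresponding subtree
    if type(tree) not in (tuple, list) or len(prefix) != len(tree):
        return False  # structure/arity mismatch: this is where pjit should raise ValueError
    return all(is_tree_prefix(p, t) for p, t in zip(prefix, tree))


p = [None, None, None]
print(is_tree_prefix((p,), ([1.0, 1.0, 1.0],)))  # True: the "OK" call in the test
print(is_tree_prefix(p, (1.0, 1.0)))             # False: 3 specs for an args tuple of length 2
```

On a healthy install the second case is where `pjit` raises the `ValueError` the test checks for; in this report the mismatch appears to be detected inside `flatten_up_to`, but the C++ exception escapes as a `SystemError` instead.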

As far as I understand, the error originates in nanobind's src/nb_func.cpp, although our build system does not provide nanobind as a separate package (jaxlib is pulled as a prebuilt wheel from PyPI).

Similar errors have previously been seen in other tests in the same nixpkgs build environment; see: https://github.com/NixOS/nixpkgs/blob/c48b8a8a88d6b398bc3ea15caee988f1a821dba5/pkgs/development/python-modules/jax/default.nix#L96

(The error message is quite generic, though, so I can't tell whether this is the same problem or a separate one.)
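The failing frame in the traceback is `treedef.flatten_up_to` inside `tree_map`, so a possibly narrower reproducer, one that skips `pjit` and meshes entirely, is a plain structure mismatch passed to `jax.tree_util.tree_map`. This is an untested sketch: `mismatch_error_type` is a hypothetical helper, and whether it triggers the same `SystemError` on the affected Darwin/nixpkgs setup is an assumption.

```python
import jax


def mismatch_error_type():
    """Return the exception type raised by a pytree structure mismatch, or None."""
    try:
        # tree_map flattens the first tree, then calls treedef.flatten_up_to()
        # on each remaining tree; a 3-element list cannot be matched against a
        # 2-element list, so the C++ pytree code has to raise an error.
        jax.tree_util.tree_map(lambda a, b: a, [1, 2, 3], [1, 2])
    except Exception as e:  # broad on purpose: we want to see *which* type escapes
        return type(e)
    return None


if __name__ == "__main__":
    # A healthy install should report ValueError. If the translation problem
    # lies in this shared pytree code path, it may report SystemError instead
    # (unverified assumption).
    print(mismatch_error_type())
```

If this prints `ValueError` on the affected machine while the pjit test still fails, the problem is presumably specific to the pjit call path rather than to pytree exception translation in general.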


Let me know if I can collect more information for you; I'm not sure how to debug this when jaxlib comes from a prebuilt wheel.

System info (python version, jaxlib version, accelerator, etc.)

(ins)bash-5.2$ python --version
Python 3.12.8
(ins)bash-5.2$ pip3 freeze
absl-py==2.1.0
attrs==24.2.0
cloudpickle==3.0.0
contourpy==1.3.0
cycler==0.12.1
execnet==2.1.1
flatbuffers==24.12.23
fonttools==4.55.2
hypothesis==6.112.2
iniconfig==2.0.0
installer==0.7.0
jax==0.5.0
jaxlib==0.5.0
kiwisolver==1.4.7
matplotlib==3.9.2
ml_dtypes==0.5.1
numpy==2.2.0
opt_einsum==3.4.0
packaging==24.2
pillow==11.1.0
pluggy==1.5.0
pyparsing==3.1.4
pytest==8.3.3
pytest-xdist==3.6.1
python-dateutil==2.9.0.post0
scipy==1.14.1
setuptools==75.3.0.post0
six==1.17.0
sortedcontainers==2.4.0
wheel==0.45.1
(ins)bash-5.2$ python -c 'import jax; jax.print_environment_info()'
jax:    0.5.0.dev20250125
jaxlib: 0.5.0
numpy:  2.2.0
python: 3.12.8 (main, Dec  3 2024, 18:42:41) [Clang 19.1.6 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='ihrachys-macpro', release='23.6.0', version='Darwin Kernel Version 23.6.0: Fri Nov 15 15:12:37 PST 2024; root:xnu-10063.141.1.702.7~1/RELEASE_ARM64_T6030', machine='arm64')