
fix(jax): include in_out_argnames and stage argnames in FFI registry …#1249

Open
Adityakk9031 wants to merge 1 commit into NVIDIA:main from Adityakk9031:#1215

Conversation

@Adityakk9031 (Contributor) commented Feb 23, 2026

Issue:
Both jax_kernel() and jax_callable() use a registry dict to cache and deduplicate FFI wrappers. The key tuple used to look up the registry was missing parameters that affect the wrapper's behaviour: in_out_argnames in jax_kernel(), and in_out_argnames, stage_in_argnames, stage_out_argnames in jax_callable(). This meant calling either function with the same kernel but different argname configurations silently returned the first cached object with the wrong configuration.
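The failure mode can be sketched with a toy registry (all names here are hypothetical stand-ins, not Warp's actual internals):

```python
# Simplified sketch of the caching bug (hypothetical registry and function
# names, not Warp's actual code).
_registry = {}

def jax_kernel_sketch(kernel_name, in_out_argnames=None):
    # BUG: in_out_argnames is omitted from the key, so a later call with a
    # different configuration silently reuses the first cached wrapper.
    key = (kernel_name,)
    if key not in _registry:
        _registry[key] = {"in_out": in_out_argnames}  # stand-in for the FFI wrapper
    return _registry[key]

a = jax_kernel_sketch("saxpy", in_out_argnames=["y"])
b = jax_kernel_sketch("saxpy", in_out_argnames=None)
assert a is b                 # the same cached object is returned
assert b["in_out"] == ["y"]   # b silently carries a's configuration
```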

Fix:
Added the missing parameters to the key tuples in both functions in warp/_src/jax_experimental/ffi.py. Since a list is not hashable, each is converted to a tuple before being added to the key. Two wrappers with different configurations now correctly produce different keys and are stored as separate objects in the registry.
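The fixed keying pattern can be sketched as follows (hypothetical names; the actual change is in warp/_src/jax_experimental/ffi.py):

```python
# Sketch of the fixed keying pattern (hypothetical stand-in names).
_registry = {}

def jax_kernel_fixed(kernel_name, in_out_argnames=None):
    # Lists are unhashable, so convert to a tuple before keying.
    # None is kept as-is so "not specified" remains a distinct key.
    hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None
    key = (kernel_name, hashable_in_out)
    if key not in _registry:
        _registry[key] = {"in_out": in_out_argnames}  # stand-in for the FFI wrapper
    return _registry[key]

a = jax_kernel_fixed("saxpy", in_out_argnames=["y"])
b = jax_kernel_fixed("saxpy", in_out_argnames=None)
assert a is not b   # different configurations now map to different wrappers
assert a is jax_kernel_fixed("saxpy", in_out_argnames=["y"])  # dedup still works
```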

Summary by CodeRabbit

  • Bug Fixes
    • Improved FFI operation caching to correctly distinguish entries based on argument configurations, enhancing cache accuracy and reliability.

@copy-pr-bot bot commented Feb 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

…cache keys

Signed-off-by: Aditya kumar singh <143548997+Adityakk9031@users.noreply.github.com>
@coderabbitai bot commented Feb 23, 2026

📝 Walkthrough

Adds hashable components (hashable_in_out, hashable_stage_in, hashable_stage_out) to cache keys in FFI kernel and callable functions, enabling the cache to distinguish entries based on argument name configuration.

Changes

Cohort: Cache Key Enhancement
File(s): warp/_src/jax_experimental/ffi.py
Summary: Extends cache keys for jax_kernel and jax_callable with hashable tuples representing in_out_argnames and stage-related argnames to improve cache discrimination by argument configuration.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks: ✅ 3 passed
  • Description Check: ✅ Passed. Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title clearly identifies the fix: adding in_out_argnames and stage argnames to the FFI registry cache key to prevent incorrect wrapper reuse.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.


@greptile-apps bot commented Feb 23, 2026

Greptile Summary

Fixes a caching bug where FFI wrappers with different argument configurations were incorrectly deduplicated. The registry cache keys now include in_out_argnames, stage_in_argnames, and stage_out_argnames parameters to ensure each unique configuration gets its own wrapper.

  • Added hashable_in_out, hashable_stage_in, and hashable_stage_out to cache keys in jax_kernel() and jax_callable()
  • Converts list parameters to tuples before adding to keys (lists are not hashable)
  • Follows existing pattern for handling optional parameters with None checks

Confidence Score: 5/5

  • Safe to merge with no concerns
  • The fix correctly addresses the caching bug by adding missing parameters to cache keys, follows existing patterns in the codebase, and properly handles list-to-tuple conversion for hashability
  • No files require special attention

Important Files Changed

Filename: warp/_src/jax_experimental/ffi.py
Overview: Fixes FFI registry cache keys by including argname parameters that affect wrapper behavior.

Last reviewed commit: dad527b

@greptile-apps bot left a comment

1 file reviewed, no comments

@coderabbitai bot left a comment
🧹 Nitpick comments (2)
warp/_src/jax_experimental/ffi.py (2)

1206-1214: Use frozenset (or tuple(sorted(...))) for order-independent cache key semantics.

in_out_argnames is immediately converted to a set in FfiKernel.__init__ (line 170), so argument order is semantically irrelevant. Using tuple(in_out_argnames) makes the cache key order-sensitive: ["a", "b"] and ["b", "a"] produce different keys but create behaviorally identical FfiKernel objects, causing redundant FFI target registrations.

♻️ Proposed fix
-        hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None
+        hashable_in_out = frozenset(in_out_argnames) if in_out_argnames is not None else None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@warp/_src/jax_experimental/ffi.py` around lines 1206 - 1214, The cache key
currently uses tuple(in_out_argnames) which is order-sensitive even though
FfiKernel.__init__ turns in_out_argnames into a set, causing equivalent kernels
with different arg order to miss cache hits; change the construction of
hashable_in_out to an order-independent representation (e.g.,
frozenset(in_out_argnames) or tuple(sorted(in_out_argnames))) and update the key
creation that references hashable_in_out so identical FfiKernel objects produce
the same cache key.
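The order-sensitivity the reviewer points out is easy to demonstrate in plain Python, independent of Warp:

```python
# tuple-based keys distinguish orderings; frozenset-based keys do not.
assert tuple(["a", "b"]) != tuple(["b", "a"])
assert frozenset(["a", "b"]) == frozenset(["b", "a"])
assert hash(frozenset(["a", "b"])) == hash(frozenset(["b", "a"]))

# With frozenset components, reordered argnames hit the same cache entry:
key1 = ("saxpy", frozenset(["a", "b"]))
key2 = ("saxpy", frozenset(["b", "a"]))
assert key1 == key2

# tuple(sorted(...)) is an equivalent order-insensitive alternative:
assert tuple(sorted(["b", "a"])) == tuple(sorted(["a", "b"]))
```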

1553-1564: Same order-sensitivity concern for all three hashable argname variables in jax_callable.

All three parameters — in_out_argnames, stage_in_argnames, stage_out_argnames — are stored as sets inside FfiCallable.__init__ (lines 529, 525, 526 respectively), so their ordering is semantically irrelevant. Using tuple(...) makes the cache key order-dependent, causing unnecessary duplicate FfiCallable registrations and FFI target registrations.

♻️ Proposed fix
-    hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None
-    hashable_stage_in = tuple(stage_in_argnames) if stage_in_argnames is not None else None
-    hashable_stage_out = tuple(stage_out_argnames) if stage_out_argnames is not None else None
+    hashable_in_out = frozenset(in_out_argnames) if in_out_argnames is not None else None
+    hashable_stage_in = frozenset(stage_in_argnames) if stage_in_argnames is not None else None
+    hashable_stage_out = frozenset(stage_out_argnames) if stage_out_argnames is not None else None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@warp/_src/jax_experimental/ffi.py` around lines 1553 - 1564, The cache key in
jax_callable is order-sensitive because in_out_argnames, stage_in_argnames, and
stage_out_argnames (which are stored as sets in FfiCallable.__init__) are
converted with tuple(...); change those conversions to an order-insensitive
representation (e.g., use frozenset(...) or tuple(sorted(...))) when building
key so the key does not depend on arbitrary set iteration order and avoids
duplicate FfiCallable/FFI registrations; update the creation of hashable_in_out,
hashable_stage_in, and hashable_stage_out accordingly where key is constructed.

ℹ️ Review info

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1de9c91 and dad527b.

📒 Files selected for processing (1)
  • warp/_src/jax_experimental/ffi.py

@Adityakk9031 (Contributor, Author) commented:

@shi-eric have a look

@Adityakk9031 (Contributor, Author) commented:

@c0d1f1ed have a look

@christophercrouzet (Member) commented:

@Adityakk9031 Do you have a need for these changes? What is your use case?

@Adityakk9031 (Contributor, Author) commented:

@christophercrouzet The fix addresses a silent correctness bug where reusing the same kernel with different in_out_argnames or staging configurations would return a stale cached wrapper. This ensures JAX FFI registrations correctly and uniquely identify the intended argument behavior, as reported in #1215.

@christophercrouzet (Member) commented:

@Adityakk9031 Thanks for the AI summary, but my question was why you need this.

If you're pinging each team member daily on this PR (and on #1248), surely it must be because you have an urgent need for it in a project of yours?

Making such changes with an AI agent is fast, but reviewing and iterating on these takes time, energy, and a holistic understanding of the codebase.

Please help us understand why we should prioritize your pull requests over what we're currently working on.

@Adityakk9031 (Contributor, Author) commented Feb 25, 2026

@christophercrouzet Sir, thanks for the clarification. The root cause and suggested fix were already described in the issue, and I simply implemented that fix. Also, sorry for the earlier AI-generated summary; the code change itself was written by me, not by an agent.

@greptile-apps bot left a comment

1 file reviewed, no comments

@Adityakk9031 (Contributor, Author) commented:

Also, sorry for tagging multiple people; I can be a little dumb sometimes.

@christophercrouzet (Member) commented:

@Adityakk9031 No worries, and thanks for your willingness to contribute to the project.

We're a small team with a lot on our plate, and when issues are already assigned to team members, it generally means we have them on our radar and plan to address them as part of our roadmap. Uncoordinated PRs for these issues can end up adding to our workload rather than reducing it.

If you do open a PR and need it reviewed promptly, please help us understand why. For example, if it's blocking a project of yours or addresses a critical bug you're hitting. Without that context, it's difficult for us to justify prioritizing an external PR over our current work.

We'll be updating our contribution guidelines to make this clearer for everyone, and we'll take a look at your PR when we get a chance.

Thanks for your understanding!
