Use grad context for hashing the generated stablehlo program #1604
Hello. You may have forgotten to update the changelog!
I think the proposed solution in the current PR is a bit insufficient. It depends on what influences the contextual decomposition, but it is a good starting point. We can discuss more about this tomorrow :)
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@ Coverage Diff @@
## v0.11.0-rc #1604 +/- ##
=============================================
Coverage ? 96.61%
=============================================
Files ? 80
Lines ? 8615
Branches ? 837
=============================================
Hits ? 8323
Misses ? 238
Partials ? 54
Hi @erick-xanadu, do you think this PR can be merged for the release, or does it need more work?
@mehrdad2m I think it can be merged. (I improved the hashing after my comment.) One issue is testing. I tested it by running it against the benchmark and confirming that the regression is no longer there, but I don't think it is possible to write a unit test for it.
@mehrdad2m in the release
Makes sense to me!
8a79821 to 9028dac
Context: After PR #1562, a single function could have multiple JAXPR representations, depending on whether it was under a grad context or not. As a result, the previous hash, which was based on the function id, could produce conflicts. To address this, we hashed the JAXPR string representation instead. (We cannot hash the JAX objects themselves, since each one is unique.)
The JAXPR string representation can be very long, and hashing long strings can take a long time.
Description of the Change: Instead of hashing the string representation, add a simple key that denotes whether the function is inside a grad context or not.
Benefits: Reduced compilation time.
Possible Drawbacks: The cache key is getting more complicated. Maybe the drawbacks outweigh the benefits now?
Related GitHub Issues:
[sc-88454]
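For illustration only, here is a minimal sketch of the idea described above: the cache key combines the function id with a boolean flag for the grad context, so the JAXPR string is never hashed. The names `GradContext`, `cache_key`, `lower_cached`, and `emit` are hypothetical and are not taken from the actual code in this PR.

```python
from typing import Callable, Dict, Tuple


class GradContext:
    """Hypothetical stand-in for whatever tracks 'are we under grad?'."""

    _depth = 0  # nesting depth; > 0 means a grad context is active

    def __enter__(self) -> "GradContext":
        GradContext._depth += 1
        return self

    def __exit__(self, *exc) -> bool:
        GradContext._depth -= 1
        return False  # do not swallow exceptions

    @staticmethod
    def active() -> bool:
        return GradContext._depth > 0


def cache_key(fn: Callable) -> Tuple[int, bool]:
    # Cheap, fixed-size key: function id plus the grad-context flag.
    # This replaces hashing a potentially very long JAXPR string.
    return (id(fn), GradContext.active())


# Assumed cache layout: key -> generated program text.
_cache: Dict[Tuple[int, bool], str] = {}


def lower_cached(fn: Callable, emit: Callable[[Callable, bool], str]) -> str:
    """Return the cached program for fn, generating it on a cache miss.

    `emit` is a hypothetical callback that produces the program text.
    """
    key = cache_key(fn)
    if key not in _cache:
        _cache[key] = emit(fn, GradContext.active())
    return _cache[key]
```

With this layout, the same function lowered inside and outside a `with GradContext():` block gets two separate cache entries, which is the conflict that the function-id-only hash could not resolve. Keying on `id(fn)` assumes the function object stays alive for as long as the cache entry does.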