Enhance gradient api #2965
Conversation
Question: Should the return type be propagated?
We have two options:
- Propagate the result to the GradientsParams methods (get(...) and remove(...)).
- .unwrap() the result, which will panic with the error type (or .expect(...) with a descriptive message).

Either way, I don't think we should simply ignore the error with .ok(). That would make the error even harder for the user to understand 😅
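To make the trade-off concrete, here is a minimal, self-contained sketch of both options using toy types (this is not the actual burn API; Container, ContainerError, and the method names are placeholders):

```rust
use std::any::Any;
use std::collections::HashMap;

/// Placeholder error type standing in for the container error discussed above.
#[derive(Debug)]
pub enum ContainerError {
    NotFound,
    DowncastError,
}

pub struct Container {
    tensors: HashMap<String, Box<dyn Any>>,
}

impl Container {
    /// Option 1: propagate the error and let the caller decide how to handle it.
    pub fn get<T: 'static>(&self, id: &str) -> Result<&T, ContainerError> {
        self.tensors
            .get(id)
            .ok_or(ContainerError::NotFound)?
            .downcast_ref::<T>()
            .ok_or(ContainerError::DowncastError)
    }

    /// Option 2: unwrap/expect inside the wrapper, panicking with a descriptive
    /// message instead of surfacing the error to the caller.
    pub fn get_or_panic<T: 'static>(&self, id: &str) -> &T {
        self.get(id)
            .expect("tensor should be registered with the expected type")
    }
}
```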
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2965      +/-   ##
==========================================
- Coverage   82.17%   81.06%   -1.11%
==========================================
  Files         871      872       +1
  Lines      120526   121221     +695
==========================================
- Hits        99041    98268     -773
- Misses      21485    22953    +1468

☔ View full report in Codecov by Sentry.
Hey! If you don't mind, I'll choose the first option after getting the checks to run on my machine.
Ha, no need to apologize! 😄 Your feedback is important, especially since you were one of the first users to report this discrepancy. So if you feel that one option is more natural / helpful for an end-user, that's great!
…ptions and avoiding unwraps
…non disruptive way
Force-pushed from d9f4ee2 to 7741896
Please excuse the delay. Here is my proposed solution for #2924.
Overall, the addition of the result makes sense.
Just have some comments on the usage. See below 🙂
Thank you for taking the time to review my PR, and sorry for that sloppy oversight!
Ha no worries! That's what reviews are for 🙂
…ing to handle the result value in methods downstream in a way that is consistent and helpful.
…rContainerError gets another variant for whatever reason and to increase the ease of debugging.
@VirtualNonsense also let me know when this is ready for another round of review 🙂
@laggui
/// Gradients are not stored on the autodiff backend.
DowncastError,
}

/// Contains tensor of arbitrary dimension.
#[derive(Debug)]
pub struct TensorContainer<ID> {
The main reason to have a Box<dyn Any> here was that TensorPrimitive used to be generic over the rank. Now that this is no longer the case, we could actually remove the Box<dyn Any> and see if we can use the right backend as a generic parameter instead.
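For context, a simplified sketch of the type-erased pattern being described here (names are illustrative; this is not the exact burn implementation):

```rust
use std::any::Any;
use std::collections::HashMap;
use std::hash::Hash;

/// Type-erased container: values are stored as Box<dyn Any> and have to be
/// downcast back to a concrete type on retrieval, which is where a downcast
/// error can arise.
pub struct TensorContainer<ID> {
    tensors: HashMap<ID, Box<dyn Any + Send>>,
}

impl<ID: Hash + Eq> TensorContainer<ID> {
    pub fn register<T: Any + Send>(&mut self, id: ID, value: T) {
        self.tensors.insert(id, Box::new(value));
    }

    /// Retrieval fails (returns None) if the entry is missing or if the
    /// stored value is not actually of type T.
    pub fn get<T: Any>(&self, id: &ID) -> Option<&T> {
        self.tensors.get(id)?.downcast_ref::<T>()
    }
}
```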
Thank you for taking the time to review my PR!
If I understand you correctly, you are suggesting changing the TensorContainer to something like this:
/// Contains tensor of arbitrary dimension.
#[derive(Debug)]
pub struct TensorContainer<ID, B: Backend> {
tensors: HashMap<ID, TensorPrimitive<B>>,
}
and thereby sidestepping the issue entirely?
I have to admit I'm not that good with Rust yet and don't know burn well enough to properly foresee the consequences of that change, but I can try if you like :D
Should I create another PR for that version?
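For illustration, a rough sketch of how the accessors could look with the backend-generic container above; with the concrete TensorPrimitive<B> stored directly there is no downcast left, so the only remaining failure mode is a missing entry (method names are assumptions, not the final API):

```rust
impl<ID: core::hash::Hash + Eq, B: Backend> TensorContainer<ID, B> {
    /// Store a tensor primitive for the given id.
    pub fn register(&mut self, id: ID, tensor: TensorPrimitive<B>) {
        self.tensors.insert(id, tensor);
    }

    /// No Box<dyn Any> is involved, so no DowncastError can occur;
    /// a missing entry simply yields None.
    pub fn get(&self, id: &ID) -> Option<&TensorPrimitive<B>> {
        self.tensors.get(id)
    }

    /// Remove and return the tensor primitive, if present.
    pub fn remove(&mut self, id: &ID) -> Option<TensorPrimitive<B>> {
        self.tensors.remove(id)
    }
}
```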
This PR has been marked as stale because it has not been updated for over a month.
Pull Request Template
Checklist
- The run-checks all script has been executed.
Related Issues/PRs
#2924
Changes
I adjusted the TensorContainer according to the discussion in #2924 and adjusted the usage accordingly.
Question: Should the return type be propagated?
Testing