Skip to content

Enhance gradient api #2965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

VirtualNonsense
Copy link

@VirtualNonsense VirtualNonsense commented Mar 27, 2025

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#2924

Changes

I adjusted the TensorContainer according the discussion in #2924 and adjusted the usage accordingly.
Question: Should the return type be propagated?

Testing

Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Should the return type be propagated?

We have two options:

  1. Propagate the result to the GradientsParams methods (get(...) and remove(...))
  2. .unwrap() the result, which will panic with the error type (or .expect(...) with a descriptive message)

but I don't think we should not simply ignore the error with .ok(). This will be even more difficult to understand for user error 😅

Copy link

codecov bot commented Mar 27, 2025

Codecov Report

Attention: Patch coverage is 64.42953% with 53 lines in your changes missing coverage. Please review.

Project coverage is 81.06%. Comparing base (70e2bc7) to head (3ea5e58).
Report is 187 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-core/src/optim/grad_accum.rs 12.12% 29 Missing ⚠️
crates/burn-core/src/optim/visitor.rs 0.00% 12 Missing ⚠️
crates/burn-core/src/optim/simple/adaptor.rs 74.19% 8 Missing ⚠️
crates/burn-autodiff/src/backend.rs 76.92% 3 Missing ⚠️
crates/burn-tensor/src/tensor/container.rs 91.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2965      +/-   ##
==========================================
- Coverage   82.17%   81.06%   -1.11%     
==========================================
  Files         871      872       +1     
  Lines      120526   121221     +695     
==========================================
- Hits        99041    98268     -773     
- Misses      21485    22953    +1468     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@VirtualNonsense
Copy link
Author

Hey!
Please excuse my rather cautious approach. I'm a little bit burned by my corporate environment where even minor changes to lead to people scream at you 😅

If you don't mind I'll choose the first option after getting the check to run at my machine.

@laggui
Copy link
Member

laggui commented Mar 31, 2025

Ha, no need to apologize! 😄

Your feedback is important, especially since you were one of the first users to report this discrepancy. So if you feel that one option is more natural / helpful for an end-user, that's great!

@VirtualNonsense VirtualNonsense marked this pull request as ready for review April 7, 2025 14:59
@VirtualNonsense
Copy link
Author

Please excuse the delay.

Here is my proposed solution for #2924.
I finally did not change the return type of the backend trait methods and chose to panic instead since it might break a lot of code and I don't feel like this change is important enough to justify such a change.
Let me know if I can enhance this PR in any way!

Copy link
Member

@laggui laggui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, the addition of the result makes sense.

Just have some comments on the usage. See below 🙂

@VirtualNonsense
Copy link
Author

VirtualNonsense commented Apr 9, 2025

Thank you for taking the time for reviewing my PR and sorry for that sloppy oversight!
I'll try to account for those issues as soon as possible.

@laggui
Copy link
Member

laggui commented Apr 9, 2025

Ha no worries! That's what reviews are for 🙂

Andreas Nachtmann and others added 7 commits April 10, 2025 13:09
@laggui
Copy link
Member

laggui commented Apr 14, 2025

@VirtualNonsense also let me know when this is ready for another round of review 🙂

@VirtualNonsense
Copy link
Author

VirtualNonsense commented Apr 14, 2025

@laggui
I tried to address all issues/comments and the pipeline seems to be happy as well so it should be fine for a review.
Thanks in advance!

/// Gradients are not stored on the autodiff backend.
DowncastError,
}

/// Contains tensor of arbitrary dimension.
#[derive(Debug)]
pub struct TensorContainer<ID> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason to have a Box<dyn Any> here was because the TensorPrimitive used to be generic over the rank. Now that's not the case anymore, we could actually remove the Box<dyn Any> and see if we can use the right backend as generic.

Copy link
Author

@VirtualNonsense VirtualNonsense Apr 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for taking the time to review my PR!

If I unterstand you correctly you are suggesting to change the TensorContainer to something like this:

/// Contains tensor of arbitrary dimension.
#[derive(Debug)]
pub struct TensorContainer<ID, B: Backend> {
    tensors: HashMap<ID, TensorPrimitive<B>>,
}

and thereby sidestepping the issue entirely?
I have to admit I'm not that good with rust yet and do not know burn well enough to properly foresee the consequences of that change but I can try if you like :D
Should I create another PR for that version?

Copy link
Contributor

This PR has been marked as stale because it has not been updated for over a month

@github-actions github-actions bot added the stale The issue or pr has been open for too long label May 15, 2025
@github-actions github-actions bot removed the stale The issue or pr has been open for too long label May 16, 2025
Copy link
Contributor

This PR has been marked as stale because it has not been updated for over a month

@github-actions github-actions bot added stale The issue or pr has been open for too long and removed stale The issue or pr has been open for too long labels Jun 16, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants