Added GlobalDriftCompensationWithExactReference drift compensation class #674
Signed-off-by: Corey Liam Lammie <[email protected]>
@coreylammie, the drift compensation deliberately uses the rpu config of the tile, as this is the realistic situation. In reality, how can you achieve a perfect reference readout when there will always be noise when reading the analogue devices? I don't think you should change this default behaviour, as it would make the simulation unrealistic and wrong.
If you want a perfect readout during drift compensation, then you should explicitly write a new drift compensation. This should be a special case, as this behaviour is wrong in most natural cases.
This PR would also invalidate the papers we have written using this standard drift compensation, where the drift compensation was correctly estimated in a noisy manner. So please take this PR back. If you need a perfect compensation for some reason, write a special drift compensation class by deriving from the default one.
Also, the implementation seems to permanently set the rpu config to perfect, so that the perfect case is then always used. That is wrong even if you only want a noise-free reference, because the drift readout itself will then be completely noise-free as well.
Please don't change the default drift class in this manner. Use a special drift class if you need this special behaviour for some reason.
I am about to close this PR. Please submit a new one in which you derive a new drift compensation class from the default class instead of changing the behaviour of the default class. Also, as stated above, it seems that the rpu configuration is permanently changed, which is wrong in any case.
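For illustration, the requested pattern (a derived class that changes only how the reference is obtained, leaving both the default class and the tile configuration untouched) could look roughly like the sketch below. All names here (`FakeTile`, `DefaultCompensation`, `ExactReferenceCompensation`, `forward_readout`, `ideal_readout`) are hypothetical stand-ins written in plain Python, not the aihwkit API, and the mean-absolute-output statistic is an assumption about how a global compensation scheme summarises the readout.

```python
class FakeTile:
    """Hypothetical stand-in for an analog tile; not the aihwkit tile API."""

    def __init__(self, reference_weights, read_noise):
        self.reference_weights = reference_weights  # floating-point weights
        self.read_noise = read_noise                # fixed per-output perturbations
        self.is_perfect = False                     # stands in for the tile's IO setting

    def forward_readout(self):
        # Readout as the configured (noisy) hardware would produce it.
        return [w + n for w, n in zip(self.reference_weights, self.read_noise)]

    def ideal_readout(self):
        # Readout computed from the floating-point weights, with no noise.
        return list(self.reference_weights)


class DefaultCompensation:
    """Stand-in for the default behaviour: the reference is read out noisily."""

    def readout(self, outputs):
        # Assumed statistic: mean absolute output, floored to avoid division by zero.
        return max(sum(abs(o) for o in outputs) / len(outputs), 1e-4)

    def init_baseline(self, tile):
        return self.readout(tile.forward_readout())


class ExactReferenceCompensation(DefaultCompensation):
    """Special case: only the reference readout is ideal; nothing else changes."""

    def init_baseline(self, tile):
        # The tile's configuration is not modified, so later drift readouts
        # still go through the noisy path.
        return self.readout(tile.ideal_readout())


tile = FakeTile([0.5, -0.25, 0.75, -0.5], [0.5, -0.25, 0.25, -0.5])
default_baseline = DefaultCompensation().init_baseline(tile)
exact_baseline = ExactReferenceCompensation().init_baseline(tile)
```

In this toy setup the two baselines differ, while `tile.is_perfect` stays `False` after both calls, which is the property asked for above.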
Btw, if you want the exact weights of a tile for the reference computation of your special drift compensation class, simply use |
@coreylammie can you please let us know the latest on this PR? Should we go with the changes you have? We need to close this PR soon.
self.tile.set_weights(self.programmed_weights)

if (
    hasattr(self.rpu_config, "drift_compensation")
    and self.rpu_config.drift_compensation is not None
):
    forward_output = self._forward_drift_readout_tensor(True)
    self.drift_baseline = self.rpu_config.drift_compensation.init_baseline(forward_output)
    self.drift_baseline = self.rpu_config.drift_compensation.init_baseline(self)
I think there is still a conceptual error here, as far as I can see. If you simply set is_perfect=True, then the programmed_weights are used to establish the reference without any additional noise. However, this cannot be done in reality, as the programmed weights are those stored in the conductances. What you want to do is take self.reference_combined_weights as the reference for the drift, since those are the floating point weights that are then programmed onto the crossbar. This makes a difference, because the programmed weights might in fact include some corrupt devices etc. and thus differ from the floating point weights even if is_perfect=True is set.
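A toy calculation may make the distinction concrete. The sketch below is plain Python with made-up numbers and does not use aihwkit: it assumes the reference statistic is the mean absolute output of an identity-input readout (so the outputs are just the weights), models one corrupt device as stuck at zero, and models drift as a uniform decay by a factor of 0.5.

```python
def mean_abs_readout(weights):
    """Identity-input readout: outputs equal the weights; return their mean |value|."""
    flat = [w for row in weights for w in row]
    return max(sum(abs(w) for w in flat) / len(flat), 1e-4)


# Floating-point weights that were meant to be programmed (illustrative values).
reference_weights = [[0.5, -0.5], [0.25, -0.75]]
# What actually ended up in the conductances: one device is stuck at zero.
programmed_weights = [[0.5, -0.5], [0.25, 0.0]]

baseline_from_reference = mean_abs_readout(reference_weights)
baseline_from_programmed = mean_abs_readout(programmed_weights)

# Later readout after drift, modelled here as a uniform decay of the conductances.
drifted_weights = [[0.5 * w for w in row] for row in programmed_weights]
current_readout = mean_abs_readout(drifted_weights)

# Compensation scale = baseline / current readout.
scale_from_reference = baseline_from_reference / current_readout
scale_from_programmed = baseline_from_programmed / current_readout
```

Under these assumptions the two references give different compensation scales even though neither readout contains any noise, because the programmed weights already deviate from the floating-point weights.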
Makes sense - thanks for the explanation! I'll change this now. As an aside, I will raise a separate issue to add a warning/note to the documentation here: https://aihwkit.readthedocs.io/en/latest/api/aihwkit.simulator.parameters.io.html#aihwkit.simulator.parameters.io.IOParameters.is_perfect. It is somewhat ambiguous what a forward pass is and whether a user would expect the FP or programmed weights to be used to compute it.
@maljoras could you please check my recent code changes? I have used self.reference_combined_weights, as suggested.
Force-pushed from 02ec964 to 1203880.
is_perfect=True when determining the reference readout
@maljoras-sony @maljoras can you please review the changes that Corey made to address your review?
Description
Added a new drift compensation mechanism which uses an ideal reference readout. In the default global drift compensation mechanism, all non-idealities (as set by the corresponding rpu_config) are modeled during the reference readout. In some scenarios, e.g., when the output noise is sufficiently large, this can result in sub-optimal drift compensation scales being computed.
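The effect of large output noise on the reference can be illustrated with a small, self-contained simulation. This is only a sketch under stated assumptions, not the PR's implementation: the baseline is taken to be the mean absolute output of an identity-input readout, the weights are drawn uniformly from [-0.1, 0.1], and the output noise is Gaussian with a deliberately large standard deviation (`out_noise_std`, a hypothetical parameter name).

```python
import random


def mean_abs(values):
    """Assumed reference statistic: mean absolute output, floored at 1e-4."""
    return max(sum(abs(v) for v in values) / len(values), 1e-4)


rng = random.Random(0)  # fixed seed so the sketch is reproducible

# Small weights; with an identity input the ideal outputs equal the weights.
weights = [rng.uniform(-0.1, 0.1) for _ in range(4096)]

out_noise_std = 1.0  # deliberately large relative to the weight magnitudes

exact_outputs = list(weights)
noisy_outputs = [w + rng.gauss(0.0, out_noise_std) for w in weights]

exact_baseline = mean_abs(exact_outputs)
noisy_baseline = mean_abs(noisy_outputs)
```

Because taking absolute values rectifies the noise, `noisy_baseline` is dominated by the noise magnitude rather than by the weights, so a compensation scale derived from it no longer reflects the signal level. The exact-reference variant avoids this for the baseline only, while the later drift readouts remain noisy.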