Added GlobalDriftCompensationWithExactReference drift compensation class #674

Open · wants to merge 7 commits into master
7 changes: 5 additions & 2 deletions src/aihwkit/inference/compensation/base.py
@@ -11,6 +11,8 @@
from torch import Tensor
from torch.autograd import no_grad

from aihwkit.simulator.tiles.inference import InferenceTileWithPeriphery


class BaseDriftCompensation:
"""Base class for drift compensations."""
@@ -19,17 +21,18 @@ def __init__(self) -> None:
pass

@no_grad()
def init_baseline(self, forward_output: Tensor) -> Tuple[Tensor, Tensor]:
def init_baseline(self, tile: InferenceTileWithPeriphery) -> Tuple[Tensor, Tensor]:
"""Initialize the base line for applying the compensation.

Uses a all one tensor for read_out.

Args:
forward_output: forward output of the read out vector to compensate
tile: inference tile from which the drift read-out is computed

Returns:
reference tensor readout
"""
forward_output = tile._forward_drift_readout_tensor(True, exact_reference=False)
ref_value = self.readout(forward_output)

return ref_value
27 changes: 27 additions & 0 deletions src/aihwkit/inference/compensation/drift.py
@@ -6,11 +6,14 @@

"""Global drift compensation for inference."""

from typing import Tuple

from torch.autograd import no_grad
from torch import abs as torch_abs
from torch import clamp, Tensor, eye

from aihwkit.inference.compensation.base import BaseDriftCompensation
from aihwkit.simulator.tiles.inference import InferenceTileWithPeriphery


class GlobalDriftCompensation(BaseDriftCompensation):
@@ -36,6 +39,30 @@ def __str__(self) -> str:
return "{}()".format(self.__class__.__name__)


class GlobalDriftCompensationWithExactReference(GlobalDriftCompensation):
"""Global drift compensation using an exact (ideal) reference readout.

Uses a constant factor for compensating the drift.
"""

@no_grad()
def init_baseline(self, tile: InferenceTileWithPeriphery) -> Tuple[Tensor, Tensor]:
"""Initialize the base line for applying the compensation.

Uses a all one tensor for read_out.

Args:
tile: forward output of the read out vector to compensate

Returns:
reference tensor readout
"""
forward_output = tile._forward_drift_readout_tensor(True, exact_reference=True)
ref_value = self.readout(forward_output)

return ref_value


class PerColumnDriftCompensation(BaseDriftCompensation):
"""Per column drift compensation.
Uses a vector for compensating the drift.
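For orientation, a hedged usage sketch of the new compensation class. It follows the existing GlobalDriftCompensation workflow in aihwkit; the noise model choice and the parameter values below are illustrative assumptions, not part of this PR:

```python
from aihwkit.nn import AnalogLinear
from aihwkit.simulator.configs import InferenceRPUConfig
from aihwkit.inference.noise.pcm import PCMLikeNoiseModel
from aihwkit.inference.compensation.drift import GlobalDriftCompensationWithExactReference

# Inference config using the new exact-reference drift compensation.
rpu_config = InferenceRPUConfig()
rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0)  # illustrative noise model
rpu_config.drift_compensation = GlobalDriftCompensationWithExactReference()

layer = AnalogLinear(10, 5, bias=True, rpu_config=rpu_config)
layer.eval()

# Programming establishes the drift baseline; with this class the baseline
# is computed from the floating-point reference weights rather than the
# (noisy) programmed conductances.
layer.program_analog_weights()

# Apply drift for a given inference time; the compensation rescales the
# output using the ratio of the stored baseline to the current read-out.
layer.drift_analog_weights(t_inference=3600.0)
```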
85 changes: 67 additions & 18 deletions src/aihwkit/simulator/tiles/inference.py
@@ -23,7 +23,11 @@
from aihwkit.simulator.tiles.base import BaseTile
from aihwkit.simulator.rpu_base import tiles
from aihwkit.simulator.parameters.helpers import parameters_to_bindings
from aihwkit.simulator.parameters.enums import WeightModifierType, WeightClipType, WeightRemapType
from aihwkit.simulator.parameters.enums import (
WeightModifierType,
WeightClipType,
WeightRemapType,
)
from aihwkit.inference.noise.base import BaseNoiseModel

if TYPE_CHECKING:
@@ -84,19 +88,28 @@ def init_mapping_scales(self) -> None:
This method is called from the constructor.
"""
super().init_mapping_scales()
if hasattr(self.rpu_config, "remap") and self.rpu_config.remap.type != WeightRemapType.NONE:
if (
hasattr(self.rpu_config, "remap")
and self.rpu_config.remap.type != WeightRemapType.NONE
):
# needs to be always out_size
mapping_scales = ones(
(self.out_size,), dtype=self.get_dtype(), device=self.device, requires_grad=False
(self.out_size,),
dtype=self.get_dtype(),
device=self.device,
requires_grad=False,
)
self.set_mapping_scales(mapping_scales)

@no_grad()
def _forward_drift_readout_tensor(self, reset_if: bool = False) -> Optional[Tensor]:
def _forward_drift_readout_tensor(
self, reset_if: bool = False, exact_reference: bool = False
) -> Optional[Tensor]:
"""Perform a forward pass using the drift read-out tensor.

Args:
reset_if: Will reset the readout tensor, otherwise use the stored one
exact_reference: Whether or not to compute the reference using an "ideal" forward pass.

Returns:
Readout tensor if drift compensation is on
@@ -109,20 +122,41 @@ def _forward_drift_readout_tensor(self, reset_if: bool = False) -> Optional[Tens

if self.drift_readout_tensor is None or reset_if:
self.drift_readout_tensor = (
self.rpu_config.drift_compensation.get_readout_tensor(self.tile.get_x_size())
self.rpu_config.drift_compensation.get_readout_tensor(
self.tile.get_x_size()
)
.detach()
.to(self.device)
)
if self.in_trans:
self.drift_readout_tensor = self.drift_readout_tensor.tranpose(0, 1).clone()
self.drift_readout_tensor = self.drift_readout_tensor.tranpose(
0, 1
).clone()
else:
self.drift_readout_tensor = self.drift_readout_tensor.to(self.device)

# We need to take the bias as a common column here, also we do
# not want to use indexed.
return self.tile.forward(
self.drift_readout_tensor, False, self.in_trans, self.out_trans, True, self.non_blocking
)
if exact_reference:
input_ = self.drift_readout_tensor
if self.in_trans:
input_ = input_.T

output = (input_ @ self.reference_combined_weights.T)
if self.out_trans:
output = output.T

else:
output = self.tile.forward(
self.drift_readout_tensor,
False,
self.in_trans,
self.out_trans,
True,
self.non_blocking,
)

return output

@no_grad()
def program_weights(
@@ -160,7 +194,9 @@ def program_weights(

if noise_model is not None:
if not isinstance(noise_model, BaseNoiseModel):
raise ConfigError("Given noise model has to be of type 'BaseNoiseModel'")
raise ConfigError(
"Given noise model has to be of type 'BaseNoiseModel'"
)

self.rpu_config.noise_model = noise_model

@@ -170,16 +206,17 @@ def program_weights(
(
self.programmed_weights,
self.drift_noise_parameters,
) = self.rpu_config.noise_model.apply_programming_noise(self.reference_combined_weights)
) = self.rpu_config.noise_model.apply_programming_noise(
self.reference_combined_weights
)

self.tile.set_weights(self.programmed_weights)

if (
hasattr(self.rpu_config, "drift_compensation")
and self.rpu_config.drift_compensation is not None
):
forward_output = self._forward_drift_readout_tensor(True)
self.drift_baseline = self.rpu_config.drift_compensation.init_baseline(forward_output)
self.drift_baseline = self.rpu_config.drift_compensation.init_baseline(self)
maljoras-sony (Contributor) commented on Jan 29, 2025:

I think there is still a conceptual error here as far as I can see. If you simply set is_perfect=True, then the programmed_weights are used to establish the reference without any additional noise. However, this cannot be done in reality, as the programmed weights are those in the conductances. What you want to do is take self.reference_combined_weights as the reference for the drift, since those are the floating-point weights that are then programmed onto the crossbar. This makes a difference, because the programmed weights might in fact have some corrupt devices etc. and thus differ from the floating-point weights even if is_perfect=True is set.
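A rough standalone sketch of that distinction (hypothetical shapes and noise level; the mean-of-absolute read-out below is only a simplified stand-in for the compensation's actual readout):

```python
import torch

# Floating-point target weights the user intends to program.
reference_combined_weights = torch.randn(8, 4)

# Programming is imperfect: stuck or noisy devices make the conductances
# deviate from the floating-point target even before any drift occurs.
programmed_weights = reference_combined_weights + 0.05 * torch.randn(8, 4)

# One-hot drift read-out inputs (one per input column).
readout_inputs = torch.eye(4)

# Baseline from the floating-point reference weights ...
baseline_exact = (readout_inputs @ reference_combined_weights.T).abs().mean()

# ... versus a baseline from the programmed conductances, which differs
# even with an otherwise noise-free ("is_perfect") forward pass.
baseline_programmed = (readout_inputs @ programmed_weights.T).abs().mean()

print(baseline_exact.item(), baseline_programmed.item())
```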

The PR author (Collaborator) replied:

Makes sense - thanks for the explanation! I'll change this now. As an aside, I will raise a separate issue to add a warning/note to the documentation here: https://aihwkit.readthedocs.io/en/latest/api/aihwkit.simulator.parameters.io.html#aihwkit.simulator.parameters.io.IOParameters.is_perfect. It is somewhat ambiguous what a forward pass is and whether a user would expect the FP or programmed weights to be used to compute it.

The PR author (Collaborator) added:

@maljoras could you please check my recent code changes? I have used self.reference_combined_weights, as suggested.


@no_grad()
def drift_weights(self, t_inference: float = 0.0) -> None:
@@ -223,7 +260,9 @@ def drift_weights(self, t_inference: float = 0.0) -> None:
and self.rpu_config.drift_compensation is not None
):
forward_output = self._forward_drift_readout_tensor()
alpha = self.rpu_config.drift_compensation.apply(forward_output, self.drift_baseline)
alpha = self.rpu_config.drift_compensation.apply(
forward_output, self.drift_baseline
)
if isinstance(self, Module):
# somehow legacy is incompatible with torch buffers
self.__dict__.pop("alpha", None)
@@ -268,14 +307,20 @@ def post_update_step(self) -> None:
if not hasattr(self, "_tmp"):
# pylint: disable=attribute-defined-outside-init
self._tmp = {} # type: Dict[str, Any]
if hasattr(self.rpu_config, "clip") and self.rpu_config.clip.type != WeightClipType.NONE:
if (
hasattr(self.rpu_config, "clip")
and self.rpu_config.clip.type != WeightClipType.NONE
):
if on_the_fly_bindings or "weight_clip_params" not in self._tmp:
self._tmp["weight_clip_params"] = parameters_to_bindings(
self.rpu_config.clip, data_type
)
self.tile.clip_weights(self._tmp["weight_clip_params"])

if hasattr(self.rpu_config, "remap") and self.rpu_config.remap.type != WeightRemapType.NONE:
if (
hasattr(self.rpu_config, "remap")
and self.rpu_config.remap.type != WeightRemapType.NONE
):
if on_the_fly_bindings or "weight_remap_params" not in self._tmp:
self._tmp["weight_remap_params"] = parameters_to_bindings(
self.rpu_config.remap, data_type
@@ -306,7 +351,9 @@ def __getstate__(self) -> Dict:
state = super().__getstate__()
return state

def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "BaseTile":
def cuda(
self, device: Optional[Union[torch_device, str, int]] = None
) -> "BaseTile":
self.alpha = self.alpha.cuda(device)
ret = super().cuda(device)
return ret
Expand All @@ -317,7 +364,9 @@ def cpu(self) -> "BaseTile":
return ret


class InferenceTile(TileModule, InferenceTileWithPeriphery, RPUCudaSimulatorTileWrapper):
class InferenceTile(
TileModule, InferenceTileWithPeriphery, RPUCudaSimulatorTileWrapper
):
"""Tile used for analog inference and hardware-aware training for inference.

Note: