Skip to content

Commit 0cd9a9f

Browse files
authored
added FP preset (#705)
Signed-off-by: Julian Buechel <[email protected]>
1 parent 86620b3 commit 0cd9a9f

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

src/aihwkit/simulator/presets/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@
6868
MixedPrecisionGokmenVlasovPreset,
6969
MixedPrecisionPCMPreset,
7070
)
71-
from .inference import StandardHWATrainingPreset
72-
71+
from .inference import StandardHWATrainingPreset, FloatingPointPreset
7372
from .devices import (
7473
ReRamESPresetDevice,
7574
ReRamSBPresetDevice,

src/aihwkit/simulator/presets/inference.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Optional
1111
from dataclasses import dataclass, field
1212

13-
from aihwkit.simulator.configs.configs import InferenceRPUConfig
13+
from aihwkit.simulator.configs.configs import InferenceRPUConfig, TorchInferenceRPUConfig
1414
from aihwkit.simulator.parameters import (
1515
MappingParameter,
1616
IOParameters,
@@ -34,6 +34,28 @@
3434

3535

3636
# Inference
37+
@dataclass
38+
class FloatingPointPreset(TorchInferenceRPUConfig):
39+
"""Preset configuration for FP-like AIMC (Analog In-Mememory Compute)
40+
accuracy evaluation/training.
41+
42+
This preset configuration does not inject any noise in any form (weight noise
43+
quantization etc.) and is equivalent to the FP model.
44+
"""
45+
46+
mapping: MappingParameter = field(
47+
default_factory=lambda: MappingParameter(max_input_size=0, max_output_size=0)
48+
)
49+
50+
forward: IOParameters = field(default_factory=lambda: IOParameters(is_perfect=True))
51+
52+
pre_post: PrePostProcessingParameter = field(
53+
default_factory=lambda: PrePostProcessingParameter(
54+
input_range=InputRangeParameter(enable=False)
55+
)
56+
)
57+
58+
3759
@dataclass
3860
class StandardHWATrainingPreset(InferenceRPUConfig):
3961
"""Preset configuration for AIMC (Analog In-Mememory Compute)

tests/test_presets.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
"""Tests for analog presets."""
88

9-
from torch import Tensor
9+
from torch import Tensor, randn
1010

1111
from aihwkit.simulator.tiles.analog import AnalogTile
1212
from aihwkit.simulator.presets import (
@@ -50,6 +50,7 @@
5050
TTv2EcRamPreset,
5151
TTv2EcRamMOPreset,
5252
TTv2IdealizedPreset,
53+
FloatingPointPreset,
5354
)
5455
from .helpers.decorators import parametrize_over_presets
5556
from .helpers.testcases import AihwkitTestCase
@@ -131,3 +132,21 @@ def test_tile_preset(self):
131132
self.assertEqual(tile_biases, None)
132133
# TODO: disabled as the comparison needs to take into account noise
133134
# self.assertTensorAlmostEqual(tile_weights, weights)
135+
136+
137+
class PresetTestFP(AihwkitTestCase):
138+
"""Test for FP preset."""
139+
140+
def test_tile_preset(self):
141+
"""Test fwd behavior of FP preset."""
142+
out_size = 2
143+
in_size = 3
144+
weights = randn(out_size, in_size)
145+
inp = randn(in_size)
146+
fp_out = inp @ weights.T
147+
148+
rpu_config = FloatingPointPreset()
149+
analog_tile = AnalogTile(out_size, in_size, rpu_config, bias=False)
150+
analog_tile.set_weights(weights)
151+
tile_out = analog_tile(inp)
152+
self.assertTensorAlmostEqual(fp_out, tile_out)

0 commit comments

Comments
 (0)