Skip to content

Commit 8cb5063

Browse files
PabloCarmonacharlesmackincharles-mackinatvasilopoulos
authored
add release content for 0.9.2 (#685)
* add release content for 0.9.2 Signed-off-by: Pablo Carmona <[email protected]> * remove added patch text Signed-off-by: Pablo Carmona <[email protected]> * modify changelog.md and version Signed-off-by: Pablo Carmona <[email protected]> * ufix(examples/31_custom_drift_models.py): change way rpu config is setup to a cleaner way Signed-off-by: Pablo Carmona <[email protected]> * fix(examples/31_custom_drift_models.py): fix errors with pycodestyle Signed-off-by: Pablo Carmona <[email protected]> * feat(notebooks): clean up and enhacenments on notebook contents for clarification and typos fixings * feat(converter/conductance): new conductance converters development * updates to custom drift model suggested by Malte (#690) Co-authored-by: Charles Mackin <[email protected]> * Update hermes.py (#691) * fix(hermes.py): fix linting errors related to whitespacing * fix(requirements.txt): solve problem with numpy versions in some cases * feat(CHANGELOG.md): add PR number for new related developments and modify the date of release to a proper one * feat(CHANGELOG.md): add PR number for new related developments and modify the date of release to a proper one * feat(travis): update travis build with latest pytorch version --------- Signed-off-by: Pablo Carmona <[email protected]> Co-authored-by: charlesmackin <[email protected]> Co-authored-by: Charles Mackin <[email protected]> Co-authored-by: Athanasios Vasilopoulos <[email protected]>
1 parent 6bcc9d6 commit 8cb5063

19 files changed

+2027
-114
lines changed

.travis.yml

+8-8
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,11 @@ jobs:
9494
if: branch =~ /^release\/.*$/
9595
env:
9696
# Use a specific torch version.
97-
- CIBW_ENVIRONMENT="TORCH_VERSION_SPECIFIER='==2.0.1'"
98-
- CIBW_BEFORE_BUILD="pip install torch==2.0.1 torchvision && pip install -r requirements.txt"
97+
- CIBW_ENVIRONMENT="TORCH_VERSION_SPECIFIER='==2.4.1'"
98+
- CIBW_BEFORE_BUILD="pip install torch==2.4.1 torchvision && pip install -r requirements.txt"
9999
- CIBW_MANYLINUX_X86_64_IMAGE="aihwkit/manylinux2014_x86_64_aihwkit"
100100
- CIBW_REPAIR_WHEEL_COMMAND="auditwheel repair -w {dest_dir} {wheel} --exclude libtorch_python.so"
101-
- CIBW_BUILD="cp38-manylinux_x86_64 cp39-manylinux_x86_64 cp310-manylinux_x86_64"
101+
- CIBW_BUILD="cp38-manylinux_x86_64 cp39-manylinux_x86_64 cp310-manylinux_x86_64 cp311-manylinux_x86_64"
102102
before_install:
103103
- docker pull aihwkit/manylinux2014_x86_64_aihwkit
104104
install:
@@ -120,8 +120,8 @@ jobs:
120120
update: true
121121
env:
122122
# Use a specific torch version.
123-
- CIBW_ENVIRONMENT="TORCH_VERSION_SPECIFIER='==2.0.1'"
124-
- CIBW_BEFORE_BUILD="pip install torch==2.0.1 torchvision && pip install ./delocate && pip install -r requirements.txt"
123+
- CIBW_ENVIRONMENT="TORCH_VERSION_SPECIFIER='==2.4.1'"
124+
- CIBW_BEFORE_BUILD="pip install torch==2.4.1 torchvision && pip install ./delocate && pip install -r requirements.txt"
125125
- CIBW_BUILD="cp38-macosx_x86_64 cp39-macosx_x86_64"
126126
before_install:
127127
- git clone -b aihwkit https://github.com/aihwkit-bot/delocate.git
@@ -139,9 +139,9 @@ jobs:
139139
if: branch =~ /^release\/win.*$/
140140
env:
141141
# Use a specific torch version.
142-
- CIBW_ENVIRONMENT="TORCH_VERSION_SPECIFIER='==2.0.1'"
143-
- CIBW_BEFORE_BUILD="pip install torch==2.0.1 && pip install -r requirements.txt"
144-
- CIBW_BUILD="cp37-win_amd64 cp38-win_amd64 cp39-win_amd64 cp310-win_amd64"
142+
- CIBW_ENVIRONMENT="TORCH_VERSION_SPECIFIER='==2.4.1'"
143+
- CIBW_BEFORE_BUILD="pip install torch==2.4.1 && pip install -r requirements.txt"
144+
- CIBW_BUILD="cp38-win_amd64 cp39-win_amd64 cp310-win_amd64"
145145
# Use unzipped OpenBLAS.
146146
- OPENBLAS_ROOT=C:\\BLAS
147147
- OPENBLAS_ROOT_DIR=C:\\BLAS

CHANGELOG.md

+14
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@ The format is based on [Keep a Changelog], and this project adheres to
1212
* `Fixed` for any bug fixes.
1313
* `Security` in case of vulnerabilities.
1414

15+
## [0.9.2] - 2024/09/18
16+
17+
### Added
18+
19+
* Added new Hermes noise model and related notebooks (\#685)
20+
* Added new conductance converters (\#685)
21+
* Make Conv layers also compatible with non-batched inputs (\#685)
22+
* Added per column drift compensation (\#685)
23+
* Added custom drifts (\#685)
24+
25+
### Changed
26+
27+
* Update requirements-examples.txt (\#685)
28+
1529
## [0.9.1] - 2024/05/16
1630

1731
### Added

docs/source/pcm_inference.rst

+9-1
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ The fits between these equations and the hardware measurements are shown below:
121121
.. image:: ../img/pcm_drift_plot.png
122122
:alt:
123123

124+
Users can also provide custom drift characteristics to override the default drift model,
125+
which can be used to evaluate the performance and trade-offs of different devices
126+
:ref:`[6] <references_pcm>`. See
127+
`example 31 <https://github.com/IBM/aihwkit/blob/master/examples/31_custom_drift_models.py>`_
128+
for an example on how to customize drift models.
129+
124130
Read noise
125131
----------
126132

@@ -225,4 +231,6 @@ References
225231

226232
* [5] Le Gallo, M., Krebs, D., Zipoli, F., Salinga, M., & Sebastian, A. `Collective Structural Relaxation in Phase‐Change Memory Devices <https://onlinelibrary.wiley.com/doi/full/10.1002/aelm.201700627>`_. Advanced Electronic Materials, 4(9), 1700627. 2018
227233

228-
* [6] Le Gallo, M., Sebastian, A., Cherubini, G., Giefers, H., & Eleftheriou, E. `Compressed sensing with approximate message passing using in-memory computing <https://ieeexplore.ieee.org/abstract/document/8450603>`_. IEEE Transactions on Electron Devices, 65(10), 4304-4312. 2018
234+
* [6] N. Li, C. Mackin, A. Chen, K. Brew, T. Philip, A. Simon, I. Saraf, J.-P. Han, S. G. Sarwat, G. W. Burr, M. Rasch, A. Sebastian, V. Narayanan, N. Saulnier. `Optimization of Projected Phase Change Memory for Analog In-Memory Computing Inference <https://doi.org/10.1002/aelm.202201190>`_. Advanced Electronic Materials, 9, 2201190. 2023
235+
236+
* [7] Le Gallo, M., Sebastian, A., Cherubini, G., Giefers, H., & Eleftheriou, E. `Compressed sensing with approximate message passing using in-memory computing <https://ieeexplore.ieee.org/abstract/document/8450603>`_. IEEE Transactions on Electron Devices, 65(10), 4304-4312. 2018

examples/31_custom_drift_models.py

+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
4+
#
5+
# This code is licensed under the Apache License, Version 2.0. You may
6+
# obtain a copy of this license in the LICENSE.txt file in the root directory
7+
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8+
#
9+
# Any modifications or derivative works of this code must retain this
10+
# copyright notice, and modified files need to carry a notice indicating
11+
# that they have been altered from the originals.
12+
13+
"""aihwkit example 31: customized conductance drift models
14+
15+
Simple 1-layer network that demonstrates user-defined drift model
16+
capability and impact on output over time due to drift.
17+
18+
Reference paper evaluating a number of different conductance-dependent
19+
drift models exhibiting complex characteristics:
20+
https://onlinelibrary.wiley.com/doi/full/10.1002/aelm.202201190
21+
"""
22+
# pylint: disable=invalid-name
23+
from numpy import asarray
24+
25+
# Imports from PyTorch.
26+
from torch import (
27+
zeros,
28+
ones,
29+
mean,
30+
std,
31+
linspace,
32+
)
33+
import matplotlib.pyplot as plt
34+
35+
# Imports from aihwkit.
36+
from aihwkit.nn import AnalogLinear
37+
from aihwkit.simulator.rpu_base import cuda
38+
from aihwkit.inference.converter.conductance import SinglePairConductanceConverter
39+
from aihwkit.inference.noise.pcm import CustomDriftPCMLikeNoiseModel
40+
from aihwkit.simulator.parameters.enums import BoundManagementType
41+
from aihwkit.simulator.parameters.io import IOParameters
42+
from aihwkit.simulator.configs import TorchInferenceRPUConfig
43+
44+
g_min, g_max = 0.0, 25.
45+
# define custom drift model
46+
custom_drift_model = dict(g_lst=[g_min, 10., g_max],
47+
nu_mean_lst=[0.08, 0.05, 0.03],
48+
nu_std_lst=[0.03, 0.02, 0.01])
49+
50+
t_inference_times = [1, # 1 sec
51+
60, # 1 min
52+
60 * 60, # 1 hour
53+
24 * 60 * 60, # 1 day
54+
30 * 24 * 60 * 60, # 1 month
55+
12 * 30 * 24 * 60 * 60, # 1 year
56+
]
57+
58+
IN_FEATURES = 512
59+
OUT_FEATURES = 512
60+
BATCH_SIZE = 1
61+
62+
# define rpu_config
63+
io_params = IOParameters(
64+
bound_management=BoundManagementType.NONE,
65+
nm_thres=1.0,
66+
inp_res=2 ** 8 - 2,
67+
out_bound=-1,
68+
out_res=-1,
69+
out_noise=0.0)
70+
71+
noise_model = CustomDriftPCMLikeNoiseModel(custom_drift_model,
72+
prog_noise_scale=0.0, # turn off to show drift only
73+
read_noise_scale=0.0, # turn off to show drift only
74+
drift_scale=1.0,
75+
g_converter=SinglePairConductanceConverter(g_min=g_min,
76+
g_max=g_max),
77+
)
78+
79+
rpu_config = TorchInferenceRPUConfig(noise_model=noise_model, forward=io_params)
80+
81+
# define simple model, weights, and activations
82+
model = AnalogLinear(IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config)
83+
weights = linspace(custom_drift_model['g_lst'][0],
84+
custom_drift_model['g_lst'][-1],
85+
OUT_FEATURES).repeat(IN_FEATURES, 1)
86+
x = ones(BATCH_SIZE, IN_FEATURES)
87+
88+
# set weights
89+
for name, layer in model.named_analog_layers():
90+
layer.set_weights(weights.T, zeros(OUT_FEATURES))
91+
92+
# Move the model and tensors to cuda if it is available
93+
if cuda.is_compiled():
94+
x = x.cuda()
95+
model = model.cuda()
96+
97+
model.eval()
98+
model.drift_analog_weights(t_inference_times[0]) # generate drift (nu) coefficients
99+
100+
# Extract drift coefficients nu as a function of conductance
101+
g_lst, _ = rpu_config.noise_model.g_converter.convert_to_conductances(weights)
102+
nu_lst = model.analog_module.drift_noise_parameters
103+
104+
# Get mean and std drift coefficient (nu) as function of conductance
105+
gs = mean(g_lst[0], dim=0).numpy()
106+
nu_mean = mean(nu_lst[0].T, dim=0).numpy()
107+
nu_std = std(nu_lst[0].T, dim=0).numpy()
108+
109+
# Plot device drift model
110+
plt.figure()
111+
plt.plot(gs, nu_mean)
112+
plt.fill_between(gs, nu_mean - nu_std, nu_mean + nu_std, alpha=0.2)
113+
plt.xlabel(r"$Conductance \ [\mu S]$")
114+
plt.ylabel(r"$\nu \ [1]$")
115+
plt.tight_layout()
116+
plt.savefig('custom_drift_model.png')
117+
plt.close()
118+
119+
# create simple linear layer model
120+
model = AnalogLinear(IN_FEATURES, OUT_FEATURES, bias=False, rpu_config=rpu_config)
121+
122+
# define weights, activations
123+
weights = (1. / 512.) * ones(IN_FEATURES, OUT_FEATURES)
124+
x = ones(BATCH_SIZE, IN_FEATURES)
125+
126+
# set weights
127+
for _, layer in model.named_analog_layers():
128+
layer.set_weights(weights.T, zeros(OUT_FEATURES))
129+
130+
# Move the model and tensors to cuda if it is available
131+
if cuda.is_compiled():
132+
x = x.cuda()
133+
model = model.cuda()
134+
135+
# Eval model at different time steps
136+
model.eval()
137+
out_lst = []
138+
for t_inference in t_inference_times:
139+
model.drift_analog_weights(t_inference) # generate new nu coefficients at each time step
140+
out = model(x)
141+
out_lst.append(out)
142+
143+
# Plot drift compensated outputs as a function of time
144+
t = asarray(t_inference_times)
145+
out_mean = asarray([mean(out).detach().cpu().numpy() for out in out_lst])
146+
out_std = asarray([std(out).detach().cpu().numpy() for out in out_lst])
147+
plt.figure()
148+
plt.plot(t, out_mean)
149+
plt.fill_between(t, out_mean - out_std, out_mean + out_std, alpha=0.2)
150+
plt.xscale("log")
151+
plt.xticks(t, ['1 sec', '1 min', '1 hour', '1 day', '1 month', '1 year'], rotation='vertical')
152+
plt.xlabel(r"$Time$")
153+
plt.ylabel(r"$Drift \ Compensated \ Outputs \ [1]$")
154+
plt.tight_layout()
155+
plt.savefig('custom_drift_model_output.png')
156+
plt.close()

0 commit comments

Comments
 (0)