From 29987b4c8ebbc17880aa0292b1b7ed436bf9402a Mon Sep 17 00:00:00 2001
From: "Mario Sieg (Neo)"
Date: Thu, 13 Feb 2025 08:54:14 +0100
Subject: [PATCH] Core backprop module

---
 magnetron/magnetron.c                          |  35 +++--
 magnetron/magnetron_cpu_blas.inl               |   4 +-
 python/benchmarks/bench_tool.py                |   2 +-
 python/examples/xor.py                         |  76 +++++------
 .../magnetron_framework/magnetron/__init__.py  |   2 +
 python/magnetron_framework/magnetron/core.py   |  78 ----------
 python/magnetron_framework/magnetron/layer.py  | 124 ------------------
 python/magnetron_framework/magnetron/model.py  |  81 ------------
 .../magnetron_framework/magnetron/module.py    |  98 ++++++++++++++
 python/magnetron_framework/magnetron/optim.py  |  37 +++---
 10 files changed, 178 insertions(+), 359 deletions(-)
 delete mode 100644 python/magnetron_framework/magnetron/layer.py
 delete mode 100644 python/magnetron_framework/magnetron/model.py
 create mode 100644 python/magnetron_framework/magnetron/module.py

diff --git a/magnetron/magnetron.c b/magnetron/magnetron.c
index fa5f4ec..13a3d3c 100644
--- a/magnetron/magnetron.c
+++ b/magnetron/magnetron.c
@@ -1202,7 +1202,9 @@ static void mag_op_backward_view(mag_tensor_t* node, mag_tensor_t** grads) {
 }

 static void mag_op_backward_transpose(mag_tensor_t* node, mag_tensor_t** grads) {
-    *grads = mag_transpose(node->grad);
+    mag_tensor_t* t = mag_transpose(node->grad);
+    *grads = mag_clone(t);
+    mag_tensor_decref(t);
 }

 static void mag_op_backward_permute(mag_tensor_t* node, mag_tensor_t** grads) {
@@ -1219,7 +1221,7 @@ static void mag_op_backward_mean(mag_tensor_t* node, mag_tensor_t** grads) {
     mag_tensor_t* x = node->op_inputs[0];
     double scale = 1.0/(double)x->numel;
-    *grads = mag_muls(node->grad, scale);
+    *grads = mag_muls(node->grad, (float)scale);
 }

 static void mag_op_backward_min(mag_tensor_t* node, mag_tensor_t** grads) {
@@ -1318,7 +1320,7 @@ static void mag_op_backward_softmax(mag_tensor_t* node, mag_tensor_t** grads) {
 static void mag_op_backward_sigmoid(mag_tensor_t* node, mag_tensor_t** grads) {
     mag_tensor_t* x = node->op_inputs[0];
     mag_tensor_t* dv = mag_sigmoid_dv(x);
-    grads[0] = mag_mul(node->grad, dv);
+    grads[0] = mag_mul(dv, node->grad);
     mag_tensor_decref(dv);
 }
@@ -1336,6 +1338,13 @@ static void mag_op_backward_tanh(mag_tensor_t* node, mag_tensor_t** grads) {
     mag_tensor_decref(dv);
 }

+static void mag_op_backward_relu(mag_tensor_t* node, mag_tensor_t** grads) {
+    mag_tensor_t* x = node->op_inputs[0];
+    mag_tensor_t* mask = mag_step(x);
+    grads[0] = mag_mul(node->grad, mask);
+    mag_tensor_decref(mask);
+}
+
 static void mag_op_backward_gelu(mag_tensor_t* node, mag_tensor_t** grads) {
     mag_tensor_t* x = node->op_inputs[0];
     mag_tensor_t* dv = mag_gelu_dv(x);
@@ -1356,8 +1365,8 @@ static void mag_op_backward_sub(mag_tensor_t* node, mag_tensor_t** grads) {
 static void mag_op_backward_mul(mag_tensor_t* node, mag_tensor_t** grads) {
     mag_tensor_t* x = node->op_inputs[0];
     mag_tensor_t* y = node->op_inputs[1];
-    grads[0] = mag_mul(node->grad, y);
-    grads[1] = mag_mul(node->grad, x);
+    grads[0] = mag_mul(y, node->grad);
+    grads[1] = mag_mul(x, node->grad);
 }

 static void mag_op_backward_div(mag_tensor_t* node, mag_tensor_t** grads) {
@@ -2017,12 +2026,12 @@ static mag_tensor_t* MAG_HOTPROC mag_tensor_operator(
     const mag_op_meta_t* meta = mag_op_meta_of(op);
     mag_tensor_t* (*r_alloc)(mag_tensor_t**, const mag_op_param_t*) = meta->r_alloc;
     bool (*validate_op)(mag_op_t, mag_tensor_t*, mag_tensor_t**, const mag_op_param_t*) = meta->validator;
-    mag_tensor_t* R = (inplace && numin && meta->inplace) /* Inplace requested? */
-        ? mag_tensor_create(ctx, (*inputs)->dtype, (*inputs)->shape, (*inputs)->rank, *inputs, 0) /* View R <- X for inplace aliasing op. */
-        : (*r_alloc)(inputs, params); /* Construct new result tensor. */
-    if (mag_unlikely(!(*validate_op)(op, R, inputs, params))) return NULL; /* Validation failed. */
-    R->op = op; /* Set operation for deferred execution mode. */
-    for (uint32_t i=0; i < numin; ++i) { /* Set input tensors and flags. */
+    mag_tensor_t* R = (inplace && numin && meta->inplace)                                         /* Inplace requested? */
+        ? mag_tensor_create(ctx, (*inputs)->dtype, (*inputs)->shape, (*inputs)->rank, *inputs, 0) /* View R <- X for inplace aliasing op. */
+        : (*r_alloc)(inputs, params);                                                             /* Construct new result tensor. */
+    mag_assert((*validate_op)(op, R, inputs, params), "Invalid operation %s.", meta->mnemonic);   /* Validate operation */
+    R->op = op;                                                                                   /* Set operation for deferred execution mode. */
+    for (uint32_t i=0; i < numin; ++i) {                                                          /* Set input tensors and flags. */
         mag_tensor_t* input = inputs[i];
         R->op_inputs[i] = input;
         if (ctx->flags & MAG_CTX_FLAG_GRAD_RECORDER)
@@ -2533,6 +2542,7 @@ void mag_tensor_backward(mag_tensor_t* root) {
             mag_tensor_t* grad = mag_tensor_create(child->ctx, child->dtype, child->shape, child->rank, NULL, 0);
             grad->flags = (grad->flags | MAG_TFLAG_IS_GRAD) & ~MAG_TFLAG_REQUIRES_GRAD;
             mag_tensor_fill(grad, 1.0f);
+            mag_tensor_fmt_name(grad, "%s (grad)", child->name);
             child->grad = grad;
         }
         if (mag_unlikely(child->op == MAG_OP_NOP)) continue;
@@ -2546,7 +2556,8 @@ void mag_tensor_backward(mag_tensor_t* root) {
             mag_assert2(input);
             if (!(input->flags & MAG_TFLAG_REQUIRES_GRAD)) continue;
             mag_tensor_t* gri = grads[i];
-            if (!gri) continue;
+            mag_assert(gri, "Gradient for op %s, input #%d is not computed", meta->mnemonic, i);
+            mag_tensor_fmt_name(gri, "%s (grad)", input->name);
             if (!input->grad) {
                 gri->flags = (gri->flags | MAG_TFLAG_IS_GRAD) & ~MAG_TFLAG_REQUIRES_GRAD;
                 input->grad = gri;
diff --git a/magnetron/magnetron_cpu_blas.inl b/magnetron/magnetron_cpu_blas.inl
index db0dbe7..530deec 100644
--- a/magnetron/magnetron_cpu_blas.inl
+++ b/magnetron/magnetron_cpu_blas.inl
@@ -998,8 +998,8 @@ static void MAG_HOTPROC mag_vtanh_dv_f32( /* tanh' : ℝ -> (-1, 1), x |-> 1 / (
     const mag_f32_t* x
 ) {
     for (int64_t i=0; i < numel; ++i) {
-        const mag_f32_t cx = coshf(x[i]);
-        o[i] = 1.0f / (cx*cx);
+        float t = tanhf(x[i]);
+        o[i] = 1.0f - t*t;
     }
 }
diff --git a/python/benchmarks/bench_tool.py b/python/benchmarks/bench_tool.py
index 197e318..daffd46 100644
--- a/python/benchmarks/bench_tool.py
+++ b/python/benchmarks/bench_tool.py
@@ -52,7 +52,7 @@ def plot(self, flops_per_op: int=2, plot_style: str='bars'):
             ax2.set_xticklabels(x_labels, rotation=45, ha='right')
         else:
             dims = [sa[0] for sa in self.shapes_a]  # For square matrices, any dimension works
-            markers = ['o', '+', 'x', '*', '.', 'X', '^']
+            markers = ['o', '+', 'x', '*', '.', 'x', '^']
             for i, participant in enumerate(self.participants):
                 ax1.plot(dims, participant.timings,
                          label=participant.name,
diff --git a/python/examples/xor.py b/python/examples/xor.py
index c492230..9c97caa 100644
--- a/python/examples/xor.py
+++ b/python/examples/xor.py
@@ -1,53 +1,49 @@
 # (c) 2025 Mario "Neo" Sieg.

-from magnetron import Tensor
-from magnetron.layer import DenseLayer
-from magnetron.model import SequentialModel, HyperParams
-import matplotlib.pyplot as plt
+from magnetron import Tensor, Module, Linear
+from magnetron.optim import SGD, mse_loss

-EPOCHS: int = 10000
-RATE: float = 0.1

-# Inputs: shape (4, 2)
-inputs = Tensor.const([
+class XOR(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.l1 = Linear(2, 2)
+        self.l2 = Linear(2, 1)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.l1(x).tanh()
+        x = self.l2(x).tanh()
+        return x
+
+model = XOR()
+optim = SGD(model.parameters(), lr=1e-1)
+
+x = Tensor.const([
     [0, 0],
     [0, 1],
     [1, 0],
     [1, 1]
-])
+], name='x')

-# Targets: shape (4, 1)
-targets = Tensor.const([
+y = Tensor.const([
     [0],
     [1],
     [1],
     [0]
-])
-
-params = HyperParams(lr=RATE, epochs=EPOCHS)
-mlp = SequentialModel(params, [
-    DenseLayer(2, 4),
-    DenseLayer(4, 1)
-])
-
-# Train
-losses = mlp.train(inputs, targets)
-
-# Inference
-test_points = [
-    (0, 0),
-    (0, 1),
-    (1, 0),
-    (1, 1),
-]
-
-for (x_val, y_val) in test_points:
-    result = mlp.forward(Tensor.const([[x_val, y_val]]))[0]
-    print(f"{x_val} XOR {y_val} => {result:.4f}")
-
-# Plot MSE loss
-plt.plot(list(range(0, EPOCHS)), losses)
-plt.xlabel('Epochs')
-plt.ylabel('MSE Loss')
-plt.title('XOR Problem')
-plt.show()
+], name='y')
+
+epochs: int = 2000
+
+y_hat = model(x)
+print(y_hat)
+for epoch in range(epochs):
+    y_hat = model(x)
+    loss = mse_loss(y_hat, y)
+    loss.backward()
+    optim.step()
+    optim.zero_grad()
+    if epoch % 1000 == 0:
+        print(f'Epoch: {epoch}, Loss: {loss.item()}')
+
+y_hat = model(x)
+print(y_hat)
diff --git a/python/magnetron_framework/magnetron/__init__.py b/python/magnetron_framework/magnetron/__init__.py
index d1e19df..f13ddae 100644
--- a/python/magnetron_framework/magnetron/__init__.py
+++ b/python/magnetron_framework/magnetron/__init__.py
@@ -8,3 +8,5 @@
 __url__ = 'https://github.com/MarioSieg/magnetron'

 from .core import *
+from .module import *
+from .optim import *
diff --git a/python/magnetron_framework/magnetron/core.py b/python/magnetron_framework/magnetron/core.py
index 1c0bb81..ca43b4b 100644
--- a/python/magnetron_framework/magnetron/core.py
+++ b/python/magnetron_framework/magnetron/core.py
@@ -225,8 +225,6 @@ class Tensor:
     __slots__ = ('_ctx', '_ptr', '_inputs')

     def __init__(self, ptr: ffi.CData | None = None) -> None:
-        if isinstance(ptr, ffi.CData):
-            assert ptr != ffi.NULL, 'Invalid tensor pointer'
         self._ctx = None
         self._ptr = ptr
         self._inputs = None
@@ -927,79 +925,3 @@ def __setitem__(self, indices: int | tuple[int, ...], value: float) -> None:
             C.mag_tensor_set_scalar_physical_index(self._ptr, *idx, float(value))
         else:
             raise TypeError('Indices must be an int or a tuple of ints.')
-
-
-class Parameter:
-    """A tensor that is a learnable parameter of a model."""
-
-    __slots__ = ('x',)
-
-    def __init__(self, x: Tensor) -> None:
-        x.requires_grad = True
-        self.x = x
-
-
-class Module:
-    """Base class for all neural network modules."""
-
-    def parameters(self) -> list[Parameter]:
-        """Returns all unique and nested parameters of the module."""
-        params: list[Parameter] = []
-        for k, v in self.__dict__.items():
-            if isinstance(v, Parameter):
-                params.append(v)
-            elif isinstance(v, Module):
-                params += v.parameters()
-            elif isinstance(v, ModuleList):
-                for mod in v:
-                    params += mod.parameters()
-        return list(set(params))
-
-    def eval(self) -> None:
-        """Sets the module in evaluation mode."""
-        for p in self.parameters():
-            p.x.requires_grad = False
-
-    def train(self) -> None:
-        """Sets the module in training mode."""
-        for p in self.parameters():
-            p.x.requires_grad = True
-
-    def forward(self, *args: Tensor, **kwargs: dict) -> Tensor:
-        """Forward pass must be implemented by subclass."""
-        raise NotImplementedError
-
-    def __call__(self, *args: Tensor, **kwargs: dict) -> Tensor:
-        return self.forward(*args, **kwargs)
-
-
-class ModuleList(Module, list):
-    """A list of modules that can be used as a single module."""
-
-    def __init__(self, mods: list[Module] | None) -> None:
-        super().__init__()
-        if mods is not None:
-            self += mods
-
-    def append(self, mod: Module) -> None:
-        super().append(mod)
-
-    def extend(self, __iterable: list[Module]) -> None:
-        super().extend(__iterable)
-
-    def __iadd__(self, other: list[Module]) -> 'ModuleList':
-        self.extend(other)
-        return self
-
-    def __setitem__(self, k: int, v: Module) -> None:
-        super().__setitem__(k, v)
-
-    def __getitem__(self, k: int) -> Module:
-        return super().__getitem__(k)
-
-    def parameters(self) -> list[Parameter]:
-        """Returns all unique and nested parameters of the module."""
-        params: list[Parameter] = []
-        for mod in self:
-            params += mod.parameters()
-        return list(set(params))
diff --git a/python/magnetron_framework/magnetron/layer.py b/python/magnetron_framework/magnetron/layer.py
deleted file mode 100644
index 3978af9..0000000
--- a/python/magnetron_framework/magnetron/layer.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# (c) 2025 Mario "Neo" Sieg.
-
-import math
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from enum import unique, Enum
-
-from magnetron import Tensor
-
-
-class Layer(ABC):
-    """Abstract base class for all layers"""
-
-    @abstractmethod
-    def forward(self, inputs: Tensor) -> Tensor:
-        pass
-
-    @abstractmethod
-    def backward(self, is_hidden_layer: bool, delta: Tensor, rate: float) -> Tensor:
-        pass
-
-
-@dataclass
-class LayerInit:
-    """Weight/bias initialization methods and parameters"""
-
-    @unique
-    class Distribution(Enum):
-        NORMAL = 0
-        UNIFORM = 1
-
-    @unique
-    class Method(Enum):
-        RANDOM = 0
-        XAVIER = 1
-        HE = 2
-
-    method: Method
-    dist: Distribution
-    uniform_interval: (float, float) = (-1.0, 1.0)
-    mean: float = 0.0
-    stddev: float = 1.0
-    gain: float = 1.0
-
-    def __init__(
-        self, method: Method, dist: Distribution, **kwargs: dict[str, object]
-    ) -> None:
-        self.method = method
-        self.dist = dist
-        for key, value in kwargs.items():
-            setattr(self, key, value)
-
-    def apply(self, shape: tuple[int, ...]) -> Tensor:
-        assert len(shape) >= 1
-
-        if self.method == self.Method.RANDOM:
-            if self.dist == self.Distribution.NORMAL:
-                return Tensor.normal(shape, mean=self.mean, stddev=self.stddev)
-            elif self.dist == self.Distribution.UNIFORM:
-                return Tensor.uniform(shape, interval=self.uniform_interval)
-
-        fan_in: int = shape[0]
-        fan_out: int = shape[1] if self.method == self.Method.XAVIER else None
-        factor: float = 1.0
-        bound: float = 1.0
-
-        if self.method == self.Method.XAVIER:
-            factor = 2.0 / (fan_in + fan_out)
-            bound = math.sqrt(6.0 / (fan_in + fan_out))
-        elif self.method == self.Method.HE:
-            factor = 1.0 / fan_in
-            bound = math.sqrt(3.0 / fan_in)
-
-        if self.dist == self.Distribution.NORMAL:
-            stddev = self.gain * math.sqrt(factor)
-            return Tensor.normal(shape, mean=0.0, stddev=stddev)
-        elif self.dist == self.Distribution.UNIFORM:
-            return Tensor.uniform(
-                shape, interval=(-self.gain * bound, self.gain * bound)
-            )
-        else:
-            raise ValueError('Invalid weight/bias initialization method')
-
-
-class DenseLayer(Layer):
-    """Fully connected layer"""
layer""" - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool, - init: LayerInit = LayerInit( - LayerInit.Method.RANDOM, LayerInit.Distribution.UNIFORM - ), - ) -> None: - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = init.apply((in_features, out_features)) - self.bias = init.apply((1, out_features)) if bias else None - self._x = None - self._z = None - self._out = None - - def forward(self, x: Tensor) -> Tensor: - self._x = x - self._z = x @ self.weight - if self.bias is not None: - self._z += self.bias - self._out = self._z.sigmoid() - return self._out - - def backward(self, is_hidden_layer: bool, delta: Tensor, rate: float) -> Tensor: - dW = self._x.T.clone() @ delta - ones_batch = Tensor.full((delta.shape[0], 1), fill_value=1.0) - dB = (delta.T.clone() @ ones_batch).T.clone() - self.weight -= dW * rate - self.bias -= dB * rate - - next_delta = delta @ self.weight.T.clone() - if is_hidden_layer: - next_delta *= self._z.sigmoid(derivative=True) - return next_delta diff --git a/python/magnetron_framework/magnetron/model.py b/python/magnetron_framework/magnetron/model.py deleted file mode 100644 index 829408c..0000000 --- a/python/magnetron_framework/magnetron/model.py +++ /dev/null @@ -1,81 +0,0 @@ -# (c) 2025 Mario "Neo" Sieg. - -import time -from abc import ABC, abstractmethod -from dataclasses import dataclass - -from magnetron import Tensor -from magnetron.layer import Layer -from magnetron.optim import Optimizer - - -@dataclass -class HyperParams: - lr: float = 0.01 - epochs: int = 10000 - epoch_step: int = 1000 - - -class Model(ABC): - """Abstract base class for all models""" - - def __init__(self, hyper_params: HyperParams) -> None: - self.hyper_params = hyper_params - - @abstractmethod - def forward(self, inputs: Tensor) -> Tensor: - pass - - @abstractmethod - def backward(self, outputs: Tensor, targets: Tensor, rate: float) -> None: - pass - - @abstractmethod - def train(self, inputs: Tensor, targets: Tensor) -> None: - pass - - @abstractmethod - def summary(self) -> None: - pass - - -class SequentialModel(Model): - """Feedforward neural network model (multi-layer perceptron)""" - - def __init__(self, hyper_params: HyperParams, layers: list[Layer]) -> None: - super().__init__(hyper_params) - self.layers = layers - - def forward(self, inputs: Tensor) -> Tensor: - x = inputs - for layer in self.layers: - x = layer.forward(x) - return x - - def backward(self, outputs: Tensor, targets: Tensor, rate: float) -> None: - delta = (outputs - targets) * outputs.sigmoid(derivative=True) - for i in reversed(range(len(self.layers))): - is_hidden = i > 0 - delta = self.layers[i].backward(is_hidden, delta, rate) - - def train(self, inputs: Tensor, targets: Tensor) -> None: - epochs: int = self.hyper_params.epochs - rate: float = self.hyper_params.lr - - print(f'Training started for {epochs} epochs with learning rate {rate}') - start_time = time.time_ns() - losses = [] - for epoch in range(epochs): - output = self.forward(inputs) - self.backward(output, targets, rate) - loss = Optimizer.mse(output, targets) - losses.append(loss) - if epoch % self.hyper_params.epoch_step == 0: - print(f'Epoch: {epoch}, Loss: {loss:.6f}') - - duration = (time.time_ns() - start_time) / 1e9 - print(f'Training finished in {duration:.2f} seconds') - return losses - - def summary(self) -> None: - pass diff --git a/python/magnetron_framework/magnetron/module.py b/python/magnetron_framework/magnetron/module.py new file mode 100644 index 
index 0000000..c65aae3
--- /dev/null
+++ b/python/magnetron_framework/magnetron/module.py
@@ -0,0 +1,98 @@
+# (c) 2025 Mario "Neo" Sieg.
+import math
+
+from magnetron.core import Tensor
+
+class Parameter:
+    """A tensor that is a learnable parameter of a model."""
+
+    __slots__ = ('x',)
+
+    def __init__(self, x: Tensor) -> None:
+        x.requires_grad = True
+        self.x = x
+
+class Module:
+    """Base class for all neural network modules."""
+
+    def parameters(self) -> list[Parameter]:
+        """Returns all unique and nested parameters of the module."""
+        params: list[Parameter] = []
+        for k, v in self.__dict__.items():
+            if isinstance(v, Parameter):
+                params.append(v)
+            elif isinstance(v, Module):
+                params += v.parameters()
+            elif isinstance(v, ModuleList):
+                for mod in v:
+                    params += mod.parameters()
+        return list(set(params))
+
+    def eval(self) -> None:
+        """Sets the module in evaluation mode."""
+        for p in self.parameters():
+            p.x.requires_grad = False
+
+    def train(self) -> None:
+        """Sets the module in training mode."""
+        for p in self.parameters():
+            p.x.requires_grad = True
+
+    def forward(self, *args: Tensor, **kwargs: dict) -> Tensor:
+        """Forward pass must be implemented by subclass."""
+        raise NotImplementedError
+
+    def __call__(self, *args: Tensor, **kwargs: dict) -> Tensor:
+        return self.forward(*args, **kwargs)
+
+
+class ModuleList(Module, list):
+    """A list of modules that can be used as a single module."""
+
+    def __init__(self, mods: list[Module] | None) -> None:
+        super().__init__()
+        if mods is not None:
+            self += mods
+
+    def append(self, mod: Module) -> None:
+        super().append(mod)
+
+    def extend(self, __iterable: list[Module]) -> None:
+        super().extend(__iterable)
+
+    def __iadd__(self, other: list[Module]) -> 'ModuleList':
+        self.extend(other)
+        return self
+
+    def __setitem__(self, k: int, v: Module) -> None:
+        super().__setitem__(k, v)
+
+    def __getitem__(self, k: int) -> Module:
+        return super().__getitem__(k)
+
+    def parameters(self) -> list[Parameter]:
+        """Returns all unique and nested parameters of the module."""
+        params: list[Parameter] = []
+        for mod in self:
+            params += mod.parameters()
+        return list(set(params))
+
+class Linear(Module):
+    """A fully connected linear layer."""
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        weight = Tensor.normal((out_features, in_features), mean=0, stddev=1)
+        weight = weight / math.sqrt(in_features + out_features)
+        self.weight = Parameter(weight)
+        if bias:
+            self.bias = Parameter(Tensor.zeros((out_features,), name='bias'))
+        else:
+            self.bias = None
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = x @ self.weight.x.T.clone()
+        if self.bias is not None:
+            x = x + self.bias.x
+        return x
diff --git a/python/magnetron_framework/magnetron/optim.py b/python/magnetron_framework/magnetron/optim.py
index a5176a7..3242e9b 100644
--- a/python/magnetron_framework/magnetron/optim.py
+++ b/python/magnetron_framework/magnetron/optim.py
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod

 from magnetron import Tensor
-
+from magnetron.module import Parameter

 class PolynomialDecayLRScheduler:
     """Polynomial Decay Learning Rate Scheduler"""
@@ -15,40 +15,35 @@ def step(self, iter: float) -> float:
         y: float = iter / self.max_iter
         return max(self.initial_lr * (1 - y) ** 2, 1.0e-7)

+def mse_loss(y_hat: Tensor, y: Tensor) -> Tensor:
+    delta = y_hat - y
+    return (delta * delta).mean()

 class Optimizer(ABC):
     """Base class for all optimizers."""
+    __slots__ = ('params', 'lr')

-    def __init__(self, params: list[Tensor], lr: float) -> None:
-        self.lr = lr
+    def __init__(self, params: list[Parameter], lr: float) -> None:
         self.params = params
+        self.lr = lr

     @abstractmethod
     def step(self) -> None:
-        pass
+        raise NotImplementedError

     def zero_grad(self) -> None:
         for param in self.params:
-            param.zero_grad()
-
-    @staticmethod
-    def mse(y: Tensor, y_hat: Tensor) -> float:
-        return (y - y_hat).sqr_().mean()[0]
-
-    @staticmethod
-    def cross_entropy(y: Tensor, y_hat: Tensor) -> float:
-        return -(y * y_hat.log_()).sum()[0]
-
+            param.x.zero_grad()

 class SGD(Optimizer):
     """Stochastic Gradient Descent"""

-    def __init__(self, params: list[Tensor], lr: float) -> None:
+    def __init__(self, params: list[Parameter], lr: float) -> None:
         super().__init__(params, lr)

     def step(self) -> None:
         for param in self.params:
-            param -= param.grad * self.lr
+            param.x -= param.x.grad * self.lr


 class Adam(Optimizer):
@@ -56,7 +51,7 @@ class Adam(Optimizer):

     def __init__(
         self,
-        params: list[Tensor],
+        params: list[Parameter],
         lr: float = 0.001,
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
@@ -65,17 +60,17 @@ def __init__(
         self.betas = betas
         self.eps = eps
         self.t = 0
-        self.m = [Tensor.zeros(p.shape) for p in self.params]
-        self.v = [Tensor.zeros(p.shape) for p in self.params]
+        self.m = [Tensor.zeros(p.x.shape) for p in self.params]
+        self.v = [Tensor.zeros(p.x.shape) for p in self.params]

     def step(self) -> None:
         self.t += 1
         for i, p in enumerate(self.params):
-            grad = p.grad
+            grad = p.x.grad
             if grad is None:
                 continue
             self.m[i] = self.betas[0] * self.m[i] + (1.0 - self.betas[0]) * grad
             self.v[i] = self.betas[1] * self.v[i] + (1.0 - self.betas[1]) * grad.sqr_()
             m_hat: Tensor = self.m[i] / (1.0 - self.betas[0] ** self.t)
             v_hat: Tensor = self.v[i] / (1.0 - self.betas[1] ** self.t)
-            p -= self.lr * m_hat / (v_hat.sqrt_() + self.eps)
+            p.x -= self.lr * m_hat / (v_hat.sqrt_() + self.eps)