Core backprop module
MarioSieg committed Feb 13, 2025
1 parent c4987b6 commit 29987b4
Showing 10 changed files with 178 additions and 359 deletions.
35 changes: 23 additions & 12 deletions magnetron/magnetron.c
@@ -1202,7 +1202,9 @@ static void mag_op_backward_view(mag_tensor_t* node, mag_tensor_t** grads) {
}

static void mag_op_backward_transpose(mag_tensor_t* node, mag_tensor_t** grads) {
*grads = mag_transpose(node->grad);
mag_tensor_t* t = mag_transpose(node->grad);
*grads = mag_clone(t);
mag_tensor_decref(t);
}
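
Note: for Y = Xᵀ the backward pass is simply ∂L/∂X = (∂L/∂Y)ᵀ, i.e. the upstream gradient transposed. The clone/decref pair presumably materializes the transposed result into its own storage (mag_transpose likely returning a view) before the temporary is dropped; that reading is an inference from this hunk rather than something the diff states.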

static void mag_op_backward_permute(mag_tensor_t* node, mag_tensor_t** grads) {
@@ -1219,7 +1221,7 @@ static void mag_op_backward_permute(mag_tensor_t* node, mag_tensor_t** grads) {
static void mag_op_backward_mean(mag_tensor_t* node, mag_tensor_t** grads) {
mag_tensor_t* x = node->op_inputs[0];
double scale = 1.0/(double)x->numel;
*grads = mag_muls(node->grad, scale);
*grads = mag_muls(node->grad, (float)scale);
}
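
Note: for m = (1/N) Σᵢ xᵢ the partial derivative ∂m/∂xᵢ is 1/N, so every input element just receives the upstream gradient scaled by 1/N. The added (float) cast presumably only matches the scalar parameter type of mag_muls and silences an implicit double-to-float conversion; the math is unchanged.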

static void mag_op_backward_min(mag_tensor_t* node, mag_tensor_t** grads) {
@@ -1318,7 +1320,7 @@ static void mag_op_backward_softmax(mag_tensor_t* node, mag_tensor_t** grads) {
static void mag_op_backward_sigmoid(mag_tensor_t* node, mag_tensor_t** grads) {
mag_tensor_t* x = node->op_inputs[0];
mag_tensor_t* dv = mag_sigmoid_dv(x);
grads[0] = mag_mul(node->grad, dv);
grads[0] = mag_mul(dv, node->grad);
mag_tensor_decref(dv);
}

@@ -1336,6 +1338,13 @@ static void mag_op_backward_tanh(mag_tensor_t* node, mag_tensor_t** grads) {
mag_tensor_decref(dv);
}

static void mag_op_backward_relu(mag_tensor_t* node, mag_tensor_t** grads) {
mag_tensor_t* x = node->op_inputs[0];
mag_tensor_t* mask = mag_step(x);
grads[0] = mag_mul(node->grad, mask);
mag_tensor_decref(mask);
}
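
Note: ReLU's derivative is the Heaviside step, 1 for x > 0 and 0 for x < 0, so the upstream gradient is masked elementwise by mag_step(x); the value used at x = 0 follows whatever convention mag_step implements (commonly 0).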

static void mag_op_backward_gelu(mag_tensor_t* node, mag_tensor_t** grads) {
mag_tensor_t* x = node->op_inputs[0];
mag_tensor_t* dv = mag_gelu_dv(x);
@@ -1356,8 +1365,8 @@ static void mag_op_backward_sub(mag_tensor_t* node, mag_tensor_t** grads) {
static void mag_op_backward_mul(mag_tensor_t* node, mag_tensor_t** grads) {
mag_tensor_t* x = node->op_inputs[0];
mag_tensor_t* y = node->op_inputs[1];
grads[0] = mag_mul(node->grad, y);
grads[1] = mag_mul(node->grad, x);
grads[0] = mag_mul(y, node->grad);
grads[1] = mag_mul(x, node->grad);
}
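
Note: this is the elementwise product rule, z = x ⊙ y gives ∂L/∂x = ∂L/∂z ⊙ y and ∂L/∂y = ∂L/∂z ⊙ x. For same-shaped operands the multiplication is commutative, so the swapped argument order here (upstream gradient moved to the second position, as in the sigmoid hunk above) presumably matters only for broadcasting or dispatch details, not for the result.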

static void mag_op_backward_div(mag_tensor_t* node, mag_tensor_t** grads) {
@@ -2017,12 +2026,12 @@ static mag_tensor_t* MAG_HOTPROC mag_tensor_operator(
const mag_op_meta_t* meta = mag_op_meta_of(op);
mag_tensor_t* (*r_alloc)(mag_tensor_t**, const mag_op_param_t*) = meta->r_alloc;
bool (*validate_op)(mag_op_t, mag_tensor_t*, mag_tensor_t**, const mag_op_param_t*) = meta->validator;
mag_tensor_t* R = (inplace && numin && meta->inplace) /* Inplace requested? */
? mag_tensor_create(ctx, (*inputs)->dtype, (*inputs)->shape, (*inputs)->rank, *inputs, 0) /* View R <- X for inplace aliasing op. */
: (*r_alloc)(inputs, params); /* Construct new result tensor. */
if (mag_unlikely(!(*validate_op)(op, R, inputs, params))) return NULL; /* Validation failed. */
R->op = op; /* Set operation for deferred execution mode. */
for (uint32_t i=0; i < numin; ++i) { /* Set input tensors and flags. */
mag_tensor_t* R = (inplace && numin && meta->inplace) /* Inplace requested? */
? mag_tensor_create(ctx, (*inputs)->dtype, (*inputs)->shape, (*inputs)->rank, *inputs, 0) /* View R <- X for inplace aliasing op. */
: (*r_alloc)(inputs, params); /* Construct new result tensor. */
mag_assert((*validate_op)(op, R, inputs, params), "Invalid operation %s.", meta->mnemonic); /* Validate operation */
R->op = op; /* Set operation for deferred execution mode. */
for (uint32_t i=0; i < numin; ++i) { /* Set input tensors and flags. */
mag_tensor_t* input = inputs[i];
R->op_inputs[i] = input;
if (ctx->flags & MAG_CTX_FLAG_GRAD_RECORDER)
@@ -2533,6 +2542,7 @@ void mag_tensor_backward(mag_tensor_t* root) {
mag_tensor_t* grad = mag_tensor_create(child->ctx, child->dtype, child->shape, child->rank, NULL, 0);
grad->flags = (grad->flags | MAG_TFLAG_IS_GRAD) & ~MAG_TFLAG_REQUIRES_GRAD;
mag_tensor_fill(grad, 1.0f);
mag_tensor_fmt_name(grad, "%s (grad)", child->name);
child->grad = grad;
}
if (mag_unlikely(child->op == MAG_OP_NOP)) continue;
@@ -2546,7 +2556,8 @@ void mag_tensor_backward(mag_tensor_t* root) {
mag_assert2(input);
if (!(input->flags & MAG_TFLAG_REQUIRES_GRAD)) continue;
mag_tensor_t* gri = grads[i];
if (!gri) continue;
mag_tensor_fmt_name(gri, "%s (grad)", input->name);
mag_assert(gri, "Gradient for op %s, input #%d is not computed", meta->mnemonic, i);
if (!input->grad) {
gri->flags = (gri->flags | MAG_TFLAG_IS_GRAD) & ~MAG_TFLAG_REQUIRES_GRAD;
input->grad = gri;
4 changes: 2 additions & 2 deletions magnetron/magnetron_cpu_blas.inl
@@ -998,8 +998,8 @@ static void MAG_HOTPROC mag_vtanh_dv_f32( /* tanh' : ℝ -> (-1, 1), x |-> 1 / (
const mag_f32_t* x
) {
for (int64_t i=0; i < numel; ++i) {
const mag_f32_t cx = coshf(x[i]);
o[i] = 1.0f / (cx*cx);
float t = tanhf(x[i]);
o[i] = 1.0f - t*t;
}
}
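
Note: the two expressions are the same derivative, since tanh'(x) = 1/cosh²(x) = 1 − tanh²(x); the likely motivation for the rewrite is numerical range, because coshf(x)² overflows float32 once |x| exceeds roughly 45, while 1 − tanh²(x) always stays in [0, 1]. A minimal Python sketch of the identity (double precision, so no overflow here):

import math

# d/dx tanh(x) written both ways; the tanh form is the one that stays bounded
# in float32, since cosh(x)**2 overflows there for |x| greater than about 45.
for x in (0.5, 2.0, 5.0):
    old = 1.0 / (math.cosh(x) ** 2)
    new = 1.0 - math.tanh(x) ** 2
    print(f'{x:4.1f}  {old:.6f}  {new:.6f}')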

2 changes: 1 addition & 1 deletion python/benchmarks/bench_tool.py
@@ -52,7 +52,7 @@ def plot(self, flops_per_op: int=2, plot_style: str='bars'):
ax2.set_xticklabels(x_labels, rotation=45, ha='right')
else:
dims = [sa[0] for sa in self.shapes_a] # For square matrices, any dimension works
markers = ['o', '+', 'x', '*', '.', 'X', '^']
markers = ['o', '+', 'x', '*', '.', 'x', '^']

for i, participant in enumerate(self.participants):
ax1.plot(dims, participant.timings, label=participant.name,
76 changes: 36 additions & 40 deletions python/examples/xor.py
@@ -1,53 +1,49 @@
# (c) 2025 Mario "Neo" Sieg. <[email protected]>

from magnetron import Tensor
from magnetron.layer import DenseLayer
from magnetron.model import SequentialModel, HyperParams
import matplotlib.pyplot as plt
from magnetron import Tensor, Module, Linear
from magnetron.optim import SGD, mse_loss

EPOCHS: int = 10000
RATE: float = 0.1

# Inputs: shape (4, 2)
inputs = Tensor.const([
class XOR(Module):
def __init__(self) -> None:
super().__init__()
self.l1 = Linear(2, 2)
self.l2 = Linear(2, 1)

def forward(self, x: Tensor) -> Tensor:
x = self.l1(x).tanh()
x = self.l2(x).tanh()
return x

model = XOR()
optim = SGD(model.parameters(), lr=1e-1)

x = Tensor.const([
[0, 0],
[0, 1],
[1, 0],
[1, 1]
])
], name='x')

# Targets: shape (4, 1)
targets = Tensor.const([
y = Tensor.const([
[0],
[1],
[1],
[0]
])

params = HyperParams(lr=RATE, epochs=EPOCHS)
mlp = SequentialModel(params, [
DenseLayer(2, 4),
DenseLayer(4, 1)
])

# Train
losses = mlp.train(inputs, targets)

# Inference
test_points = [
(0, 0),
(0, 1),
(1, 0),
(1, 1),
]

for (x_val, y_val) in test_points:
result = mlp.forward(Tensor.const([[x_val, y_val]]))[0]
print(f"{x_val} XOR {y_val} => {result:.4f}")

# Plot MSE loss
plt.plot(list(range(0, EPOCHS)), losses)
plt.xlabel('Epochs')
plt.ylabel('MSE Loss')
plt.title('XOR Problem')
plt.show()
], name='y')

epochs: int = 2000

y_hat = model(x)
print(y_hat)
for epoch in range(epochs):
y_hat = model(x)
loss = mse_loss(y_hat, y)
loss.backward()
optim.step()
optim.zero_grad()
if epoch % 1000 == 0:
print(f'Epoch: {epoch}, Loss: {loss.item()}')

y_hat = model(x)
print(y_hat)
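
Note: the example now trains with an explicit loop instead of the old SequentialModel/DenseLayer helper: forward pass, mse_loss (presumably the mean of (ŷ − y)²), loss.backward() to populate gradients, optim.step() to apply the SGD update, and optim.zero_grad() to clear them for the next iteration. With epochs = 2000, the epoch % 1000 == 0 guard prints only at epochs 0 and 1000.
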
2 changes: 2 additions & 0 deletions python/magnetron_framework/magnetron/__init__.py
@@ -8,3 +8,5 @@
__url__ = 'https://github.com/MarioSieg/magnetron'

from .core import *
from .module import *
from .optim import *
78 changes: 0 additions & 78 deletions python/magnetron_framework/magnetron/core.py
@@ -225,8 +225,6 @@ class Tensor:
__slots__ = ('_ctx', '_ptr', '_inputs')

def __init__(self, ptr: ffi.CData | None = None) -> None:
if isinstance(ptr, ffi.CData):
assert ptr != ffi.NULL, 'Invalid tensor pointer'
self._ctx = None
self._ptr = ptr
self._inputs = None
@@ -927,79 +925,3 @@ def __setitem__(self, indices: int | tuple[int, ...], value: float) -> None:
C.mag_tensor_set_scalar_physical_index(self._ptr, *idx, float(value))
else:
raise TypeError('Indices must be an int or a tuple of ints.')


class Parameter:
"""A tensor that is a learnable parameter of a model."""

__slots__ = ('x',)

def __init__(self, x: Tensor) -> None:
x.requires_grad = True
self.x = x


class Module:
"""Base class for all neural network modules."""

def parameters(self) -> list[Parameter]:
"""Returns all unique and nested parameters of the module."""
params: list[Parameter] = []
for k, v in self.__dict__.items():
if isinstance(v, Parameter):
params.append(v)
elif isinstance(v, Module):
params += v.parameters()
elif isinstance(v, ModuleList):
for mod in v:
params += mod.parameters()
return list(set(params))

def eval(self) -> None:
"""Sets the module in evaluation mode."""
for p in self.parameters():
p.x.requires_grad = False

def train(self) -> None:
"""Sets the module in training mode."""
for p in self.parameters():
p.x.requires_grad = True

def forward(self, *args: Tensor, **kwargs: dict) -> Tensor:
"""Forward pass must be implemented by subclass."""
raise NotImplementedError

def __call__(self, *args: Tensor, **kwargs: dict) -> Tensor:
return self.forward(*args, **kwargs)


class ModuleList(Module, list):
"""A list of modules that can be used as a single module."""

def __init__(self, mods: list[Module] | None) -> None:
super().__init__()
if mods is not None:
self += mods

def append(self, mod: Module) -> None:
super().append(mod)

def extend(self, __iterable: list[Module]) -> None:
super().extend(__iterable)

def __iadd__(self, other: list[Module]) -> 'ModuleList':
self.extend(other)
return self

def __setitem__(self, k: int, v: Module) -> None:
super().__setitem__(k, v)

def __getitem__(self, k: int) -> Module:
return super().__getitem__(k)

def parameters(self) -> list[Parameter]:
"""Returns all unique and nested parameters of the module."""
params: list[Parameter] = []
for mod in self:
params += mod.parameters()
return list(set(params))
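
Note: Parameter, Module, and ModuleList are removed from core.py without a replacement in this file; since __init__.py now also re-exports from .module and .optim, they have presumably been relocated into new module.py and optim.py files (not among the hunks shown here) rather than dropped.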
