31 changes: 30 additions & 1 deletion lambeq/training/pytorch_trainer.py
@@ -42,10 +42,13 @@ def __init__(self,
loss_function: Callable[..., torch.Tensor],
epochs: int,
optimizer: type[torch.optim.Optimizer] = torch.optim.AdamW,
scheduler: type[torch.optim.lr_scheduler.LRScheduler] | None
= None,
learning_rate: float = 1e-3,
device: int = -1,
*,
optimizer_args: dict[str, Any] | None = None,
scheduler_args: dict[str, Any] | None = None,
evaluate_functions: Mapping[str, EvalFuncT] | None = None,
evaluate_on_train: bool = True,
use_tensorboard: bool = False,
@@ -66,13 +69,18 @@ def __init__(self,
Number of training epochs.
optimizer : torch.optim.Optimizer, default: torch.optim.AdamW
A PyTorch optimizer from `torch.optim`.
scheduler : torch.optim.lr_scheduler.LRScheduler, default: None
A PyTorch scheduler for the learning rate, from
`torch.optim.lr_scheduler`.
learning_rate : float, default: 1e-3
The learning rate provided to the optimizer for training.
device : int, default: -1
CUDA device ID used for tensor operation speed-up.
A negative value uses the CPU.
optimizer_args : dict of str to Any, optional
Any extra arguments to pass to the optimizer.
scheduler_args : dict of str to Any, optional
Any extra arguments to pass to the scheduler.
evaluate_functions : mapping of str to callable, optional
Mapping of evaluation metric functions from their names.
Structure [{"metric": func}].
@@ -118,6 +126,11 @@ def __init__(self,
if learning_rate is not None:
optimizer_args['lr'] = learning_rate
self.optimizer = optimizer(self.model.parameters(), **optimizer_args)

scheduler_args = dict(scheduler_args or {})
self.scheduler = (scheduler(self.optimizer, **scheduler_args)
if scheduler is not None else None)

self.model.to(self.device)

def _add_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None:
@@ -136,7 +149,10 @@ def _add_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None:
"""
checkpoint.add_many(
{'torch_random_state': torch.get_rng_state(),
'optimizer_state_dict': self.optimizer.state_dict()})
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict()
if self.scheduler is not None
else None})

def _load_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None:
"""Load additional checkpoint information.
@@ -152,6 +168,9 @@ def _load_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None:
"""
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
torch.set_rng_state(checkpoint['torch_random_state'])
if (checkpoint['scheduler_state_dict'] is not None
and self.scheduler is not None):
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

def validation_step(
self,
Expand Down Expand Up @@ -201,3 +220,13 @@ def training_step(
loss.backward()
self.optimizer.step()
return y_hat, loss.item()

def post_epoch_step(self, epoch: int, loss: float) -> None:
# Step the learning-rate scheduler, if one was provided
if self.scheduler is not None:
# ReduceLROnPlateau needs the epoch loss to decide when to reduce the lr
if isinstance(self.scheduler,
torch.optim.lr_scheduler.ReduceLROnPlateau):
self.scheduler.step(loss, epoch=epoch)
else:
self.scheduler.step(epoch=epoch)
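For reference, a minimal usage sketch of the new scheduler arguments (illustrative, not part of the diff: `train_circuits` and the hyperparameters are placeholders). As the added `__init__` code shows, the scheduler class is passed uninstantiated and the trainer builds it around the optimizer it creates internally:

import torch
from lambeq import PytorchModel, PytorchTrainer

model = PytorchModel.from_diagrams(train_circuits)  # placeholder circuits
trainer = PytorchTrainer(
    model=model,
    loss_function=torch.nn.BCEWithLogitsLoss(),
    optimizer=torch.optim.AdamW,
    scheduler=torch.optim.lr_scheduler.StepLR,      # new argument
    scheduler_args={'step_size': 1, 'gamma': 0.1},  # new argument
    learning_rate=3e-3,
    epochs=5,
)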
11 changes: 9 additions & 2 deletions lambeq/training/trainer.py
@@ -368,6 +368,11 @@ def validation_step(

"""

def post_epoch_step(self, epoch: int, loss: float) -> None:
"""Perform any post-epoch updates, such as updating the scheduled
learning rate."""
return None

def _get_weighted_mean(self,
metric_running: list[tuple[int, Any]]):
"""Calculate weighted mean of metric from the running results."""
@@ -761,13 +766,15 @@ def fit(self,
if early_stopping:
break # inner epoch loop

epoch_loss = self._get_weighted_mean(train_losses)
self.post_epoch_step(epoch, loss=epoch_loss)

epoch_end = time.time()
epoch_duration = epoch_end - epoch_start
self.train_epoch_durations.append(epoch_duration)

# record epoch loss
self.train_epoch_costs.append(
self._get_weighted_mean(train_losses))
self.train_epoch_costs.append(epoch_loss)
self._to_tensorboard('train/epoch_loss',
self.train_epoch_costs[-1],
epoch)
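For context, a minimal sketch of how a subclass could build on the new post_epoch_step hook (the subclass and its logging are illustrative only, not part of this diff):

from lambeq import PytorchTrainer

class LrLoggingTrainer(PytorchTrainer):
    """Hypothetical trainer that logs the learning rate after every epoch."""

    def post_epoch_step(self, epoch: int, loss: float) -> None:
        super().post_epoch_step(epoch, loss)  # steps self.scheduler, if set
        current_lr = self.optimizer.param_groups[0]['lr']
        print(f'epoch {epoch}: loss={loss:.4f}, lr={current_lr:.2e}')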
12 changes: 11 additions & 1 deletion tests/training/test_pytorch_trainer.py
@@ -3,6 +3,7 @@

import torch
import numpy as np
import pytest

from lambeq.backend.grammar import Cup, Id, Word
from lambeq.backend.tensor import Dim
@@ -36,7 +37,8 @@
dev_circuits = [ansatz(d) for d in dev_diagrams]


def test_trainer(tmp_path):
@pytest.mark.parametrize("scheduler", [None, torch.optim.lr_scheduler.StepLR])
def test_trainer(tmp_path, scheduler):
model = PytorchModel.from_diagrams(train_circuits + dev_circuits)

log_dir = tmp_path / 'test_runs'
@@ -45,6 +47,8 @@ def test_trainer(tmp_path):
model=model,
loss_function=torch.nn.BCEWithLogitsLoss(),
optimizer=torch.optim.AdamW,
scheduler=scheduler,
scheduler_args={"step_size": 1, "gamma": 0.1},
learning_rate=3e-3,
epochs=EPOCHS,
evaluate_functions={"acc": acc},
@@ -66,6 +70,12 @@ def test_trainer(tmp_path, scheduler):
assert len(trainer.train_durations) == EPOCHS * (
ceil(len(train_diagrams) / train_dataset.batch_size))
assert len(trainer.val_durations) == EPOCHS
if scheduler is not None:
# Expect the lr to have decayed exactly once: 0.1 * 3e-3 = 3e-4 (up to float error)
assert torch.allclose(
torch.tensor(trainer.scheduler.get_last_lr()),
torch.tensor(3e-4)
)
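# Illustrative standalone check of the expected value (not part of this
# test; it exercises StepLR directly, outside the trainer):
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-3)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.1)
sched.step()  # one epoch boundary: the lr decays once, 3e-3 -> 3e-4
assert abs(opt.param_groups[0]['lr'] - 3e-4) < 1e-12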


def test_restart_training(tmp_path):