Pythae 0.1.1
New features
- Added the `TrainHistoryCallback` training callback, which records the training metrics during training, in #71 by @VolodyaCO
```python
>>> from pythae.trainers.training_callbacks import TrainHistoryCallback
>>> # `pipeline` is a TrainingPipeline; `train_dataset` and `eval_dataset` are defined beforehand
>>> train_history = TrainHistoryCallback()
>>> callbacks = [train_history]
>>> pipeline(
...     train_data=train_dataset,
...     eval_data=eval_dataset,
...     callbacks=callbacks
... )
>>> train_history.history
{
    'train_loss': [58.51896972363562, 42.15931177749049, 40.583426756017346],
    'eval_loss': [43.39408182034827, 41.45351771943888, 39.77221281209569]
}
```
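The idea behind a history callback can be illustrated with a minimal pure-Python sketch. `HistoryRecorder`, its `on_epoch_end(logs)` hook and `run_fake_training` are hypothetical names chosen for this example only; Pythae's actual `TrainingCallback` interface and hook names may differ, and no model is trained here.

```python
from collections import defaultdict


class HistoryRecorder:
    """Hypothetical stand-in for a history callback (not Pythae's TrainHistoryCallback)."""

    def __init__(self):
        # Maps each metric name to the list of its per-epoch values.
        self.history = defaultdict(list)

    def on_epoch_end(self, logs):
        # Hypothetical hook: `logs` holds the metrics computed for one epoch.
        for name, value in logs.items():
            self.history[name].append(value)


def run_fake_training(callbacks, epoch_logs):
    """Toy loop that replays pre-computed losses instead of training a model."""
    for logs in epoch_logs:
        for callback in callbacks:
            callback.on_epoch_end(logs)


# Per-epoch losses copied from the example output above.
epoch_logs = [
    {"train_loss": 58.51896972363562, "eval_loss": 43.39408182034827},
    {"train_loss": 42.15931177749049, "eval_loss": 41.45351771943888},
    {"train_loss": 40.583426756017346, "eval_loss": 39.77221281209569},
]

recorder = HistoryRecorder()
run_fake_training([recorder], epoch_logs)
print(dict(recorder.history))
```

Collecting metrics in a callback keeps the training loop unaware of how, or whether, the history is stored.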
- Added a `predict` method that encodes and decodes input data without computing the loss in #75 by @soumickmj and @ravih18
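```python
>>> # `model` is a trained Pythae model on `device`; `eval_dataset` is a torch.Tensor of 1x28x28 images
>>> out = model.predict(eval_dataset[:3].to(device))
>>> out.embedding.shape, out.recon_x.shape
(torch.Size([3, 16]), torch.Size([3, 1, 28, 28]))
>>> # `embed` returns only the latent embeddings
>>> out = model.embed(eval_dataset[:3].to(device))
>>> out.shape
torch.Size([3, 16])
```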
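The difference between a training-style call, `predict` and `embed` can be sketched with a toy model. `ToyLinearAE`, `ToyOutput` and the `forward` method below are hypothetical, pure-Python stand-ins (no PyTorch or Pythae involved); only the `predict`/`embed` method names and the `embedding`/`recon_x` fields are taken from the example above.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyOutput:
    # Field names mirror the attributes shown above; the class itself is hypothetical.
    embedding: List[float]
    recon_x: List[float]


class ToyLinearAE:
    """Hypothetical 1-D autoencoder: encode scales the input, decode undoes the scaling."""

    def __init__(self, scale=0.5):
        self.scale = scale

    def encode(self, x):
        return [self.scale * v for v in x]

    def decode(self, z):
        return [v / self.scale for v in z]

    def forward(self, x):
        # Training-style call: reconstruct AND compute a squared-error loss.
        z = self.encode(x)
        recon = self.decode(z)
        loss = sum((a - b) ** 2 for a, b in zip(x, recon))
        return z, recon, loss

    def predict(self, x):
        # Inference-style call: encode and decode, but skip the loss computation.
        z = self.encode(x)
        return ToyOutput(embedding=z, recon_x=self.decode(z))

    def embed(self, x):
        # Only the latent representation, no decoding.
        return self.encode(x)


model = ToyLinearAE()
out = model.predict([2.0, 4.0])
print(out.embedding, out.recon_x)  # [1.0, 2.0] [2.0, 4.0]
print(model.embed([2.0, 4.0]))     # [1.0, 2.0]
```

Skipping the loss in `predict` means callers at inference time do not need to supply anything a loss would require.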