Building a platform to train mixed-input deep learning models from scratch, based on the lessons from part 2 of fast.ai. To illustrate the concepts, we will train a model that accepts both tabular and image data and outputs a single prediction. This will be a multi-part series, with each part addressing a subset of the total components. The libraries we will rely on:
- fastai
- torch
- torchvision
- numpy
- pandas
Deep learning models are trained on batches. The dataloader is responsible for iterating over the dataset and producing these batches. To accomplish this we will go over constructing the following components:
- A Dataset class, which defines how items are retrieved from the source data
- A Sampler class, which outlines how items are sampled from the dataset to assemble the batches
- A Collate function, which assembles the sampled items into the input data batch (xb) and labels batch (yb)
- A Transform class, which serves as a placeholder for the transforms that we will introduce later
- A DataLoader class, which ties all the components together. Once you have your dataloader, you can simply call

```python
for xb, yb in dataloader:
    ...
```

to generate all the batches for one epoch (one full pass over the dataset) of training.
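The components above can be sketched in plain Python. This is a minimal, hypothetical version (the class and parameter names are illustrative, not the final implementation); it operates on plain lists rather than tensors to keep the mechanics visible:

```python
import random

class Dataset:
    """Defines how items are retrieved from the source data."""
    def __init__(self, x, y):
        self.x, self.y = x, y
    def __len__(self):
        return len(self.x)
    def __getitem__(self, i):
        return self.x[i], self.y[i]

class Sampler:
    """Yields batches of indices, optionally shuffled."""
    def __init__(self, ds, bs, shuffle=False):
        self.n, self.bs, self.shuffle = len(ds), bs, shuffle
    def __iter__(self):
        idxs = list(range(self.n))
        if self.shuffle:
            random.shuffle(idxs)
        for i in range(0, self.n, self.bs):
            yield idxs[i:i + self.bs]

def collate(items):
    """Assembles sampled (x, y) pairs into an input batch and a label batch."""
    xs, ys = zip(*items)
    return list(xs), list(ys)

class DataLoader:
    """Ties dataset, sampler, and collate function together."""
    def __init__(self, ds, sampler, collate_fn=collate):
        self.ds, self.sampler, self.collate_fn = ds, sampler, collate_fn
    def __iter__(self):
        for idxs in self.sampler:
            yield self.collate_fn([self.ds[i] for i in idxs])
```

With these pieces, `DataLoader(ds, Sampler(ds, bs=4))` can be iterated exactly as shown above, yielding `(xb, yb)` pairs one batch at a time.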
PyTorch models are composed of nn.Module instances that are strung together in sequence to pass the input data through the model. We will go over how to create your own modules and assemble them into a custom network:
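As a rough sketch of what such a custom module might look like for our mixed-input case, the hypothetical `MixedInputModel` below runs tabular and image inputs through separate branches and concatenates them into a shared head. All names and sizes here are illustrative, and the image branch is reduced to a linear layer over flattened features for brevity (a real image branch would typically be convolutional):

```python
import torch
from torch import nn

class MixedInputModel(nn.Module):
    """Illustrative sketch: two input branches joined by a shared head."""
    def __init__(self, n_tab, n_img_features, n_out):
        super().__init__()
        # Tabular branch: small fully connected block
        self.tab = nn.Sequential(nn.Linear(n_tab, 16), nn.ReLU())
        # Image branch: linear layer over flattened pixel features (placeholder)
        self.img = nn.Sequential(nn.Linear(n_img_features, 16), nn.ReLU())
        # Head: maps the concatenated branch outputs to a single prediction
        self.head = nn.Linear(32, n_out)

    def forward(self, x_tab, x_img):
        t = self.tab(x_tab)
        i = self.img(x_img)
        return self.head(torch.cat([t, i], dim=1))
```

Because `forward` takes two arguments, the model is called as `model(x_tab, x_img)` rather than with a single input tensor.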
Training involves a specific sequence of steps:
- Run the batch of input data (xb) through the model
- Compare the predictions to the labels using a Loss Function, which determines the penalty for incorrect predictions
- A backward step, which calculates the parameter gradients using backpropagation
- An Optimizer step, which updates all the trainable parameters of the network
- Resetting all parameter gradients to zero
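The steps above map directly onto a few lines of PyTorch. The helper name `fit_one_batch` is ours, but each call inside it is standard PyTorch:

```python
import torch
from torch import nn

def fit_one_batch(model, loss_func, opt, xb, yb):
    """One training iteration following the steps above."""
    pred = model(xb)            # 1. run the input batch through the model
    loss = loss_func(pred, yb)  # 2. compare predictions to labels
    loss.backward()             # 3. backpropagate to compute gradients
    opt.step()                  # 4. update all trainable parameters
    opt.zero_grad()             # 5. reset parameter gradients to zero
    return loss.item()
```

For example, `fit_one_batch(nn.Linear(3, 1), nn.MSELoss(), opt, xb, yb)` with an SGD optimizer performs one complete update and returns the batch loss as a float.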
We can customize training by inserting entry points for functions to run at key steps during training (e.g. after the model prediction), called callbacks. We will combine the basic training cycle and callbacks in a new Learner class. This will allow us to customize training without altering the base code of the trainer. Callbacks can perform a host of functions, ranging from logging training statistics to scheduling training parameters.
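A stripped-down sketch of this idea, assuming the class names and event names (`before_batch`, `after_pred`, etc.) are placeholders rather than the final API:

```python
import torch
from torch import nn

class Callback:
    """Base class: subclasses override whichever no-op hooks they need."""
    def before_batch(self, learner): pass
    def after_pred(self, learner): pass
    def after_loss(self, learner): pass
    def after_batch(self, learner): pass

class Learner:
    """Runs the basic training cycle, firing callbacks at each key step."""
    def __init__(self, model, loss_func, opt, dl, cbs=None):
        self.model, self.loss_func, self.opt, self.dl = model, loss_func, opt, dl
        self.cbs = cbs or []

    def _call(self, event):
        for cb in self.cbs:
            getattr(cb, event)(self)

    def fit_epoch(self):
        for self.xb, self.yb in self.dl:
            self._call('before_batch')
            self.pred = self.model(self.xb)
            self._call('after_pred')
            self.loss = self.loss_func(self.pred, self.yb)
            self._call('after_loss')
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
            self._call('after_batch')

class LossLogger(Callback):
    """Example callback: records the loss after each batch."""
    def __init__(self):
        self.losses = []
    def after_loss(self, learner):
        self.losses.append(learner.loss.item())
```

Note that the training loop itself never changes; adding behavior is just a matter of passing another callback in `cbs`.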
Our callbacks system works well, but doesn't allow us to delve into the individual layers of the network. Although it is possible to create our own method of doing this through the HookedSequential class, we can also make use of the hook system inherent to PyTorch.
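As a small taste of PyTorch's built-in hook system, `register_forward_hook` (a real PyTorch API) attaches a function that fires whenever a layer produces its output. The example below records the mean activation of each layer in a toy network (the network itself is illustrative):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
stats = []

def record_mean(module, inp, outp):
    # Forward hook: called with the module, its input, and its output
    stats.append(outp.detach().mean().item())

# Attach one hook per layer; each returns a handle for later removal
handles = [m.register_forward_hook(record_mean) for m in model]
model(torch.randn(16, 4))           # one forward pass fires every hook
for h in handles:
    h.remove()                      # always clean up hooks when done
```

After the forward pass, `stats` holds one mean activation per layer, without any modification to the model's code.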
To be implemented