This is part of a blog series on building a full-stack ML application with PyTorch. This post covers the main process and the lessons learned from re-implementing lab2 in PyTorch (PT for short). The main goal of lab2 is to get familiar with the APIs for datasets, simple networks, and the training pipeline.
TL;DR:
1. A 3-step (re)abstraction of the dataset class
2. Warming up with simple NNs
NN-wise
Re-implementing LeNet and the MLP from their Keras versions in PyTorch was easy.
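For reference, here is a minimal sketch of what the MLP looks like on the PyTorch side; the layer sizes and class count are illustrative, not necessarily the lab's exact config.

import torch.nn as nn

class MLP(nn.Module):
    """Minimal MLP for flattened 28x28 character images (sizes are illustrative)."""
    def __init__(self, input_dim=28 * 28, hidden_dim=128, num_classes=62):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        return self.net(x)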
Data-wise
The original API design for the dataset object is a 3-step abstraction: Dataset (an abstract class containing basic ops like _download_raw_dataset) – EmnistDataset (a named dataset containing the metadata, like num_classes and input_shape) – DataSequence (a class based on Keras' Sequence class, dealing with a batch of data). DataSequence has the classic signature of __len__ and __getitem__ and contains the data augmentation function, so it is equivalent to PT's Dataset + DataLoader.
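One detail worth spelling out: in Keras, Sequence.__getitem__ returns a whole batch, whereas in PyTorch the Dataset returns a single sample and the DataLoader handles batching and shuffling. A minimal sketch of the PyTorch side, with hypothetical names and an optional transform standing in for the augmentation function:

from torch.utils.data import Dataset

class ArrayDataset(Dataset):
    """Per-sample view over in-memory arrays; augmentation lives in transform."""
    def __init__(self, x, y, transform=None):
        self.x, self.y, self.transform = x, y, transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        sample = self.x[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.y[idx]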
With this in mind, the underlying logic of the re-implementation in PT becomes: Dataset (a minimal class with util class methods) – EmnistDataset (similar) – DatasetSequence (equivalent to PT's DataLoader(Dataset()): it takes a named dataset's config and returns a DataLoader).
One thing to note: DatasetSequence is more of a util wrapper; at a high level it should only be used in .fit()/.predict() and shouldn't be intrinsic to a Model or a Dataset.
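As a sketch (hypothetical names and arguments, assuming the split arrays are already loaded into memory), such a util wrapper can be as small as:

import torch
from torch.utils.data import TensorDataset, DataLoader

def make_dataloader(x, y, batch_size=32, shuffle=True):
    # Util wrapper: wrap arrays in a Dataset and hand back a DataLoader.
    # Called from fit()/predict(); not stored on the Model or the Dataset.
    ds = TensorDataset(torch.as_tensor(x, dtype=torch.float32),
                       torch.as_tensor(y, dtype=torch.long))
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

Swapping TensorDataset for a transform-aware Dataset like the ArrayDataset sketch above would restore the augmentation hook.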
In addition, the method _download_and_process_emnist() was originally contained in Dataset, which is in turn contained in both Model and Predictor. This means that even when only serving the predictor, the model will download and process the training dataset. This seems like a design flaw, so I moved the method out into a standalone function that is only used when creating a Model, not a Predictor.
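In code, the refactor looks roughly like the following sketch; the URL is a placeholder for illustration only, and the processing step is omitted.

from pathlib import Path
from urllib.request import urlretrieve

EMNIST_URL = "https://example.com/emnist.zip"  # placeholder, not the real source

def download_and_process_emnist(data_dir="data/raw"):
    # Standalone function: fetch (and then process) EMNIST once, before training.
    # Only the training entry point calls this; the Predictor path never touches it.
    archive = Path(data_dir) / "emnist.zip"
    if not archive.exists():
        archive.parent.mkdir(parents=True, exist_ok=True)
        urlretrieve(EMNIST_URL, archive)
    return archive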
Ops-wise
Removed the Callback function: the Callback here is only used for logging with Weights & Biases, so I simply removed it and didn't re-implement it. Generally speaking, though, a callback system has proved to be an important util, and many PT-based frameworks (fastai, PyTorch Lightning, etc.) provide one.
Manual GPU management: Keras enables the GPU by manually setting os.environ["CUDA_VISIBLE_DEVICES"]. Here I used the classic idiom in PT:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
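The rest then follows the usual pattern: move the model's parameters and each batch to that device.

import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Linear(10, 2).to(device)       # parameters go to the GPU when available
x = torch.randn(4, 10, device=device)     # inputs are created on the same device
with torch.no_grad():
    y = model(x)                          # runs on GPU if present, otherwise on CPU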