This is part of a blog series on building a full-stack ML application with PyTorch. This post covers the process and the lessons learned from re-implementing lab2 in PyTorch. The main goal of lab2 is to get familiar with the APIs for datasets, simple networks, and the training pipeline.

TLDR:
1. 3-step (re)abstraction of the dataset class
2. Warm up with simple NNs

NN-wise

Re-implementing LeNet and an MLP from their Keras versions to PyTorch is straightforward.
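As a sketch of what that re-implementation looks like, here is a LeNet-style CNN in PyTorch. The input shape (1, 28, 28) and the number of classes are assumptions matching EMNIST-like data, not the lab's exact configuration:

```python
import torch
import torch.nn as nn

class LeNet(nn.Module):
    """LeNet-style CNN sketch for a 1x28x28 input (assumed, EMNIST-like)."""
    def __init__(self, num_classes: int = 62):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2),  # -> (6, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                            # -> (6, 14, 14)
            nn.Conv2d(6, 16, kernel_size=5),            # -> (16, 10, 10)
            nn.ReLU(),
            nn.MaxPool2d(2),                            # -> (16, 5, 5)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

model = LeNet()
out = model(torch.zeros(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 62])
```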

Data-wise

The original API design for the dataset object is a 3-step abstraction: Dataset (an abstract class containing basic ops like _download_raw_dataset) – EmnistDataset (a named dataset holding the metadata, such as num_classes and input_shape) – DatasetSequence (a class based on Keras' Sequence class, which handles a batch of data).
DatasetSequence has the classic __len__/__getitem__ signature and contains the data augmentation function, so it is equivalent to PyTorch's Dataset + DataLoader.
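That Keras-Sequence-to-PyTorch mapping can be sketched as follows. The class and argument names here are illustrative, not the lab's actual code: __len__/__getitem__ and the augmentation transform live in the Dataset, while batching and shuffling move to the DataLoader:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SequenceLikeDataset(Dataset):
    """Hypothetical PyTorch counterpart of a Keras Sequence: per-sample
    access via __len__/__getitem__, with an optional augmentation hook."""
    def __init__(self, x, y, transform=None):
        self.x, self.y, self.transform = x, y, transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        sample = self.x[idx]
        if self.transform is not None:
            sample = self.transform(sample)  # data augmentation hook
        return sample, self.y[idx]

# The DataLoader takes over what Keras' Sequence did per batch.
ds = SequenceLikeDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 62, (100,)))
loader = DataLoader(ds, batch_size=32, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape)  # torch.Size([32, 1, 28, 28])
```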

With this in mind, the re-implemented logic in PyTorch becomes: Dataset (a minimal class with utility class methods) – EmnistDataset (similar to before) – DatasetSequence (equivalent to PyTorch's DataLoader(Dataset()): it takes a named dataset's config and returns a DataLoader). One thing to note: DatasetSequence is more of a utility wrapper; at a high level it should only be used in .fit()/.predict() and shouldn't be intrinsic to a Model or a Dataset. In addition, the method _download_and_process_emnist() was originally contained in Dataset, which is contained in both Model and Predictor. This means that even when serving the predictor, the model would download and process the training dataset. This seemed like a design flaw, so I moved the method out into a standalone function that is only used when creating a Model, not a Predictor.
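A minimal sketch of this refactor, with hypothetical names and a placeholder body standing in for the real EMNIST download:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, batch_size=32, train=True):
    """Utility wrapper standing in for DatasetSequence: not intrinsic to a
    Model or Dataset, just a helper used around fit()/predict()."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=train)

def download_and_process_emnist():
    """Standalone stand-in for _download_and_process_emnist(), moved out of
    Dataset so that serving a Predictor never touches training data.
    Placeholder: returns random tensors instead of real EMNIST."""
    x = torch.randn(64, 1, 28, 28)
    y = torch.randint(0, 62, (64,))
    return TensorDataset(x, y)

# Only the training path calls the download function; a Predictor would not.
train_loader = make_loader(download_and_process_emnist(), batch_size=16)
xb, yb = next(iter(train_loader))
print(xb.shape)  # torch.Size([16, 1, 28, 28])
```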

Ops-wise

Removing the Callback function: the Callback here is only used for logging to Weights & Biases, so I removed it rather than re-implementing it. Generally speaking, though, a callback system has proved to be an important utility, and many PyTorch-based frameworks (fastai, PyTorch Lightning, etc.) provide one.
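For reference, the essence of such a callback system fits in a few lines. This is a toy sketch loosely modeled on the hook style of Keras and PyTorch Lightning; all names and the dummy metric are illustrative:

```python
class Callback:
    """Base class: subclasses override the hooks they care about."""
    def on_epoch_end(self, epoch, logs):
        pass

class PrintLogger(Callback):
    """Toy stand-in for a W&B logger: just prints the epoch metrics."""
    def on_epoch_end(self, epoch, logs):
        print(f"epoch {epoch}: {logs}")

def run_epochs(n_epochs, callbacks):
    """Dummy training loop that fires callbacks at each epoch boundary."""
    history = []
    for epoch in range(n_epochs):
        logs = {"loss": 1.0 / (epoch + 1)}  # fake metric for illustration
        history.append(logs)
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
    return history

history = run_epochs(3, [PrintLogger()])
```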

Manual GPU management: Keras enables the GPU by manually setting os.environ["CUDA_VISIBLE_DEVICES"]. Here I used the classic PyTorch idiom:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
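In use, both the model and each batch must be moved to that device explicitly. A minimal sketch (the linear layer and shapes are illustrative):

```python
import torch
import torch.nn as nn

# Pick the GPU if available, else fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model parameters and input tensors must live on the same device.
model = nn.Linear(784, 10).to(device)
x = torch.randn(8, 784, device=device)
logits = model(x)
print(logits.shape)  # torch.Size([8, 10])
```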

Category: research