This is a blog series on building a full-stack ML application with PyTorch. This post covers the main process and learnings from re-implementing lab 6 in PyTorch. In this lab we build a model for line detection: it takes an image containing multiple lines of text and outputs the coordinates of the image segments that contain lines.
TLDR:
1. Apply the same transformation to both X and y during data augmentation.
2. Implement FCN and U-Net; manually calculate padding for dilated convolutions.
NN-wise
Two networks are used: a Fully Convolutional Network (FCN) and a U-Net. Both use dilated convolutions, which require padding to keep the output the same spatial size as the input (the effective receptive field grows with dilation). Keras handles input/output shapes automatically, while in PyTorch you need to calculate the padding logic manually.
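As a minimal sketch of that manual calculation (the helper name `same_padding` is my own, not from the lab code): a dilated kernel has effective size `dilation * (kernel_size - 1) + 1`, so "same" padding is half of that growth.

```python
import torch
import torch.nn as nn

def same_padding(kernel_size, dilation=1):
    # Effective kernel size is dilation * (kernel_size - 1) + 1,
    # so half of the growth on each side preserves the spatial size
    # (for odd kernels and stride 1).
    return dilation * (kernel_size - 1) // 2

conv = nn.Conv2d(1, 8, kernel_size=3, dilation=2,
                 padding=same_padding(3, dilation=2))
x = torch.zeros(1, 1, 28, 28)
assert conv(x).shape[-2:] == (28, 28)  # spatial size preserved
```

In Keras you would simply pass `padding="same"`; in PyTorch this tiny formula replaces that convenience for stride-1, odd-kernel convolutions.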
Data-wise
This lab introduces a new dataset, IAM paragraphs, which consists of ~1,500 images. The highlight here is leveraging data augmentation. Keras ships with a pre-built ImageDataGenerator class that only needs a config dictionary to specify the transformations. In PyTorch we leveraged torchvision.transforms. Initially I implemented a series of random transformations, and things seemingly ran OK, except that the training loss was not decreasing: it turned out the random transformations were applied separately to X (the image) and y (the coordinates), which essentially broke the correlation between the two!
To solve this problem, in PyTorch we need a special pair_transform() that applies the same transformation to multiple objects at once. The cleanest way is to sample the random parameters once per call, then apply them to each object via the functional API (torchvision.transforms.functional). There seems to be no way to achieve this with the object-based API (there appears to be an open issue about it). An alternative is to reset the random seed before transforming each object.
Ops-wise
Nothing particularly challenging in this lab, except that the model needs to output probabilities (as opposed to the log-probabilities used during training) to match the _find_line_bounding_boxes interface in LineDetectorModel, so that the best segmentation can be found from this output during the evaluation step.
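The conversion itself is a one-liner; a sketch with a dummy logits tensor (the shapes and class count are made up for illustration):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 3, 8, 8)           # (batch, classes, H, W)
log_probs = F.log_softmax(logits, dim=1)   # what the loss (e.g. NLLLoss) consumes
probs = log_probs.exp()                    # probabilities for the bounding-box search
assert torch.allclose(probs, F.softmax(logits, dim=1))
```

Exponentiating the log-probabilities is numerically equivalent to applying softmax directly to the logits, so the predict path can reuse the training head and just add `.exp()`.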
Annoyance:
A NumPy-based image has HWC shape unless it's grayscale, while PyTorch assumes NCHW (note the position of the channel dimension C).
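The usual fix is a permute plus a batch dimension; a sketch with made-up image dimensions:

```python
import numpy as np
import torch

img = np.zeros((28, 32, 3), dtype=np.float32)          # NumPy: (H, W, C)
t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)  # -> (N, C, H, W)
assert tuple(t.shape) == (1, 3, 28, 32)
```

Forgetting this permute typically surfaces as a shape-mismatch error in the first Conv2d layer, or worse, silently wrong channel semantics for 3-channel images.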