TTAch
Image Test Time Augmentation with PyTorch!
Just as Data Augmentation modifies the training set, Test Time Augmentation applies modifications (flips, rotations, rescaling, etc.) to the test images. Instead of showing the trained model each regular, “clean” image only once, we show it several augmented versions of that image. We then average the predictions for each corresponding image and take the result as our final guess [1].
           Input
             |           # input batch of images
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output
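The pipeline in the diagram can be sketched in plain Python, with nested lists standing in for image tensors. This is a minimal illustration, not ttach's implementation: `tta_predict`, `identity`, `hflip` and the deliberately position-biased `biased_model` are hypothetical names, and each transform is assumed to be an (augment, deaugment) pair of functions.

```python
from statistics import mean
from typing import Callable, List, Sequence, Tuple

Image = List[List[float]]
Transform = Tuple[Callable[[Image], Image], Callable[[Image], Image]]  # (augment, deaugment)


def identity(img: Image) -> Image:
    return [row[:] for row in img]


def hflip(img: Image) -> Image:
    """Horizontal flip: mirror every row. Applying it twice restores the input."""
    return [row[::-1] for row in img]


def biased_model(img: Image) -> Image:
    """Hypothetical per-pixel 'model' whose output drifts with the column index."""
    return [[v + 0.1 * j for j, v in enumerate(row)] for row in img]


def tta_predict(model: Callable[[Image], Image], image: Image,
                transforms: Sequence[Transform],
                merge: Callable[[List[float]], float] = mean) -> Image:
    predictions = []
    for augment, deaugment in transforms:
        augmented = augment(image)             # apply augmentation
        output = model(augmented)              # pass the augmented input through the model
        predictions.append(deaugment(output))  # reverse the transformation on the output
    h, w = len(image), len(image[0])
    # merge the aligned predictions pixel by pixel
    return [[merge([p[i][j] for p in predictions]) for j in range(w)] for i in range(h)]


if __name__ == "__main__":
    image = [[0.0, 0.0, 0.0]]
    print(tta_predict(biased_model, image, [(identity, identity), (hflip, hflip)]))
```

Because the toy model's bias grows from left to right, averaging the plain prediction with the de-augmented prediction on the flipped input cancels it, and every merged pixel comes out as 0.1. Skipping the de-augmentation step would average each pixel with the prediction for its mirror-image position instead.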
Table of Contents
- Quick Start
- Transforms
- Aliases
- Merge modes
- Installation
Quick start
Segmentation model wrapping:
import ttach as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
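The alias list describes `d4_transform` as flips plus rotation by 0, 90, 180 and 270 degrees. Assuming that means a horizontal flip (off or on) combined with each of the four rotations, it produces the eight symmetries of a square (the dihedral group D4), so the wrapper would run eight forward passes per batch. The pure-Python sketch below, with hypothetical helper names, enumerates those eight views of a 2x2 grid and checks that each one is undone by rotating back first and flipping back second.

```python
from itertools import product
from typing import List, Tuple

Grid = List[List[int]]


def hflip(g: Grid) -> Grid:
    """Mirror every row (horizontal flip)."""
    return [row[::-1] for row in g]


def rot90(g: Grid, k: int) -> Grid:
    """Rotate counter-clockwise by k * 90 degrees (negative k rotates clockwise)."""
    for _ in range(k % 4):
        # transpose, then reverse the row order: one 90-degree counter-clockwise turn
        g = [list(col) for col in zip(*g)][::-1]
    return g


def d4_augment(g: Grid, flip: bool, angle: int) -> Grid:
    g = hflip(g) if flip else g
    return rot90(g, angle // 90)


def d4_deaugment(g: Grid, flip: bool, angle: int) -> Grid:
    # undo in reverse order: rotate back first, then flip back
    g = rot90(g, -(angle // 90))
    return hflip(g) if flip else g


# (flip, angle) pairs: 2 flip settings x 4 angles = 8 views
D4_PARAMS: List[Tuple[bool, int]] = list(product([False, True], [0, 90, 180, 270]))

if __name__ == "__main__":
    grid = [[1, 2], [3, 4]]
    for flip, angle in D4_PARAMS:
        view = d4_augment(grid, flip, angle)
        assert d4_deaugment(view, flip, angle) == grid
        print(flip, angle, view)
```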
Classification model wrapping:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform(224, 224))  # crop_height, crop_width (example values)
Keypoints model wrapping:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
Note: the model must return keypoints as a tensor in the format [x1, y1, ..., xn, yn]; scaled=True tells the wrapper that these x and y values are normalized to [0, 1].
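For keypoints, de-augmentation means mapping the predicted coordinates back rather than transforming an image. The sketch below assumes coordinates normalized to [0, 1] in the flat [x1, y1, ..., xn, yn] layout, so a horizontal flip sends x to 1 - x and leaves y unchanged. `hflip_keypoints` and `merge_mean` are hypothetical helpers, and the prediction values are made up for illustration.

```python
from typing import List


def hflip_keypoints(kps: List[float]) -> List[float]:
    """Mirror flat [x1, y1, ..., xn, yn] keypoints horizontally.

    Assumes coordinates normalized to [0, 1]: x becomes 1 - x, y is unchanged.
    The flip is its own inverse, so the same function de-augments.
    """
    if len(kps) % 2:
        raise ValueError("expected an even number of values: x1, y1, ..., xn, yn")
    out = list(kps)
    out[0::2] = [1.0 - x for x in kps[0::2]]
    return out


def merge_mean(predictions: List[List[float]]) -> List[float]:
    """Average aligned keypoint vectors coordinate by coordinate."""
    return [sum(vals) / len(vals) for vals in zip(*predictions)]


pred_original = [0.25, 0.5, 0.75, 0.1]  # made-up prediction on the original image
pred_flipped = [0.77, 0.5, 0.23, 0.1]   # made-up prediction on the mirrored image
merged = merge_mean([pred_original, hflip_keypoints(pred_flipped)])
print(merged)  # roughly [0.24, 0.5, 0.76, 0.1]
```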
Advanced Examples
Custom transform:
transforms = tta.Compose(
[
tta.HorizontalFlip(),
tta.Rotate90(angles=[0, 180]),
tta.Scale(scales=[1, 2, 4]),
tta.Multiply(factors=[0.9, 1, 1.1]),
]
)
tta_model = tta.SegmentationTTAWrapper(model, transforms)
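Composed transforms multiply. Assuming `tta.Compose` runs every combination of the listed parameters, and that `HorizontalFlip` contributes two settings (no flip and flip), the example above performs 2 × 2 × 3 × 3 = 36 forward passes per input. The hypothetical `enumerate_passes` helper below lists those combinations with `itertools.product` and pairs each augmentation chain with its reversed de-augmentation chain.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple

Step = Tuple[str, object]  # (transform name, parameter value)


def enumerate_passes(params: Dict[str, Sequence]) -> List[Tuple[List[Step], List[Step]]]:
    """Return one (augment_chain, deaugment_chain) pair per parameter combination."""
    passes = []
    for combo in product(*params.values()):
        aug_chain = list(zip(params.keys(), combo))
        # de-augmentation undoes the steps in the opposite order
        passes.append((aug_chain, aug_chain[::-1]))
    return passes


example = {
    "HorizontalFlip": [False, True],  # assumed: "no flip" and "flip"
    "Rotate90": [0, 180],
    "Scale": [1, 2, 4],
    "Multiply": [0.9, 1, 1.1],
}

passes = enumerate_passes(example)
print(len(passes))  # 36
```

Inference time grows with that product, so each extra parameter value has a multiplicative cost.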
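Custom model (multi-input / multi-output)
import torch
labels = []  # de-augmented label predictions
masks = []   # de-augmented mask predictions
for transformer in transforms:  # e.g. a custom tta.Compose or tta.aliases.d4_transform()
    augmented_image = transformer.augment_image(image)  # augment the image
    model_output = model(augmented_image, another_input_data)  # model takes two inputs, returns a dict
    deaug_mask = transformer.deaugment_mask(model_output['mask'])  # undo the augmentation on the mask
    deaug_label = transformer.deaugment_label(model_output['label'])  # and on the label
    labels.append(deaug_label)
    masks.append(deaug_mask)
label = torch.stack(labels).mean(dim=0)  # merge, e.g. by averaging
mask = torch.stack(masks).mean(dim=0)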
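The loop above relies on each `transformer` undoing its augmentation correctly. When a chain mixes steps that do not commute, such as a horizontal flip and a 90-degree rotation, the inverse steps have to run in reverse order. This pure-Python sketch uses hypothetical helpers, a 2x2 grid standing in for a mask, and a model assumed to return its input unchanged; it shows that the reversed order recovers the mask while the forward order does not.

```python
from typing import List

Grid = List[List[int]]


def hflip(g: Grid) -> Grid:
    return [row[::-1] for row in g]


def rot90_ccw(g: Grid) -> Grid:
    # transpose, then reverse the row order
    return [list(col) for col in zip(*g)][::-1]


def rot90_cw(g: Grid) -> Grid:
    # transpose, then reverse each row (inverse of rot90_ccw)
    return [list(row)[::-1] for row in zip(*g)]


mask = [[1, 2], [3, 4]]
# augmentation chain: flip first, then rotate; the model is assumed to echo its input
predicted = rot90_ccw(hflip(mask))

correct = hflip(rot90_cw(predicted))  # undo the rotation first, then the flip
wrong = rot90_cw(hflip(predicted))    # undo in the same order as the augmentation

print(correct == mask)  # True
print(wrong == mask)    # False
```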
Transforms
| Transform | Parameters | Values |
|---|---|---|
| HorizontalFlip | - | - |
| VerticalFlip | - | - |
| Rotate90 | angles | List[0, 90, 180, 270] |
| Scale | scales, interpolation | List[float], "nearest"/"linear" |
| Resize | sizes, original_size, interpolation | List[Tuple[int, int]], Tuple[int, int], "nearest"/"linear" |
| Add | values | List[float] |
| Multiply | factors | List[float] |
| FiveCrops | crop_height, crop_width | int, int |
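For geometric transforms such as `Scale`, de-augmenting a mask means resizing it back to the original size. The sketch below is a pure-Python stand-in for `"nearest"` interpolation, assuming integer scale factors; `scale_nearest` is a hypothetical helper, not ttach's implementation. Under those assumptions, scaling up and back down returns the original mask exactly.

```python
from typing import List

Grid = List[List[int]]


def scale_nearest(g: Grid, out_h: int, out_w: int) -> Grid:
    """Nearest-neighbour resize: each output pixel copies the source pixel it maps onto."""
    in_h, in_w = len(g), len(g[0])
    return [
        [g[i * in_h // out_h][j * in_w // out_w] for j in range(out_w)]
        for i in range(out_h)
    ]


mask = [[0, 1], [1, 0]]
scale = 2  # one entry of a hypothetical `scales` list
# augment: enlarge (in real TTA the model's output on the enlarged image is resized back)
augmented = scale_nearest(mask, len(mask) * scale, len(mask[0]) * scale)
restored = scale_nearest(augmented, len(mask), len(mask[0]))  # de-augment: shrink back
print(restored == mask)  # True
```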
Aliases
- flip_transform (horizontal + vertical flips)
- hflip_transform (horizontal flip)
- d4_transform (flips + rotation 0, 90, 180, 270)
- multiscale_transform (scale transform; takes the list of scales as an input parameter)
- five_crop_transform (four corner crops + center crop; takes crop_height and crop_width)
- ten_crop_transform (five crops + the same five crops on the horizontally flipped image)
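The crop aliases depend only on crop geometry. Assuming the five crops are the four corners plus the center, as listed above, their boxes can be computed as in the sketch below; `five_crop_boxes` and `ten_crop_boxes` are hypothetical helpers, the box order is arbitrary, and center offsets round down when the margin is odd. For classification, crop positions need no undoing, since each crop yields a label for the whole image that can be merged directly.

```python
from typing import List, Tuple

Box = Tuple[int, int, int, int]  # (top, left, bottom, right); bottom and right are exclusive


def five_crop_boxes(img_h: int, img_w: int, crop_h: int, crop_w: int) -> List[Box]:
    """Four corner crops plus a center crop (hypothetical helper)."""
    if not (0 < crop_h <= img_h and 0 < crop_w <= img_w):
        raise ValueError("crop must be non-empty and fit inside the image")
    top = (img_h - crop_h) // 2   # center offsets round down when the margin is odd
    left = (img_w - crop_w) // 2
    return [
        (0, 0, crop_h, crop_w),                          # top-left
        (0, img_w - crop_w, crop_h, img_w),              # top-right
        (img_h - crop_h, 0, img_h, crop_w),              # bottom-left
        (img_h - crop_h, img_w - crop_w, img_h, img_w),  # bottom-right
        (top, left, top + crop_h, left + crop_w),        # center
    ]


def ten_crop_boxes(img_h: int, img_w: int, crop_h: int, crop_w: int) -> List[Tuple[bool, Box]]:
    """The five crops on the original image and again on its horizontal flip."""
    boxes = five_crop_boxes(img_h, img_w, crop_h, crop_w)
    return [(flipped, box) for flipped in (False, True) for box in boxes]


print(five_crop_boxes(10, 10, 4, 4)[-1])  # (3, 3, 7, 7)
print(len(ten_crop_boxes(10, 10, 4, 4)))  # 10
```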
Merge modes
- mean
- gmean (geometric mean)
- sum
- max
- min
- tsharpen (temperature sharpening with t=0.5)
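The diagram at the top mentions merging with mean, max or gmean. A minimal sketch of element-wise merging over per-augmentation predictions, covering those three plus sum and min, is below; the `merge` function and `MERGERS` table are hypothetical names, not ttach's API, and the geometric mean is meant for non-negative values such as probabilities.

```python
import math
from typing import Callable, Dict, List


def _gmean(xs: List[float]) -> float:
    # geometric mean; intended for non-negative values such as probabilities
    return math.prod(xs) ** (1 / len(xs))


MERGERS: Dict[str, Callable[[List[float]], float]] = {
    "mean": lambda xs: sum(xs) / len(xs),
    "gmean": _gmean,
    "sum": sum,
    "max": max,
    "min": min,
}


def merge(predictions: List[List[float]], mode: str = "mean") -> List[float]:
    """Merge per-augmentation prediction vectors element by element."""
    if mode not in MERGERS:
        raise ValueError(f"unknown merge mode: {mode!r}")
    reducer = MERGERS[mode]
    return [reducer(list(values)) for values in zip(*predictions)]


preds = [[0.2, 0.8], [0.4, 0.6]]  # made-up class probabilities from two augmented passes
print(merge(preds, "max"))  # [0.4, 0.8]
```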
Installation
PyPI:
$ pip install ttach
Source:
$ pip install git+https://github.com/qubvel/ttach
Run tests
docker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider