
Image Test Time Augmentation with PyTorch!
Similar to what data augmentation does for the training set, the purpose of test time augmentation (TTA) is to apply random modifications to the test images. Thus, instead of showing the regular, “clean” images to the trained model only once, we show it augmented versions several times. We then average the predictions for each corresponding image and take that as our final guess [1].
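The averaging step can be illustrated with a minimal pure-Python sketch. It is not how ttach implements TTA: images are plain lists of pixel rows, `toy_model` is a hypothetical two-class model that is deliberately not flip invariant, and the only augmentations are the identity and a horizontal flip.

```python
# Minimal TTA sketch for classification (pure Python, illustrative only).
# Images are lists of pixel rows; `toy_model` is a hypothetical 2-class model.


def identity(image):
    """No-op augmentation: return a copy of the image."""
    return [row[:] for row in image]


def hflip(image):
    """Horizontal flip: mirror every row left to right."""
    return [row[::-1] for row in image]


def toy_model(image):
    """Hypothetical model, deliberately NOT flip invariant.

    Assumes pixel values in [0, 10] and weighs the top-left pixel twice as
    much as its right-hand neighbour, so the score lies in [0, 1].
    """
    score = (2 * image[0][0] + image[0][1]) / 30
    return [1 - score, score]


def tta_predict(model, image, augmentations):
    """Predict on every augmented copy and average the class probabilities.

    A class label does not change under a flip, so no deaugmentation is needed.
    """
    predictions = [model(augment(image)) for augment in augmentations]
    num_classes = len(predictions[0])
    return [sum(p[c] for p in predictions) / len(predictions) for c in range(num_classes)]


probs = tta_predict(toy_model, [[9, 3], [4, 4]], [identity, hflip])
```

Here the unflipped image alone gives P(class 1) = 0.7 and the flipped copy gives 0.5, so the TTA estimate is their mean, 0.6 (up to floating-point rounding).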
```
           Input
             |           # input batch of images
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output
```
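The key step in the diagram is that each predicted mask is mapped back to the original geometry before merging. The sketch below walks through the pipeline in plain Python under illustrative assumptions: masks are lists of rows of probabilities, `toy_segmenter` is a hypothetical model, only flips are used (each flip is its own inverse), and `merge` implements the mean, max and geometric-mean options named in the diagram. It does not reproduce ttach's code.

```python
# Sketch of augment -> predict -> deaugment -> merge for segmentation masks.
# Pure Python, illustrative only: masks are lists of rows of probabilities and
# `toy_segmenter` is a hypothetical model.
import math


def identity(grid):
    return [row[:] for row in grid]


def hflip(grid):
    return [row[::-1] for row in grid]


def vflip(grid):
    return [row[:] for row in grid[::-1]]


# (augmentation, deaugmentation) pairs; each flip is its own inverse.
PAIRS = [(identity, identity), (hflip, hflip), (vflip, vflip)]


def toy_segmenter(image):
    """Hypothetical model: probability = pixel / 10, plus a 0.1 bias on the
    first column so the predictions for different augmentations disagree."""
    return [
        [min(1.0, px / 10 + (0.1 if j == 0 else 0.0)) for j, px in enumerate(row)]
        for row in image
    ]


def merge(masks, mode="mean"):
    """Combine aligned masks pixel-wise with mean, max or geometric mean."""
    n = len(masks)
    reducers = {
        "mean": lambda vals: sum(vals) / n,
        "max": max,
        "gmean": lambda vals: math.prod(vals) ** (1 / n),
    }
    if mode not in reducers:
        raise ValueError(f"unknown merge mode: {mode}")
    reducer = reducers[mode]
    height, width = len(masks[0]), len(masks[0][0])
    return [[reducer([m[i][j] for m in masks]) for j in range(width)] for i in range(height)]


def tta_segment(model, image, pairs=PAIRS, mode="mean"):
    """Predict on each augmented image, undo the augmentation on each mask so
    all masks align with the input, then merge them."""
    aligned = [deaugment(model(augment(image))) for augment, deaugment in pairs]
    return merge(aligned, mode)


mean_mask = tta_segment(toy_segmenter, [[2, 4], [6, 8]])
max_mask = tta_segment(toy_segmenter, [[2, 4], [6, 8]], mode="max")
```

Skipping the deaugmentation step would average a flipped mask with an unflipped one, so each output pixel would mix predictions that belong to different input pixels.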
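```python
import ttach as tta
# Segmentation model wrapping
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
# Classification model wrapping (five_crop_transform takes crop height and width; 224 x 224 is only an example)
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform(224, 224))
# Keypoints model wrapping
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
```

Note: the keypoints model must return keypoints as a flat tensor in the format `[x1, y1, ..., xn, yn]`.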
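For keypoint models, deaugmentation acts on coordinates rather than pixels. Below is a minimal sketch for the flat `[x1, y1, ..., xn, yn]` format, assuming (this is not taken from the ttach docs) that coordinates are normalized to [0, 1], so a horizontal flip maps x to 1 - x. All function names and predicted values are hypothetical.

```python
# Sketch: undoing flips on keypoints given in the flat [x1, y1, ..., xn, yn] format.
# Assumption (illustrative, not taken from ttach): coordinates are normalized to
# [0, 1], so a horizontal flip maps x -> 1 - x and a vertical flip maps y -> 1 - y.


def to_pairs(flat):
    """[x1, y1, ..., xn, yn] -> [(x1, y1), ..., (xn, yn)]"""
    if len(flat) % 2:
        raise ValueError("expected an even number of values")
    return [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2)]


def to_flat(pairs):
    """[(x1, y1), ..., (xn, yn)] -> [x1, y1, ..., xn, yn]"""
    return [value for point in pairs for value in point]


def deaugment_hflip(flat):
    return to_flat([(1 - x, y) for x, y in to_pairs(flat)])


def deaugment_vflip(flat):
    return to_flat([(x, 1 - y) for x, y in to_pairs(flat)])


def merge_mean(predictions):
    """Average aligned keypoint predictions coordinate by coordinate."""
    n = len(predictions)
    return [sum(p[k] for p in predictions) / n for k in range(len(predictions[0]))]


# Hypothetical model outputs for the original and the horizontally flipped image.
pred_original = [0.25, 0.5, 0.75, 0.1]
pred_on_flipped = [0.7, 0.5, 0.2, 0.1]
merged = merge_mean([pred_original, deaugment_hflip(pred_on_flipped)])
```

After the flipped prediction is mapped back to `[0.3, 0.5, 0.8, 0.1]`, the merged keypoints are approximately `[0.275, 0.5, 0.775, 0.1]`.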
```python
import ttach as tta

# defines 2 * 2 * 3 * 3 = 36 augmentations!
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)
```
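The count in the comment comes from enumerating every combination of parameter choices. The sketch below reproduces the arithmetic with `itertools.product`; modeling HorizontalFlip as the two choices "no flip"/"flip" is an assumption for illustration, not ttach's internal representation.

```python
# Sketch of why the Compose above yields 36 augmentations: each transform
# offers a few parameter choices, and every combination is tried once
# (a Cartesian product). Illustrative only, not ttach's internal code.
from itertools import product

choices = {
    "hflip": [False, True],  # assumption: "no flip" / "flip"
    "rotate90": [0, 180],
    "scale": [1, 2, 4],
    "multiply": [0.9, 1, 1.1],
}

combinations = [dict(zip(choices, values)) for values in product(*choices.values())]

print(len(combinations))  # 36
print(combinations[0])  # {'hflip': False, 'rotate90': 0, 'scale': 1, 'multiply': 0.9}
```

Each combination costs one forward pass per batch, so inference time grows with this product (and the 2x and 4x scales also enlarge the inputs the model sees).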
```python
import torch

# Example of how to process ONE batch of images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is a 2D tensor (B, N);
# `model` and `another_input_data` are placeholders for your own model and extra inputs
masks, labels = [], []
for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()
    # augment image
    augmented_image = transformer.augment_image(image)
    # pass to model
    model_output = model(augmented_image, another_input_data)
    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])
    # save results
    masks.append(deaug_mask)
    labels.append(deaug_label)

# reduce results as you want, e.g. mean/max/min over the augmentation axis
label = torch.stack(labels).mean(dim=0)
mask = torch.stack(masks).mean(dim=0)
```
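The `deaugment_mask` call matters most for geometric transforms. As a hedged sketch of the idea for rotations (pure Python, with hypothetical helper names and a counter-clockwise convention chosen for illustration; ttach's own convention may differ), undoing a rotation by `angle` means rotating the prediction by `-angle`:

```python
# Sketch of "reverse augmentation" for a Rotate90-style transform: a mask
# predicted on a rotated image is rotated back by the opposite angle so it
# lines up with the original image before merging. Pure Python; the helper
# names and the counter-clockwise convention are assumptions for illustration.


def rot90_ccw(grid):
    """Rotate a 2D grid (list of rows) 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*grid)][::-1]


def rotate(grid, angle):
    """Rotate counter-clockwise by a multiple of 90 degrees (negative = clockwise)."""
    if angle % 90:
        raise ValueError("angle must be a multiple of 90")
    for _ in range((angle // 90) % 4):
        grid = rot90_ccw(grid)
    return grid


def deaugment_rotation(mask, angle):
    """Undo a rotation by `angle` by rotating by `-angle`."""
    return rotate(mask, -angle)


mask = [[1, 2, 3], [4, 5, 6]]  # non-square on purpose
for angle in (0, 90, 180, 270):
    # with an identity "model", deaugmenting its output must recover the original mask
    assert deaugment_rotation(rotate(mask, angle), angle) == mask
```

Whatever direction convention is used, only the pairing matters: the deaugmentation must be the exact inverse of the augmentation, otherwise masks from different augmentations are misaligned when merged.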
Available transforms:

| Transform | Parameters | Values |
|---|---|---|
| HorizontalFlip | - | - |
| VerticalFlip | - | - |
| Rotate90 | angles | List[int] (any of 0, 90, 180, 270) |
| Scale | scales; interpolation | List[float]; "nearest"/"linear" |
| Resize | sizes; original_size; interpolation | List[Tuple[int, int]]; Tuple[int, int]; "nearest"/"linear" |
| Add | values | List[float] |
| Multiply | factors | List[float] |
| FiveCrops | crop_height; crop_width | int; int |
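
To make the FiveCrops parameters concrete, here is a sketch that computes the five crop windows: the four corners plus the center. The function name, the `(top, left, bottom, right)` box format with exclusive bottom/right, and rounding the center offset down are assumptions for illustration; check the ttach source for its exact conventions.

```python
# Sketch of the five crop windows taken by a FiveCrops-style transform:
# four corners plus the center. Boxes are (top, left, bottom, right) in pixel
# indices, bottom/right exclusive. Names and conventions are hypothetical.


def five_crop_boxes(image_height, image_width, crop_height, crop_width):
    if crop_height > image_height or crop_width > image_width:
        raise ValueError("crop must fit inside the image")
    bottom = image_height - crop_height  # top offset of the bottom crops
    right = image_width - crop_width  # left offset of the right crops
    offsets = [
        (0, 0),  # top-left
        (0, right),  # top-right
        (bottom, 0),  # bottom-left
        (bottom, right),  # bottom-right
        (bottom // 2, right // 2),  # center (offset rounded down)
    ]
    return [(t, l, t + crop_height, l + crop_width) for t, l in offsets]


boxes = five_crop_boxes(256, 256, 224, 224)
```

For a 256 x 256 image and 224 x 224 crops, the boxes start at (0, 0), (0, 32), (32, 0), (32, 32) and (16, 16). Because the crops only change which region the model sees, a class label needs no deaugmentation, which suits the classification wrapper shown earlier.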
PyPI:

```bash
$ pip install ttach
```

Source:

```bash
$ pip install git+https://github.com/qubvel/ttach
```
Run tests:

```bash
docker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider
```