TTAch

Image Test Time Augmentation with PyTorch!

Similar to what data augmentation does for the training set, test time augmentation (TTA) applies random modifications to the test images. Instead of showing each regular, “clean” image to the trained model only once, we show it several augmented versions of that image. We then average the predictions for each corresponding image and take that average as the final guess [1].
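A minimal sketch of this idea for classification, written in plain PyTorch rather than through the library. The `tta_classify` helper is hypothetical, and the choice of views (the clean image plus a horizontal and a vertical flip) is an assumption for illustration. Class predictions need no reverse transformation, because the label does not depend on where objects sit in the image.

```python
import torch
import torch.nn as nn


def tta_classify(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Average class probabilities over a few flipped views of a (B, C, H, W) batch."""
    views = [
        images,                        # the regular, "clean" image
        torch.flip(images, dims=[3]),  # horizontal flip (width axis)
        torch.flip(images, dims=[2]),  # vertical flip (height axis)
    ]
    with torch.no_grad():
        # Assumes the model returns logits of shape (B, N); turn them into probabilities.
        probs = [model(view).softmax(dim=1) for view in views]
    # Average the predictions for each image over all of its views.
    return torch.stack(probs).mean(dim=0)
```

Averaging probabilities rather than raw logits keeps each merged row a valid distribution over the N classes.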

           Input
             |           # input batch of images 
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output
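The steps in the diagram can be sketched for segmentation in plain PyTorch. `FLIPS` and `tta_segment` are hypothetical names, not part of the library, and only identity, horizontal and vertical flips are used. Unlike class labels, predicted masks must be transformed back before merging so that their pixels line up with the input.

```python
import torch
import torch.nn as nn

# Hypothetical (augment, deaugment) pairs standing in for real transforms.
# A flip is its own inverse, so the same operation undoes it.
FLIPS = [
    (lambda x: x, lambda y: y),                                              # identity
    (lambda x: torch.flip(x, dims=[3]), lambda y: torch.flip(y, dims=[3])),  # horizontal
    (lambda x: torch.flip(x, dims=[2]), lambda y: torch.flip(y, dims=[2])),  # vertical
]


def tta_segment(model: nn.Module, images: torch.Tensor, merge: str = "mean") -> torch.Tensor:
    outputs = []
    with torch.no_grad():
        for augment, deaugment in FLIPS:
            mask = model(augment(images))    # pass the augmented batch through the model
            outputs.append(deaugment(mask))  # reverse the transformation on the mask
    stacked = torch.stack(outputs)           # (n_transforms, B, C, H, W)
    if merge == "mean":
        return stacked.mean(dim=0)
    if merge == "max":
        return stacked.max(dim=0).values
    raise ValueError(f"unknown merge mode: {merge!r}")
```

A 1×1 convolution treats every pixel independently, so flipping, predicting and flipping back reproduces its plain output; that makes it a convenient sanity check for the reverse step.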

Table of Contents

  1. Quick Start
  2. Transforms
  3. Aliases
  4. Merge modes
  5. Installation

Quick start

Segmentation model wrapping:
import ttach as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
Classification model wrapping:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
Keypoints model wrapping:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)

Note: the model must return keypoints as a tensor laid out as [x1, y1, ..., xn, yn].
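To show why keypoints need a reverse transformation, here is a sketch of undoing a horizontal flip on the flat [x1, y1, ..., xn, yn] layout. The helper name is hypothetical, and it assumes coordinates are normalized to [0, 1] so that a flip maps x to 1 - x; whether this matches the library's `scaled=True` convention is an assumption, not something stated here.

```python
import torch


def deaugment_keypoints_hflip(flat: torch.Tensor) -> torch.Tensor:
    """Map keypoints predicted on a horizontally flipped image back to the original.

    Assumes `flat` has shape (B, 2 * n) laid out as [x1, y1, ..., xn, yn],
    with coordinates normalized to [0, 1].
    """
    points = flat.reshape(flat.shape[0], -1, 2).clone()  # (B, n, 2); clone keeps input intact
    points[..., 0] = 1.0 - points[..., 0]                # mirror x, leave y unchanged
    return points.reshape(flat.shape[0], -1)
```

Once mapped back, the keypoints from the flipped and unflipped passes refer to the same image coordinates and can be averaged.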

Advanced Examples

Custom transform:
# defines 2 * 2 * 3 * 3 = 36 augmentations!
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)
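The 36 in the comment is the product of the number of options each transform contributes. Assuming `Compose` enumerates every combination (a Cartesian product), the count can be checked with the standard library. The `params` dict is a hypothetical stand-in for the transforms above, with the flip counted as two states: applied or not.

```python
from itertools import product

# Options contributed by each transform in the Compose above.
params = {
    "HorizontalFlip": [False, True],  # flip off / on
    "Rotate90": [0, 180],
    "Scale": [1, 2, 4],
    "Multiply": [0.9, 1, 1.1],
}

# One value per transform makes one augmentation.
combinations = list(product(*params.values()))
print(len(combinations))  # 36
```

Because the count multiplies, every extra transform scales inference time by its number of options, so long lists in `Compose` get expensive quickly.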
Custom model (multi-input / multi-output)
# Example: how to process ONE batch of images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is a 2D tensor (B, N)
import torch

masks, labels = [], []

for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()

    # augment image
    augmented_image = transformer.augment_image(image)

    # pass to model
    model_output = model(augmented_image, another_input_data)

    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])

    # save results, each into the list of its own kind
    masks.append(deaug_mask)
    labels.append(deaug_label)

# reduce results as you want, e.g. mean/max/min over the augmentation axis
label = torch.stack(labels).mean(dim=0)
mask = torch.stack(masks).mean(dim=0)

Transforms

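| Transform      | Parameter     | Values                              |
|----------------|---------------|-------------------------------------|
| HorizontalFlip | -             | -                                   |
| VerticalFlip   | -             | -                                   |
| Rotate90       | angles        | List[int], each of 0, 90, 180, 270  |
| Scale          | scales        | List[float]                         |
|                | interpolation | "nearest"/"linear"                  |
| Resize         | sizes         | List[Tuple[int, int]]               |
|                | original_size | Tuple[int, int]                     |
|                | interpolation | "nearest"/"linear"                  |
| Add            | values        | List[float]                         |
| Multiply       | factors       | List[float]                         |
| FiveCrops      | crop_height   | int                                 |
|                | crop_width    | int                                 |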
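Transforms that move pixels, such as Rotate90, must be undone on predicted masks before merging. A minimal sketch with `torch.rot90`; both helper names are hypothetical. The reverse step rotates by the negative multiple of 90 degrees, and the round trip restores the original tensor even for non-square inputs.

```python
import torch


def rotate90_augment(image: torch.Tensor, angle: int) -> torch.Tensor:
    """Rotate a (B, C, H, W) batch in the H-W plane by a multiple of 90 degrees."""
    if angle not in (0, 90, 180, 270):
        raise ValueError(f"angle must be 0, 90, 180 or 270, got {angle}")
    return torch.rot90(image, k=angle // 90, dims=(2, 3))


def rotate90_deaugment(mask: torch.Tensor, angle: int) -> torch.Tensor:
    """Undo rotate90_augment by rotating the same amount in the opposite direction."""
    return torch.rot90(mask, k=-(angle // 90), dims=(2, 3))
```

Intensity transforms such as Add and Multiply leave pixel positions unchanged, so a mask predicted on such an input should need no geometric reverse step.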

Aliases

  • flip_transform (horizontal + vertical flips)
  • hflip_transform (horizontal flip)
  • d4_transform (flips + rotation 0, 90, 180, 270)
  • multiscale_transform (scale transform, take scales as input parameter)
  • five_crop_transform (corner crops + center crop)
  • ten_crop_transform (five crops + five crops on horizontal flip)
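The d4 name refers to the dihedral group D4, the 8 symmetries of a square. The sketch below uses plain PyTorch, not the library's implementation, to enumerate them as a horizontal flip (on or off) times four 90° rotations, and checks that a vertical flip is already among them. How `d4_transform` itself is composed is not confirmed here.

```python
import torch

# An asymmetric 3x3 "image": all values differ, so distinct symmetries give distinct results.
image = torch.arange(9.0).reshape(1, 1, 3, 3)

views = []
for hflip in (False, True):
    flipped = torch.flip(image, dims=[3]) if hflip else image
    for k in range(4):  # rotations by 0, 90, 180, 270 degrees
        views.append(torch.rot90(flipped, k=k, dims=(2, 3)))

# A vertical flip equals a horizontal flip followed by a 180-degree rotation.
vflip = torch.flip(image, dims=[2])
assert any(torch.equal(vflip, v) for v in views)

unique = {tuple(v.flatten().tolist()) for v in views}
print(len(unique))  # 8
```

This is why combining both flips with all four rotations would only repeat views: the extra vertical flip adds no new images.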

Merge modes
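The diagram above names mean, max and gmean as ways to merge predictions, and the custom-model example mentions min. A minimal sketch of those reductions over a list of de-augmented predictions; the `merge` helper is hypothetical and is not the library's own merger.

```python
import torch


def merge(predictions: list[torch.Tensor], mode: str = "mean") -> torch.Tensor:
    """Reduce de-augmented predictions over the augmentation axis."""
    stacked = torch.stack(predictions)  # (n_augmentations, ...)
    if mode == "mean":
        return stacked.mean(dim=0)
    if mode == "max":
        return stacked.max(dim=0).values
    if mode == "min":
        return stacked.min(dim=0).values
    if mode == "gmean":
        # Geometric mean; assumes strictly positive values such as probabilities.
        return stacked.log().mean(dim=0).exp()
    raise ValueError(f"unsupported merge mode: {mode!r}")
```

The geometric mean is pulled down more strongly than the arithmetic mean by a single low prediction, so it rewards agreement across augmentations.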

Installation

PyPI:

$ pip install ttach

Source:

$ pip install git+https://github.com/qubvel/ttach

Run tests

docker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider
