- Published on
Test-Time Augmentation (TTA): Boosting Model Performance at Inference
- Authors
- Name
- Jared Chung
Test-Time Augmentation (TTA) is a powerful technique that improves model predictions by applying data augmentation not just during training, but also at inference time. By creating multiple augmented versions of a test image and aggregating their predictions, TTA can significantly boost model performance without any retraining, at the cost of one extra forward pass per augmented view.
In this comprehensive guide, we'll explore various TTA strategies, implementation techniques, and advanced methods for maximizing their effectiveness.
What is Test-Time Augmentation?
Test-Time Augmentation involves:
- Creating multiple versions of the input image using various transformations
- Running inference on each augmented version
- Aggregating predictions (typically averaging) to get the final result
- Improving robustness by reducing prediction variance
The key insight is that while individual predictions may vary, the averaged prediction across multiple augmentations tends to be more reliable and accurate.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Callable, Union
import cv2
# Basic TTA implementation
class TestTimeAugmentation:
    """Run a model on an image and on augmented copies, then aggregate.

    The model is switched to eval mode when the wrapper is constructed.
    """

    def __init__(self, model, device='cuda'):
        """
        Args:
            model: Callable model exposing ``eval()``.
            device: Device every input is moved to before the forward pass.
        """
        self.model = model
        self.device = device
        self.model.eval()

    def predict_with_tta(self, image, augmentations, aggregation='mean'):
        """
        Perform TTA on a single image

        Args:
            image: Input image (PIL Image or tensor)
            augmentations: List of augmentation functions; each one is
                called with ``image`` and its result is fed to the model
            aggregation: Method to aggregate predictions
                ('mean', 'geometric_mean', 'voting')

        Returns:
            Tuple ``(final_pred, predictions)``. ``predictions`` is the
            stacked per-view output with the un-augmented view first.
            ``final_pred`` is the aggregate: an element-wise mean, an
            element-wise geometric mean (not renormalised), or, for
            'voting', the most frequent argmax index across views.

        Raises:
            ValueError: If ``aggregation`` is not a known method.
        """
        with torch.no_grad():
            # The original image always contributes the first prediction.
            predictions = [self._get_prediction(image)]
            for aug_fn in augmentations:
                predictions.append(self._get_prediction(aug_fn(image)))

        predictions = torch.stack(predictions)

        if aggregation == 'mean':
            final_pred = torch.mean(predictions, dim=0)
        elif aggregation == 'geometric_mean':
            # exp(mean(log p)); the epsilon keeps log() finite for p == 0.
            final_pred = torch.exp(torch.mean(torch.log(predictions + 1e-8), dim=0))
        elif aggregation == 'voting':
            # Hard voting (for classification)
            votes = torch.argmax(predictions, dim=-1)
            final_pred = torch.mode(votes, dim=0)[0]
        else:
            raise ValueError(f"Unknown aggregation method: {aggregation}")

        return final_pred, predictions

    def _get_prediction(self, image):
        """Get model prediction for a single image.

        Non-tensor inputs (e.g. PIL images) go through ``ToTensor`` and a
        ``Normalize`` with the commonly used ImageNet channel statistics.
        A 3-D tensor is treated as one unbatched image and gets a leading
        batch dimension, so the model always receives a batch.

        Returns:
            The model output with the batch dimension squeezed away.
            Softmax is applied first when the output is 2-D with more
            than one column.
        """
        if not isinstance(image, torch.Tensor):
            preprocess = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            image = preprocess(image)
        if image.dim() == 3:
            image = image.unsqueeze(0)

        image = image.to(self.device)
        output = self.model(image)

        # Apply softmax for classification
        if output.dim() == 2 and output.size(1) > 1:
            output = F.softmax(output, dim=1)
        return output.squeeze(0)
Standard TTA Transformations
Geometric Augmentations
class GeometricTTA:
    """Standard geometric transformations for TTA.

    Every method accepts either a ``torch.Tensor`` whose last two
    dimensions are (height, width) or a PIL image, and returns the same
    kind of object. Tensors are handled with torch ops; anything else is
    handed to the PIL / torchvision path.
    """

    @staticmethod
    def horizontal_flip(image):
        """Horizontal flip (mirror left/right)."""
        if isinstance(image, torch.Tensor):
            return torch.flip(image, [-1])
        return transforms.functional.hflip(image)

    @staticmethod
    def vertical_flip(image):
        """Vertical flip (mirror top/bottom)."""
        if isinstance(image, torch.Tensor):
            return torch.flip(image, [-2])
        return transforms.functional.vflip(image)

    @staticmethod
    def rotation_90(image):
        """90-degree rotation."""
        if isinstance(image, torch.Tensor):
            return torch.rot90(image, k=1, dims=[-2, -1])
        # expand=True grows the canvas so a non-square image is not cropped.
        return image.rotate(90, expand=True)

    @staticmethod
    def rotation_180(image):
        """180-degree rotation."""
        if isinstance(image, torch.Tensor):
            return torch.rot90(image, k=2, dims=[-2, -1])
        return image.rotate(180)

    @staticmethod
    def rotation_270(image):
        """270-degree rotation."""
        if isinstance(image, torch.Tensor):
            return torch.rot90(image, k=3, dims=[-2, -1])
        return image.rotate(270, expand=True)

    @staticmethod
    def transpose(image):
        """Transpose (flip across the main diagonal)."""
        if isinstance(image, torch.Tensor):
            return torch.transpose(image, -2, -1)
        return image.transpose(Image.TRANSPOSE)

    @staticmethod
    def transverse(image):
        """Transverse (flip across the anti-diagonal)."""
        if isinstance(image, torch.Tensor):
            # Anti-diagonal flip = transpose followed by a flip of BOTH
            # spatial axes. Flipping only the last axis after the
            # transpose yields the 270-degree rotation instead.
            return torch.flip(torch.transpose(image, -2, -1), [-2, -1])
        return image.transpose(Image.TRANSVERSE)
# Eight-fold TTA (D4 group transformations)
def get_d4_transformations():
    """Return the eight callables used for eight-fold (D4) TTA.

    The list starts with an identity function and continues with the two
    flips, the three rotations, the transpose and the transverse from
    ``GeometricTTA``, in that order.
    """
    def identity(x):
        return x

    method_names = (
        'horizontal_flip',
        'vertical_flip',
        'rotation_90',
        'rotation_180',
        'rotation_270',
        'transpose',
        'transverse',
    )
    d4 = [identity]
    for method_name in method_names:
        d4.append(getattr(GeometricTTA, method_name))
    return d4
Multi-Scale TTA
class MultiScaleTTA:
    """Multi-scale test-time augmentation.

    Each scale produces one view: the image is resized to
    ``int(base_size * scale)`` on both sides and then centre-cropped to
    ``base_size``.
    """

    def __init__(self, scales=(0.8, 0.9, 1.0, 1.1, 1.2), base_size=224):
        """
        Args:
            scales: Iterable of scale factors applied to ``base_size``.
            base_size: Side length of the square view fed to the model.
        """
        # Copy into a fresh list so instances never share (or mutate) the
        # default argument.
        self.scales = list(scales)
        self.base_size = base_size

    def generate_scale_transforms(self):
        """Generate one preprocessing transform per configured scale.

        Returns:
            List of ``transforms.Compose`` pipelines (resize, centre crop,
            ``ToTensor``, ``Normalize``), in the order of ``self.scales``.
        """
        transforms_list = []
        for scale in self.scales:
            size = int(self.base_size * scale)
            # NOTE(review): for scale < 1 the resized image is smaller
            # than base_size, so CenterCrop has to pad it - confirm that
            # this is the intended behaviour.
            transforms_list.append(transforms.Compose([
                transforms.Resize((size, size)),
                transforms.CenterCrop(self.base_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ]))
        return transforms_list

    def predict_multi_scale(self, model, image, device='cuda'):
        """Predict with multi-scale TTA.

        Args:
            model: Model to evaluate; it is switched to eval mode.
            image: Input image accepted by the generated transforms.
            device: Device each scaled view is moved to.

        Returns:
            Element-wise mean of the per-scale predictions. Softmax is
            applied to each 2-D output with more than one column first.
        """
        model.eval()
        predictions = []
        with torch.no_grad():
            for transform in self.generate_scale_transforms():
                scaled_image = transform(image).unsqueeze(0).to(device)
                output = model(scaled_image)
                if output.dim() == 2 and output.size(1) > 1:
                    output = F.softmax(output, dim=1)
                predictions.append(output.squeeze(0))

        # Average predictions
        return torch.mean(torch.stack(predictions), dim=0)
Crop-based TTA
class CropBasedTTA:
    """TTA using different crop positions (centre, corners, random)."""

    def __init__(self, crop_size=224, stride=32, num_random_crops=5):
        """
        Args:
            crop_size: Side length of each square crop.
            stride: Stored on the instance; no method of this class
                reads it.
            num_random_crops: How many randomly positioned crops to add
                after the centre and corner crops.
        """
        self.crop_size = crop_size
        self.stride = stride
        self.num_random_crops = num_random_crops

    def generate_crops(self, image):
        """Generate multiple crops from different positions.

        Args:
            image: PIL image, or a tensor whose last two dimensions are
                (height, width). Tensors are converted with torchvision's
                ``to_pil_image`` (assumes float tensors hold values in
                [0, 1] - TODO confirm against callers).

        Returns:
            Tuple ``(crops, positions)``: a list of PIL crops and the
            matching list of ``(x1, y1, x2, y2)`` boxes. Centre and
            corner crops are skipped when they do not fit; random crops
            are always emitted and are clipped by slicing when the image
            is smaller than ``crop_size``.
        """
        if isinstance(image, Image.Image):
            w, h = image.size
            image_np = np.array(image)
        else:
            h, w = image.shape[-2:]
            # Image.fromarray cannot handle float HxWxC arrays, so go
            # through to_pil_image to obtain a uint8 array.
            image_np = np.array(transforms.functional.to_pil_image(image))

        size = self.crop_size
        crops = []
        positions = []

        def take(x1, y1):
            """Cut the size x size window whose top-left corner is (x1, y1)."""
            x2, y2 = x1 + size, y1 + size
            crops.append(Image.fromarray(image_np[y1:y2, x1:x2]))
            positions.append((x1, y1, x2, y2))

        # Center crop
        cx = w // 2 - size // 2
        cy = h // 2 - size // 2
        if cx >= 0 and cy >= 0 and cx + size <= w and cy + size <= h:
            take(cx, cy)

        # Corner crops: top-left, top-right, bottom-left, bottom-right
        corners = ((0, 0), (w - size, 0), (0, h - size), (w - size, h - size))
        for x1, y1 in corners:
            if x1 >= 0 and y1 >= 0:
                take(x1, y1)

        # Random crops. randint's upper bound is exclusive, so add 1 to
        # make the last valid offset (w - size / h - size) reachable.
        for _ in range(self.num_random_crops):
            x1 = np.random.randint(0, max(1, w - size + 1))
            y1 = np.random.randint(0, max(1, h - size + 1))
            take(x1, y1)

        return crops, positions

    def predict_with_crops(self, model, image, device='cuda'):
        """Predict using multiple crops.

        Args:
            model: Model to evaluate; it is switched to eval mode.
            image: Image passed to ``generate_crops``.
            device: Device each crop is moved to.

        Returns:
            Tuple ``(final_prediction, predictions)``: the element-wise
            mean over crops and the list of per-crop predictions.
        """
        crops, positions = self.generate_crops(image)
        model.eval()

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        predictions = []
        with torch.no_grad():
            for crop in crops:
                crop_tensor = transform(crop).unsqueeze(0).to(device)
                output = model(crop_tensor)
                if output.dim() == 2 and output.size(1) > 1:
                    output = F.softmax(output, dim=1)
                predictions.append(output.squeeze(0))

        # Average predictions
        final_prediction = torch.mean(torch.stack(predictions), dim=0)
        return final_prediction, predictions
Advanced TTA Techniques
Noise-based TTA
class NoiseTTA:
    """TTA using different types of additive noise.

    Both noise helpers clamp the result to [0, 1] and always return a
    PIL image, whatever the input type was.
    """

    def __init__(self, noise_levels=(0.01, 0.02, 0.03)):
        """
        Args:
            noise_levels: Iterable of noise magnitudes. Each level is
                used both as a Gaussian std and as a uniform half-width.
        """
        # Copy so instances never share or mutate the default argument.
        self.noise_levels = list(noise_levels)

    def add_gaussian_noise(self, image, std):
        """Add zero-mean Gaussian noise with standard deviation ``std``.

        Args:
            image: PIL image or tensor. The clamp to [0, 1] assumes
                un-normalised pixel values - TODO confirm for tensors.
            std: Standard deviation of the noise.

        Returns:
            The noisy image as a PIL image.
        """
        if isinstance(image, Image.Image):
            image = transforms.functional.to_tensor(image)
        noise = torch.randn_like(image) * std
        noisy_image = torch.clamp(image + noise, 0, 1)
        return transforms.functional.to_pil_image(noisy_image)

    def add_uniform_noise(self, image, magnitude):
        """Add uniform noise drawn from ``[-magnitude, magnitude)``.

        Args:
            image: PIL image or tensor (see ``add_gaussian_noise``).
            magnitude: Half-width of the uniform noise interval.

        Returns:
            The noisy image as a PIL image.
        """
        if isinstance(image, Image.Image):
            image = transforms.functional.to_tensor(image)
        noise = (torch.rand_like(image) - 0.5) * 2 * magnitude
        noisy_image = torch.clamp(image + noise, 0, 1)
        return transforms.functional.to_pil_image(noisy_image)

    def generate_noise_transforms(self):
        """Generate noise-based transformations.

        Returns:
            List of single-argument callables: for every noise level, a
            Gaussian transform followed by a uniform transform.
        """
        transforms_list = []
        for level in self.noise_levels:
            # Bind the level as a default argument so each lambda keeps
            # its own value rather than the loop variable's final value.
            transforms_list.append(lambda x, s=level: self.add_gaussian_noise(x, s))
            transforms_list.append(lambda x, m=level: self.add_uniform_noise(x, m))
        return transforms_list
# Color augmentation TTA
class ColorTTA:
    """TTA using color space manipulations.

    Each ``*_variants`` method returns one deterministic variant per
    factor. ``transforms.ColorJitter(brightness=f)`` is deliberately not
    used: it samples a random factor from a range around 1 instead of
    applying ``f`` itself, so the factor must be passed to the functional
    ``adjust_*`` API.
    """

    def __init__(self):
        pass

    def brightness_variants(self, image, factors=(0.9, 1.0, 1.1)):
        """Generate brightness variants.

        Args:
            image: Image accepted by torchvision's ``adjust_brightness``.
            factors: Brightness multipliers, one output per factor.

        Returns:
            List of adjusted images in the order of ``factors``.
        """
        return [transforms.functional.adjust_brightness(image, factor)
                for factor in factors]

    def contrast_variants(self, image, factors=(0.9, 1.0, 1.1)):
        """Generate contrast variants.

        Args:
            image: Image accepted by torchvision's ``adjust_contrast``.
            factors: Contrast multipliers, one output per factor.

        Returns:
            List of adjusted images in the order of ``factors``.
        """
        return [transforms.functional.adjust_contrast(image, factor)
                for factor in factors]

    def saturation_variants(self, image, factors=(0.9, 1.0, 1.1)):
        """Generate saturation variants.

        Args:
            image: Image accepted by torchvision's ``adjust_saturation``.
            factors: Saturation multipliers, one output per factor.

        Returns:
            List of adjusted images in the order of ``factors``.
        """
        return [transforms.functional.adjust_saturation(image, factor)
                for factor in factors]
Ensemble TTA
class EnsembleTTA:
    """TTA with model ensembles: every model sees every augmented view."""

    def __init__(self, models, device='cuda'):
        """
        Args:
            models: Iterable of models; each is switched to eval mode.
            device: Device every input is moved to before inference.
        """
        self.models = models
        self.device = device
        # Set all models to evaluation mode
        for model in self.models:
            model.eval()

    def predict_ensemble_tta(self, image, augmentations,
                             model_weights=None, aug_weights=None):
        """
        Perform TTA with multiple models

        Args:
            image: Input image
            augmentations: List of augmentation functions
            model_weights: Weights for each model (optional, default 1.0)
            aug_weights: Weights for each view (optional, default 1.0).
                Index 0 weights the original image; index ``i + 1``
                weights ``augmentations[i]``.

        Returns:
            The weighted sum of all predictions divided by its own total,
            so the result sums to 1.
        """
        if model_weights is None:
            model_weights = [1.0] * len(self.models)
        if aug_weights is None:
            aug_weights = [1.0] * (len(augmentations) + 1)  # +1 for original

        all_predictions = []
        with torch.no_grad():
            for model_idx, model in enumerate(self.models):
                # Original image first, then each augmented view.
                model_preds = [self._get_prediction(model, image) * aug_weights[0]]
                for aug_idx, aug_fn in enumerate(augmentations):
                    aug_pred = self._get_prediction(model, aug_fn(image))
                    model_preds.append(aug_pred * aug_weights[aug_idx + 1])

                # Weight by model importance
                weighted_model_pred = (
                    torch.stack(model_preds).sum(dim=0) * model_weights[model_idx]
                )
                all_predictions.append(weighted_model_pred)

        # Combine all predictions
        final_prediction = torch.stack(all_predictions).sum(dim=0)
        return final_prediction / final_prediction.sum()  # Normalize

    def _get_prediction(self, model, image):
        """Get prediction from a single model.

        Non-tensor inputs go through ``ToTensor`` + ``Normalize``; a 3-D
        tensor is treated as one unbatched image and gets a batch
        dimension. Softmax is applied to 2-D outputs with more than one
        column, and the batch dimension is squeezed from the result.
        """
        if not isinstance(image, torch.Tensor):
            preprocess = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            image = preprocess(image)
        if image.dim() == 3:
            image = image.unsqueeze(0)

        image = image.to(self.device)
        output = model(image)
        if output.dim() == 2 and output.size(1) > 1:
            output = F.softmax(output, dim=1)
        return output.squeeze(0)
Segmentation TTA
class SegmentationTTA:
    """TTA for segmentation tasks.

    Every augmented prediction is mapped back through the inverse
    transform before averaging, so all predictions are aligned with the
    original image.
    """

    def __init__(self, model, device='cuda'):
        """
        Args:
            model: Segmentation model; it is switched to eval mode.
            device: Device every input is moved to before inference.
        """
        self.model = model
        self.device = device
        self.model.eval()

    def predict_with_tta(self, image, return_all=False):
        """
        Perform TTA for segmentation

        Args:
            image: Input image tensor [C, H, W]
            return_all: Whether to return all predictions

        Returns:
            The mean over seven aligned predictions (original, horizontal
            flip, vertical flip, both flips, and rotations by 90, 180 and
            270 degrees). With ``return_all=True`` a tuple
            ``(avg_prediction, predictions)`` is returned, where
            ``predictions`` is the stacked per-view tensor in that order.
        """
        # (forward transform, inverse transform) for every flipped view.
        flip_axes = ([-1], [-2], [-2, -1])

        with torch.no_grad():
            predictions = [self._predict_single(image)]

            for axes in flip_axes:
                flipped_pred = self._predict_single(torch.flip(image, axes))
                # A flip is its own inverse.
                predictions.append(torch.flip(flipped_pred, axes))

            # NOTE: flipping both axes equals the 180-degree rotation, so
            # that view contributes twice to the average. Kept as-is so
            # the number of returned predictions stays at seven.
            for k in (1, 2, 3):  # 90, 180, 270 degrees
                rotated = torch.rot90(image, k=k, dims=[-2, -1])
                pred_rot = self._predict_single(rotated)
                # Rotate back
                predictions.append(torch.rot90(pred_rot, k=-k, dims=[-2, -1]))

        # Average predictions
        predictions = torch.stack(predictions)
        avg_prediction = torch.mean(predictions, dim=0)

        if return_all:
            return avg_prediction, predictions
        return avg_prediction

    def _predict_single(self, image):
        """Get prediction for a single image.

        A 3-D input gets a batch dimension. Softmax over dim 1 is applied
        when the output has more than one channel, and the batch
        dimension is squeezed from the result.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)
        image = image.to(self.device)

        with torch.no_grad():
            output = self.model(image)

        # Apply softmax if multi-class
        if output.size(1) > 1:
            output = F.softmax(output, dim=1)
        return output.squeeze(0)

    def multi_scale_tta(self, image, scales=(0.75, 1.0, 1.25)):
        """Multi-scale TTA for segmentation.

        Args:
            image: Input image tensor [C, H, W].
            scales: Scale factors applied to the spatial size.

        Returns:
            Mean of the per-scale predictions after each one has been
            bilinearly resized back to the original spatial size.
        """
        original_size = image.shape[-2:]
        predictions = []

        with torch.no_grad():
            for scale in scales:
                # Resize image
                new_size = [int(side * scale) for side in original_size]
                scaled_image = F.interpolate(
                    image.unsqueeze(0),
                    size=new_size,
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0)

                # Get prediction
                pred = self._predict_single(scaled_image)

                # Resize prediction back to original size
                pred = F.interpolate(
                    pred.unsqueeze(0),
                    size=original_size,
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0)
                predictions.append(pred)

        # Average predictions
        return torch.mean(torch.stack(predictions), dim=0)
Adaptive TTA
class AdaptiveTTA:
    """Adaptive TTA that adds augmentations until uncertainty is low."""

    def __init__(self, model, device='cuda', uncertainty_threshold=0.1):
        """
        Args:
            model: Model to evaluate; it is switched to eval mode.
            device: Device every input is moved to before inference.
            uncertainty_threshold: Entropy below which no further
                augmentations are evaluated.
        """
        self.model = model
        self.device = device
        self.uncertainty_threshold = uncertainty_threshold
        self.model.eval()

    def calculate_uncertainty(self, predictions):
        """Calculate prediction uncertainty as an entropy.

        Args:
            predictions: A 1-D probability vector, or a stack of them
                along dim 0 (the mean over dim 0 is taken first).

        Returns:
            The entropy ``-sum(p * log(p + 1e-8))`` as a Python float.
        """
        if predictions.dim() == 1:
            probs = predictions
        else:
            probs = torch.mean(predictions, dim=0)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8))
        return entropy.item()

    def predict_adaptive_tta(self, image, max_augmentations=8):
        """
        Perform adaptive TTA - add augmentations until uncertainty is low

        Args:
            image: Input image (PIL image or tensor).
            max_augmentations: Upper bound on augmented views evaluated
                in addition to the original image.

        Returns:
            Tuple ``(final_pred, num_predictions, uncertainty)``: the
            mean prediction, how many predictions were averaged
            (including the original) and the last computed entropy.
        """
        # NOTE(review): ColorJitter and the noise helper draw random
        # parameters, so repeated calls may evaluate different views.
        augmentations = [
            GeometricTTA.horizontal_flip,
            GeometricTTA.vertical_flip,
            GeometricTTA.rotation_90,
            GeometricTTA.rotation_180,
            lambda x: transforms.ColorJitter(brightness=0.1)(x),
            lambda x: transforms.ColorJitter(contrast=0.1)(x),
            lambda x: self._add_noise(x, 0.01),
            lambda x: self._scale_image(x, 1.1)
        ]

        with torch.no_grad():
            # Start with original image
            predictions = [self._get_prediction(image)]
            uncertainty = self.calculate_uncertainty(torch.stack(predictions))

            # Add augmentations until uncertainty is low or max reached
            for i, aug_fn in enumerate(augmentations):
                if uncertainty < self.uncertainty_threshold or i >= max_augmentations:
                    break
                predictions.append(self._get_prediction(aug_fn(image)))
                uncertainty = self.calculate_uncertainty(torch.stack(predictions))

        # Return average prediction and number of predictions used
        final_pred = torch.mean(torch.stack(predictions), dim=0)
        return final_pred, len(predictions), uncertainty

    def _get_prediction(self, image):
        """Get model prediction.

        Non-tensor inputs go through ``ToTensor`` + ``Normalize``; a 3-D
        tensor is treated as one unbatched image and gets a batch
        dimension. Softmax is applied to 2-D outputs with more than one
        column, and the batch dimension is squeezed from the result.
        """
        if not isinstance(image, torch.Tensor):
            preprocess = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            image = preprocess(image)
        if image.dim() == 3:
            image = image.unsqueeze(0)

        image = image.to(self.device)
        output = self.model(image)
        if output.dim() == 2 and output.size(1) > 1:
            output = F.softmax(output, dim=1)
        return output.squeeze(0)

    def _add_noise(self, image, std):
        """Add Gaussian noise, clamp to [0, 1] and return a PIL image."""
        if isinstance(image, Image.Image):
            image = transforms.functional.to_tensor(image)
        noise = torch.randn_like(image) * std
        noisy_image = torch.clamp(image + noise, 0, 1)
        return transforms.functional.to_pil_image(noisy_image)

    def _scale_image(self, image, scale):
        """Scale a PIL image bilinearly; other inputs are returned as-is."""
        if not isinstance(image, Image.Image):
            return image
        w, h = image.size
        new_size = (int(w * scale), int(h * scale))
        return image.resize(new_size, Image.BILINEAR)
TTA Evaluation and Analysis
class TTAAnalyzer:
    """Analyze TTA performance and effectiveness."""

    def __init__(self, model, device='cuda'):
        """
        Args:
            model: Classification model under analysis.
            device: Device inputs are moved to before inference.
        """
        self.model = model
        self.device = device

    def analyze_tta_impact(self, test_loader, augmentations, num_samples=100):
        """Analyze the impact of TTA on model performance.

        Args:
            test_loader: Iterable yielding ``(data, target)`` batches.
            augmentations: Augmentation functions passed to
                ``TestTimeAugmentation.predict_with_tta``.
            num_samples: Maximum number of individual images to evaluate.

        Returns:
            Tuple ``(summary, results)``. ``results`` holds per-image
            lists (correctness with and without TTA, top-class
            confidence with and without TTA, entropy reduction);
            ``summary`` holds their means plus the accuracy difference.
        """
        self.model.eval()
        results = {
            'without_tta': [],
            'with_tta': [],
            'confidence_without': [],
            'confidence_with': [],
            'uncertainty_reduction': []
        }
        tta = TestTimeAugmentation(self.model, self.device)
        to_pil = transforms.ToPILImage()
        evaluated = 0

        with torch.no_grad():
            for data, target in test_loader:
                # Count images, not batches, so num_samples means what
                # its name says regardless of the loader's batch size.
                if evaluated >= num_samples:
                    break
                for i in range(data.size(0)):
                    if evaluated >= num_samples:
                        break
                    image = data[i]
                    true_label = target[i].item()

                    # Prediction without TTA
                    pred_without = self.model(image.unsqueeze(0).to(self.device))
                    pred_without = F.softmax(pred_without, dim=1).squeeze(0)

                    # Prediction with TTA
                    # NOTE(review): the TTA path converts to PIL and
                    # normalises again, while the baseline uses the
                    # loader tensor directly - confirm the loader yields
                    # un-normalised images.
                    pil_image = to_pil(image)
                    pred_with, all_preds = tta.predict_with_tta(pil_image, augmentations)

                    # Calculate uncertainty reduction
                    uncertainty_without = self._calculate_entropy(pred_without)
                    uncertainty_with = self._calculate_entropy(pred_with)

                    # Store results
                    results['without_tta'].append(
                        int(torch.argmax(pred_without).item() == true_label))
                    results['with_tta'].append(
                        int(torch.argmax(pred_with).item() == true_label))
                    results['confidence_without'].append(torch.max(pred_without).item())
                    results['confidence_with'].append(torch.max(pred_with).item())
                    results['uncertainty_reduction'].append(
                        uncertainty_without - uncertainty_with)
                    evaluated += 1

        # Calculate summary statistics
        acc_without = np.mean(results['without_tta'])
        acc_with = np.mean(results['with_tta'])
        summary = {
            'accuracy_without_tta': acc_without,
            'accuracy_with_tta': acc_with,
            'accuracy_improvement': acc_with - acc_without,
            'avg_confidence_without': np.mean(results['confidence_without']),
            'avg_confidence_with': np.mean(results['confidence_with']),
            'avg_uncertainty_reduction': np.mean(results['uncertainty_reduction'])
        }
        return summary, results

    def _calculate_entropy(self, probs):
        """Calculate entropy ``-sum(p * log(p + 1e-8))`` as a float."""
        return -torch.sum(probs * torch.log(probs + 1e-8)).item()

    def visualize_tta_analysis(self, summary, results):
        """Visualize TTA analysis results in a 2x2 matplotlib figure.

        Args:
            summary: Summary dict returned by ``analyze_tta_impact``.
            results: Per-image results dict returned by the same call.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        acc_ax, conf_ax = axes[0, 0], axes[0, 1]
        unc_ax, gain_ax = axes[1, 0], axes[1, 1]

        # Accuracy comparison
        acc_ax.bar(['Without TTA', 'With TTA'],
                   [summary['accuracy_without_tta'], summary['accuracy_with_tta']])
        acc_ax.set_ylabel('Accuracy')
        acc_ax.set_title('Accuracy Comparison')
        acc_ax.set_ylim(0, 1)

        # Confidence comparison
        conf_ax.hist(results['confidence_without'], alpha=0.7, label='Without TTA', bins=20)
        conf_ax.hist(results['confidence_with'], alpha=0.7, label='With TTA', bins=20)
        conf_ax.set_xlabel('Confidence')
        conf_ax.set_ylabel('Frequency')
        conf_ax.set_title('Confidence Distribution')
        conf_ax.legend()

        # Uncertainty reduction
        unc_ax.hist(results['uncertainty_reduction'], bins=20)
        unc_ax.set_xlabel('Uncertainty Reduction')
        unc_ax.set_ylabel('Frequency')
        unc_ax.set_title('Uncertainty Reduction Distribution')

        # Improvement vs original confidence
        per_image_gain = np.array(results['with_tta']) - np.array(results['without_tta'])
        gain_ax.scatter(results['confidence_without'], per_image_gain)
        gain_ax.set_xlabel('Original Confidence')
        gain_ax.set_ylabel('Accuracy Improvement')
        gain_ax.set_title('TTA Improvement vs Original Confidence')
        gain_ax.axhline(y=0, color='r', linestyle='--', alpha=0.5)

        plt.tight_layout()
        plt.show()
Efficient TTA Implementation
class EfficientTTA:
    """Memory and time efficient TTA implementation."""

    def __init__(self, model, device='cuda', batch_size=8):
        """
        Args:
            model: Model to evaluate; it is switched to eval mode.
            device: Device inputs are moved to before inference.
            batch_size: Number of source images processed per chunk in
                ``predict_batch_tta``.
        """
        self.model = model
        self.device = device
        self.batch_size = batch_size
        self.model.eval()

    def predict_batch_tta(self, images, augmentations):
        """Perform TTA on a batch of images efficiently.

        Args:
            images: Sliceable sequence of images.
            augmentations: Augmentation functions applied to every image.

        Returns:
            List with one averaged prediction per input image.
        """
        batch_predictions = []
        # Process in batches to manage memory
        for start in range(0, len(images), self.batch_size):
            chunk = images[start:start + self.batch_size]
            batch_predictions.extend(self._process_batch(chunk, augmentations))
        return batch_predictions

    def _process_batch(self, images, augmentations):
        """Process a single batch with TTA in one forward pass.

        Returns:
            CPU tensor with one row per image: the mean over the original
            view and all augmented views of that image.
        """
        # Layout: [img0, aug0(img0), aug1(img0), ..., img1, aug0(img1), ...]
        # so a view() to (images, views, -1) groups views per image.
        all_augmented = []
        for image in images:
            all_augmented.append(image)
            all_augmented.extend(aug_fn(image) for aug_fn in augmentations)

        # Convert to tensor batch
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        batch_tensor = torch.stack([transform(img) for img in all_augmented])
        batch_tensor = batch_tensor.to(self.device)

        # Get predictions for entire batch
        with torch.no_grad():
            outputs = self.model(batch_tensor)
            if outputs.dim() == 2 and outputs.size(1) > 1:
                outputs = F.softmax(outputs, dim=1)

        # Reshape and average predictions
        num_augs = len(augmentations) + 1  # +1 for original
        outputs = outputs.view(len(images), num_augs, -1)
        return torch.mean(outputs, dim=1).cpu()

    def predict_with_memory_limit(self, image, augmentations, memory_limit_mb=1000):
        """Perform TTA with memory constraints.

        The process RSS is sampled with ``psutil`` before each augmented
        view. Once it has grown by more than ``memory_limit_mb`` since
        the start of the call, caches are cleared and no further views
        are evaluated.

        Args:
            image: Input image (PIL image or tensor).
            augmentations: Augmentation functions to try in order.
            memory_limit_mb: Allowed RSS growth in megabytes.

        Returns:
            Tuple ``(final_pred, num_predictions)``: the mean prediction
            and how many predictions were averaged.
        """
        import psutil
        import gc

        # Monitor memory usage
        process = psutil.Process()
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        with torch.no_grad():
            # Original prediction
            predictions = [self._get_prediction(image)]

            for aug_fn in augmentations:
                current_memory = process.memory_info().rss / 1024 / 1024
                if current_memory - initial_memory > memory_limit_mb:
                    # Clear cache and garbage collect
                    torch.cuda.empty_cache()
                    gc.collect()
                    break
                predictions.append(self._get_prediction(aug_fn(image)))

        # Average predictions
        final_pred = torch.mean(torch.stack(predictions), dim=0)
        return final_pred, len(predictions)

    def _get_prediction(self, image):
        """Get prediction for single image.

        Non-tensor inputs go through ``ToTensor`` + ``Normalize``; a 3-D
        tensor is treated as one unbatched image and gets a batch
        dimension. Softmax is applied to 2-D outputs with more than one
        column, and the batch dimension is squeezed from the result.
        """
        if not isinstance(image, torch.Tensor):
            preprocess = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            image = preprocess(image)
        if image.dim() == 3:
            image = image.unsqueeze(0)

        image = image.to(self.device)
        output = self.model(image)
        if output.dim() == 2 and output.size(1) > 1:
            output = F.softmax(output, dim=1)
        return output.squeeze(0)
Complete TTA Pipeline
class CompleteTTAPipeline:
    """Complete TTA pipeline with configurable strategies.

    ``predict`` uses exactly one strategy, chosen in this priority order:
    adaptive, multi-scale, crop-based, otherwise standard TTA with the
    augmentations returned by ``get_augmentations``.
    """

    # Defaults for every key the pipeline reads. A user config only needs
    # to list the keys it wants to override.
    DEFAULT_CONFIG = {
        'geometric_transforms': True,
        'color_transforms': False,
        'noise_transforms': False,
        'multi_scale': False,
        'crop_based': False,
        'adaptive': False,
        'scales': [0.9, 1.0, 1.1],
        'noise_levels': [0.01, 0.02],
        'uncertainty_threshold': 0.1,
        'max_augmentations': 8,
        'aggregation': 'mean'  # 'mean', 'geometric_mean', 'voting'
    }

    def __init__(self, model, device='cuda', config=None):
        """
        Args:
            model: Model to evaluate; it is switched to eval mode.
            device: Device inputs are moved to before inference.
            config: Optional dict overriding entries of
                ``DEFAULT_CONFIG``. Missing keys fall back to the
                defaults; the passed dict is not modified.
        """
        self.model = model
        self.device = device
        self.model.eval()
        # Merge instead of replace: a partial config would otherwise
        # raise KeyError as soon as an omitted key is read.
        self.config = {**self.DEFAULT_CONFIG, **(config or {})}

    def get_augmentations(self):
        """Get augmentations based on configuration.

        Returns:
            The enabled geometric, color and noise augmentations (in that
            order), truncated to ``config['max_augmentations']`` entries.
        """
        augmentations = []

        if self.config['geometric_transforms']:
            augmentations.extend([
                GeometricTTA.horizontal_flip,
                GeometricTTA.vertical_flip,
                GeometricTTA.rotation_90,
                GeometricTTA.rotation_180
            ])

        if self.config['color_transforms']:
            augmentations.extend([
                lambda x: transforms.ColorJitter(brightness=0.1)(x),
                lambda x: transforms.ColorJitter(contrast=0.1)(x),
                lambda x: transforms.ColorJitter(saturation=0.1)(x)
            ])

        if self.config['noise_transforms']:
            noise_tta = NoiseTTA(self.config['noise_levels'])
            augmentations.extend(noise_tta.generate_noise_transforms())

        return augmentations[:self.config['max_augmentations']]

    def predict(self, image, return_details=False):
        """Main prediction method.

        Args:
            image: Input image.
            return_details: When true, also return a dict with
                ``'num_augmentations'`` (and ``'uncertainty'`` for the
                adaptive strategy).

        Returns:
            ``final_pred``, or ``(final_pred, details)`` when
            ``return_details`` is true.
        """
        if self.config['adaptive']:
            adaptive_tta = AdaptiveTTA(self.model, self.device,
                                       self.config['uncertainty_threshold'])
            final_pred, num_augs, uncertainty = adaptive_tta.predict_adaptive_tta(
                image, self.config['max_augmentations']
            )
            details = {'num_augmentations': num_augs, 'uncertainty': uncertainty}
        elif self.config['multi_scale']:
            multi_scale_tta = MultiScaleTTA(self.config['scales'])
            final_pred = multi_scale_tta.predict_multi_scale(self.model, image, self.device)
            details = {'num_augmentations': len(self.config['scales'])}
        elif self.config['crop_based']:
            crop_tta = CropBasedTTA()
            final_pred, all_preds = crop_tta.predict_with_crops(self.model, image, self.device)
            details = {'num_augmentations': len(all_preds)}
        else:
            # Standard TTA
            tta = TestTimeAugmentation(self.model, self.device)
            final_pred, all_preds = tta.predict_with_tta(
                image, self.get_augmentations(), self.config['aggregation']
            )
            details = {'num_augmentations': len(all_preds)}

        if return_details:
            return final_pred, details
        return final_pred
# Usage example
def main():
    """Usage example: classify ``test_image.jpg`` with a ResNet-50 and TTA."""
    # The file-level imports only bind ``torchvision.transforms`` (as
    # ``transforms``), so the top-level package is imported here before
    # its model zoo is used.
    import torchvision

    # Load model
    model = torchvision.models.resnet50(pretrained=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Configure TTA
    tta_config = {
        'geometric_transforms': True,
        'color_transforms': True,
        'multi_scale': True,
        'scales': [0.8, 0.9, 1.0, 1.1, 1.2],
        'aggregation': 'mean',
        'max_augmentations': 10
    }

    # Create TTA pipeline
    tta_pipeline = CompleteTTAPipeline(model, device, tta_config)

    # Load test image. The preprocessing normalises three channels, so
    # grayscale / RGBA / palette files are converted to RGB first.
    image = Image.open('test_image.jpg').convert('RGB')

    # Get prediction with TTA
    prediction, details = tta_pipeline.predict(image, return_details=True)
    print(f"Prediction: {torch.argmax(prediction).item()}")
    print(f"Confidence: {torch.max(prediction).item():.4f}")
    print(f"Number of augmentations used: {details['num_augmentations']}")


if __name__ == "__main__":
    main()
Best Practices and Guidelines
When to Use TTA
- Competition settings where small improvements matter
- Critical applications requiring high reliability
- Limited training data scenarios
- When computational budget allows at inference time
TTA Selection Guidelines
def select_tta_strategy(task_type, model_type, computational_budget):
    """Look up a recommended list of TTA strategy names.

    Args:
        task_type: 'classification', 'segmentation' or
            'object_detection'; any other value gets the fallback.
        model_type: Accepted but not consulted by the lookup.
        computational_budget: 'low' selects the lightweight tier,
            'medium' the standard tier, anything else the heavy tier.

    Returns:
        List of strategy names; ``['horizontal_flip']`` when the task
        type is unknown.
    """
    tier = (
        'lightweight' if computational_budget == 'low'
        else 'standard' if computational_budget == 'medium'
        else 'heavy'
    )

    catalogue = {
        'classification': {
            'lightweight': ['horizontal_flip'],
            'standard': ['horizontal_flip', 'vertical_flip', 'rotation_90', 'rotation_180'],
            'heavy': ['d4_transforms', 'multi_scale', 'color_jitter'],
        },
        'segmentation': {
            'lightweight': ['horizontal_flip', 'vertical_flip'],
            'standard': ['d4_transforms'],
            'heavy': ['d4_transforms', 'multi_scale'],
        },
        'object_detection': {
            'lightweight': ['horizontal_flip'],
            'standard': ['horizontal_flip', 'multi_scale'],
            'heavy': ['horizontal_flip', 'multi_scale', 'crop_based'],
        },
    }

    fallback = ['horizontal_flip']
    per_task = catalogue.get(task_type)
    if per_task is None:
        return fallback
    return per_task.get(tier, fallback)
Conclusion
Test-Time Augmentation is a powerful technique that can provide significant performance improvements with relatively simple implementation. Key takeaways:
Benefits
- Improved accuracy with minimal code changes
- Better uncertainty estimation through prediction averaging
- Increased robustness to input variations
- Model-agnostic approach
Considerations
- Computational cost increases linearly with augmentations
- Diminishing returns beyond certain number of augmentations
- Memory requirements can be significant
- Task-specific augmentation selection is important
Best Practices
- Start with simple geometric transforms
- Use domain knowledge for augmentation selection
- Monitor computational vs. accuracy trade-offs
- Consider adaptive strategies for efficiency
TTA remains one of the most effective techniques for improving model performance at inference time, making it an essential tool in the computer vision practitioner's toolkit.
References
- Krizhevsky, A., et al. (2012). "ImageNet Classification with Deep Convolutional Neural Networks."
- Wang, G., et al. (2019). "Test-time augmentation for deep learning-based cell segmentation on microscopy images."
- Shanmugam, D., et al. (2021). "Better Aggregation in Test-Time Augmentation."
- Lyzhov, A., et al. (2020). "Greedy Policy Search: A Simple Baseline for Learnable Test-Time Augmentation."