Test Time Augmentation using PyTorch

Soumo Chatterjee
Published in Analytics Vidhya
Dec 28, 2020


In image classification, test-time augmentation (TTA) is applied at prediction time, after the model has been trained. Instead of predicting each test image once, we run the model n times on n randomly augmented versions of the image, average the predicted class probabilities over those n runs, and assign the image the class with the highest average probability.
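To make the averaging step concrete, here is a minimal NumPy sketch. The function name `tta_predict` and the probability values are hypothetical, chosen only for illustration; in practice each row would be the model's softmax output for one augmented copy of the same test image.

```python
import numpy as np


def tta_predict(probs_per_pass):
    """Average class probabilities over n TTA passes and return (winning class, mean probabilities).

    probs_per_pass: array-like of shape (n_passes, n_classes), one row per augmented copy.
    """
    probs_per_pass = np.asarray(probs_per_pass, dtype=float)
    mean_probs = probs_per_pass.mean(axis=0)  # average over the n passes
    return int(np.argmax(mean_probs)), mean_probs


# Hypothetical softmax outputs for ONE image under 3 different augmentations (3 classes)
passes = [
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.5, 0.4, 0.1],
]
label, mean_probs = tta_predict(passes)
print(mean_probs)  # about 0.433, 0.467 and 0.1
print(label)       # 1
```

Note that the first augmented copy on its own would have been classified as class 0; averaging over all three copies changes the decision to class 1, which illustrates how TTA smooths out a single unlucky augmentation.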

To do this in PyTorch, we first define a test dataset class that applies random augmentations every time an image is loaded:

import albumentations
import numpy as np
import torch
from PIL import Image


class test_Dataset(torch.utils.data.Dataset):
    def __init__(self, ids, image_ids):
        self.ids = ids
        self.image_ids = image_ids  # list of test-set image ids (file names appended to _PATH_)
        # Random test-time augmentations: a new random combination is drawn on every call
        self.aug = albumentations.Compose([
            # albumentations 1.x signature (height, width); newer releases expect size=(256, 256)
            albumentations.RandomResizedCrop(256, 256),
            albumentations.Transpose(p=0.5),
            albumentations.HorizontalFlip(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.HueSaturationValue(
                hue_shift_limit=0.2,
                sat_shift_limit=0.2,
                val_shift_limit=0.2,
                p=0.5
            ),
            albumentations.RandomBrightnessContrast(
                brightness_limit=(-0.1, 0.1),
                contrast_limit=(-0.1, 0.1),
                p=0.5
            ),
            # Subtract the ImageNet mean and divide by the std (both scaled by max_pixel_value)
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
                p=1.0
            )
        ], p=1.)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        # Load the image file as RGB and convert it to a NumPy array of shape (H, W, C)
        img = np.array(Image.open('_PATH_' + self.image_ids[index]).convert('RGB'))
        # Apply the random augmentations to the NumPy array
        img = self.aug(image=img)['image']
        # PyTorch expects channels first (C, H, W), so move axis 2 (channels) to the front
        img = np.transpose(img, (2, 0, 1)).astype(np.float32)

        # Return the image tensor together with its image id
        return torch.tensor(img, dtype=torch.float), self.image_ids[index]
Image source: https://miro.medium.com/max/850/1*ae1tW5ngf1zhPRyh7aaM1Q.png
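The `np.transpose(img, (2, 0, 1))` step in `__getitem__` only reorders axes: the NumPy image coming out of albumentations is laid out as height x width x channels (HWC), while PyTorch's convolution layers expect channels x height x width (CHW). The quick check below uses a random 256 x 256 x 3 array as a hypothetical stand-in for an augmented image; it confirms the shape change and shows that the permutation `(1, 2, 0)`, used later for plotting, restores the original layout.

```python
import numpy as np

# Hypothetical stand-in for an augmented image: 256 x 256 pixels, 3 colour channels (HWC)
hwc = np.random.default_rng(0).random((256, 256, 3)).astype(np.float32)

# Move the channel axis to the front, as the Dataset does before building the tensor
chw = np.transpose(hwc, (2, 0, 1))
print(chw.shape)  # (3, 256, 256)

# The permutation (1, 2, 0) undoes it, which is what the plotting code relies on
back = np.transpose(chw, (1, 2, 0))
print(back.shape)                 # (256, 256, 3)
print(np.array_equal(back, hwc))  # True

# The red channel at pixel (row 10, column 20) holds the same value in both layouts
print(hwc[10, 20, 0] == chw[0, 10, 20])  # True
```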

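Next, we create an instance of the dataset, where YOUR_LIST is the list of test image ids (the file names that get appended to _PATH_):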

test_data = test_Dataset(ids = [i for i in range(len(YOUR_LIST))], image_ids = YOUR_LIST)

If we index the dataset repeatedly with the same index, we get a differently augmented image on (almost) every call, because each augmentation is drawn at random with its probability p.
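The mechanism is easy to reproduce without albumentations. The sketch below is a pure-NumPy stand-in, not albumentations' implementation: the hypothetical `random_flips` function mimics `HorizontalFlip(p=0.5)` and `VerticalFlip(p=0.5)` by deciding independently, on every call, whether to mirror the image. Calling it repeatedly on the same tiny 2 x 2 array yields several different outputs.

```python
import numpy as np


def random_flips(img, rng, p=0.5):
    """Hypothetical stand-in for HorizontalFlip(p) + VerticalFlip(p): each flip is applied independently with probability p."""
    if rng.random() < p:
        img = img[:, ::-1]  # horizontal flip: mirror left-right
    if rng.random() < p:
        img = img[::-1, :]  # vertical flip: mirror top-bottom
    return img


rng = np.random.default_rng(42)

# A tiny 2 x 2 "image" whose pixel values make each flip easy to recognise
img = np.array([[1, 2],
                [3, 4]])

# "Index" the same image 50 times and record which augmented versions came out
outputs = [random_flips(img, rng) for _ in range(50)]
distinct = {tuple(int(v) for v in out.ravel()) for out in outputs}
print(sorted(distinct))  # each entry is one of the four flip combinations
```

With the real dataset, you can see the same effect by running the snippet below several times for the same index.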

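import matplotlib.pyplot as plt
import numpy as np

idx = ANY_IMAGE_ID  # any integer index from 0 to len(YOUR_LIST) - 1
print(test_data[idx][1])  # image id of the sample
img = test_data[idx][0]  # augmented image tensor, shape (C, H, W)
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))  # back to (H, W, C); normalized values are clipped to [0, 1]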
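Because the Dataset applies `albumentations.Normalize`, the tensor holds standardized values (roughly centred on zero) rather than intensities in [0, 1], so `plt.imshow` clips them and the colours look off. A small helper can undo the normalization before plotting. This is a sketch that assumes the same ImageNet mean/std as the Dataset above; the name `denormalize_for_display` is hypothetical. With `max_pixel_value=255.0`, albumentations computes `(x / 255 - mean) / std`, so multiplying by `std` and adding `mean` recovers `x / 255`.

```python
import numpy as np

# ImageNet statistics, matching the albumentations.Normalize call in the Dataset
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def denormalize_for_display(chw):
    """Turn a normalized CHW float array back into an HWC array in [0, 1] for plt.imshow."""
    hwc = np.transpose(chw, (1, 2, 0))  # CHW -> HWC
    hwc = hwc * STD + MEAN              # undo (x / 255 - mean) / std, leaving x / 255
    return np.clip(hwc, 0.0, 1.0)       # guard against rounding slightly outside [0, 1]


# Round trip on hypothetical 0..255 pixels, normalized the same way albumentations does
pixels = np.random.default_rng(1).integers(0, 256, size=(4, 4, 3)).astype(np.float32)
normalized = (pixels / 255.0 - MEAN) / STD
chw = np.transpose(normalized, (2, 0, 1))
restored = denormalize_for_display(chw)
print(np.allclose(restored, pixels / 255.0, atol=1e-5))  # True
```

With the real data, you would call `plt.imshow(denormalize_for_display(npimg))` instead of transposing `npimg` by hand.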

Now wrap test_data in a DataLoader and run the prediction loop. Keep shuffle=False: the predictions from every pass are added element-wise, so the images must come out in the same order each time.

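import numpy as np
import torch

# shuffle=False keeps the image order identical in every pass; batch_size is an arbitrary choice
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

# `model` (your trained network) and `device` are assumed to be defined already
model.eval()  # disable dropout and use running batch-norm statistics

no_of_times_we_wanna_do_prediction_for_one_image = 5  # any integer value: the number of TTA passes
final_predictions = None
for j in range(no_of_times_we_wanna_do_prediction_for_one_image):
    temporary_predictions = []  # per-batch probabilities for this pass
    for image, image_id in test_dataloader:
        image = image.to(device, dtype=torch.float)
        with torch.no_grad():
            preds = model(image)
            # Assuming the model outputs raw logits, convert them to class probabilities
            probs = torch.softmax(preds, dim=1)
        temporary_predictions.append(probs.cpu().numpy())
    # Stack the batches into one (num_images, num_classes) array for this pass
    temporary_predictions = np.vstack(temporary_predictions)
    if final_predictions is None:
        final_predictions = temporary_predictions
    else:
        final_predictions += temporary_predictions

# Average over the passes, then pick the most probable class for each image
final_predictions /= no_of_times_we_wanna_do_prediction_for_one_image
final_list_of_all_predictions = np.argmax(final_predictions, axis=1)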
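The bookkeeping in this loop (stack the batches of one pass, add the passes together, divide by the number of passes, take the per-row argmax) can be checked without a trained network. In the NumPy sketch below, `fake_model` is a hypothetical stand-in that returns random softmax rows, and the sizes are arbitrary; only the accumulation logic mirrors the prediction loop.

```python
import numpy as np

rng = np.random.default_rng(0)
num_images, num_classes, batch_size, n_passes = 10, 4, 3, 5


def fake_model(batch_len):
    """Hypothetical stand-in for softmax(model(image)): random rows that each sum to 1."""
    logits = rng.normal(size=(batch_len, num_classes))
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))  # numerically stable softmax
    return exp / exp.sum(axis=1, keepdims=True)


final_predictions = None
for _ in range(n_passes):
    batch_probs = []
    # Walk over the test set batch by batch in a fixed order (like shuffle=False)
    for start in range(0, num_images, batch_size):
        batch_len = min(batch_size, num_images - start)
        batch_probs.append(fake_model(batch_len))
    pass_probs = np.vstack(batch_probs)  # (num_images, num_classes) for this pass
    if final_predictions is None:
        final_predictions = pass_probs
    else:
        final_predictions += pass_probs

final_predictions /= n_passes
labels = np.argmax(final_predictions, axis=1)  # one class index per image
flat = np.argmax(final_predictions)            # without axis: a single index into the flattened array
print(final_predictions.shape, labels.shape)   # (10, 4) (10,)
```

Passing `axis=1` to `np.argmax` matters: without it NumPy returns one index into the flattened array instead of one label per image.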

At this point, final_list_of_all_predictions holds one predicted class index per test image, in the same order as YOUR_LIST (provided the DataLoader does not shuffle). I hope this was clear. If you have any questions, comments, or concerns, please post them in the comments section of this article; until then, enjoy learning.


Machine Learning and Deep Learning Enthusiast | Mindtree Mind | Python Lover