pytorch-image-models
997 строк · 34.7 Кб
""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch

This code implements the searched ImageNet policies with various tweaks and improvements and
does not include any of the search code.

AA and RA Implementation adapted from:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py

AugMix adapted from:
    https://github.com/google-research/augmix

3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md

Papers:
    AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
    Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
    AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
    3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118

Hacked together by / Copyright 2019, Ross Wightman
"""
23import random
24import math
25import re
26from functools import partial
27from typing import Dict, List, Optional, Union
28
29from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
30import PIL
31import numpy as np
32
33
# (major, minor) of the installed Pillow, used to gate version-specific code paths below.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split('.')[:2])

# Default fill colour handed to the geometric ops when hparams carry no 'img_mean'.
_FILL = (128, 128, 128)

# Denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments.
_LEVEL_DENOM = 10.

# Hyper-parameters used whenever a caller supplies none.
_HPARAMS_DEFAULT = {
    'translate_const': 250,
    'img_mean': _FILL,
}

try:
    # Prefer the ``Image.Resampling`` namespace when the installed Pillow provides it.
    _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
    _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
except AttributeError:
    # Otherwise fall back to the constants exposed directly on the ``Image`` module.
    _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
    _DEFAULT_INTERPOLATION = Image.BICUBIC
51
52
def _interpolation(kwargs):
    """Pop ``'resample'`` from *kwargs* and resolve it to a single resampling mode.

    A list or tuple is treated as a set of candidates from which one is drawn at random;
    any other value is returned unchanged. ``_DEFAULT_INTERPOLATION`` is used when the key
    is absent. Note that *kwargs* is mutated: the ``'resample'`` key is removed.
    """
    candidate = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
    return random.choice(candidate) if isinstance(candidate, (list, tuple)) else candidate
58
59
def _check_args_tf(kwargs):
    """Normalise, in place, the keyword arguments passed on to PIL transform calls.

    ``'fillcolor'`` is dropped when the installed PIL is older than 5.0, and ``'resample'``
    is replaced by the single mode chosen by ``_interpolation``.
    """
    if _PIL_VER < (5, 0) and 'fillcolor' in kwargs:
        del kwargs['fillcolor']
    kwargs['resample'] = _interpolation(kwargs)
64
65
def shear_x(img, factor, **kwargs):
    """Shear *img* along x via an affine transform with coefficients (1, factor, 0, 0, 1, 0)."""
    _check_args_tf(kwargs)
    coeffs = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
69
70
def shear_y(img, factor, **kwargs):
    """Shear *img* along y via an affine transform with coefficients (1, 0, 0, factor, 1, 0)."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
74
75
def translate_x_rel(img, pct, **kwargs):
    """Translate *img* horizontally by *pct* times its width (``img.size[0]``)."""
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    coeffs = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
80
81
def translate_y_rel(img, pct, **kwargs):
    """Translate *img* vertically by *pct* times its height (``img.size[1]``)."""
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
86
87
def translate_x_abs(img, pixels, **kwargs):
    """Translate *img* horizontally using *pixels* as the affine x-offset coefficient."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
91
92
def translate_y_abs(img, pixels, **kwargs):
    """Translate *img* vertically using *pixels* as the affine y-offset coefficient."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
96
97
def rotate(img, degrees, **kwargs):
    """Rotate *img* by *degrees*, picking a code path from the installed PIL version.

    * PIL >= 5.2: delegate to ``img.rotate`` with all (normalised) kwargs.
    * PIL 5.0 / 5.1: build the affine matrix for a rotation about the image centre by hand
      and apply it with ``img.transform``.
    * older PIL: call ``img.rotate`` passing only the resample mode.
    """
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if _PIL_VER < (5, 0):
        return img.rotate(degrees, resample=kwargs['resample'])

    width, height = img.size
    center_x, center_y = width / 2.0, height / 2.0
    theta = -math.radians(degrees)
    cos_t = round(math.cos(theta), 15)
    sin_t = round(math.sin(theta), 15)
    neg_sin_t = round(-math.sin(theta), 15)

    # Map the negated centre through the pure rotation (translation terms still 0.0) ...
    src_x, src_y = -center_x, -center_y
    shift_x = cos_t * src_x + sin_t * src_y + 0.0
    shift_y = neg_sin_t * src_x + cos_t * src_y + 0.0

    # ... then move it back, which yields the translation column of the final matrix.
    matrix = [
        cos_t, sin_t, shift_x + center_x,
        neg_sin_t, cos_t, shift_y + center_y,
    ]
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
127
128
def auto_contrast(img, **__):
    """Return ``ImageOps.autocontrast(img)``; extra keyword arguments are ignored."""
    result = ImageOps.autocontrast(img)
    return result
131
132
def invert(img, **__):
    """Return ``ImageOps.invert(img)``; extra keyword arguments are ignored."""
    result = ImageOps.invert(img)
    return result
135
136
def equalize(img, **__):
    """Return ``ImageOps.equalize(img)``; extra keyword arguments are ignored."""
    result = ImageOps.equalize(img)
    return result
139
140
def solarize(img, thresh, **__):
    """Return ``ImageOps.solarize(img, thresh)``; extra keyword arguments are ignored."""
    result = ImageOps.solarize(img, thresh)
    return result
143
144
def solarize_add(img, add, thresh=128, **__):
    """Add *add* (capped at 255) to every pixel value below *thresh* via a lookup table.

    Only images whose mode is ``"L"`` or ``"RGB"`` are modified; for any other mode the
    input image is returned as is.
    """
    # 256-entry table: values below the threshold are brightened, the rest pass through.
    table = [min(255, value + add) if value < thresh else value for value in range(256)]

    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB" and len(table) == 256:
        # one copy of the table per band
        table = table + table + table
    return img.point(table)
159
160
def posterize(img, bits_to_keep, **__):
    """Posterize *img* to *bits_to_keep* bits; 8 or more bits returns *img* unchanged."""
    if bits_to_keep < 8:
        return ImageOps.posterize(img, bits_to_keep)
    return img
165
166
def contrast(img, factor, **__):
    """Apply ``ImageEnhance.Contrast`` to *img* with the given enhancement *factor*."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
169
170
def color(img, factor, **__):
    """Apply ``ImageEnhance.Color`` to *img* with the given enhancement *factor*."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
173
174
def brightness(img, factor, **__):
    """Apply ``ImageEnhance.Brightness`` to *img* with the given enhancement *factor*."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
177
178
def sharpness(img, factor, **__):
    """Apply ``ImageEnhance.Sharpness`` to *img* with the given enhancement *factor*."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
181
182
def gaussian_blur(img, factor, **__):
    """Blur *img* with a Gaussian filter whose radius is *factor*."""
    blur = ImageFilter.GaussianBlur(radius=factor)
    return img.filter(blur)
186
187
def gaussian_blur_rand(img, factor, **__):
    """Blur *img* with a Gaussian radius drawn uniformly between 0.1 and ``2.0 * factor``."""
    lowest_radius, highest_radius = 0.1, 2.0
    radius = random.uniform(lowest_radius, highest_radius * factor)
    return img.filter(ImageFilter.GaussianBlur(radius=radius))
193
194
def desaturate(img, factor, **_):
    """Reduce colour saturation; larger *factor* means stronger desaturation.

    The value handed to ``ImageEnhance.Color`` is ``1 - factor`` clamped to [0, 1], where
    0 = grayscale and 1.0 = no change.
    """
    keep = 1. - factor
    keep = min(1., max(0., keep))
    return ImageEnhance.Color(img).enhance(keep)
199
200
201def _randomly_negate(v):
202"""With 50% prob, negate the value"""
203return -v if random.random() > 0.5 else v
204
205
def _rotate_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding a randomly signed rotation angle.

    A level of ``_LEVEL_DENOM`` corresponds to 30 degrees, i.e. range [-30, 30].
    """
    degrees = _randomly_negate((level / _LEVEL_DENOM) * 30.)
    return (degrees,)
211
212
def _enhance_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple enhancement factor.

    Levels 0 .. ``_LEVEL_DENOM`` map linearly onto [0.1, 1.9].
    """
    factor = (level / _LEVEL_DENOM) * 1.8 + 0.1
    return (factor,)
216
217
def _enhance_increasing_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple enhancement factor that grows away from 1.0.

    The 'no change' factor is 1.0; the level is scaled to an offset of up to 0.9 (for
    level <= ``_LEVEL_DENOM``), applied with a random sign, and the result is floored
    at 0.1.
    """
    offset = _randomly_negate((level / _LEVEL_DENOM) * .9)
    factor = max(0.1, 1.0 + offset)  # keep it >= 0.1
    return (factor,)
224
225
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
    """Linearly interpolate a magnitude level into [min_val, max_val]; returns a 1-tuple.

    With ``clamp`` set, the result is limited to that interval even when the level lies
    outside 0 .. ``_LEVEL_DENOM``.
    """
    fraction = level / _LEVEL_DENOM
    value = min_val + (max_val - min_val) * fraction
    if clamp:
        value = max(min_val, min(max_val, value))
    return (value,)
232
233
def _shear_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding a randomly signed shear factor.

    A level of ``_LEVEL_DENOM`` corresponds to 0.3, i.e. range [-0.3, 0.3].
    """
    shear = _randomly_negate((level / _LEVEL_DENOM) * 0.3)
    return (shear,)
239
240
def _translate_abs_level_to_arg(level, hparams):
    """Map a magnitude level to a 1-tuple holding a randomly signed translation offset.

    The level is scaled by ``hparams['translate_const']``, which must be present.
    """
    max_shift = float(hparams['translate_const'])
    shift = _randomly_negate((level / _LEVEL_DENOM) * max_shift)
    return (shift,)
246
247
def _translate_rel_level_to_arg(level, hparams):
    """Map a magnitude level to a 1-tuple holding a randomly signed relative translation.

    The level is scaled by ``hparams['translate_pct']`` (0.45 when absent), so the default
    range is [-0.45, 0.45].
    """
    max_pct = hparams.get('translate_pct', 0.45)
    pct = _randomly_negate((level / _LEVEL_DENOM) * max_pct)
    return (pct,)
254
255
def _posterize_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple 'bits to keep' for posterize.

    As per the Tensorflow TPU EfficientNet impl: range [0, 4], 'keep 0 up to 4 MSB of
    original image'. Intensity/severity of the augmentation decreases with level.
    """
    bits = int((level / _LEVEL_DENOM) * 4)
    return (bits,)
261
262
def _posterize_increasing_level_to_arg(level, hparams):
    """Map a magnitude level to a 1-tuple 'bits to keep' that shrinks as level grows.

    As per the Tensorflow models research and UDA impl: range [4, 0], 'keep 4 down to 0
    MSB of original image'. Intensity/severity of the augmentation increases with level.
    """
    (decreasing_bits,) = _posterize_level_to_arg(level, hparams)
    return (4 - decreasing_bits,)
268
269
def _posterize_original_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple 'bits to keep' per the AutoAugment paper.

    Range [4, 8], 'keep 4 up to 8 MSB of image'. Intensity/severity of the augmentation
    decreases with level.
    """
    bits = int((level / _LEVEL_DENOM) * 4) + 4
    return (bits,)
275
276
def _solarize_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple solarize threshold in [0, 256].

    Intensity/severity of the augmentation decreases with level.
    """
    threshold = int((level / _LEVEL_DENOM) * 256)
    return (min(256, threshold),)
281
282
def _solarize_increasing_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple solarize threshold that falls as level grows.

    Range [0, 256]; intensity/severity of the augmentation increases with level.
    """
    (decreasing_threshold,) = _solarize_level_to_arg(level, _hparams)
    return (256 - decreasing_threshold,)
287
288
def _solarize_add_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple additive value for ``solarize_add``.

    Levels 0 .. ``_LEVEL_DENOM`` give [0, 110]; the result is capped at 128.
    """
    addition = int((level / _LEVEL_DENOM) * 110)
    return (min(128, addition),)
292
293
# Op name -> function converting a magnitude level (plus hparams) into the positional
# argument tuple of that op. ``None`` marks ops that take no level-derived argument.
LEVEL_TO_ARG = {
    # ops without a magnitude argument
    'AutoContrast': None,
    'Equalize': None,
    'Invert': None,

    'Rotate': _rotate_level_to_arg,

    # There are several variations of the posterize level scaling in various
    # Tensorflow/Google repositories/papers
    'Posterize': _posterize_level_to_arg,
    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
    'PosterizeOriginal': _posterize_original_level_to_arg,

    'Solarize': _solarize_level_to_arg,
    'SolarizeIncreasing': _solarize_increasing_level_to_arg,
    'SolarizeAdd': _solarize_add_level_to_arg,

    # ImageEnhance based ops, plain and 'increasing' flavours
    'Color': _enhance_level_to_arg,
    'ColorIncreasing': _enhance_increasing_level_to_arg,
    'Contrast': _enhance_level_to_arg,
    'ContrastIncreasing': _enhance_increasing_level_to_arg,
    'Brightness': _enhance_level_to_arg,
    'BrightnessIncreasing': _enhance_increasing_level_to_arg,
    'Sharpness': _enhance_level_to_arg,
    'SharpnessIncreasing': _enhance_increasing_level_to_arg,

    # geometric ops
    'ShearX': _shear_level_to_arg,
    'ShearY': _shear_level_to_arg,
    'TranslateX': _translate_abs_level_to_arg,
    'TranslateY': _translate_abs_level_to_arg,
    'TranslateXRel': _translate_rel_level_to_arg,
    'TranslateYRel': _translate_rel_level_to_arg,

    # min/max interpolated ops
    'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0),
    'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0),
    'GaussianBlurRand': _minmax_level_to_arg,
}


# Op name -> callable applying the op to an image. Several names share one callable and
# differ only in their LEVEL_TO_ARG entry.
NAME_TO_OP = {
    'AutoContrast': auto_contrast,
    'Equalize': equalize,
    'Invert': invert,

    'Rotate': rotate,

    'Posterize': posterize,
    'PosterizeIncreasing': posterize,
    'PosterizeOriginal': posterize,

    'Solarize': solarize,
    'SolarizeIncreasing': solarize,
    'SolarizeAdd': solarize_add,

    'Color': color,
    'ColorIncreasing': color,
    'Contrast': contrast,
    'ContrastIncreasing': contrast,
    'Brightness': brightness,
    'BrightnessIncreasing': brightness,
    'Sharpness': sharpness,
    'SharpnessIncreasing': sharpness,

    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x_abs,
    'TranslateY': translate_y_abs,
    'TranslateXRel': translate_x_rel,
    'TranslateYRel': translate_y_rel,

    'Desaturate': desaturate,
    'GaussianBlur': gaussian_blur,
    'GaussianBlurRand': gaussian_blur_rand,
}
355
356
class AugmentOp:
    """A single named augmentation applied with a probability and a magnitude.

    The op callable and its level-to-argument function are looked up by *name* in
    ``NAME_TO_OP`` and ``LEVEL_TO_ARG``.
    """

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        """
        Args:
            name: key into ``NAME_TO_OP`` / ``LEVEL_TO_ARG``.
            prob: probability of applying the op on a given call.
            magnitude: nominal magnitude passed to the level function.
            hparams: optional hyper-parameters; ``_HPARAMS_DEFAULT`` is used when falsy.
        """
        if not hparams:
            hparams = _HPARAMS_DEFAULT
        self.name = name
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        # keyword arguments forwarded to every op call
        self.kwargs = {
            'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL,
            'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
        }

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        # If magnitude_std is inf, we sample magnitude from a uniform distribution
        self.magnitude_std = self.hparams.get('magnitude_std', 0)
        self.magnitude_max = self.hparams.get('magnitude_max', None)

    def __call__(self, img):
        """Apply the op to *img* with probability ``prob``; otherwise return *img* as is."""
        if self.prob < 1.0 and random.random() > self.prob:
            return img

        magnitude = self.magnitude
        std = self.magnitude_std
        if std > 0:
            # magnitude randomization enabled
            if std == float('inf'):
                # inf == uniform sampling
                magnitude = random.uniform(0, magnitude)
            else:
                magnitude = random.gauss(magnitude, std)

        # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
        # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
        upper_bound = self.magnitude_max or _LEVEL_DENOM
        magnitude = max(0., min(magnitude, upper_bound))

        if self.level_fn is None:
            level_args = tuple()
        else:
            level_args = self.level_fn(magnitude, self.hparams)
        return self.aug_fn(img, *level_args, **self.kwargs)

    def __repr__(self):
        parts = [
            f'name={self.name}',
            f'p={self.prob}',
            f'm={self.magnitude}',
            f'mstd={self.magnitude_std}',
        ]
        if self.magnitude_max is not None:
            parts.append(f'mmax={self.magnitude_max}')
        return self.__class__.__name__ + '(' + ', '.join(parts) + ')'
405
406
def auto_augment_policy_v0(hparams):
    """Build the ImageNet 'v0' AutoAugment policy as nested lists of ``AugmentOp``.

    ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
    Each entry below is ``(op name, probability, magnitude)``.
    """
    sub_policies = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AugmentOp(name, prob, magnitude, hparams=hparams) for name, prob, magnitude in sub_policy]
        for sub_policy in sub_policies
    ]
438
439
def auto_augment_policy_v0r(hparams):
    """Build the ImageNet 'v0r' AutoAugment policy as nested lists of ``AugmentOp``.

    ImageNet v0 policy from TPU EfficientNet impl, with the variation of Posterize used
    in the Google research implementation (number of bits discarded increases with
    magnitude). Each entry below is ``(op name, probability, magnitude)``.
    """
    sub_policies = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AugmentOp(name, prob, magnitude, hparams=hparams) for name, prob, magnitude in sub_policy]
        for sub_policy in sub_policies
    ]
472
473
def auto_augment_policy_original(hparams):
    """Build the 'original' AutoAugment policy as nested lists of ``AugmentOp``.

    ImageNet policy from https://arxiv.org/abs/1805.09501.
    Each entry below is ``(op name, probability, magnitude)``.
    """
    sub_policies = [
        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    return [
        [AugmentOp(name, prob, magnitude, hparams=hparams) for name, prob, magnitude in sub_policy]
        for sub_policy in sub_policies
    ]
505
506
def auto_augment_policy_originalr(hparams):
    """Build the 'originalr' AutoAugment policy as nested lists of ``AugmentOp``.

    ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation.
    Each entry below is ``(op name, probability, magnitude)``.
    """
    sub_policies = [
        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    return [
        [AugmentOp(name, prob, magnitude, hparams=hparams) for name, prob, magnitude in sub_policy]
        for sub_policy in sub_policies
    ]
538
539
def auto_augment_policy_3a(hparams):
    """Build the '3a' policy: three single-op sub-policies, each applied with prob 1.0.

    Each entry below is ``(op name, probability, magnitude)``.
    """
    sub_policies = [
        [('Solarize', 1.0, 5)],  # 128 solarize threshold @ 5 magnitude
        [('Desaturate', 1.0, 10)],  # grayscale at 10 magnitude
        [('GaussianBlurRand', 1.0, 10)],
    ]
    return [
        [AugmentOp(name, prob, magnitude, hparams=hparams) for name, prob, magnitude in sub_policy]
        for sub_policy in sub_policies
    ]
548
549
def auto_augment_policy(name='v0', hparams=None):
    """Return the AutoAugment policy registered under *name*.

    Args:
        name: one of 'original', 'originalr', 'v0', 'v0r', '3a'.
        hparams: hyper-parameters forwarded to the policy builder; ``_HPARAMS_DEFAULT``
            is used when falsy.

    Returns:
        The nested list of ``AugmentOp`` produced by the selected policy builder.

    Raises:
        AssertionError: if *name* is not a known policy. This is raised explicitly rather
            than through an ``assert`` statement so the check also runs under
            ``python -O``, where it would otherwise be stripped and ``None`` returned.
    """
    hparams = hparams or _HPARAMS_DEFAULT
    if name == 'original':
        return auto_augment_policy_original(hparams)
    if name == 'originalr':
        return auto_augment_policy_originalr(hparams)
    if name == 'v0':
        return auto_augment_policy_v0(hparams)
    if name == 'v0r':
        return auto_augment_policy_v0r(hparams)
    if name == '3a':
        return auto_augment_policy_3a(hparams)
    raise AssertionError(f'Unknown AA policy {name}')
563
564
class AutoAugment:
    """Transform that applies one randomly chosen sub-policy (a sequence of ops) per call."""

    def __init__(self, policy):
        # sequence of sub-policies, each itself a sequence of callables
        self.policy = policy

    def __call__(self, img):
        """Pick a sub-policy with ``random.choice`` and apply its ops to *img* in order."""
        for op in random.choice(self.policy):
            img = op(img)
        return img

    def __repr__(self):
        rows = ''.join(
            '\n\t[' + ', '.join(str(op) for op in sub_policy) + ']'
            for sub_policy in self.policy
        )
        return self.__class__.__name__ + '(policy=' + rows + ')'
584
585
def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
    """
    Create a AutoAugment transform

    Args:
        config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
            dashes ('-').
            The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr', '3a').

            The remaining sections:
                'mstd' -  float std deviation of magnitude noise applied
            Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5

        hparams: Other hparams (kwargs) for the AutoAugmentation scheme. When a dict is passed it is
            updated in place with values parsed from ``config_str`` (existing keys are not overwritten).
            When ``None``, a private copy of the default hparams is used.

    Returns:
        A PyTorch compatible Transform

    Raises:
        AssertionError: on an unknown config section or unknown policy name.
    """
    if hparams is None:
        # Work on a copy so that config-derived values never leak into the shared defaults
        # (and so that ``setdefault`` below has a dict to operate on).
        hparams = dict(_HPARAMS_DEFAULT)

    policy_name, *sections = config_str.split('-')
    for section in sections:
        # split 'key<number>' into key and numeric text; sections without a digit are skipped
        parts = re.split(r'(\d.*)', section)
        if len(parts) < 2:
            continue
        key, val = parts[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        else:
            # explicit raise: must not be stripped under ``python -O``
            raise AssertionError('Unknown AutoAugment config section')
    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
    return AutoAugment(aa_policy)
619
620
621_RAND_TRANSFORMS = [
622'AutoContrast',
623'Equalize',
624'Invert',
625'Rotate',
626'Posterize',
627'Solarize',
628'SolarizeAdd',
629'Color',
630'Contrast',
631'Brightness',
632'Sharpness',
633'ShearX',
634'ShearY',
635'TranslateXRel',
636'TranslateYRel',
637# 'Cutout' # NOTE I've implement this as random erasing separately
638]
639
640
641_RAND_INCREASING_TRANSFORMS = [
642'AutoContrast',
643'Equalize',
644'Invert',
645'Rotate',
646'PosterizeIncreasing',
647'SolarizeIncreasing',
648'SolarizeAdd',
649'ColorIncreasing',
650'ContrastIncreasing',
651'BrightnessIncreasing',
652'SharpnessIncreasing',
653'ShearX',
654'ShearY',
655'TranslateXRel',
656'TranslateYRel',
657# 'Cutout' # NOTE I've implement this as random erasing separately
658]
659
660
661_RAND_3A = [
662'SolarizeIncreasing',
663'Desaturate',
664'GaussianBlur',
665]
666
667
668_RAND_WEIGHTED_3A = {
669'SolarizeIncreasing': 6,
670'Desaturate': 6,
671'GaussianBlur': 6,
672'Rotate': 3,
673'ShearX': 2,
674'ShearY': 2,
675'PosterizeIncreasing': 1,
676'AutoContrast': 1,
677'ColorIncreasing': 1,
678'SharpnessIncreasing': 1,
679'ContrastIncreasing': 1,
680'BrightnessIncreasing': 1,
681'Equalize': 1,
682'Invert': 1,
683}
684
685
686# These experimental weights are based loosely on the relative improvements mentioned in paper.
687# They may not result in increased performance, but could likely be tuned to so.
688_RAND_WEIGHTED_0 = {
689'Rotate': 3,
690'ShearX': 2,
691'ShearY': 2,
692'TranslateXRel': 1,
693'TranslateYRel': 1,
694'ColorIncreasing': .25,
695'SharpnessIncreasing': 0.25,
696'AutoContrast': 0.25,
697'SolarizeIncreasing': .05,
698'SolarizeAdd': .05,
699'ContrastIncreasing': .05,
700'BrightnessIncreasing': .05,
701'Equalize': .05,
702'PosterizeIncreasing': 0.05,
703'Invert': 0.05,
704}
705
706
707def _get_weighted_transforms(transforms: Dict):
708transforms, probs = list(zip(*transforms.items()))
709probs = np.array(probs)
710probs = probs / np.sum(probs)
711return transforms, probs
712
713
def rand_augment_choices(name: str, increasing=True):
    """Resolve a transform-set *name* to the corresponding list or weight dict.

    'weights', '3aw' and '3a' select their dedicated sets; any other name falls back to the
    increasing or plain default list depending on *increasing*.
    """
    if name == 'weights':
        selected = _RAND_WEIGHTED_0
    elif name == '3aw':
        selected = _RAND_WEIGHTED_3A
    elif name == '3a':
        selected = _RAND_3A
    elif increasing:
        selected = _RAND_INCREASING_TRANSFORMS
    else:
        selected = _RAND_TRANSFORMS
    return selected
722
723
def rand_augment_ops(
        magnitude: Union[int, float] = 10,
        prob: float = 0.5,
        hparams: Optional[Dict] = None,
        transforms: Optional[Union[Dict, List]] = None,
):
    """Instantiate one ``AugmentOp`` per transform name with shared prob/magnitude/hparams.

    Falsy *hparams* / *transforms* fall back to ``_HPARAMS_DEFAULT`` / ``_RAND_TRANSFORMS``.
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _RAND_TRANSFORMS
    ops = []
    for op_name in transforms:
        ops.append(AugmentOp(op_name, prob=prob, magnitude=magnitude, hparams=hparams))
    return ops
734
735
class RandAugment:
    """Transform that applies ``num_layers`` ops drawn from ``ops`` on every call."""

    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        # optional per-op selection probabilities, passed to ``np.random.choice`` as ``p``
        self.choice_weights = choice_weights

    def __call__(self, img):
        """Draw the ops with ``np.random.choice`` and apply them to *img* in order."""
        # no replacement when using weighted choice
        with_replacement = self.choice_weights is None
        selected = np.random.choice(
            self.ops,
            self.num_layers,
            replace=with_replacement,
            p=self.choice_weights,
        )
        for op in selected:
            img = op(img)
        return img

    def __repr__(self):
        op_lines = ''.join(f'\n\t{op}' for op in self.ops)
        return self.__class__.__name__ + f'(n={self.num_layers}, ops=' + op_lines + ')'
760
761
def rand_augment_transform(
        config_str: str,
        hparams: Optional[Dict] = None,
        transforms: Optional[Union[str, Dict, List]] = None,
):
    """
    Create a RandAugment transform

    Args:
        config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
            by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
            The remaining sections, not order specific determine
                'm' - integer magnitude of rand augment
                'n' - integer num layers (number of transform ops selected per image)
                'p' - float probability of applying each layer (default 0.5)
                'mstd' -  float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
                'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
                'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
                't' - str name of transform set to use
            Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
            'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2

        hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme. When a dict is passed it is
            updated in place with values parsed from ``config_str`` (existing keys are not overwritten).
            When ``None``, a private copy of the default hparams is used.

        transforms: Optional transform set: a set name (str), a ``{name: weight}`` dict, or a list of
            op names. When given, it takes precedence over a 't' section in ``config_str``.

    Returns:
        A PyTorch compatible Transform

    Raises:
        AssertionError: if ``config_str`` does not start with 'rand' or holds an unknown section.
    """
    config = config_str.split('-')
    if config[0] != 'rand':
        # explicit raise: must not be stripped under ``python -O``
        raise AssertionError(f"RandAugment config must start with 'rand', got '{config[0]}'")
    if hparams is None:
        # Work on a copy so that config-derived values never leak into the shared defaults
        # (and so that ``setdefault`` below has a dict to operate on).
        hparams = dict(_HPARAMS_DEFAULT)

    magnitude = _LEVEL_DENOM  # default to _LEVEL_DENOM for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    increasing = False
    prob = 0.5
    for c in config[1:]:
        if c.startswith('t'):
            # NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights'
            val = str(c[1:])
            if transforms is None:
                transforms = val
            continue

        # numeric options: split 'key<number>' into key and numeric text
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param / randomization of magnitude values
            mstd = float(val)
            if mstd > 100:
                # use uniform sampling in 0 to magnitude if mstd is > 100
                mstd = float('inf')
            hparams.setdefault('magnitude_std', mstd)
        elif key == 'mmax':
            # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
            hparams.setdefault('magnitude_max', int(val))
        elif key == 'inc':
            # ``val`` is a non-empty string, so it must be parsed as a number before the
            # truth test; ``bool('0')`` would be True and 'inc0' would enable the option.
            if int(val) != 0:
                increasing = True
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'p':
            prob = float(val)
        else:
            raise AssertionError('Unknown RandAugment config section')

    if isinstance(transforms, str):
        transforms = rand_augment_choices(transforms, increasing=increasing)
    elif transforms is None:
        transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS

    choice_weights = None
    if isinstance(transforms, dict):
        # weighted set: split into op names and normalised selection probabilities
        transforms, choice_weights = _get_weighted_transforms(transforms)

    ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
841
842
843_AUGMIX_TRANSFORMS = [
844'AutoContrast',
845'ColorIncreasing', # not in paper
846'ContrastIncreasing', # not in paper
847'BrightnessIncreasing', # not in paper
848'SharpnessIncreasing', # not in paper
849'Equalize',
850'Rotate',
851'PosterizeIncreasing',
852'SolarizeIncreasing',
853'ShearX',
854'ShearY',
855'TranslateXRel',
856'TranslateYRel',
857]
858
859
def augmix_ops(
        magnitude: Union[int, float] = 10,
        hparams: Optional[Dict] = None,
        transforms: Optional[Union[str, Dict, List]] = None,
):
    """Instantiate one always-applied (``prob=1.0``) ``AugmentOp`` per transform name.

    Falsy *hparams* / *transforms* fall back to ``_HPARAMS_DEFAULT`` / ``_AUGMIX_TRANSFORMS``.
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _AUGMIX_TRANSFORMS
    ops = []
    for op_name in transforms:
        ops.append(AugmentOp(op_name, prob=1.0, magnitude=magnitude, hparams=hparams))
    return ops
873
874
class AugMixAugment:
    """ AugMix Transform
    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781

    Args:
        ops: sequence of callables (image -> image) to sample augmentation chains from.
        alpha: concentration parameter of the Dirichlet (chain weights) and Beta (final mix) draws.
        width: number of augmentation chains mixed together.
        depth: number of ops per chain; values <= 0 draw a random depth in [1, 3] per chain.
        blended: use the per-chain PIL blending variant instead of the NumPy accumulation.
    """
    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended  # blended mode is faster but not well tested

    @staticmethod
    def _array_shape(size, num_bands):
        """Return the NumPy array shape matching a PIL image of the given ``size`` and band count.

        PIL reports ``size`` as (width, height) whereas ``np.asarray(image)`` is laid out as
        (height, width, bands) for multi-band images and (height, width) for single-band ones.
        """
        width, height = size
        if num_bands > 1:
            return height, width, num_bands
        return height, width

    def _calc_blended_weights(self, ws, m):
        """Convert chain weights *ws* and mix factor *m* into sequential per-chain blend alphas."""
        ws = ws * m
        cump = 1.
        rws = []
        # walk the chains last-to-first, tracking the fraction that remains un-blended
        for w in ws[::-1]:
            alpha = w / cump
            cump *= (1 - alpha)
            rws.append(alpha)
        return np.array(rws[::-1], dtype=np.float32)

    def _apply_blended(self, img, mixing_weights, m):
        # This is my first crack and implementing a slightly faster mixed augmentation. Instead
        # of accumulating the mix for each chain in a Numpy array and then blending with original,
        # it recomputes the blending coefficients and applies one PIL image blend per chain.
        # TODO the results appear in the right ballpark but they differ by more than rounding.
        img_orig = img.copy()
        ws = self._calc_blended_weights(mixing_weights, m)
        for w in ws:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img_orig  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            img = Image.blend(img, img_aug, w)
        return img

    def _apply_basic(self, img, mixing_weights, m):
        # This is a literal adaptation of the paper/official implementation without normalizations and
        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
        # typical augmentation transforms, could use a GPU / Kornia implementation.
        #
        # The accumulation buffer must use the array layout of the image (height first), not the
        # PIL (width, height) order, otherwise the ``+=`` below fails for non-square images.
        img_shape = self._array_shape(img.size, len(img.getbands()))
        mixed = np.zeros(img_shape, dtype=np.float32)
        for mw in mixing_weights:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            mixed += mw * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255., out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        """Sample chain weights and a mix factor, then return the augmented-and-mixed image."""
        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.blended:
            mixed = self._apply_blended(img, mixing_weights, m)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed

    def __repr__(self):
        fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
        for op in self.ops:
            fs += f'\n\t{op}'
        fs += ')'
        return fs
946
947
def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None):
    """ Create AugMix PyTorch transform

    Args:
        config_str (str): String defining configuration of the AugMix augmentation. Consists of multiple sections
            separated by dashes ('-'). The first section must be 'augmix'.
            The remaining sections, not order specific determine
                'm' - integer magnitude (severity) of augmentation mix (default: 3)
                'w' - integer width of augmentation chain (default: 3)
                'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
                'a' - float alpha for the Dirichlet / Beta mixing draws (default: 1.0)
                'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
                'mstd' -  float std deviation of magnitude noise applied (default: uniform sampling)
            Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2

        hparams: Other hparams (kwargs) for the Augmentation transforms. When a dict is passed it is
            updated in place with values parsed from ``config_str`` (existing keys are not overwritten).
            When ``None``, a private copy of the default hparams is used.

    Returns:
         A PyTorch compatible Transform

    Raises:
        AssertionError: if ``config_str`` does not start with 'augmix' or holds an unknown section.
    """
    config = config_str.split('-')
    if config[0] != 'augmix':
        # explicit raise: must not be stripped under ``python -O``
        raise AssertionError(f"AugMix config must start with 'augmix', got '{config[0]}'")
    if hparams is None:
        # Work on a copy so that config-derived values never leak into the shared defaults
        # (and so that ``setdefault`` below has a dict to operate on).
        hparams = dict(_HPARAMS_DEFAULT)

    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.
    blended = False
    for c in config[1:]:
        # split 'key<number>' into key and numeric text; sections without a digit are skipped
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'w':
            width = int(val)
        elif key == 'd':
            depth = int(val)
        elif key == 'a':
            alpha = float(val)
        elif key == 'b':
            # ``val`` is a non-empty string, so parse it as a number before the truth test;
            # ``bool('0')`` would be True and 'b0' would enable blended mode.
            blended = int(val) != 0
        else:
            raise AssertionError('Unknown AugMix config section')
    hparams.setdefault('magnitude_std', float('inf'))  # default to uniform sampling (if not set via mstd arg)
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
998