# pytorch-image-models (fork) / auto_augment.py
# Original listing: 997 lines, 34.7 KB
""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch

This code implements the searched ImageNet policies with various tweaks and improvements and
does not include any of the search code.

AA and RA Implementation adapted from:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py

AugMix adapted from:
    https://github.com/google-research/augmix

3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md

Papers:
    AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
    Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
    AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
    3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118

Hacked together by / Copyright 2019, Ross Wightman
"""
23
import random
24
import math
25
import re
26
from functools import partial
27
from typing import Dict, List, Optional, Union
28

29
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
30
import PIL
31
import numpy as np
32

33

# Pillow (major, minor) version tuple, used to gate code paths for older Pillow releases.
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])

# Default fill color (mid-gray RGB) passed as `fillcolor` to geometric ops when hparams lacks 'img_mean'.
_FILL = (128, 128, 128)

_LEVEL_DENOM = 10.  # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments

# Fallback hparams used when a caller passes none (or an empty dict).
_HPARAMS_DEFAULT = dict(
    translate_const=250,  # max absolute translation (pixels) for TranslateX / TranslateY
    img_mean=_FILL,  # fill color for pixels exposed by geometric ops
)

# Newer Pillow exposes resampling filters through the `Image.Resampling` enum; older releases
# only provide the module-level constants.
if hasattr(Image, "Resampling"):
    _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
    _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
else:
    _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
    _DEFAULT_INTERPOLATION = Image.BICUBIC


def _interpolation(kwargs):
    """Pop 'resample' from ``kwargs`` and resolve it to a single resampling filter.

    A list/tuple of filters means "pick one at random per call"; a missing entry falls back
    to ``_DEFAULT_INTERPOLATION``. Note that ``kwargs`` is mutated (the key is removed).
    """
    interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
    if isinstance(interpolation, (list, tuple)):
        return random.choice(interpolation)
    return interpolation


def _check_args_tf(kwargs):
    """Normalize keyword args for ``Image.transform`` / ``Image.rotate`` in place.

    Drops 'fillcolor' when running on Pillow < 5.0, and replaces 'resample' with one concrete
    filter (randomly chosen if a list/tuple was configured).
    """
    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
        kwargs.pop('fillcolor')
    kwargs['resample'] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
    """Shear ``img`` along x with affine coefficients (1, factor, 0, 0, 1, 0)."""
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)


def shear_y(img, factor, **kwargs):
    """Shear ``img`` along y with affine coefficients (1, 0, 0, factor, 1, 0)."""
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)


def translate_x_rel(img, pct, **kwargs):
    """Translate ``img`` along x by ``pct`` times ``img.size[0]`` (the width) pixels."""
    pixels = pct * img.size[0]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_rel(img, pct, **kwargs):
    """Translate ``img`` along y by ``pct`` times ``img.size[1]`` (the height) pixels."""
    pixels = pct * img.size[1]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def translate_x_abs(img, pixels, **kwargs):
    """Translate ``img`` along x by an absolute number of ``pixels``."""
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_abs(img, pixels, **kwargs):
    """Translate ``img`` along y by an absolute number of ``pixels``."""
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def rotate(img, degrees, **kwargs):
    """Rotate ``img`` by ``degrees`` about its center, picking a code path by Pillow version.

    - Pillow >= 5.2: ``Image.rotate`` with all normalized kwargs.
    - Pillow 5.0/5.1: builds the equivalent center-rotation affine matrix by hand and calls
      ``Image.transform`` with all normalized kwargs.
    - Older Pillow: ``Image.rotate`` with only ``resample`` (fill color is not passed).
    """
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if _PIL_VER >= (5, 0):
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        # Shift the origin to the rotation center, rotate, then shift back.
        matrix[2], matrix[5] = transform(
            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
        )
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    return img.rotate(degrees, resample=kwargs['resample'])


def auto_contrast(img, **__):
    """Apply ``ImageOps.autocontrast``; magnitude-related kwargs are ignored."""
    return ImageOps.autocontrast(img)


def invert(img, **__):
    """Apply ``ImageOps.invert``; extra kwargs are ignored."""
    return ImageOps.invert(img)


def equalize(img, **__):
    """Apply ``ImageOps.equalize``; extra kwargs are ignored."""
    return ImageOps.equalize(img)


def solarize(img, thresh, **__):
    """Apply ``ImageOps.solarize`` with threshold ``thresh``; extra kwargs are ignored."""
    return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
    """Add ``add`` to every pixel value below ``thresh``; values at/above ``thresh`` are kept.

    Results are clamped to [0, 255], so a negative ``add`` cannot produce out-of-range table
    entries. Only 'L' and 'RGB' images are modified; other modes are returned unchanged.
    """
    lut = [min(255, max(0, i + add)) if i < thresh else i for i in range(256)]
    if img.mode in ("L", "RGB"):
        if img.mode == "RGB":
            # Multi-band images take one 256-entry table per band.
            lut = lut * 3
        return img.point(lut)
    return img


def posterize(img, bits_to_keep, **__):
    """Keep the ``bits_to_keep`` most significant bits per channel (no-op when >= 8)."""
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
    """Apply ``ImageEnhance.Contrast`` with enhancement ``factor`` (1.0 = unchanged)."""
    return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
    """Apply ``ImageEnhance.Color`` (saturation) with enhancement ``factor`` (1.0 = unchanged)."""
    return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
    """Apply ``ImageEnhance.Brightness`` with enhancement ``factor`` (1.0 = unchanged)."""
    return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
    """Apply ``ImageEnhance.Sharpness`` with enhancement ``factor`` (1.0 = unchanged)."""
    return ImageEnhance.Sharpness(img).enhance(factor)


def gaussian_blur(img, factor, **__):
    """Blur ``img`` with a Gaussian of radius ``factor``."""
    return img.filter(ImageFilter.GaussianBlur(radius=factor))


def gaussian_blur_rand(img, factor, **__):
    """Blur ``img`` with a Gaussian whose radius is drawn uniformly from [0.1, 2.0 * factor].

    When ``2.0 * factor`` < 0.1, ``random.uniform`` simply samples between the swapped bounds.
    """
    radius_min = 0.1
    radius_max = 2.0
    radius = random.uniform(radius_min, radius_max * factor)
    return img.filter(ImageFilter.GaussianBlur(radius=radius))


def desaturate(img, factor, **_):
    """Reduce saturation by ``factor``: 0 = unchanged, 1 = fully grayscale (clamped to [0, 1])."""
    factor = min(1., max(0., 1. - factor))
    # enhance factor 0 = grayscale, 1.0 = no-change
    return ImageEnhance.Color(img).enhance(factor)


def _randomly_negate(v):
    """With 50% prob, negate the value"""
    return -v if random.random() > 0.5 else v


def _rotate_level_to_arg(level, _hparams):
    """Map ``level`` to a rotation angle in degrees, randomly signed.

    Range is [-30, 30] for level <= _LEVEL_DENOM. Returns a 1-tuple of positional op args.
    """
    level = (level / _LEVEL_DENOM) * 30.
    level = _randomly_negate(level)
    return level,


def _enhance_level_to_arg(level, _hparams):
    """Map ``level`` linearly to an ImageEnhance factor in [0.1, 1.9] (for level <= _LEVEL_DENOM)."""
    return (level / _LEVEL_DENOM) * 1.8 + 0.1,


def _enhance_increasing_level_to_arg(level, _hparams):
    """Map ``level`` to an ImageEnhance factor that moves further from 1.0 as level grows.

    The 'no change' factor is 1.0; moving away from it towards 0. or 2.0 increases the
    enhancement blend. The direction is random. Range is [0.1, 1.9] for level <= _LEVEL_DENOM;
    the result is floored at 0.1.
    """
    level = (level / _LEVEL_DENOM) * .9
    level = max(0.1, 1.0 + _randomly_negate(level))  # keep it >= 0.1
    return level,


def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
    """Map ``level`` linearly from [0, _LEVEL_DENOM] onto [min_val, max_val].

    With ``clamp`` the result is limited to [min_val, max_val]. This matters when the
    magnitude exceeds _LEVEL_DENOM (e.g. via 'magnitude_max').
    """
    level = (level / _LEVEL_DENOM)
    level = min_val + (max_val - min_val) * level
    if clamp:
        level = max(min_val, min(max_val, level))
    return level,


def _shear_level_to_arg(level, _hparams):
    """Map ``level`` to a randomly signed shear factor in [-0.3, 0.3] (for level <= _LEVEL_DENOM)."""
    level = (level / _LEVEL_DENOM) * 0.3
    level = _randomly_negate(level)
    return level,


def _translate_abs_level_to_arg(level, hparams):
    """Map ``level`` to a randomly signed absolute translation in pixels.

    The scale is ``hparams['translate_const']``. When that key is absent, the value from
    ``_HPARAMS_DEFAULT`` is used instead of raising KeyError, mirroring how
    ``translate_pct`` falls back to a default for the relative variant.
    """
    translate_const = hparams.get('translate_const', _HPARAMS_DEFAULT['translate_const'])
    level = (level / _LEVEL_DENOM) * float(translate_const)
    level = _randomly_negate(level)
    return level,


def _translate_rel_level_to_arg(level, hparams):
    """Map ``level`` to a randomly signed translation as a fraction of image size.

    Default range is [-0.45, 0.45]; the scale comes from ``hparams['translate_pct']``.
    """
    translate_pct = hparams.get('translate_pct', 0.45)
    level = (level / _LEVEL_DENOM) * translate_pct
    level = _randomly_negate(level)
    return level,


def _posterize_level_to_arg(level, _hparams):
    """Map ``level`` to bits to keep, as per the Tensorflow TPU EfficientNet impl.

    Range is [0, 4] ('keep 0 up to 4 MSB of original image'). Intensity/severity of the
    augmentation decreases with level.
    """
    return int((level / _LEVEL_DENOM) * 4),


def _posterize_increasing_level_to_arg(level, hparams):
    """Map ``level`` to bits to keep, as per the Tensorflow models research and UDA impl.

    Range is [4, 0] ('keep 4 down to 0 MSB of original image'). Intensity/severity of the
    augmentation increases with level.
    """
    return 4 - _posterize_level_to_arg(level, hparams)[0],


def _posterize_original_level_to_arg(level, _hparams):
    """Map ``level`` to bits to keep, as per the original AutoAugment paper description.

    Range is [4, 8] ('keep 4 up to 8 MSB of image'). Intensity/severity of the augmentation
    decreases with level.
    """
    return int((level / _LEVEL_DENOM) * 4) + 4,


def _solarize_level_to_arg(level, _hparams):
    """Map ``level`` to a solarize threshold in [0, 256] (capped at 256).

    Intensity/severity of the augmentation decreases with level.
    """
    return min(256, int((level / _LEVEL_DENOM) * 256)),


def _solarize_increasing_level_to_arg(level, _hparams):
    """Map ``level`` to a solarize threshold in [0, 256].

    Intensity/severity of the augmentation increases with level.
    """
    return 256 - _solarize_level_to_arg(level, _hparams)[0],


def _solarize_add_level_to_arg(level, _hparams):
    """Map ``level`` to the value added by SolarizeAdd.

    Range is [0, 110] for level <= _LEVEL_DENOM; larger levels are capped at 128.
    """
    return min(128, int((level / _LEVEL_DENOM) * 110)),


# Op name -> function mapping a magnitude (and hparams) to the op's positional args.
# None means the op takes no magnitude-derived arguments.
LEVEL_TO_ARG = {
    'AutoContrast': None,
    'Equalize': None,
    'Invert': None,
    'Rotate': _rotate_level_to_arg,
    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
    'Posterize': _posterize_level_to_arg,
    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
    'PosterizeOriginal': _posterize_original_level_to_arg,
    'Solarize': _solarize_level_to_arg,
    'SolarizeIncreasing': _solarize_increasing_level_to_arg,
    'SolarizeAdd': _solarize_add_level_to_arg,
    'Color': _enhance_level_to_arg,
    'ColorIncreasing': _enhance_increasing_level_to_arg,
    'Contrast': _enhance_level_to_arg,
    'ContrastIncreasing': _enhance_increasing_level_to_arg,
    'Brightness': _enhance_level_to_arg,
    'BrightnessIncreasing': _enhance_increasing_level_to_arg,
    'Sharpness': _enhance_level_to_arg,
    'SharpnessIncreasing': _enhance_increasing_level_to_arg,
    'ShearX': _shear_level_to_arg,
    'ShearY': _shear_level_to_arg,
    'TranslateX': _translate_abs_level_to_arg,
    'TranslateY': _translate_abs_level_to_arg,
    'TranslateXRel': _translate_rel_level_to_arg,
    'TranslateYRel': _translate_rel_level_to_arg,
    'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0),
    'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0),
    'GaussianBlurRand': _minmax_level_to_arg,
}


# Op name -> image op function. Keys must match LEVEL_TO_ARG.
NAME_TO_OP = {
    'AutoContrast': auto_contrast,
    'Equalize': equalize,
    'Invert': invert,
    'Rotate': rotate,
    'Posterize': posterize,
    'PosterizeIncreasing': posterize,
    'PosterizeOriginal': posterize,
    'Solarize': solarize,
    'SolarizeIncreasing': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'ColorIncreasing': color,
    'Contrast': contrast,
    'ContrastIncreasing': contrast,
    'Brightness': brightness,
    'BrightnessIncreasing': brightness,
    'Sharpness': sharpness,
    'SharpnessIncreasing': sharpness,
    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x_abs,
    'TranslateY': translate_y_abs,
    'TranslateXRel': translate_x_rel,
    'TranslateYRel': translate_y_rel,
    'Desaturate': desaturate,
    'GaussianBlur': gaussian_blur,
    'GaussianBlurRand': gaussian_blur_rand,
}


class AugmentOp:
    """One named augmentation applied with probability ``prob`` at a given ``magnitude``.

    ``name`` selects the image op (``NAME_TO_OP``) and the function mapping magnitude to the
    op's positional args (``LEVEL_TO_ARG``).

    hparams keys read here (all optional; ``_HPARAMS_DEFAULT`` is used if hparams is falsy):
        img_mean: fill color passed as ``fillcolor`` (default ``_FILL``)
        interpolation: resample filter, or a list/tuple to pick from per call
            (default ``_RANDOM_INTERPOLATION``)
        magnitude_std: > 0 samples magnitude from N(magnitude, magnitude_std) per call;
            ``inf`` samples uniformly from [0, magnitude]
        magnitude_max: upper clamp for the magnitude (default ``_LEVEL_DENOM``)
    The full hparams copy is also passed to the level function.
    """

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        hparams = hparams or _HPARAMS_DEFAULT
        self.name = name
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        self.kwargs = dict(
            fillcolor=hparams.get('img_mean', _FILL),
            resample=hparams.get('interpolation', _RANDOM_INTERPOLATION),
        )

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        # If magnitude_std is inf, we sample magnitude from a uniform distribution
        self.magnitude_std = self.hparams.get('magnitude_std', 0)
        self.magnitude_max = self.hparams.get('magnitude_max', None)

    def __call__(self, img):
        if self.prob < 1.0 and random.random() > self.prob:
            return img
        magnitude = self.magnitude
        if self.magnitude_std == float('inf'):
            # inf == uniform sampling
            magnitude = random.uniform(0, magnitude)
        elif self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
        # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
        upper_bound = self.magnitude_max or _LEVEL_DENOM
        magnitude = max(0., min(magnitude, upper_bound))
        level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
        # `**self.kwargs` unpacks into a fresh dict, so ops that pop/rewrite kwargs
        # (e.g. via _check_args_tf) do not modify self.kwargs.
        return self.aug_fn(img, *level_args, **self.kwargs)

    def __repr__(self):
        fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
        fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
        if self.magnitude_max is not None:
            fs += f', mmax={self.magnitude_max}'
        fs += ')'
        return fs


def auto_augment_policy_v0(hparams):
    """ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.

    Returns a list of sub-policies, each a list of ``AugmentOp`` built from
    (name, prob, magnitude) tuples.
    """
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_v0r(hparams):
    """ImageNet v0 policy from TPU EfficientNet impl, with the Posterize variant used in the
    Google research implementation (number of bits discarded increases with magnitude).

    Returns a list of sub-policies, each a list of ``AugmentOp``.
    """
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_original(hparams):
    """ImageNet policy from https://arxiv.org/abs/1805.09501

    Returns a list of sub-policies, each a list of ``AugmentOp``.
    """
    policy = [
        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_originalr(hparams):
    """ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation

    Returns a list of sub-policies, each a list of ``AugmentOp``.
    """
    policy = [
        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_3a(hparams):
    """3-Augment policy: three single-op sub-policies, each always applied when chosen.

    Returns a list of sub-policies, each a list of ``AugmentOp``.
    """
    policy = [
        [('Solarize', 1.0, 5)],  # 128 solarize threshold @ 5 magnitude
        [('Desaturate', 1.0, 10)],  # grayscale at 10 magnitude
        [('GaussianBlurRand', 1.0, 10)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy(name='v0', hparams=None):
    """Build the named AutoAugment policy as a list of sub-policies (lists of ``AugmentOp``).

    Args:
        name: one of 'original', 'originalr', 'v0', 'v0r', '3a'.
        hparams: op hparams; ``_HPARAMS_DEFAULT`` is used when falsy.

    Raises:
        AssertionError: if ``name`` is not a known policy. This is raised explicitly rather than
            via ``assert`` so it still fires under ``python -O``.
    """
    hparams = hparams or _HPARAMS_DEFAULT
    policy_fns = {
        'original': auto_augment_policy_original,
        'originalr': auto_augment_policy_originalr,
        'v0': auto_augment_policy_v0,
        'v0r': auto_augment_policy_v0r,
        '3a': auto_augment_policy_3a,
    }
    if name not in policy_fns:
        raise AssertionError(f'Unknown AA policy {name}')
    return policy_fns[name](hparams)


class AutoAugment:
    """Apply one randomly chosen sub-policy per call.

    ``policy`` is a list of sub-policies; each sub-policy is a list of callables applied in order.
    """

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        sub_policy = random.choice(self.policy)
        for op in sub_policy:
            img = op(img)
        return img

    def __repr__(self):
        fs = self.__class__.__name__ + '(policy='
        for p in self.policy:
            fs += '\n\t['
            fs += ', '.join([str(op) for op in p])
            fs += ']'
        fs += ')'
        return fs


def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
    """
    Create an AutoAugment transform

    Args:
        config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
            dashes ('-').
            The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr', '3a').

            The remaining sections:
                'mstd' -  float std deviation of magnitude noise applied
            Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5

        hparams: Other hparams (kwargs) for the AutoAugmentation scheme. The caller's dict is not
            modified (a copy is used); when None, a copy of ``_HPARAMS_DEFAULT`` is used. An explicit
            'magnitude_std' in hparams takes precedence over an 'mstd' section.

    Returns:
         A PyTorch compatible Transform

    Raises:
        AssertionError: on an unknown config section or policy name (raised explicitly, so it
            also fires under ``python -O``).
    """
    # Work on a copy. The original setdefault() call crashed when hparams was None, and it
    # leaked 'magnitude_std' into a caller's dict, so a reused dict silently kept a stale value.
    hparams = dict(_HPARAMS_DEFAULT if hparams is None else hparams)
    policy_name, *sections = config_str.split('-')
    for c in sections:
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            # section has no numeric value; ignored
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        else:
            raise AssertionError(f'Unknown AutoAugment config section: {c}')
    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
    return AutoAugment(aa_policy)


# Default RandAugment op set (severity not necessarily increasing with magnitude).
_RAND_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'Posterize',
    'Solarize',
    'SolarizeAdd',
    'Color',
    'Contrast',
    'Brightness',
    'Sharpness',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
    # 'Cutout'  # NOTE I've implemented this as random erasing separately
]


# RandAugment op set whose severity increases with magnitude ('inc1' / increasing=True).
_RAND_INCREASING_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'SolarizeAdd',
    'ColorIncreasing',
    'ContrastIncreasing',
    'BrightnessIncreasing',
    'SharpnessIncreasing',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
    # 'Cutout'  # NOTE I've implemented this as random erasing separately
]


# 3-Augment op set (uniform choice).
_RAND_3A = [
    'SolarizeIncreasing',
    'Desaturate',
    'GaussianBlur',
]


# 3-Augment ops plus extra ops, with relative choice weights.
_RAND_WEIGHTED_3A = {
    'SolarizeIncreasing': 6,
    'Desaturate': 6,
    'GaussianBlur': 6,
    'Rotate': 3,
    'ShearX': 2,
    'ShearY': 2,
    'PosterizeIncreasing': 1,
    'AutoContrast': 1,
    'ColorIncreasing': 1,
    'SharpnessIncreasing': 1,
    'ContrastIncreasing': 1,
    'BrightnessIncreasing': 1,
    'Equalize': 1,
    'Invert': 1,
}


# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_WEIGHTED_0 = {
    'Rotate': 3,
    'ShearX': 2,
    'ShearY': 2,
    'TranslateXRel': 1,
    'TranslateYRel': 1,
    'ColorIncreasing': .25,
    'SharpnessIncreasing': 0.25,
    'AutoContrast': 0.25,
    'SolarizeIncreasing': .05,
    'SolarizeAdd': .05,
    'ContrastIncreasing': .05,
    'BrightnessIncreasing': .05,
    'Equalize': .05,
    'PosterizeIncreasing': 0.05,
    'Invert': 0.05,
}


def _get_weighted_transforms(transforms: Dict):
    """Split a {name: weight} dict into (names tuple, normalized probability array)."""
    transforms, probs = list(zip(*transforms.items()))
    probs = np.array(probs)
    probs = probs / np.sum(probs)
    return transforms, probs


def rand_augment_choices(name: str, increasing=True):
    """Resolve a RandAugment transform-set name to its op list or {op: weight} dict.

    'weights' and '3aw' return weighted dicts; '3a' returns the 3-Augment list. Any other name
    returns the increasing or standard default list, depending on ``increasing``.
    """
    if name == 'weights':
        return _RAND_WEIGHTED_0
    if name == '3aw':
        return _RAND_WEIGHTED_3A
    if name == '3a':
        return _RAND_3A
    return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS


def rand_augment_ops(
        magnitude: Union[int, float] = 10,
        prob: float = 0.5,
        hparams: Optional[Dict] = None,
        transforms: Optional[Union[Dict, List]] = None,
):
    """Build one ``AugmentOp`` per op name in ``transforms`` with shared prob/magnitude/hparams.

    Iterating a dict yields its keys, so a weighted dict also works here. Falsy ``hparams`` /
    ``transforms`` fall back to ``_HPARAMS_DEFAULT`` / ``_RAND_TRANSFORMS``.
    """
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _RAND_TRANSFORMS
    return [AugmentOp(
        name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]


class RandAugment:
    """Apply ``num_layers`` randomly chosen ops from ``ops`` per call.

    Without ``choice_weights``, ops are sampled uniformly with replacement. With weights, ops are
    sampled without replacement according to ``choice_weights``, so ``num_layers`` must not
    exceed ``len(ops)``.
    """

    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        # no replacement when using weighted choice
        ops = np.random.choice(
            self.ops,
            self.num_layers,
            replace=self.choice_weights is None,
            p=self.choice_weights,
        )
        for op in ops:
            img = op(img)
        return img

    def __repr__(self):
        fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
        for op in self.ops:
            fs += f'\n\t{op}'
        fs += ')'
        return fs


def rand_augment_transform(
        config_str: str,
        hparams: Optional[Dict] = None,
        transforms: Optional[Union[str, Dict, List]] = None,
):
    """
    Create a RandAugment transform

    Args:
        config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
            by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
            The remaining sections, in any order, determine
                'm' - integer magnitude of rand augment
                'n' - integer num layers (number of transform ops selected per image)
                'p' - float probability of applying each layer (default 0.5)
                'mstd' -  float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
                'mmax' - set upper bound for magnitude to something other than default of  _LEVEL_DENOM (10)
                'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
                't' - str name of transform set to use
            Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
            'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2

        hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme. The caller's dict is not
            modified (a copy is used); when None, a copy of ``_HPARAMS_DEFAULT`` is used. Explicit
            'magnitude_std' / 'magnitude_max' entries take precedence over 'mstd' / 'mmax' sections.
        transforms: transform set override, either a name understood by ``rand_augment_choices``,
            a list of op names, or a {op name: weight} dict. Takes precedence over a 't' section.

    Returns:
         A PyTorch compatible Transform

    Raises:
        AssertionError: if the variant is not 'rand' or a section key is unknown (raised
            explicitly, so the checks also fire under ``python -O``).
    """
    magnitude = _LEVEL_DENOM  # default to _LEVEL_DENOM for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    increasing = False
    prob = 0.5
    # Work on a copy. The original setdefault() calls crashed when hparams was None, and they
    # leaked config values into a caller's dict that might be reused.
    hparams = dict(_HPARAMS_DEFAULT if hparams is None else hparams)
    variant, *sections = config_str.split('-')
    if variant != 'rand':
        raise AssertionError(f'Unknown RandAugment variant: {variant}')
    for c in sections:
        if c.startswith('t'):
            # NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights'
            if transforms is None:
                transforms = c[1:]
            continue
        # numeric options
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            # section has no numeric value; ignored
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param / randomization of magnitude values
            mstd = float(val)
            if mstd > 100:
                # use uniform sampling in 0 to magnitude if mstd is > 100
                mstd = float('inf')
            hparams.setdefault('magnitude_std', mstd)
        elif key == 'mmax':
            # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
            hparams.setdefault('magnitude_max', int(val))
        elif key == 'inc':
            # Parse as an integer flag: bool('0') is True, which made 'inc0' enable increasing mode.
            increasing = bool(int(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'p':
            prob = float(val)
        else:
            raise AssertionError(f'Unknown RandAugment config section: {c}')

    if isinstance(transforms, str):
        transforms = rand_augment_choices(transforms, increasing=increasing)
    elif transforms is None:
        transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS

    choice_weights = None
    if isinstance(transforms, dict):
        transforms, choice_weights = _get_weighted_transforms(transforms)

    ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)


# Default AugMix op set (all severity-increasing variants).
_AUGMIX_TRANSFORMS = [
    'AutoContrast',
    'ColorIncreasing',  # not in paper
    'ContrastIncreasing',  # not in paper
    'BrightnessIncreasing',  # not in paper
    'SharpnessIncreasing',  # not in paper
    'Equalize',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
]


860
def augmix_ops(
        magnitude: Union[int, float] = 10,
        hparams: Optional[Dict] = None,
        transforms: Optional[Union[str, Dict, List]] = None,
):
    """ Build the AugmentOp instances that AugMix samples its augmentation chains from.

    Args:
        magnitude: magnitude passed to every created op.
        hparams: hparams passed to every created op; a falsy value (None or empty) selects _HPARAMS_DEFAULT.
        transforms: op names to instantiate, iterated directly; a falsy value selects _AUGMIX_TRANSFORMS.
            NOTE(review): a str is iterated character by character here, and a dict yields only its keys.

    Returns:
        A list with one AugmentOp per name, each created with prob=1.0.
    """
    op_hparams = hparams if hparams else _HPARAMS_DEFAULT
    op_names = transforms if transforms else _AUGMIX_TRANSFORMS
    # every op shares the same fixed prob / magnitude / hparams; only the op name varies
    build_op = partial(AugmentOp, prob=1.0, magnitude=magnitude, hparams=op_hparams)
    return [build_op(op_name) for op_name in op_names]
873

874

875
class AugMixAugment:
    """ AugMix Transform
    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781

    Args:
        ops: augmentation ops sampled (with replacement) to build each chain; each is called with a
            PIL image and its result is fed to the next op of the chain.
        alpha: concentration parameter of the Dirichlet distribution for the per-chain mixing weights
            and of the Beta(alpha, alpha) distribution for the final blend factor with the original.
        width: number of augmentation chains that are mixed.
        depth: ops per chain; values <= 0 draw a random depth in [1, 3] independently for every chain.
        blended: use the faster per-chain PIL blending path instead of a float numpy accumulator.
    """
    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended  # blended mode is faster but not well tested

    def _calc_blended_weights(self, ws, m):
        """ Convert chain mixing weights into sequential Image.blend factors.

        The weights are scaled by m and then walked from the last chain to the first. Blending the
        chains one after another with the returned factors gives, in exact arithmetic,
        (1 - sum(m * ws)) * original + sum(m * ws[i] * chain_i), i.e. (1 - m) * original + ... when
        ws sums to 1 (as Dirichlet samples do).

        Returns:
            float32 numpy array of per-chain blend factors, in chain order.
        """
        ws = ws * m
        cump = 1.
        rws = []
        for w in ws[::-1]:
            alpha = w / cump
            cump *= (1 - alpha)
            rws.append(alpha)
        return np.array(rws[::-1], dtype=np.float32)

    def _apply_blended(self, img, mixing_weights, m):
        # This is my first crack at implementing a slightly faster mixed augmentation. Instead
        # of accumulating the mix for each chain in a Numpy array and then blending with original,
        # it recomputes the blending coefficients and applies one PIL image blend per chain.
        # TODO the results appear in the right ballpark but they differ by more than rounding.
        img_orig = img.copy()
        ws = self._calc_blended_weights(mixing_weights, m)
        for w in ws:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img_orig  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            img = Image.blend(img, img_aug, w)
        return img

    def _apply_basic(self, img, mixing_weights, m):
        # This is a literal adaptation of the paper/official implementation without normalizations and
        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
        # typical augmentation transforms, could use a GPU / Kornia implementation.
        width, height = img.size  # PIL reports size as (width, height)
        num_bands = len(img.getbands())
        # np.asarray(<PIL image>) is laid out (H, W) for single-band modes and (H, W, C) otherwise.
        # The accumulator must use the same layout, otherwise non-square or grayscale images fail to broadcast.
        img_shape = (height, width) if num_bands == 1 else (height, width, num_bands)
        mixed = np.zeros(img_shape, dtype=np.float32)
        for mw in mixing_weights:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            mixed += mw * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255., out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        """ Apply AugMix to a PIL image and return the mixed PIL image. """
        # per-chain weights ~ Dirichlet(alpha, ..., alpha); final blend factor with original ~ Beta(alpha, alpha)
        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.blended:
            mixed = self._apply_blended(img, mixing_weights, m)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed

    def __repr__(self):
        fs = self.__class__.__name__ + (
            f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, blended={self.blended}, ops='
        )
        fs += ''.join(f'\n\t{op}' for op in self.ops)
        fs += ')'
        return fs
946

947

948
def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None):
    """ Create AugMix PyTorch transform

    Args:
        config_str (str): String defining configuration of AugMix. Consists of multiple sections separated
            by dashes ('-'). The first section must be 'augmix'.
            The remaining sections, not order specific, determine
                'm' - integer magnitude (severity) of augmentation mix (default: 3)
                'w' - integer width of augmentation chain (default: 3)
                'd' - integer depth of augmentation chain (values <= 0 pick a random depth in [1, 3], default: -1)
                'a' - float alpha of the Dirichlet / Beta mixing distributions (default: 1.0)
                'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
                'mstd' -  float std deviation of magnitude noise applied
                          (default: inf, i.e. uniform magnitude sampling, unless already set in hparams)
            Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2

        hparams: Other hparams (kwargs) for the Augmentation transforms. The dict is copied, not modified.
            When None, a copy of the module-level defaults (_HPARAMS_DEFAULT) is used.

    Returns:
         A PyTorch compatible Transform (an AugMixAugment instance)
    """
    # Work on a copy: setdefault below must not mutate the caller's dict or the shared module defaults,
    # and a None default previously crashed on the first setdefault call.
    hparams = dict(_HPARAMS_DEFAULT) if hparams is None else dict(hparams)
    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.
    blended = False
    config = config_str.split('-')
    assert config[0] == 'augmix'
    for c in config[1:]:
        # split a section into its alpha key and numeric value, e.g. 'mstd0.5' -> ['mstd', '0.5', '']
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'w':
            width = int(val)
        elif key == 'd':
            depth = int(val)
        elif key == 'a':
            alpha = float(val)
        elif key == 'b':
            # val is a non-empty string, so bool(val) would be True even for 'b0'; parse the integer first
            blended = bool(int(val))
        else:
            assert False, 'Unknown AugMix config section'
    hparams.setdefault('magnitude_std', float('inf'))  # default to uniform sampling (if not set via mstd arg)
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
998

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.