pytorch-image-models

Форк
0
534 строки · 18.4 Кб
1
import math
2
import numbers
3
import random
4
import warnings
5
from typing import List, Sequence, Tuple, Union
6

7
import torch
8
import torchvision.transforms.functional as F
9
try:
10
    from torchvision.transforms.functional import InterpolationMode
11
    has_interpolation_mode = True
12
except ImportError:
13
    has_interpolation_mode = False
14
from PIL import Image
15
import numpy as np
16

17
# Names exported by ``from <module> import *``.
__all__ = [
    "ToNumpy",
    "ToTensor",
    "str_to_interp_mode",
    "str_to_pil_interp",
    "interp_mode_to_str",
    "RandomResizedCropAndInterpolation",
    "CenterCropOrPad",
    "center_crop_or_pad",
    "crop_or_pad",
    "RandomCropOrPad",
    "RandomPad",
    "ResizeKeepRatio",
    "TrimBorder",
]
22

23

24
class ToNumpy:
    """Convert a PIL image (or array-like) into a ``uint8`` numpy array laid out as CHW.

    A 2-D (single-channel) input first gains a trailing channel axis, so a grayscale
    ``H x W`` image comes out as ``1 x H x W``.
    """

    def __call__(self, pil_img):
        arr = np.array(pil_img, dtype=np.uint8)
        if arr.ndim < 3:
            # add a trailing channel axis: HW -> HW1
            arr = arr[..., np.newaxis]
        # move the channel axis (index 2 in HWC) to the front -> CHW
        return np.moveaxis(arr, 2, 0)
32

33

34
class ToTensor:
    """Convert a PIL image to a tensor and cast it to ``dtype``.

    Pixel values are only cast, never rescaled (a ``uint8`` 255 stays 255.0).

    Args:
        dtype: dtype of the returned tensor (default ``torch.float32``).
    """

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, pil_img):
        tensor = F.pil_to_tensor(pil_img)
        return tensor.to(dtype=self.dtype)
41

42

43
# Lookup tables between interpolation names ('bilinear', ...) and library enums.
# Newer Pillow releases (9.1+) expose resampling filters through the
# ``Image.Resampling`` enum; older releases only provide module-level constants
# such as ``Image.BILINEAR``. Prefer the enum whenever it exists.
if hasattr(Image, "Resampling"):
    _pil_interpolation_to_str = {
        Image.Resampling.NEAREST: 'nearest',
        Image.Resampling.BILINEAR: 'bilinear',
        Image.Resampling.BICUBIC: 'bicubic',
        Image.Resampling.BOX: 'box',
        Image.Resampling.HAMMING: 'hamming',
        Image.Resampling.LANCZOS: 'lanczos',
    }
else:
    _pil_interpolation_to_str = {
        Image.NEAREST: 'nearest',
        Image.BILINEAR: 'bilinear',
        Image.BICUBIC: 'bicubic',
        Image.BOX: 'box',
        Image.HAMMING: 'hamming',
        Image.LANCZOS: 'lanczos',
    }

# Reverse lookup: name -> Pillow filter.
_str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()}


if has_interpolation_mode:
    # torchvision provides InterpolationMode; map it to the same names.
    _torch_interpolation_to_str = {
        InterpolationMode.NEAREST: 'nearest',
        InterpolationMode.BILINEAR: 'bilinear',
        InterpolationMode.BICUBIC: 'bicubic',
        InterpolationMode.BOX: 'box',
        InterpolationMode.HAMMING: 'hamming',
        InterpolationMode.LANCZOS: 'lanczos',
    }
    _str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}
else:
    # No InterpolationMode in this torchvision: the name<->mode helpers fall back to
    # the Pillow tables. Still define every torch table (empty) so the module
    # namespace is the same regardless of which branch ran.
    _pil_interpolation_to_torch = {}  # kept for backward compatibility; not referenced in this module
    _torch_interpolation_to_str = {}
    _str_to_torch_interpolation = {}
81

82

83
def str_to_pil_interp(mode_str):
    """Return the Pillow resampling filter for an interpolation name such as ``'bicubic'``.

    Raises:
        KeyError: if ``mode_str`` is not one of the known names.
    """
    pil_table = _str_to_pil_interpolation
    return pil_table[mode_str]
85

86

87
def str_to_interp_mode(mode_str):
    """Translate an interpolation name into the mode object the installed libraries accept.

    Returns a torchvision ``InterpolationMode`` when torchvision provides it, otherwise
    the matching Pillow resampling filter.

    Raises:
        KeyError: if ``mode_str`` is not a known interpolation name.
    """
    if not has_interpolation_mode:
        return _str_to_pil_interpolation[mode_str]
    return _str_to_torch_interpolation[mode_str]
92

93

94
def interp_mode_to_str(mode):
    """Inverse of ``str_to_interp_mode``: return the name (e.g. ``'bilinear'``) of ``mode``.

    Raises:
        KeyError: if ``mode`` is not in the active lookup table.
    """
    table = _torch_interpolation_to_str if has_interpolation_mode else _pil_interpolation_to_str
    return table[mode]
99

100

101
# Modes sampled per call by transforms constructed with interpolation='random'.
_RANDOM_INTERPOLATION = tuple(str_to_interp_mode(name) for name in ('bilinear', 'bicubic'))
102

103

104
def _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."):
    """Normalize a size argument to a two-element ``(h, w)`` value.

    * a number ``n`` becomes ``(int(n), int(n))``;
    * a one-element sequence ``[n]`` becomes ``(n, n)``;
    * a two-element sequence is returned unchanged (same object).

    Raises:
        ValueError: with ``error_msg`` if the length is anything other than 1 or 2.
    """
    if isinstance(size, numbers.Number):
        edge = int(size)
        return edge, edge

    is_single = isinstance(size, Sequence) and len(size) == 1
    if is_single:
        return size[0], size[0]

    if len(size) == 2:
        return size
    raise ValueError(error_msg)
115

116

117
class RandomResizedCropAndInterpolation:
    """Random area/aspect crop followed by a resize with a fixed or random interpolation.

    A crop covering a random fraction of the source area (``scale``, default 0.08-1.0)
    with a random aspect ratio (``ratio``, default 3/4-4/3, sampled log-uniformly) is cut
    out and resized to ``size``. This is the augmentation popularized by Inception training.

    Args:
        size: output size; a list/tuple is used as given, anything else ``s`` gives ``(s, s)``.
        scale: ``(min, max)`` fraction of the source area to crop.
        ratio: ``(min, max)`` crop aspect ratio (width / height).
        interpolation: interpolation name (default ``'bilinear'``), or ``'random'`` to pick
            one per call from ``_RANDOM_INTERPOLATION``.
    """

    def __init__(
            self,
            size,
            scale=(0.08, 1.0),
            ratio=(3. / 4., 4. / 3.),
            interpolation='bilinear',
    ):
        self.size = tuple(size) if isinstance(size, (list, tuple)) else (size, size)
        if scale[0] > scale[1] or ratio[0] > ratio[1]:
            warnings.warn("range should be of kind (min, max)")

        self.interpolation = (
            _RANDOM_INTERPOLATION if interpolation == 'random' else str_to_interp_mode(interpolation)
        )
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Sample the crop window for ``img``.

        Up to 10 random windows are tried; if none fits inside the image, a centered crop
        whose aspect ratio is clamped to the ``ratio`` bounds is used instead.

        Args:
            img (PIL Image or Tensor): image to be cropped.
            scale (tuple): ``(min, max)`` fraction of the image area to crop.
            ratio (tuple): ``(min, max)`` aspect ratio (width / height) of the crop.

        Returns:
            tuple: ``(top, left, height, width)`` to pass to ``crop``.
        """
        width, height = F.get_image_size(img)
        area = width * height

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            log_bounds = (math.log(ratio[0]), math.log(ratio[1]))
            aspect = math.exp(random.uniform(*log_bounds))

            crop_w = int(round(math.sqrt(target_area * aspect)))
            crop_h = int(round(math.sqrt(target_area / aspect)))
            if crop_w <= width and crop_h <= height:
                top = random.randint(0, height - crop_h)
                left = random.randint(0, width - crop_w)
                return top, left, crop_h, crop_w

        # Fallback: centered crop with the aspect ratio clamped to the ratio bounds.
        image_aspect = width / height
        if image_aspect < min(ratio):
            crop_w = width
            crop_h = int(round(crop_w / min(ratio)))
        elif image_aspect > max(ratio):
            crop_h = height
            crop_w = int(round(crop_h * max(ratio)))
        else:
            crop_w, crop_h = width, height
        top = (height - crop_h) // 2
        left = (width - crop_w) // 2
        return top, left, crop_h, crop_w

    def __call__(self, img):
        """
        Args:
            img (PIL Image or Tensor): image to crop and resize.

        Returns:
            The randomly cropped image resized to ``self.size``.
        """
        top, left, height, width = self.get_params(img, self.scale, self.ratio)
        interp = self.interpolation
        if isinstance(interp, (tuple, list)):
            interp = random.choice(interp)
        return F.resized_crop(img, top, left, height, width, self.size, interp)

    def __repr__(self):
        modes = self.interpolation
        if isinstance(modes, (tuple, list)):
            interp_desc = ' '.join(interp_mode_to_str(m) for m in modes)
        else:
            interp_desc = interp_mode_to_str(modes)
        scale_desc = tuple(round(s, 4) for s in self.scale)
        ratio_desc = tuple(round(r, 4) for r in self.ratio)
        return (
            f'{self.__class__.__name__}(size={self.size}'
            f', scale={scale_desc}'
            f', ratio={ratio_desc}'
            f', interpolation={interp_desc})'
        )
221

222

223
def center_crop_or_pad(
        img: torch.Tensor,
        output_size: Union[int, List[int]],
        fill: Union[int, Tuple[int, int, int]] = 0,
        padding_mode: str = 'constant',
) -> torch.Tensor:
    """Center-crop ``img`` to ``output_size``, padding it first where it is too small.

    Tensors are expected as ``[..., H, W]`` (any number of leading dims). Each edge
    shorter than requested is padded on both sides; an odd pad puts the extra pixel on
    the right/bottom. The padded image is then center-cropped.

    Args:
        img (PIL Image or Tensor): image to crop.
        output_size: ``(height, width)`` of the result; an int or a one-element sequence
            is used for both edges.
        fill: padding value, forwarded to ``F.pad``.
        padding_mode: padding mode, forwarded to ``F.pad``.

    Returns:
        PIL Image or Tensor: the cropped (and possibly padded) image.
    """
    crop_height, crop_width = _setup_size(output_size)
    _, img_height, img_width = F.get_dimensions(img)

    if crop_width > img_width or crop_height > img_height:
        pad_w = max(crop_width - img_width, 0)
        pad_h = max(crop_height - img_height, 0)
        # left/top take the floor half of the pad, right/bottom the ceil half
        padding_ltrb = [pad_w // 2, pad_h // 2, (pad_w + 1) // 2, (pad_h + 1) // 2]
        img = F.pad(img, padding_ltrb, fill=fill, padding_mode=padding_mode)
        _, img_height, img_width = F.get_dimensions(img)
        if (img_height, img_width) == (crop_height, crop_width):
            return img

    top = int(round((img_height - crop_height) / 2.0))
    left = int(round((img_width - crop_width) / 2.0))
    return F.crop(img, top, left, crop_height, crop_width)
263

264

265
class CenterCropOrPad(torch.nn.Module):
    """Module form of ``center_crop_or_pad``: center-crop to ``size``, padding when needed.

    Tensors are expected as ``[..., H, W]``. Edges shorter than ``size`` are padded with
    ``fill`` / ``padding_mode`` before the center crop.

    Args:
        size (sequence or int): output ``(h, w)``. An int gives a square ``(size, size)``;
            a one-element sequence is read as ``(size[0], size[0])``.
        fill: padding value, forwarded to ``center_crop_or_pad``.
        padding_mode: padding mode, forwarded to ``center_crop_or_pad``.
    """

    def __init__(
            self,
            size: Union[int, List[int]],
            fill: Union[int, Tuple[int, int, int]] = 0,
            padding_mode: str = 'constant',
    ):
        super().__init__()
        self.size = _setup_size(size)
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): image to crop.

        Returns:
            PIL Image or Tensor: the center-cropped (possibly padded) image.
        """
        return center_crop_or_pad(img, self.size, fill=self.fill, padding_mode=self.padding_mode)

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(size={self.size})"
300

301

302
def crop_or_pad(
        img: torch.Tensor,
        top: int,
        left: int,
        height: int,
        width: int,
        fill: Union[int, Tuple[int, int, int]] = 0,
        padding_mode: str = 'constant',
) -> torch.Tensor:
    """Cut the window ``(top, left, height, width)`` out of ``img``, padding where it
    extends beyond the image.

    A negative ``top``/``left``, or a window running past the bottom/right edge, causes
    the image to be padded with ``fill`` / ``padding_mode`` first. The crop offsets are
    then clamped at 0, which lines them up with the padded image.
    """
    _, img_height, img_width = F.get_dimensions(img)
    bottom = top + height
    right = left + width

    out_of_bounds = top < 0 or left < 0 or bottom > img_height or right > img_width
    if out_of_bounds:
        pad_left = max(min(0, right) - left, 0)
        pad_top = max(min(0, bottom) - top, 0)
        pad_right = max(right - max(img_width, left), 0)
        pad_bottom = max(bottom - max(img_height, top), 0)
        img = F.pad(
            img,
            [pad_left, pad_top, pad_right, pad_bottom],
            fill=fill,
            padding_mode=padding_mode,
        )

    return F.crop(img, max(top, 0), max(left, 0), height, width)
328

329

330
class RandomCropOrPad(torch.nn.Module):
    """Crop and/or pad an image to ``size`` at a random position.

    Along an edge where the image is larger than ``size``, a random window is cropped.
    Where it is smaller, the image lands at a random offset inside a padded canvas.

    Args:
        size: output ``(h, w)``; an int or a one-element sequence is used for both edges.
        fill: padding value, forwarded to ``crop_or_pad``.
        padding_mode: padding mode, forwarded to ``crop_or_pad``.
    """

    def __init__(
            self,
            size: Union[int, List[int]],
            fill: Union[int, Tuple[int, int, int]] = 0,
            padding_mode: str = 'constant',
    ):
        super().__init__()
        self.size = _setup_size(size)
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, size):
        """Sample the ``(top, left)`` offset of the output window.

        For each axis, ``delta = image_extent - output_extent``. The offset magnitude is a
        uniform integer in ``[0, |delta|]`` and carries the sign of ``delta``. A negative
        offset makes ``crop_or_pad`` pad on that side.
        """
        _, img_height, img_width = F.get_dimensions(img)
        dh = img_height - size[0]
        dw = img_width - size[1]
        top_mag = random.randint(0, abs(dh))
        left_mag = random.randint(0, abs(dw))
        top = -top_mag if dh < 0 else top_mag
        left = -left_mag if dw < 0 else left_mag
        return top, left

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): image to crop and/or pad.

        Returns:
            PIL Image or Tensor: the image cropped/padded to ``self.size``.
        """
        top, left = self.get_params(img, self.size)
        return crop_or_pad(
            img,
            top=top,
            left=left,
            height=self.size[0],
            width=self.size[1],
            fill=self.fill,
            padding_mode=self.padding_mode,
        )

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(size={self.size})"
375

376

377
class RandomPad:
    """Pad an image up to ``input_size`` with the original placed at a random offset.

    Edges that already reach the target size receive no padding.

    Args:
        input_size: target ``(height, width)``, indexed as ``input_size[0]`` / ``input_size[1]``.
        fill: fill value, forwarded positionally to ``F.pad``.
    """

    def __init__(self, input_size, fill=0):
        self.input_size = input_size
        self.fill = fill

    @staticmethod
    def get_params(img, input_size):
        """Return the ``(left, top, right, bottom)`` padding for ``img``.

        The missing width/height is split at a uniformly random point between the two sides.
        """
        img_w, img_h = F.get_image_size(img)
        extra_w = max(input_size[1] - img_w, 0)
        extra_h = max(input_size[0] - img_h, 0)
        left = random.randint(0, extra_w)
        top = random.randint(0, extra_h)
        return left, top, extra_w - left, extra_h - top

    def __call__(self, img):
        return F.pad(img, self.get_params(img, self.input_size), self.fill)
397

398

399
class ResizeKeepRatio:
    """Resize an image by a single scale factor (aspect ratio preserved), with optional
    random scale and aspect-ratio jitter.

    The base scale comes from the per-edge ratios ``image / target``:

    * ``longest=1`` divides by the larger ratio, so the image fits inside ``size``;
    * ``longest=0`` divides by the smaller ratio, so the image covers ``size``;
    * values in between blend the two ratios linearly.

    The image is only resized; no padding or cropping is applied.
    """

    def __init__(
            self,
            size,
            longest=0.,
            interpolation='bilinear',
            random_scale_prob=0.,
            random_scale_range=(0.85, 1.05),
            random_scale_area=False,
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11),
    ):
        """
        Args:
            size: target ``(height, width)``. A list/tuple is used as given; any other value
                ``s`` becomes ``(s, s)``.
            longest: weight of the larger image/target ratio in the base scale
                (``1 - longest`` weights the smaller one); presumably in ``[0, 1]``.
            interpolation: interpolation name (e.g. ``'bilinear'``), or ``'random'`` to pick
                one per call from ``_RANDOM_INTERPOLATION``.
            random_scale_prob: per-call probability of applying a random scale factor.
            random_scale_range: ``(min, max)`` of the uniformly sampled scale factor.
            random_scale_area: if True, a sampled factor ``f`` is converted to
                ``1 / sqrt(f)`` so it behaves like an area-based crop scale.
            random_aspect_prob: per-call probability of applying a random aspect change.
            random_aspect_range: ``(min, max)`` of the aspect factor, sampled log-uniformly.
        """
        if isinstance(size, (list, tuple)):
            self.size = tuple(size)
        else:
            self.size = (size, size)
        if interpolation == 'random':
            self.interpolation = _RANDOM_INTERPOLATION
        else:
            self.interpolation = str_to_interp_mode(interpolation)
        self.longest = float(longest)
        self.random_scale_prob = random_scale_prob
        self.random_scale_range = random_scale_range
        self.random_scale_area = random_scale_area
        self.random_aspect_prob = random_aspect_prob
        self.random_aspect_range = random_aspect_range

    @staticmethod
    def get_params(
            img,
            target_size,
            longest,
            random_scale_prob=0.,
            random_scale_range=(1.0, 1.33),
            random_scale_area=False,
            random_aspect_prob=0.,
            random_aspect_range=(0.9, 1.11)
    ):
        """Compute the output ``[height, width]`` for resizing ``img``.

        Args:
            img: image (PIL Image or Tensor); its size is read with ``F.get_dimensions``.
            target_size: ``(height, width)`` the base scale is computed against.
            longest: blend weight of the larger image/target ratio.
            random_scale_prob, random_scale_range, random_scale_area: random scale settings.
                Note that the defaults here differ from ``__init__``'s.
            random_aspect_prob, random_aspect_range: random aspect settings.

        Returns:
            list: ``[height, width]`` of the resized image, rounded to ints.
        """
        img_h, img_w = img_size = F.get_dimensions(img)[1:]
        target_h, target_w = target_size
        ratio_h = img_h / target_h
        ratio_w = img_w / target_w
        # blend of the larger (fit-inside) and smaller (cover) per-edge ratios
        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)

        if random_scale_prob > 0 and random.random() < random_scale_prob:
            ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
            if random_scale_area:
                # make ratio factor equivalent to RRC area crop where < 1.0 = area zoom,
                # otherwise like affine scale where < 1.0 = linear zoom out
                ratio_factor = 1. / math.sqrt(ratio_factor)
            ratio_factor = (ratio_factor, ratio_factor)
        else:
            ratio_factor = (1., 1.)

        if random_aspect_prob > 0 and random.random() < random_aspect_prob:
            log_aspect = (math.log(random_aspect_range[0]), math.log(random_aspect_range[1]))
            aspect_factor = math.exp(random.uniform(*log_aspect))
            aspect_factor = math.sqrt(aspect_factor)
            # height is divided and width multiplied by sqrt(aspect), so width/height
            # changes by the sampled aspect factor. Currently applied equally to both dims;
            # could change to keep output sizes above their target where possible.
            ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)

        size = [round(x * f / ratio) for x, f in zip(img_size, ratio_factor)]
        return size

    def __call__(self, img):
        """
        Args:
            img (PIL Image or Tensor): image to resize.

        Returns:
            The image resized to the size computed by ``get_params``. No padding or
            cropping is performed.
        """
        size = self.get_params(
            img, self.size, self.longest,
            self.random_scale_prob, self.random_scale_range, self.random_scale_area,
            self.random_aspect_prob, self.random_aspect_range
        )
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        img = F.resize(img, size, interpolation)
        return img

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
        else:
            interpolate_str = interp_mode_to_str(self.interpolation)
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += f', interpolation={interpolate_str}'
        format_string += f', longest={self.longest:.3f}'
        format_string += f', random_scale_prob={self.random_scale_prob:.3f}'
        # both bounds must come from random_scale_range
        format_string += (
            f', random_scale_range=('
            f'{self.random_scale_range[0]:.3f}, {self.random_scale_range[1]:.3f})'
        )
        format_string += f', random_aspect_prob={self.random_aspect_prob:.3f}'
        format_string += (
            f', random_aspect_range=('
            f'{self.random_aspect_range[0]:.3f}, {self.random_aspect_range[1]:.3f}))'
        )
        return format_string
516

517

518
class TrimBorder(torch.nn.Module):
    """Remove a fixed-width border from all four sides of an image.

    Args:
        border_size: number of pixels trimmed from each edge.
    """

    def __init__(
            self,
            border_size: int,
    ):
        super().__init__()
        self.border_size = border_size

    def forward(self, img):
        """Crop ``border_size`` pixels off every edge of ``img`` (PIL Image or Tensor).

        The crop offsets are clamped to the image extent and the output size to zero, so
        a border at least half the image size does not produce negative crop arguments.
        """
        w, h = F.get_image_size(img)
        # each offset is clamped against its own axis: top by height, left by width
        top = min(self.border_size, h)
        left = min(self.border_size, w)
        height = max(0, h - 2 * self.border_size)
        width = max(0, w - 2 * self.border_size)
        return F.crop(img, top, left, height, width)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.