# vision / test_transforms_tensor.py (892 lines, 34.3 KB)
1
import os
2
import sys
3

4
import numpy as np
5
import PIL.Image
6
import pytest
7
import torch
8
from common_utils import (
9
    _assert_approx_equal_tensor_to_pil,
10
    _assert_equal_tensor_to_pil,
11
    _create_data,
12
    _create_data_batch,
13
    assert_equal,
14
    cpu_and_cuda,
15
    float_dtypes,
16
    get_tmp_dir,
17
    int_dtypes,
18
)
19
from torchvision import transforms as T
20
from torchvision.transforms import functional as F, InterpolationMode
21
from torchvision.transforms.autoaugment import _apply_op
22

23
# Short module-level aliases for the interpolation modes exercised by the tests below.
NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
    InterpolationMode.NEAREST,
    InterpolationMode.NEAREST_EXACT,
    InterpolationMode.BILINEAR,
    InterpolationMode.BICUBIC,
)
29

30

31
def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
    """Assert that an eager transform and its scripted counterpart agree on one input.

    The global RNG is reseeded with the same value before each call so that random
    transforms draw identical parameters in both runs.

    Args:
        transform: the eager (Python) transform.
        s_transform: the ``torch.jit.script``-ed version of ``transform``.
        tensor: input passed unchanged to both transforms.
        msg: optional message forwarded to ``assert_equal`` on failure.
    """
    torch.manual_seed(12)
    out1 = transform(tensor)
    torch.manual_seed(12)
    out2 = s_transform(tensor)
    assert_equal(out1, out2, msg=msg)
37

38

39
def _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors, msg=None):
    """Check batch consistency of a transform and agreement with its scripted version.

    1. With the same seed, applying ``transform`` to the whole batch must give, for every
       sample, the same result as applying it to that sample alone.
    2. With the same seed, the scripted transform must reproduce the eager batch output.

    Args:
        transform: the eager (Python) transform.
        s_transform: the ``torch.jit.script``-ed version of ``transform``.
        batch_tensors: batched input indexed along dim 0.
        msg: optional message forwarded to ``assert_equal`` on failure.
    """
    torch.manual_seed(12)
    transformed_batch = transform(batch_tensors)

    # Iterating a tensor yields its slices along dim 0, i.e. ``batch_tensors[i, ...]``.
    for i, img_tensor in enumerate(batch_tensors):
        torch.manual_seed(12)
        transformed_img = transform(img_tensor)
        assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)

    torch.manual_seed(12)
    s_transformed_batch = s_transform(batch_tensors)
    assert_equal(transformed_batch, s_transformed_batch, msg=msg)
52

53

54
def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=True, **match_kwargs):
    """Compare a functional op applied to a tensor with the same op applied to a PIL image.

    A 10x10 image with ``channels`` channels is created on ``device`` in both tensor and PIL
    form, ``f(..., **fn_kwargs)`` is applied to each, and the results are compared.

    Args:
        f: functional op accepting both tensors and PIL images.
        device: device for the tensor input.
        channels: number of image channels.
        fn_kwargs: keyword arguments forwarded to ``f`` (``None`` means none).
        test_exact_match: compare exactly if True, otherwise approximately.
        **match_kwargs: extra options forwarded to the comparison helper.
    """
    fn_kwargs = fn_kwargs or {}

    tensor, pil_img = _create_data(height=10, width=10, channels=channels, device=device)
    transformed_tensor = f(tensor, **fn_kwargs)
    transformed_pil_img = f(pil_img, **fn_kwargs)

    assert_fn = _assert_equal_tensor_to_pil if test_exact_match else _assert_approx_equal_tensor_to_pil
    assert_fn(transformed_tensor, transformed_pil_img, **match_kwargs)
64

65

66
def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Exercise the class interface of a transform.

    Checks, in order:
      * tensor vs PIL output under the same seed (exact or approximate comparison);
      * eager vs scripted output on a single tensor;
      * batch consistency and eager vs scripted output on a batch;
      * that the scripted transform can be saved to disk.

    Args:
        transform_cls: transform class to instantiate with ``meth_kwargs``.
        device: device for the tensor inputs.
        channels: number of image channels.
        meth_kwargs: constructor keyword arguments (``None`` means none).
        test_exact_match: compare tensor vs PIL exactly if True, otherwise approximately.
        **match_kwargs: extra options forwarded to the tensor-vs-PIL comparison helper.
    """
    meth_kwargs = meth_kwargs or {}

    # test for class interface
    f = transform_cls(**meth_kwargs)
    scripted_fn = torch.jit.script(f)

    tensor, pil_img = _create_data(26, 34, channels, device=device)
    # set seed to reproduce the same transformation for tensor and PIL image
    torch.manual_seed(12)
    transformed_tensor = f(tensor)
    torch.manual_seed(12)
    transformed_pil_img = f(pil_img)
    if test_exact_match:
        _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
    else:
        # The approximate comparison is performed on a float copy of the tensor output.
        _assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

    torch.manual_seed(12)
    transformed_tensor_script = scripted_fn(tensor)
    assert_equal(transformed_tensor, transformed_tensor_script)

    batch_tensors = _create_data_batch(height=23, width=34, channels=channels, num_samples=4, device=device)
    _test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)

    with get_tmp_dir() as tmp_dir:
        scripted_fn.save(os.path.join(tmp_dir, f"t_{transform_cls.__name__}.pt"))
93

94

95
def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Run the functional check for ``func`` and the class-interface check for ``method``.

    ``fn_kwargs`` go to the functional op and ``meth_kwargs`` to the class constructor; all
    other options are shared by both checks.
    """
    _test_functional_op(func, device, channels, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
    _test_class_op(method, device, channels, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
98

99

100
def _test_fn_save_load(fn, tmpdir):
    """Script ``fn``, save it under ``tmpdir`` and make sure it can be loaded back.

    The file is named ``t_op_list_<name>.pt``, where ``<name>`` is ``fn.__name__`` when the
    object has one and its class name otherwise.
    """
    scripted_fn = torch.jit.script(fn)
    name = getattr(fn, "__name__", fn.__class__.__name__)
    p = os.path.join(tmpdir, f"t_op_list_{name}.pt")
    scripted_fn.save(p)
    _ = torch.jit.load(p)
105

106

107
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "func,method,fn_kwargs,match_kwargs",
    [
        (F.hflip, T.RandomHorizontalFlip, None, {}),
        (F.vflip, T.RandomVerticalFlip, None, {}),
        (F.invert, T.RandomInvert, None, {}),
        (F.posterize, T.RandomPosterize, {"bits": 4}, {}),
        (F.solarize, T.RandomSolarize, {"threshold": 192.0}, {}),
        (F.adjust_sharpness, T.RandomAdjustSharpness, {"sharpness_factor": 2.0}, {}),
        (
            F.autocontrast,
            T.RandomAutocontrast,
            None,
            {"test_exact_match": False, "agg_method": "max", "tol": (1 + 1e-5), "allowed_percentage_diff": 0.05},
        ),
        (F.equalize, T.RandomEqualize, None, {}),
    ],
)
@pytest.mark.parametrize("channels", [1, 3])
def test_random(func, method, device, channels, fn_kwargs, match_kwargs):
    """Compare each functional op / ``Random*`` class pair on tensors vs PIL and eager vs scripted."""
    # The functional op and the class constructor take the same keyword arguments here,
    # so ``fn_kwargs`` is passed for both ``fn_kwargs`` and ``meth_kwargs``.
    _test_op(func, method, device, channels, fn_kwargs, fn_kwargs, **match_kwargs)
129

130

131
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("channels", [1, 3])
class TestColorJitter:
    """Compare ``T.ColorJitter`` on tensors vs PIL images and vs its scripted form.

    Every test runs for 10 seeds, each available device, and 1- or 3-channel images.
    """

    @pytest.fixture(autouse=True)
    def set_random_seed(self, seed):
        # ``seed`` comes from the class-level parametrization; it seeds the global RNG
        # before each test.
        torch.random.manual_seed(seed)

    @staticmethod
    def _check_color_jitter(meth_kwargs, device, channels, tol):
        """Run the class-interface check with an approximate, max-aggregated comparison."""
        _test_class_op(
            T.ColorJitter,
            meth_kwargs=meth_kwargs,
            test_exact_match=False,
            device=device,
            tol=tol,
            agg_method="max",
            channels=channels,
        )

    @pytest.mark.parametrize("brightness", [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]])
    def test_color_jitter_brightness(self, brightness, device, channels):
        self._check_color_jitter({"brightness": brightness}, device, channels, tol=1.0 + 1e-10)

    @pytest.mark.parametrize("contrast", [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]])
    def test_color_jitter_contrast(self, contrast, device, channels):
        self._check_color_jitter({"contrast": contrast}, device, channels, tol=1.0 + 1e-10)

    @pytest.mark.parametrize("saturation", [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]])
    def test_color_jitter_saturation(self, saturation, device, channels):
        self._check_color_jitter({"saturation": saturation}, device, channels, tol=1.0 + 1e-10)

    @pytest.mark.parametrize("hue", [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]])
    def test_color_jitter_hue(self, hue, device, channels):
        # Hue adjustments get a much looser tolerance than the other parameters.
        self._check_color_jitter({"hue": hue}, device, channels, tol=16.1)

    def test_color_jitter_all(self, device, channels):
        # All 4 parameters together
        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
        self._check_color_jitter(meth_kwargs, device, channels, tol=12.1)
206

207

208
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"])
@pytest.mark.parametrize("mul", [1, -1])
def test_pad(m, mul, device):
    """Compare ``F.pad`` / ``T.Pad`` on tensors vs PIL for every accepted padding form.

    ``mul=-1`` makes some of the padding amounts negative.
    """
    fill = 127 if m == "constant" else 0

    # Test functional.pad (PIL and Tensor) with padding as single int
    _test_functional_op(F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, device=device)

    paddings = [
        [mul * 2],  # padding as [int, ]
        [mul * 4, 4],  # padding as list
        (mul * 2, 2, 2, mul * 2),  # padding as tuple
    ]
    for padding in paddings:
        # Functional op and class constructor share the same keyword arguments.
        kwargs = {"padding": padding, "fill": fill, "padding_mode": m}
        _test_op(F.pad, T.Pad, device=device, fn_kwargs=kwargs, meth_kwargs=kwargs)
229

230

231
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_crop(device):
    """Compare ``F.crop`` / ``T.RandomCrop`` on tensors vs PIL, including crops outside the image."""
    fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
    # Test transforms.RandomCrop with size and padding as tuple
    meth_kwargs = {
        "size": (4, 5),
        "padding": (4, 4),
        "pad_if_needed": True,
    }
    _test_op(F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)

    # Test transforms.functional.crop with boxes extending past each border of the image
    out_of_bounds_crops = [
        {"top": -2, "left": 3, "height": 4, "width": 5},  # top
        {"top": 1, "left": -3, "height": 4, "width": 5},  # left
        {"top": 7, "left": 3, "height": 4, "width": 5},  # bottom
        {"top": 3, "left": 8, "height": 4, "width": 5},  # right
        {"top": -3, "left": -3, "height": 15, "width": 15},  # all
    ]
    for crop_kwargs in out_of_bounds_crops:
        _test_functional_op(F.crop, fn_kwargs=crop_kwargs, device=device)
257

258

259
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "padding_config",
    [
        {"padding_mode": "constant", "fill": 0},
        {"padding_mode": "constant", "fill": 10},
        {"padding_mode": "edge"},
        {"padding_mode": "reflect"},
    ],
)
@pytest.mark.parametrize("pad_if_needed", [True, False])
@pytest.mark.parametrize("padding", [[5], [5, 4], [1, 2, 3, 4]])
@pytest.mark.parametrize("size", [5, [5], [6, 6]])
def test_random_crop(size, padding, pad_if_needed, padding_config, device):
    """Class-interface check of ``T.RandomCrop`` over size, padding and padding-mode combinations."""
    # Build a fresh dict so the shared parametrize value is never mutated.
    config = {**padding_config, "size": size, "padding": padding, "pad_if_needed": pad_if_needed}
    _test_class_op(T.RandomCrop, device, meth_kwargs=config)
278

279

280
def test_random_crop_save_load(tmpdir):
    """Scripted ``T.RandomCrop`` can be saved and loaded."""
    fn = T.RandomCrop(32, [4], pad_if_needed=True)
    _test_fn_save_load(fn, tmpdir)
283

284

285
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_center_crop(device, tmpdir):
    """Compare ``F.center_crop`` / ``T.CenterCrop`` on tensors vs PIL and script ``T.CenterCrop``.

    ``tmpdir`` is requested by the signature but not used by this test.
    """
    for output_size in ((4, 5), (5,)):
        _test_op(
            F.center_crop,
            T.CenterCrop,
            device=device,
            fn_kwargs={"output_size": output_size},
            meth_kwargs={"size": output_size},
        )

    tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=device)
    # Test torchscript of transforms.CenterCrop with size as int, as [int, ] and as tuple
    for size in (5, [5], (6, 6)):
        scripted_fn = torch.jit.script(T.CenterCrop(size=size))
        scripted_fn(tensor)
308

309

310
def test_center_crop_save_load(tmpdir):
    """Scripted ``T.CenterCrop`` can be saved and loaded."""
    fn = T.CenterCrop(size=[5])
    _test_fn_save_load(fn, tmpdir)
313

314

315
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "fn, method, out_length",
    [
        # test_five_crop
        (F.five_crop, T.FiveCrop, 5),
        # test_ten_crop
        (F.ten_crop, T.TenCrop, 10),
    ],
)
@pytest.mark.parametrize("size", [(5,), [5], (4, 5), [4, 5]])
def test_x_crop(fn, method, out_length, size, device):
    """Check five-/ten-crop functional ops and classes.

    Verifies tensor vs PIL results, the number of crops, eager vs scripted functional results,
    scripted class vs functional results, and batch consistency of the class.
    """
    meth_kwargs = fn_kwargs = {"size": size}
    scripted_fn = torch.jit.script(fn)

    tensor, pil_img = _create_data(height=20, width=20, device=device)
    transformed_t_list = fn(tensor, **fn_kwargs)
    transformed_p_list = fn(pil_img, **fn_kwargs)
    assert len(transformed_t_list) == len(transformed_p_list)
    assert len(transformed_t_list) == out_length
    for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
        _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)

    transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
    assert len(transformed_t_list) == len(transformed_t_list_script)
    assert len(transformed_t_list_script) == out_length
    for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
        assert_equal(transformed_tensor, transformed_tensor_script)

    # test for class interface (own name, rather than rebinding the ``fn`` parameter)
    transform = method(**meth_kwargs)
    s_transform = torch.jit.script(transform)
    output = s_transform(tensor)
    assert len(output) == len(transformed_t_list_script)
    # The scripted class must produce the same crops as the functional op, not just as many.
    for class_crop, functional_crop in zip(output, transformed_t_list):
        assert_equal(class_crop, functional_crop)

    # test on batch of tensors
    batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
    torch.manual_seed(12)
    transformed_batch_list = transform(batch_tensors)

    for i, img_tensor in enumerate(batch_tensors):
        torch.manual_seed(12)
        transformed_img_list = transform(img_tensor)
        for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
            assert_equal(transformed_img, transformed_batch[i, ...])
361

362

363
@pytest.mark.parametrize("method", ["FiveCrop", "TenCrop"])
def test_x_crop_save_load(method, tmpdir):
    """Scripted ``T.FiveCrop`` / ``T.TenCrop`` can be saved and loaded."""
    fn = getattr(T, method)(size=[5])
    _test_fn_save_load(fn, tmpdir)
367

368

369
class TestResize:
    """Shape and eager-vs-scripted checks for ``T.Resize`` and ``T.RandomResizedCrop``."""

    @pytest.mark.parametrize("size", [32, 34, 35, 36, 38])
    def test_resize_int(self, size):
        """An int ``size`` must match the smaller edge and scale the other edge proportionally."""
        # TODO: Minimal check for bug-fix, improve this later
        x = torch.rand(3, 32, 46)
        t = T.Resize(size=size, antialias=True)
        y = t(x)
        # If size is an int, the smaller edge of the image is matched to this number.
        # Here height (32) < width (46), so the expected output is (size, int(size * 46 / 32)).
        assert isinstance(y, torch.Tensor)
        assert y.shape[1] == size
        assert y.shape[2] == int(size * 46 / 32)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64])
    @pytest.mark.parametrize("size", [[32], [32, 32], (32, 32), [34, 35]])
    @pytest.mark.parametrize("max_size", [None, 35, 1000])
    @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
    def test_resize_scripted(self, dt, size, max_size, interpolation, device):
        """Eager vs scripted ``T.Resize`` on a single image and on a uint8 batch."""
        # Skip unsupported combinations before allocating any test data.
        if max_size is not None and len(size) != 1:
            pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified")

        tensor, _ = _create_data(height=34, width=36, device=device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

        if dt is not None:
            # This is a trivial cast to float of uint8 data to test all cases
            tensor = tensor.to(dt)

        transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size, antialias=True)
        s_transform = torch.jit.script(transform)
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_resize_save_load(self, tmpdir):
        """Scripted ``T.Resize`` can be saved and loaded."""
        fn = T.Resize(size=[32], antialias=True)
        _test_fn_save_load(fn, tmpdir)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
    @pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]])
    @pytest.mark.parametrize("size", [(32,), [44], [32], [32, 32], (32, 32), [44, 55]])
    @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC, NEAREST_EXACT])
    @pytest.mark.parametrize("antialias", [None, True, False])
    def test_resized_crop(self, scale, ratio, size, interpolation, antialias, device):
        """Eager vs scripted ``T.RandomResizedCrop`` on a uint8 image and batch."""
        if antialias and interpolation in {NEAREST, NEAREST_EXACT}:
            pytest.skip(f"Can not resize if interpolation mode is {interpolation} and antialias=True")

        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
        transform = T.RandomResizedCrop(
            size=size, scale=scale, ratio=ratio, interpolation=interpolation, antialias=antialias
        )
        s_transform = torch.jit.script(transform)
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_resized_crop_save_load(self, tmpdir):
        """Scripted ``T.RandomResizedCrop`` can be saved and loaded."""
        fn = T.RandomResizedCrop(size=[32], antialias=True)
        _test_fn_save_load(fn, tmpdir)
429

430

431
def _test_random_affine_helper(device, **kwargs):
    """Eager vs scripted ``T.RandomAffine(**kwargs)`` on a random uint8 image and batch.

    Args:
        device: device on which the random test tensors are created.
        **kwargs: constructor arguments for ``T.RandomAffine``.
    """
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
    transform = T.RandomAffine(**kwargs)
    s_transform = torch.jit.script(transform)

    _test_transform_vs_scripted(transform, s_transform, tensor)
    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
439

440

441
def test_random_affine_save_load(tmpdir):
    """Scripted ``T.RandomAffine`` can be saved and loaded."""
    fn = T.RandomAffine(degrees=45.0)
    _test_fn_save_load(fn, tmpdir)
444

445

446
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("shear", [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]])
def test_random_affine_shear(device, interpolation, shear):
    """``T.RandomAffine`` with shear only (no rotation): eager vs scripted."""
    _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear)
451

452

453
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
def test_random_affine_scale(device, interpolation, scale):
    """``T.RandomAffine`` with scaling only (no rotation): eager vs scripted."""
    _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, scale=scale)
458

459

460
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("translate", [(0.1, 0.2), [0.2, 0.1]])
def test_random_affine_translate(device, interpolation, translate):
    """``T.RandomAffine`` with translation only (no rotation): eager vs scripted."""
    _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, translate=translate)
465

466

467
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
def test_random_affine_degrees(device, interpolation, degrees):
    """``T.RandomAffine`` with rotation only: eager vs scripted."""
    _test_random_affine_helper(device, degrees=degrees, interpolation=interpolation)
472

473

474
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_random_affine_fill(device, interpolation, fill):
    """``T.RandomAffine`` with various ``fill`` forms (scalar, sequence): eager vs scripted."""
    _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, fill=fill)
479

480

481
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("center", [(0, 0), [10, 10], None, (56, 44)])
@pytest.mark.parametrize("expand", [True, False])
@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_random_rotate(device, center, expand, degrees, interpolation, fill):
    """Eager vs scripted ``T.RandomRotation`` on a random uint8 image and batch."""
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.RandomRotation(degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill)
    s_transform = torch.jit.script(transform)

    _test_transform_vs_scripted(transform, s_transform, tensor)
    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
496

497

498
def test_random_rotate_save_load(tmpdir):
    """Scripted ``T.RandomRotation`` can be saved and loaded."""
    fn = T.RandomRotation(degrees=45.0)
    _test_fn_save_load(fn, tmpdir)
501

502

503
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("distortion_scale", np.linspace(0.1, 1.0, num=20))
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_random_perspective(device, distortion_scale, interpolation, fill):
    """Eager vs scripted ``T.RandomPerspective`` over 20 distortion scales in [0.1, 1.0]."""
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.RandomPerspective(distortion_scale=distortion_scale, interpolation=interpolation, fill=fill)
    s_transform = torch.jit.script(transform)

    _test_transform_vs_scripted(transform, s_transform, tensor)
    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
516

517

518
def test_random_perspective_save_load(tmpdir):
    """Scripted ``T.RandomPerspective`` can be saved and loaded."""
    fn = T.RandomPerspective()
    _test_fn_save_load(fn, tmpdir)
521

522

523
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "Klass, meth_kwargs",
    [(T.Grayscale, {"num_output_channels": 1}), (T.Grayscale, {"num_output_channels": 3}), (T.RandomGrayscale, {})],
)
def test_to_grayscale(device, Klass, meth_kwargs):
    """Class-interface check of grayscale transforms with an approximate, max-aggregated comparison."""
    tol = 1.0 + 1e-10
    _test_class_op(Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max")
531

532

533
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("in_dtype", int_dtypes() + float_dtypes())
@pytest.mark.parametrize("out_dtype", int_dtypes() + float_dtypes())
def test_convert_image_dtype(device, in_dtype, out_dtype):
    """Eager vs scripted ``T.ConvertImageDtype`` for every pair of int/float dtypes.

    float32 -> int32/int64 and float64 -> int64 are expected to raise ``RuntimeError``
    ("cannot be performed safely") in both the single-image and the batch checks.
    """
    tensor, _ = _create_data(26, 34, device=device)
    batch_tensors = torch.rand(4, 3, 44, 56, device=device)

    in_tensor = tensor.to(in_dtype)
    in_batch_tensors = batch_tensors.to(in_dtype)

    fn = T.ConvertImageDtype(dtype=out_dtype)
    scripted_fn = torch.jit.script(fn)

    unsafe_cast = (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or (
        in_dtype == torch.float64 and out_dtype == torch.int64
    )
    if unsafe_cast:
        with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
            _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
        with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
            _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
        return

    _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
    _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
557

558

559
def test_convert_image_dtype_save_load(tmpdir):
    """Scripted ``T.ConvertImageDtype`` can be saved and loaded."""
    fn = T.ConvertImageDtype(dtype=torch.uint8)
    _test_fn_save_load(fn, tmpdir)
562

563

564
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy])
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_autoaugment(device, policy, fill):
    """Eager vs scripted ``T.AutoAugment`` for every policy and ``fill`` form."""
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.AutoAugment(policy=policy, fill=fill)
    s_transform = torch.jit.script(transform)
    # NOTE(review): both helpers reseed the RNG with 12 before every call, so each of these
    # iterations repeats the same random choices on the same inputs.
    for _ in range(25):
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
576

577

578
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_ops", [1, 2, 3])
@pytest.mark.parametrize("magnitude", [7, 9, 11])
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_randaugment(device, num_ops, magnitude, fill):
    """Eager vs scripted ``T.RandAugment`` over ``num_ops``, ``magnitude`` and ``fill`` forms."""
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
    s_transform = torch.jit.script(transform)
    # NOTE(review): both helpers reseed the RNG with 12 before every call, so each of these
    # iterations repeats the same random choices on the same inputs.
    for _ in range(25):
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
591

592

593
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_trivialaugmentwide(device, fill):
    """Eager vs scripted ``T.TrivialAugmentWide`` for each ``fill`` form."""
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.TrivialAugmentWide(fill=fill)
    s_transform = torch.jit.script(transform)
    # NOTE(review): both helpers reseed the RNG with 12 before every call, so each of these
    # iterations repeats the same random choices on the same inputs.
    for _ in range(25):
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
604

605

606
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_augmix(device, fill):
    """Eager vs scripted ``T.AugMix`` (with deterministic mixing weights) for each ``fill`` form."""
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    class DeterministicAugMix(T.AugMix):
        def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
            # patch the method to ensure that the order of rand calls doesn't affect the outcome
            return params.softmax(dim=-1)

    transform = DeterministicAugMix(fill=fill)
    s_transform = torch.jit.script(transform)
    # NOTE(review): both helpers reseed the RNG with 12 before every call, so each of these
    # iterations repeats the same random choices on the same inputs.
    for _ in range(25):
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
622

623

624
@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide, T.AugMix])
def test_autoaugment_save_load(augmentation, tmpdir):
    """Each scripted auto-augmentation transform (default arguments) can be saved and loaded."""
    fn = augmentation()
    _test_fn_save_load(fn, tmpdir)
628

629

630
@pytest.mark.parametrize("interpolation", [F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR])
@pytest.mark.parametrize("mode", ["X", "Y"])
def test_autoaugment__op_apply_shear(interpolation, mode):
    """Check that ``_apply_op``'s ShearX/ShearY matches a reference PIL affine shear.

    The PIL result must match exactly. The tensor result is compared approximately, and
    only for NEAREST interpolation.
    """
    # We check that torchvision's implementation of shear is equivalent
    # to official CIFAR10 autoaugment implementation:
    # https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L290
    image_size = 32

    def shear(pil_img, level, mode, resample):
        # Six affine coefficients passed to PIL's AFFINE transform.
        if mode == "X":
            matrix = (1, level, 0, 0, 1, 0)
        elif mode == "Y":
            matrix = (1, 0, 0, level, 1, 0)
        else:
            # Fail loudly instead of hitting an unbound ``matrix`` below.
            raise ValueError(f"Unsupported shear mode: {mode}")
        return pil_img.transform((image_size, image_size), PIL.Image.AFFINE, matrix, resample=resample)

    t_img, pil_img = _create_data(image_size, image_size)

    resample_pil = {
        F.InterpolationMode.NEAREST: PIL.Image.NEAREST,
        F.InterpolationMode.BILINEAR: PIL.Image.BILINEAR,
    }[interpolation]

    level = 0.3
    expected_out = shear(pil_img, level, mode=mode, resample=resample_pil)

    # Check pil output vs expected pil
    out = _apply_op(pil_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
    assert out == expected_out

    if interpolation == F.InterpolationMode.BILINEAR:
        # We skip bilinear mode for tensors as
        # affine transformation results are not exactly the same
        # between tensors and pil images
        # MAE as around 1.40
        # Max Abs error can be 163 or 170
        return

    # Check tensor output vs expected pil
    out = _apply_op(t_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
    _assert_approx_equal_tensor_to_pil(out, expected_out)
670

671

672
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "config",
    [
        {},
        {"value": 1},
        {"value": 0.2},
        {"value": "random"},
        {"value": (1, 1, 1)},
        {"value": (0.2, 0.2, 0.2)},
        {"value": [1, 1, 1]},
        {"value": [0.2, 0.2, 0.2]},
        {"value": "random", "ratio": (0.1, 0.2)},
    ],
)
def test_random_erasing(device, config):
    """Eager vs scripted ``T.RandomErasing`` for several ``value`` / ``ratio`` configurations."""
    tensor, _ = _create_data(24, 32, channels=3, device=device)
    batch_tensors = torch.rand(4, 3, 44, 56, device=device)

    fn = T.RandomErasing(**config)
    scripted_fn = torch.jit.script(fn)
    _test_transform_vs_scripted(fn, scripted_fn, tensor)
    _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
695

696

697
def test_random_erasing_save_load(tmpdir):
    """Scripted ``T.RandomErasing`` can be saved and loaded."""
    fn = T.RandomErasing(value=0.2)
    _test_fn_save_load(fn, tmpdir)
700

701

702
def test_random_erasing_with_invalid_data():
    """A four-element ``value`` applied to a 3-channel tensor must raise ``ValueError``."""
    img = torch.rand(3, 60, 60)

    # Invalid `value`: four entries where a single value or three are accepted.
    eraser = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
    expected_msg = "If value is a sequence, it should have either a single value or 3"
    with pytest.raises(ValueError, match=expected_msg):
        eraser(img)
708

709

710
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_normalize(device, tmpdir):
    """``Normalize`` rejects non-float input, matches its scripted form, and can be saved."""
    normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    img, _ = _create_data(26, 34, device=device)

    # The tensor from _create_data is not floating point yet (it is converted
    # below), so Normalize is expected to refuse it.
    with pytest.raises(TypeError, match="Input tensor should be a float tensor"):
        normalize(img)

    img_batch = torch.rand(4, 3, 44, 56, device=device)
    img = img.to(dtype=torch.float32) / 255.0

    scripted_normalize = torch.jit.script(normalize)
    _test_transform_vs_scripted(normalize, scripted_normalize, img)
    _test_transform_vs_scripted_on_batch(normalize, scripted_normalize, img_batch)

    scripted_normalize.save(os.path.join(tmpdir, "t_norm.pt"))
727

728

729
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_linear_transformation(device, tmpdir):
    """Check eager vs. scripted ``LinearTransformation`` and that the scripted module saves.

    The square matrix and the mean vector are both sized ``c * h * w``, matching a
    flattened ``c x h x w`` image; the transform is run on one image and on a batch.
    """
    c, h, w = 3, 24, 32
    n_features = c * h * w

    tensor, _ = _create_data(h, w, channels=c, device=device)

    matrix = torch.rand(n_features, n_features, device=device)
    mean_vector = torch.rand(n_features, device=device)

    fn = T.LinearTransformation(matrix, mean_vector)
    scripted_fn = torch.jit.script(fn)

    _test_transform_vs_scripted(fn, scripted_fn, tensor)

    batch_tensors = torch.rand(4, c, h, w, device=device)
    # Unlike _test_transform_vs_scripted_on_batch, do NOT compare each image
    # transformed on its own against the matching slice of the batched output:
    # those are not bit-identical for this transform (presumably floating-point
    # differences between per-image and batched matmul). Eager vs. scripted on
    # the whole batch must still match exactly.
    torch.manual_seed(12)
    transformed_batch = fn(batch_tensors)
    torch.manual_seed(12)
    s_transformed_batch = scripted_fn(batch_tensors)
    assert_equal(transformed_batch, s_transformed_batch)

    scripted_fn.save(os.path.join(tmpdir, "t_linear_transformation.pt"))
753

754

755
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_compose(device):
    """A ``Compose`` pipeline must match a scripted ``nn.Sequential`` of the same stages.

    Also checks that scripting a ``Compose`` that holds a plain lambda raises.
    """
    img, _ = _create_data(26, 34, device=device)
    img = img.to(dtype=torch.float32) / 255.0

    pipeline = T.Compose(
        [
            T.CenterCrop(10),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    # Script an nn.Sequential built from the same stages rather than the Compose object.
    scripted_pipeline = torch.jit.script(torch.nn.Sequential(*pipeline.transforms))

    torch.manual_seed(12)
    eager_out = pipeline(img)
    torch.manual_seed(12)
    scripted_out = scripted_pipeline(img)
    assert_equal(eager_out, scripted_out, msg=f"{pipeline}")

    identity_pipeline = T.Compose([lambda x: x])
    with pytest.raises(RuntimeError, match="cannot call a value of type 'Tensor'"):
        torch.jit.script(identity_pipeline)
781

782

783
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_random_apply(device):
    """Eager ``RandomApply`` over a list must match a scripted one over an ``nn.ModuleList``.

    On CPU only, also checks that scripting a list-backed ``RandomApply`` fails.
    """
    img, _ = _create_data(26, 34, device=device)
    img = img.to(dtype=torch.float32) / 255.0

    eager_apply = T.RandomApply([T.RandomHorizontalFlip(), T.ColorJitter()], p=0.4)
    # The scriptable variant holds the same sub-transforms in an nn.ModuleList.
    scriptable_apply = T.RandomApply(
        torch.nn.ModuleList([T.RandomHorizontalFlip(), T.ColorJitter()]),
        p=0.4,
    )
    scripted_apply = torch.jit.script(scriptable_apply)

    torch.manual_seed(12)
    eager_out = eager_apply(img)
    torch.manual_seed(12)
    scripted_out = scripted_apply(img)
    assert_equal(eager_out, scripted_out, msg=f"{eager_apply}")

    if device != "cpu":
        return

    # Checked for a single device only: doing it a second time fails with
    # "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply"
    list_apply = T.RandomApply([T.ColorJitter()], p=0.3)
    with pytest.raises(RuntimeError, match="Module 'RandomApply' has no attribute 'transforms'"):
        torch.jit.script(list_apply)
823

824

825
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "meth_kwargs",
    [
        {"kernel_size": 3, "sigma": 0.75},
        {"kernel_size": 23, "sigma": [0.1, 2.0]},
        {"kernel_size": 23, "sigma": (0.1, 2.0)},
        {"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
        {"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
        {"kernel_size": [23], "sigma": 0.75},
    ],
)
@pytest.mark.parametrize("channels", [1, 3])
def test_gaussian_blur(device, channels, meth_kwargs):
    """Run the shared class-op check for ``GaussianBlur`` with an approximate tolerance.

    Exact matching is disabled; errors are aggregated with ``"max"`` against a
    tolerance of just over 1.0.
    """
    known_windows_failure = (
        device == "cuda"
        and channels == 1
        and meth_kwargs["kernel_size"] in [23, [23]]
        and torch.version.cuda == "11.3"
        and sys.platform in ("win32", "cygwin")
    )
    if known_windows_failure:
        pytest.skip("Fails on Windows, see https://github.com/pytorch/vision/issues/5464")

    torch.manual_seed(12)
    _test_class_op(
        T.GaussianBlur,
        meth_kwargs=meth_kwargs,
        channels=channels,
        test_exact_match=False,
        device=device,
        agg_method="max",
        tol=1.0 + 1e-10,
    )
861

862

863
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "fill",
    [
        1,
        1.0,
        [1],
        [1.0],
        (1,),
        (1.0,),
        [1, 2, 3],
        [1.0, 2.0, 3.0],
        (1, 2, 3),
        (1.0, 2.0, 3.0),
    ],
)
@pytest.mark.parametrize("channels", [1, 3])
def test_elastic_transform(device, channels, fill):
    """Run the shared class-op check for ``ElasticTransform`` across ``fill`` values.

    A multi-value ``fill`` paired with a single-channel image is an invalid
    combination produced by the parametrization cross-product; it is reported
    as skipped instead of silently counting as a pass.
    """
    if isinstance(fill, (list, tuple)) and len(fill) > 1 and channels == 1:
        # For this combination the test would correctly fail, since the number of
        # channels in the image does not match `fill`. That is a property of the
        # parametrization (the product of `fill` and `channels`), not an issue in the
        # transform, so skip explicitly rather than `return` and show up as passed.
        pytest.skip("multi-value fill does not match a single-channel image")

    _test_class_op(
        T.ElasticTransform,
        meth_kwargs={"fill": fill},
        channels=channels,
        device=device,
    )
893

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.