8
from common_utils import (
9
_assert_approx_equal_tensor_to_pil,
10
_assert_equal_tensor_to_pil,
19
from torchvision import transforms as T
20
from torchvision.transforms import functional as F, InterpolationMode
21
from torchvision.transforms.autoaugment import _apply_op
23
NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
24
InterpolationMode.NEAREST,
25
InterpolationMode.NEAREST_EXACT,
26
InterpolationMode.BILINEAR,
27
InterpolationMode.BICUBIC,
31
def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
33
out1 = transform(tensor)
35
out2 = s_transform(tensor)
36
assert_equal(out1, out2, msg=msg)
39
def _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors, msg=None):
41
transformed_batch = transform(batch_tensors)
43
for i in range(len(batch_tensors)):
44
img_tensor = batch_tensors[i, ...]
46
transformed_img = transform(img_tensor)
47
assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)
50
s_transformed_batch = s_transform(batch_tensors)
51
assert_equal(transformed_batch, s_transformed_batch, msg=msg)
54
def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=True, **match_kwargs):
55
fn_kwargs = fn_kwargs or {}
57
tensor, pil_img = _create_data(height=10, width=10, channels=channels, device=device)
58
transformed_tensor = f(tensor, **fn_kwargs)
59
transformed_pil_img = f(pil_img, **fn_kwargs)
61
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
63
_assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
66
def _test_class_op(transform_cls, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
67
meth_kwargs = meth_kwargs or {}
70
f = transform_cls(**meth_kwargs)
71
scripted_fn = torch.jit.script(f)
73
tensor, pil_img = _create_data(26, 34, channels, device=device)
76
transformed_tensor = f(tensor)
78
transformed_pil_img = f(pil_img)
80
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
82
_assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
85
transformed_tensor_script = scripted_fn(tensor)
86
assert_equal(transformed_tensor, transformed_tensor_script)
88
batch_tensors = _create_data_batch(height=23, width=34, channels=channels, num_samples=4, device=device)
89
_test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
91
with get_tmp_dir() as tmp_dir:
92
scripted_fn.save(os.path.join(tmp_dir, f"t_{transform_cls.__name__}.pt"))
95
def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
96
_test_functional_op(func, device, channels, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
97
_test_class_op(method, device, channels, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)
100
def _test_fn_save_load(fn, tmpdir):
101
scripted_fn = torch.jit.script(fn)
102
p = os.path.join(tmpdir, f"t_op_list_{getattr(fn, '__name__', fn.__class__.__name__)}.pt")
104
_ = torch.jit.load(p)
107
@pytest.mark.parametrize("device", cpu_and_cuda())
108
@pytest.mark.parametrize(
109
"func,method,fn_kwargs,match_kwargs",
111
(F.hflip, T.RandomHorizontalFlip, None, {}),
112
(F.vflip, T.RandomVerticalFlip, None, {}),
113
(F.invert, T.RandomInvert, None, {}),
114
(F.posterize, T.RandomPosterize, {"bits": 4}, {}),
115
(F.solarize, T.RandomSolarize, {"threshold": 192.0}, {}),
116
(F.adjust_sharpness, T.RandomAdjustSharpness, {"sharpness_factor": 2.0}, {}),
119
T.RandomAutocontrast,
121
{"test_exact_match": False, "agg_method": "max", "tol": (1 + 1e-5), "allowed_percentage_diff": 0.05},
123
(F.equalize, T.RandomEqualize, None, {}),
126
@pytest.mark.parametrize("channels", [1, 3])
127
def test_random(func, method, device, channels, fn_kwargs, match_kwargs):
128
_test_op(func, method, device, channels, fn_kwargs, fn_kwargs, **match_kwargs)
131
@pytest.mark.parametrize("seed", range(10))
132
@pytest.mark.parametrize("device", cpu_and_cuda())
133
@pytest.mark.parametrize("channels", [1, 3])
134
class TestColorJitter:
135
@pytest.fixture(autouse=True)
136
def set_random_seed(self, seed):
137
torch.random.manual_seed(seed)
139
@pytest.mark.parametrize("brightness", [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]])
140
def test_color_jitter_brightness(self, brightness, device, channels):
142
meth_kwargs = {"brightness": brightness}
145
meth_kwargs=meth_kwargs,
146
test_exact_match=False,
153
@pytest.mark.parametrize("contrast", [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]])
154
def test_color_jitter_contrast(self, contrast, device, channels):
156
meth_kwargs = {"contrast": contrast}
159
meth_kwargs=meth_kwargs,
160
test_exact_match=False,
167
@pytest.mark.parametrize("saturation", [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]])
168
def test_color_jitter_saturation(self, saturation, device, channels):
170
meth_kwargs = {"saturation": saturation}
173
meth_kwargs=meth_kwargs,
174
test_exact_match=False,
181
@pytest.mark.parametrize("hue", [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]])
182
def test_color_jitter_hue(self, hue, device, channels):
183
meth_kwargs = {"hue": hue}
186
meth_kwargs=meth_kwargs,
187
test_exact_match=False,
194
def test_color_jitter_all(self, device, channels):
196
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
199
meth_kwargs=meth_kwargs,
200
test_exact_match=False,
208
@pytest.mark.parametrize("device", cpu_and_cuda())
209
@pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"])
210
@pytest.mark.parametrize("mul", [1, -1])
211
def test_pad(m, mul, device):
212
fill = 127 if m == "constant" else 0
215
_test_functional_op(F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, device=device)
217
fn_kwargs = meth_kwargs = {
218
"padding": [mul * 2],
222
_test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
224
fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
225
_test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
227
fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
228
_test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
231
@pytest.mark.parametrize("device", cpu_and_cuda())
232
def test_crop(device):
233
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
238
"pad_if_needed": True,
240
_test_op(F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
243
fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5}
244
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
246
fn_kwargs = {"top": 1, "left": -3, "height": 4, "width": 5}
247
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
249
fn_kwargs = {"top": 7, "left": 3, "height": 4, "width": 5}
250
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
252
fn_kwargs = {"top": 3, "left": 8, "height": 4, "width": 5}
253
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
255
fn_kwargs = {"top": -3, "left": -3, "height": 15, "width": 15}
256
_test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)
259
@pytest.mark.parametrize("device", cpu_and_cuda())
260
@pytest.mark.parametrize(
263
{"padding_mode": "constant", "fill": 0},
264
{"padding_mode": "constant", "fill": 10},
265
{"padding_mode": "edge"},
266
{"padding_mode": "reflect"},
269
@pytest.mark.parametrize("pad_if_needed", [True, False])
270
@pytest.mark.parametrize("padding", [[5], [5, 4], [1, 2, 3, 4]])
271
@pytest.mark.parametrize("size", [5, [5], [6, 6]])
272
def test_random_crop(size, padding, pad_if_needed, padding_config, device):
273
config = dict(padding_config)
274
config["size"] = size
275
config["padding"] = padding
276
config["pad_if_needed"] = pad_if_needed
277
_test_class_op(T.RandomCrop, device, meth_kwargs=config)
280
def test_random_crop_save_load(tmpdir):
281
fn = T.RandomCrop(32, [4], pad_if_needed=True)
282
_test_fn_save_load(fn, tmpdir)
285
@pytest.mark.parametrize("device", cpu_and_cuda())
286
def test_center_crop(device, tmpdir):
287
fn_kwargs = {"output_size": (4, 5)}
288
meth_kwargs = {"size": (4, 5)}
289
_test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
290
fn_kwargs = {"output_size": (5,)}
291
meth_kwargs = {"size": (5,)}
292
_test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
293
tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=device)
295
f = T.CenterCrop(size=5)
296
scripted_fn = torch.jit.script(f)
300
f = T.CenterCrop(size=[5])
301
scripted_fn = torch.jit.script(f)
305
f = T.CenterCrop(size=(6, 6))
306
scripted_fn = torch.jit.script(f)
310
def test_center_crop_save_load(tmpdir):
311
fn = T.CenterCrop(size=[5])
312
_test_fn_save_load(fn, tmpdir)
315
@pytest.mark.parametrize("device", cpu_and_cuda())
316
@pytest.mark.parametrize(
317
"fn, method, out_length",
320
(F.five_crop, T.FiveCrop, 5),
322
(F.ten_crop, T.TenCrop, 10),
325
@pytest.mark.parametrize("size", [(5,), [5], (4, 5), [4, 5]])
326
def test_x_crop(fn, method, out_length, size, device):
327
meth_kwargs = fn_kwargs = {"size": size}
328
scripted_fn = torch.jit.script(fn)
330
tensor, pil_img = _create_data(height=20, width=20, device=device)
331
transformed_t_list = fn(tensor, **fn_kwargs)
332
transformed_p_list = fn(pil_img, **fn_kwargs)
333
assert len(transformed_t_list) == len(transformed_p_list)
334
assert len(transformed_t_list) == out_length
335
for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
336
_assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)
338
transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
339
assert len(transformed_t_list) == len(transformed_t_list_script)
340
assert len(transformed_t_list_script) == out_length
341
for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
342
assert_equal(transformed_tensor, transformed_tensor_script)
345
fn = method(**meth_kwargs)
346
scripted_fn = torch.jit.script(fn)
347
output = scripted_fn(tensor)
348
assert len(output) == len(transformed_t_list_script)
351
batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
352
torch.manual_seed(12)
353
transformed_batch_list = fn(batch_tensors)
355
for i in range(len(batch_tensors)):
356
img_tensor = batch_tensors[i, ...]
357
torch.manual_seed(12)
358
transformed_img_list = fn(img_tensor)
359
for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
360
assert_equal(transformed_img, transformed_batch[i, ...])
363
@pytest.mark.parametrize("method", ["FiveCrop", "TenCrop"])
364
def test_x_crop_save_load(method, tmpdir):
365
fn = getattr(T, method)(size=[5])
366
_test_fn_save_load(fn, tmpdir)
370
@pytest.mark.parametrize("size", [32, 34, 35, 36, 38])
371
def test_resize_int(self, size):
373
x = torch.rand(3, 32, 46)
374
t = T.Resize(size=size, antialias=True)
378
assert isinstance(y, torch.Tensor)
379
assert y.shape[1] == size
380
assert y.shape[2] == int(size * 46 / 32)
382
@pytest.mark.parametrize("device", cpu_and_cuda())
383
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64])
384
@pytest.mark.parametrize("size", [[32], [32, 32], (32, 32), [34, 35]])
385
@pytest.mark.parametrize("max_size", [None, 35, 1000])
386
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
387
def test_resize_scripted(self, dt, size, max_size, interpolation, device):
388
tensor, _ = _create_data(height=34, width=36, device=device)
389
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
393
tensor = tensor.to(dt)
394
if max_size is not None and len(size) != 1:
395
pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified")
397
transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size, antialias=True)
398
s_transform = torch.jit.script(transform)
399
_test_transform_vs_scripted(transform, s_transform, tensor)
400
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
402
def test_resize_save_load(self, tmpdir):
403
fn = T.Resize(size=[32], antialias=True)
404
_test_fn_save_load(fn, tmpdir)
406
@pytest.mark.parametrize("device", cpu_and_cuda())
407
@pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
408
@pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]])
409
@pytest.mark.parametrize("size", [(32,), [44], [32], [32, 32], (32, 32), [44, 55]])
410
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC, NEAREST_EXACT])
411
@pytest.mark.parametrize("antialias", [None, True, False])
412
def test_resized_crop(self, scale, ratio, size, interpolation, antialias, device):
414
if antialias and interpolation in {NEAREST, NEAREST_EXACT}:
415
pytest.skip(f"Can not resize if interpolation mode is {interpolation} and antialias=True")
417
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
418
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
419
transform = T.RandomResizedCrop(
420
size=size, scale=scale, ratio=ratio, interpolation=interpolation, antialias=antialias
422
s_transform = torch.jit.script(transform)
423
_test_transform_vs_scripted(transform, s_transform, tensor)
424
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
426
def test_resized_crop_save_load(self, tmpdir):
427
fn = T.RandomResizedCrop(size=[32], antialias=True)
428
_test_fn_save_load(fn, tmpdir)
431
def _test_random_affine_helper(device, **kwargs):
432
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
433
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
434
transform = T.RandomAffine(**kwargs)
435
s_transform = torch.jit.script(transform)
437
_test_transform_vs_scripted(transform, s_transform, tensor)
438
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
441
def test_random_affine_save_load(tmpdir):
442
fn = T.RandomAffine(degrees=45.0)
443
_test_fn_save_load(fn, tmpdir)
446
@pytest.mark.parametrize("device", cpu_and_cuda())
447
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
448
@pytest.mark.parametrize("shear", [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]])
449
def test_random_affine_shear(device, interpolation, shear):
450
_test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear)
453
@pytest.mark.parametrize("device", cpu_and_cuda())
454
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
455
@pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
456
def test_random_affine_scale(device, interpolation, scale):
457
_test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, scale=scale)
460
@pytest.mark.parametrize("device", cpu_and_cuda())
461
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
462
@pytest.mark.parametrize("translate", [(0.1, 0.2), [0.2, 0.1]])
463
def test_random_affine_translate(device, interpolation, translate):
464
_test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, translate=translate)
467
@pytest.mark.parametrize("device", cpu_and_cuda())
468
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
469
@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
470
def test_random_affine_degrees(device, interpolation, degrees):
471
_test_random_affine_helper(device, degrees=degrees, interpolation=interpolation)
474
@pytest.mark.parametrize("device", cpu_and_cuda())
475
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
476
@pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
477
def test_random_affine_fill(device, interpolation, fill):
478
_test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, fill=fill)
481
@pytest.mark.parametrize("device", cpu_and_cuda())
482
@pytest.mark.parametrize("center", [(0, 0), [10, 10], None, (56, 44)])
483
@pytest.mark.parametrize("expand", [True, False])
484
@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
485
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
486
@pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
487
def test_random_rotate(device, center, expand, degrees, interpolation, fill):
488
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
489
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
491
transform = T.RandomRotation(degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill)
492
s_transform = torch.jit.script(transform)
494
_test_transform_vs_scripted(transform, s_transform, tensor)
495
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
498
def test_random_rotate_save_load(tmpdir):
499
fn = T.RandomRotation(degrees=45.0)
500
_test_fn_save_load(fn, tmpdir)
503
@pytest.mark.parametrize("device", cpu_and_cuda())
504
@pytest.mark.parametrize("distortion_scale", np.linspace(0.1, 1.0, num=20))
505
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
506
@pytest.mark.parametrize("fill", [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
507
def test_random_perspective(device, distortion_scale, interpolation, fill):
508
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
509
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
511
transform = T.RandomPerspective(distortion_scale=distortion_scale, interpolation=interpolation, fill=fill)
512
s_transform = torch.jit.script(transform)
514
_test_transform_vs_scripted(transform, s_transform, tensor)
515
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
518
def test_random_perspective_save_load(tmpdir):
519
fn = T.RandomPerspective()
520
_test_fn_save_load(fn, tmpdir)
523
@pytest.mark.parametrize("device", cpu_and_cuda())
524
@pytest.mark.parametrize(
525
"Klass, meth_kwargs",
526
[(T.Grayscale, {"num_output_channels": 1}), (T.Grayscale, {"num_output_channels": 3}), (T.RandomGrayscale, {})],
528
def test_to_grayscale(device, Klass, meth_kwargs):
530
_test_class_op(Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max")
533
@pytest.mark.parametrize("device", cpu_and_cuda())
534
@pytest.mark.parametrize("in_dtype", int_dtypes() + float_dtypes())
535
@pytest.mark.parametrize("out_dtype", int_dtypes() + float_dtypes())
536
def test_convert_image_dtype(device, in_dtype, out_dtype):
537
tensor, _ = _create_data(26, 34, device=device)
538
batch_tensors = torch.rand(4, 3, 44, 56, device=device)
540
in_tensor = tensor.to(in_dtype)
541
in_batch_tensors = batch_tensors.to(in_dtype)
543
fn = T.ConvertImageDtype(dtype=out_dtype)
544
scripted_fn = torch.jit.script(fn)
546
if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or (
547
in_dtype == torch.float64 and out_dtype == torch.int64
549
with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
550
_test_transform_vs_scripted(fn, scripted_fn, in_tensor)
551
with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
552
_test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
555
_test_transform_vs_scripted(fn, scripted_fn, in_tensor)
556
_test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
559
def test_convert_image_dtype_save_load(tmpdir):
560
fn = T.ConvertImageDtype(dtype=torch.uint8)
561
_test_fn_save_load(fn, tmpdir)
564
@pytest.mark.parametrize("device", cpu_and_cuda())
565
@pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy])
566
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
567
def test_autoaugment(device, policy, fill):
568
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
569
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
571
transform = T.AutoAugment(policy=policy, fill=fill)
572
s_transform = torch.jit.script(transform)
574
_test_transform_vs_scripted(transform, s_transform, tensor)
575
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
578
@pytest.mark.parametrize("device", cpu_and_cuda())
579
@pytest.mark.parametrize("num_ops", [1, 2, 3])
580
@pytest.mark.parametrize("magnitude", [7, 9, 11])
581
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
582
def test_randaugment(device, num_ops, magnitude, fill):
583
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
584
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
586
transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
587
s_transform = torch.jit.script(transform)
589
_test_transform_vs_scripted(transform, s_transform, tensor)
590
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
593
@pytest.mark.parametrize("device", cpu_and_cuda())
594
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
595
def test_trivialaugmentwide(device, fill):
596
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
597
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
599
transform = T.TrivialAugmentWide(fill=fill)
600
s_transform = torch.jit.script(transform)
602
_test_transform_vs_scripted(transform, s_transform, tensor)
603
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
606
@pytest.mark.parametrize("device", cpu_and_cuda())
607
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
608
def test_augmix(device, fill):
609
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
610
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
612
class DeterministicAugMix(T.AugMix):
613
def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
615
return params.softmax(dim=-1)
617
transform = DeterministicAugMix(fill=fill)
618
s_transform = torch.jit.script(transform)
620
_test_transform_vs_scripted(transform, s_transform, tensor)
621
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
624
@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide, T.AugMix])
625
def test_autoaugment_save_load(augmentation, tmpdir):
627
_test_fn_save_load(fn, tmpdir)
630
@pytest.mark.parametrize("interpolation", [F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR])
631
@pytest.mark.parametrize("mode", ["X", "Y"])
632
def test_autoaugment__op_apply_shear(interpolation, mode):
638
def shear(pil_img, level, mode, resample):
640
matrix = (1, level, 0, 0, 1, 0)
642
matrix = (1, 0, 0, level, 1, 0)
643
return pil_img.transform((image_size, image_size), PIL.Image.AFFINE, matrix, resample=resample)
645
t_img, pil_img = _create_data(image_size, image_size)
648
F.InterpolationMode.NEAREST: PIL.Image.NEAREST,
649
F.InterpolationMode.BILINEAR: PIL.Image.BILINEAR,
653
expected_out = shear(pil_img, level, mode=mode, resample=resample_pil)
656
out = _apply_op(pil_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
657
assert out == expected_out
659
if interpolation == F.InterpolationMode.BILINEAR:
668
out = _apply_op(t_img, op_name=f"Shear{mode}", magnitude=level, interpolation=interpolation, fill=0)
669
_assert_approx_equal_tensor_to_pil(out, expected_out)
672
@pytest.mark.parametrize("device", cpu_and_cuda())
673
@pytest.mark.parametrize(
680
{"value": (1, 1, 1)},
681
{"value": (0.2, 0.2, 0.2)},
682
{"value": [1, 1, 1]},
683
{"value": [0.2, 0.2, 0.2]},
684
{"value": "random", "ratio": (0.1, 0.2)},
687
def test_random_erasing(device, config):
688
tensor, _ = _create_data(24, 32, channels=3, device=device)
689
batch_tensors = torch.rand(4, 3, 44, 56, device=device)
691
fn = T.RandomErasing(**config)
692
scripted_fn = torch.jit.script(fn)
693
_test_transform_vs_scripted(fn, scripted_fn, tensor)
694
_test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
697
def test_random_erasing_save_load(tmpdir):
698
fn = T.RandomErasing(value=0.2)
699
_test_fn_save_load(fn, tmpdir)
702
def test_random_erasing_with_invalid_data():
703
img = torch.rand(3, 60, 60)
705
random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
706
with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value or 3"):
710
@pytest.mark.parametrize("device", cpu_and_cuda())
711
def test_normalize(device, tmpdir):
712
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
713
tensor, _ = _create_data(26, 34, device=device)
715
with pytest.raises(TypeError, match="Input tensor should be a float tensor"):
718
batch_tensors = torch.rand(4, 3, 44, 56, device=device)
719
tensor = tensor.to(dtype=torch.float32) / 255.0
721
scripted_fn = torch.jit.script(fn)
723
_test_transform_vs_scripted(fn, scripted_fn, tensor)
724
_test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
726
scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
729
@pytest.mark.parametrize("device", cpu_and_cuda())
730
def test_linear_transformation(device, tmpdir):
733
tensor, _ = _create_data(h, w, channels=c, device=device)
735
matrix = torch.rand(c * h * w, c * h * w, device=device)
736
mean_vector = torch.rand(c * h * w, device=device)
738
fn = T.LinearTransformation(matrix, mean_vector)
739
scripted_fn = torch.jit.script(fn)
741
_test_transform_vs_scripted(fn, scripted_fn, tensor)
743
batch_tensors = torch.rand(4, c, h, w, device=device)
746
torch.manual_seed(12)
747
transformed_batch = fn(batch_tensors)
748
torch.manual_seed(12)
749
s_transformed_batch = scripted_fn(batch_tensors)
750
assert_equal(transformed_batch, s_transformed_batch)
752
scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
755
@pytest.mark.parametrize("device", cpu_and_cuda())
756
def test_compose(device):
757
tensor, _ = _create_data(26, 34, device=device)
758
tensor = tensor.to(dtype=torch.float32) / 255.0
759
transforms = T.Compose(
762
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
765
s_transforms = torch.nn.Sequential(*transforms.transforms)
767
scripted_fn = torch.jit.script(s_transforms)
768
torch.manual_seed(12)
769
transformed_tensor = transforms(tensor)
770
torch.manual_seed(12)
771
transformed_tensor_script = scripted_fn(tensor)
772
assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")
779
with pytest.raises(RuntimeError, match="cannot call a value of type 'Tensor'"):
783
@pytest.mark.parametrize("device", cpu_and_cuda())
784
def test_random_apply(device):
785
tensor, _ = _create_data(26, 34, device=device)
786
tensor = tensor.to(dtype=torch.float32) / 255.0
788
transforms = T.RandomApply(
790
T.RandomHorizontalFlip(),
795
s_transforms = T.RandomApply(
798
T.RandomHorizontalFlip(),
805
scripted_fn = torch.jit.script(s_transforms)
806
torch.manual_seed(12)
807
transformed_tensor = transforms(tensor)
808
torch.manual_seed(12)
809
transformed_tensor_script = scripted_fn(tensor)
810
assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")
815
transforms = T.RandomApply(
821
with pytest.raises(RuntimeError, match="Module 'RandomApply' has no attribute 'transforms'"):
822
torch.jit.script(transforms)
825
@pytest.mark.parametrize("device", cpu_and_cuda())
826
@pytest.mark.parametrize(
829
{"kernel_size": 3, "sigma": 0.75},
830
{"kernel_size": 23, "sigma": [0.1, 2.0]},
831
{"kernel_size": 23, "sigma": (0.1, 2.0)},
832
{"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
833
{"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
834
{"kernel_size": [23], "sigma": 0.75},
837
@pytest.mark.parametrize("channels", [1, 3])
838
def test_gaussian_blur(device, channels, meth_kwargs):
843
meth_kwargs["kernel_size"] in [23, [23]],
844
torch.version.cuda == "11.3",
845
sys.platform in ("win32", "cygwin"),
848
pytest.skip("Fails on Windows, see https://github.com/pytorch/vision/issues/5464")
851
torch.manual_seed(12)
854
meth_kwargs=meth_kwargs,
856
test_exact_match=False,
863
@pytest.mark.parametrize("device", cpu_and_cuda())
864
@pytest.mark.parametrize(
879
@pytest.mark.parametrize("channels", [1, 3])
880
def test_elastic_transform(device, channels, fill):
881
if isinstance(fill, (list, tuple)) and len(fill) > 1 and channels == 1:
889
meth_kwargs=dict(fill=fill),