5
from functools import partial
6
from typing import Sequence
12
import torchvision.transforms as T
13
import torchvision.transforms._functional_pil as F_pil
14
import torchvision.transforms._functional_tensor as F_t
15
import torchvision.transforms.functional as F
16
from common_utils import (
17
_assert_approx_equal_tensor_to_pil,
18
_assert_equal_tensor_to_pil,
26
from torchvision.transforms import InterpolationMode
28
NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
29
InterpolationMode.NEAREST,
30
InterpolationMode.NEAREST_EXACT,
31
InterpolationMode.BILINEAR,
32
InterpolationMode.BICUBIC,
36
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions])
def test_image_sizes(device, fn):
    """Check that an image size / channel-count / dimensions query agrees across inputs.

    The eager result on a single tensor image is the reference value; the PIL
    image, the TorchScript-compiled function and a batch of 4 images must all
    produce that same value.
    """
    scripted_fn = torch.jit.script(fn)

    single, pil_equivalent = _create_data(16, 18, 3, device=device)
    expected = fn(single)
    assert expected == fn(pil_equivalent)

    # TorchScript must agree with eager mode.
    assert expected == scripted_fn(single)

    # A batch of same-sized images reports the per-image value.
    batch = _create_data_batch(16, 18, 3, num_samples=4, device=device)
    assert expected == fn(batch)
55
def test_scale_channel():
56
"""Make sure that _scale_channel gives the same results on CPU and GPU as
57
histc or bincount are used depending on the device.
59
# TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed,
60
# only use bincount and remove that test.
62
img_chan = torch.randint(0, 256, size=size).to("cpu")
63
scaled_cpu = F_t._scale_channel(img_chan)
64
scaled_cuda = F_t._scale_channel(img_chan.to("cuda"))
65
assert_equal(scaled_cpu, scaled_cuda.to("cpu"))
70
ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16]
71
scripted_rotate = torch.jit.script(F.rotate)
74
@pytest.mark.parametrize("device", cpu_and_cuda())
75
@pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)])
76
@pytest.mark.parametrize(
80
(int(IMG_W * 0.3), int(IMG_W * 0.4)),
81
[int(IMG_W * 0.5), int(IMG_W * 0.6)],
84
@pytest.mark.parametrize("dt", ALL_DTYPES)
85
@pytest.mark.parametrize("angle", range(-180, 180, 34))
86
@pytest.mark.parametrize("expand", [True, False])
87
@pytest.mark.parametrize(
100
@pytest.mark.parametrize("fn", [F.rotate, scripted_rotate])
101
def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn):
102
tensor, pil_img = _create_data(height, width, device=device)
104
if dt == torch.float16 and torch.device(device).type == "cpu":
105
# skip float16 on CPU case
109
tensor = tensor.to(dtype=dt)
111
f_pil = int(fill[0]) if fill is not None and len(fill) == 1 else fill
112
out_pil_img = F.rotate(pil_img, angle=angle, interpolation=NEAREST, expand=expand, center=center, fill=f_pil)
113
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
115
out_tensor = fn(tensor, angle=angle, interpolation=NEAREST, expand=expand, center=center, fill=fill).cpu()
117
if out_tensor.dtype != torch.uint8:
118
out_tensor = out_tensor.to(torch.uint8)
121
out_tensor.shape == out_pil_tensor.shape
122
), f"{(height, width, NEAREST, dt, angle, expand, center)}: {out_tensor.shape} vs {out_pil_tensor.shape}"
124
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
125
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
126
# Tolerance : less than 3% of different pixels
127
assert ratio_diff_pixels < 0.03, (
128
f"{(height, width, NEAREST, dt, angle, expand, center, fill)}: "
129
f"{ratio_diff_pixels}\n{out_tensor[0, :7, :7]} vs \n"
130
f"{out_pil_tensor[0, :7, :7]}"
133
@pytest.mark.parametrize("device", cpu_and_cuda())
134
@pytest.mark.parametrize("dt", ALL_DTYPES)
135
def test_rotate_batch(self, device, dt):
136
if dt == torch.float16 and device == "cpu":
137
# skip float16 on CPU case
140
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
142
batch_tensors = batch_tensors.to(dtype=dt)
145
_test_fn_on_batch(batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center)
147
def test_rotate_interpolation_type(self):
    """F.rotate gives the same result for PIL.Image.BILINEAR and InterpolationMode.BILINEAR."""
    img, _ = _create_data(26, 26)
    angle = 45
    via_pil_constant = F.rotate(img, angle, interpolation=PIL.Image.BILINEAR)
    via_enum = F.rotate(img, angle, interpolation=BILINEAR)
    assert_equal(via_pil_constant, via_enum)
156
ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16]
157
scripted_affine = torch.jit.script(F.affine)
159
@pytest.mark.parametrize("device", cpu_and_cuda())
160
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
161
@pytest.mark.parametrize("dt", ALL_DTYPES)
162
def test_identity_map(self, device, height, width, dt):
163
# Tests on square and rectangular images
164
tensor, pil_img = _create_data(height, width, device=device)
166
if dt == torch.float16 and device == "cpu":
167
# skip float16 on CPU case
171
tensor = tensor.to(dtype=dt)
174
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
176
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
177
out_tensor = self.scripted_affine(
178
tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
180
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
182
@pytest.mark.parametrize("device", cpu_and_cuda())
183
@pytest.mark.parametrize("height, width", [(26, 26)])
184
@pytest.mark.parametrize("dt", ALL_DTYPES)
185
@pytest.mark.parametrize(
188
(90, {"k": 1, "dims": (-1, -2)}),
193
(-90, {"k": -1, "dims": (-1, -2)}),
194
(180, {"k": 2, "dims": (-1, -2)}),
197
@pytest.mark.parametrize("fn", [F.affine, scripted_affine])
198
def test_square_rotations(self, device, height, width, dt, angle, config, fn):
200
tensor, pil_img = _create_data(height, width, device=device)
202
if dt == torch.float16 and device == "cpu":
203
# skip float16 on CPU case
207
tensor = tensor.to(dtype=dt)
209
out_pil_img = F.affine(
210
pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
212
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(device)
214
out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
215
if config is not None:
216
assert_equal(torch.rot90(tensor, **config), out_tensor)
218
if out_tensor.dtype != torch.uint8:
219
out_tensor = out_tensor.to(torch.uint8)
221
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
222
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
223
# Tolerance : less than 6% of different pixels
224
assert ratio_diff_pixels < 0.06
226
@pytest.mark.parametrize("device", cpu_and_cuda())
227
@pytest.mark.parametrize("height, width", [(32, 26)])
228
@pytest.mark.parametrize("dt", ALL_DTYPES)
229
@pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120])
230
@pytest.mark.parametrize("fn", [F.affine, scripted_affine])
231
@pytest.mark.parametrize("center", [None, [0, 0]])
232
def test_rect_rotations(self, device, height, width, dt, angle, fn, center):
233
# Tests on rectangular images
234
tensor, pil_img = _create_data(height, width, device=device)
236
if dt == torch.float16 and device == "cpu":
237
# skip float16 on CPU case
241
tensor = tensor.to(dtype=dt)
243
out_pil_img = F.affine(
244
pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center
246
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
249
tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center
252
if out_tensor.dtype != torch.uint8:
253
out_tensor = out_tensor.to(torch.uint8)
255
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
256
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
257
# Tolerance : less than 3% of different pixels
258
assert ratio_diff_pixels < 0.03
260
@pytest.mark.parametrize("device", cpu_and_cuda())
261
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
262
@pytest.mark.parametrize("dt", ALL_DTYPES)
263
@pytest.mark.parametrize("t", [[10, 12], (-12, -13)])
264
@pytest.mark.parametrize("fn", [F.affine, scripted_affine])
265
def test_translations(self, device, height, width, dt, t, fn):
266
# 3) Test translation
267
tensor, pil_img = _create_data(height, width, device=device)
269
if dt == torch.float16 and device == "cpu":
270
# skip float16 on CPU case
274
tensor = tensor.to(dtype=dt)
276
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
278
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
280
if out_tensor.dtype != torch.uint8:
281
out_tensor = out_tensor.to(torch.uint8)
283
_assert_equal_tensor_to_pil(out_tensor, out_pil_img)
285
@pytest.mark.parametrize("device", cpu_and_cuda())
286
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
287
@pytest.mark.parametrize("dt", ALL_DTYPES)
288
@pytest.mark.parametrize(
291
(45.5, [5, 6], 1.0, [0.0, 0.0], None),
292
(33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
293
(45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
294
(33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
295
(85, (10, -10), 0.7, [0.0, 0.0], [1]),
296
(0, [0, 0], 1.0, [35.0], (2.0,)),
297
(-25, [0, 0], 1.2, [0.0, 15.0], None),
298
(-45, [-10, 0], 0.7, [2.0, 5.0], None),
299
(-45, [-10, -10], 1.2, [4.0, 5.0], None),
300
(-90, [0, 0], 1.0, [0.0, 0.0], None),
303
@pytest.mark.parametrize("fn", [F.affine, scripted_affine])
304
def test_all_ops(self, device, height, width, dt, a, t, s, sh, f, fn):
305
# 4) Test rotation + translation + scale + shear
306
tensor, pil_img = _create_data(height, width, device=device)
308
if dt == torch.float16 and device == "cpu":
309
# skip float16 on CPU case
313
tensor = tensor.to(dtype=dt)
315
f_pil = int(f[0]) if f is not None and len(f) == 1 else f
316
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=NEAREST, fill=f_pil)
317
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
319
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=NEAREST, fill=f).cpu()
321
if out_tensor.dtype != torch.uint8:
322
out_tensor = out_tensor.to(torch.uint8)
324
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
325
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
326
# Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
327
tol = 0.06 if device == "cuda" else 0.05
328
assert ratio_diff_pixels < tol
330
@pytest.mark.parametrize("device", cpu_and_cuda())
331
@pytest.mark.parametrize("dt", ALL_DTYPES)
332
def test_batches(self, device, dt):
333
if dt == torch.float16 and device == "cpu":
334
# skip float16 on CPU case
337
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
339
batch_tensors = batch_tensors.to(dtype=dt)
341
_test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0])
343
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_interpolation_type(self, device):
    """F.affine (45-degree angle, no translation/scale/shear) agrees for both BILINEAR spellings.

    One call uses the PIL constant PIL.Image.BILINEAR, the other uses
    InterpolationMode.BILINEAR; the outputs must match.
    """
    img, _ = _create_data(26, 26, device=device)

    via_pil_constant = F.affine(
        img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=PIL.Image.BILINEAR
    )
    via_enum = F.affine(img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
    assert_equal(via_pil_constant, via_enum)
352
def _get_data_dims_and_points_for_perspective():
353
# Build the (data_dims, (startpoints, endpoints)) cases used to parametrize the
# perspective tests (they are unpacked that way by the tests that consume them).
# Ideally we would parametrize independently over data dims and points, but
354
# we want to test on some points that also depend on the data dims.
355
# Pytest doesn't support covariant parametrization, so we do it somewhat manually here.
357
data_dims = [(26, 34), (26, 26)]
359
[[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
360
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
361
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
364
# Every (data_dims, fixed points) pair.
dims_and_points = list(itertools.product(data_dims, points))
366
# Up to here, we could just have used two @pytest.mark.parametrize decorators.
367
# Down below is the covariant part, as the points depend on the data dims.
370
# NOTE(review): the loop below extends `points`, which is never read again after the
# itertools.product call above, so these RandomPerspective.get_params cases never reach
# the returned `dims_and_points`. The appended (dim, params) tuples match the shape of
# `dims_and_points` entries, so `dims_and_points += ...` looks intended. Confirm before
# changing: enabling it adds cases whose points come from get_params (presumably random,
# and unseeded at collection time). `n` is defined on a line not visible in this view.
for dim in data_dims:
371
points += [(dim, T.RandomPerspective.get_params(dim[1], dim[0], i / n)) for i in range(n)]
372
return dims_and_points
375
@pytest.mark.parametrize("device", cpu_and_cuda())
376
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
377
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
378
@pytest.mark.parametrize("fill", (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1], (2.0,)))
379
@pytest.mark.parametrize("fn", [F.perspective, torch.jit.script(F.perspective)])
380
def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
382
if dt == torch.float16 and device == "cpu":
383
# skip float16 on CPU case
386
data_dims, (spoints, epoints) = dims_and_points
388
tensor, pil_img = _create_data(*data_dims, device=device)
390
tensor = tensor.to(dtype=dt)
392
interpolation = NEAREST
393
fill_pil = int(fill[0]) if fill is not None and len(fill) == 1 else fill
394
out_pil_img = F.perspective(
395
pil_img, startpoints=spoints, endpoints=epoints, interpolation=interpolation, fill=fill_pil
397
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
398
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=interpolation, fill=fill).cpu()
400
if out_tensor.dtype != torch.uint8:
401
out_tensor = out_tensor.to(torch.uint8)
403
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
404
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
405
# Tolerance : less than 5% of different pixels
406
assert ratio_diff_pixels < 0.05
409
@pytest.mark.parametrize("device", cpu_and_cuda())
410
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
411
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
412
def test_perspective_batch(device, dims_and_points, dt):
414
if dt == torch.float16 and device == "cpu":
415
# skip float16 on CPU case
418
data_dims, (spoints, epoints) = dims_and_points
420
batch_tensors = _create_data_batch(*data_dims, num_samples=4, device=device)
422
batch_tensors = batch_tensors.to(dtype=dt)
424
# Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at
425
# the border may be entirely different due to small rounding errors.
426
scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8
430
scripted_fn_atol=scripted_fn_atol,
433
interpolation=NEAREST,
437
def test_perspective_interpolation_type():
    """F.perspective gives the same result for PIL.Image.BILINEAR and InterpolationMode.BILINEAR."""
    start_pts = [[0, 0], [33, 0], [33, 25], [0, 25]]
    end_pts = [[3, 2], [32, 3], [30, 24], [2, 25]]
    img = torch.randint(0, 256, (3, 26, 26))

    via_pil_constant = F.perspective(
        img, startpoints=start_pts, endpoints=end_pts, interpolation=PIL.Image.BILINEAR
    )
    via_enum = F.perspective(img, startpoints=start_pts, endpoints=end_pts, interpolation=BILINEAR)
    assert_equal(via_pil_constant, via_enum)
447
@pytest.mark.parametrize("device", cpu_and_cuda())
448
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
449
@pytest.mark.parametrize("size", [32, 26, [32], [32, 32], (32, 32), [26, 35]])
450
@pytest.mark.parametrize("max_size", [None, 34, 40, 1000])
451
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
452
def test_resize(device, dt, size, max_size, interpolation):
454
if dt == torch.float16 and device == "cpu":
455
# skip float16 on CPU case
458
if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
461
torch.manual_seed(12)
462
script_fn = torch.jit.script(F.resize)
463
tensor, pil_img = _create_data(26, 36, device=device)
464
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
467
# This is a trivial cast to float of uint8 data to test all cases
468
tensor = tensor.to(dt)
469
batch_tensors = batch_tensors.to(dt)
471
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
472
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
474
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
476
if interpolation != NEAREST:
477
# We can not check values if mode = NEAREST, as results are different
478
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
479
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
480
resized_tensor_f = resized_tensor
481
# we need to cast to uint8 to compare with PIL image
482
if resized_tensor_f.dtype == torch.uint8:
483
resized_tensor_f = resized_tensor_f.to(torch.float)
485
# Pay attention to high tolerance for MAE
486
_assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=3.0)
488
if isinstance(size, int):
493
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True)
494
assert_equal(resized_tensor, resize_result)
497
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True
501
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_resize_asserts(device):
    """Check F.resize interpolation aliasing and its max_size argument validation."""
    tensor, pil_img = _create_data(26, 36, device=device)

    # The PIL interpolation constant and the InterpolationMode member must agree.
    via_pil_constant = F.resize(tensor, size=32, interpolation=PIL.Image.BILINEAR)
    via_enum = F.resize(tensor, size=32, interpolation=BILINEAR)
    assert_equal(via_pil_constant, via_enum)

    smaller_edge_msg = "max_size should only be passed if size specifies the length of the smaller edge"
    for img in (tensor, pil_img):
        # max_size is rejected when size gives both edges ...
        with pytest.raises(ValueError, match=smaller_edge_msg):
            F.resize(img, size=(32, 34), max_size=35)
        # ... and when it is not strictly greater than the requested size.
        with pytest.raises(ValueError, match="max_size = 32 must be strictly greater"):
            F.resize(img, size=32, max_size=32)
518
@pytest.mark.parametrize("device", cpu_and_cuda())
519
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
520
@pytest.mark.parametrize("size", [[96, 72], [96, 420], [420, 72]])
521
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC])
522
def test_resize_antialias(device, dt, size, interpolation):
524
if dt == torch.float16 and device == "cpu":
525
# skip float16 on CPU case
528
torch.manual_seed(12)
529
script_fn = torch.jit.script(F.resize)
530
tensor, pil_img = _create_data(320, 290, device=device)
533
# This is a trivial cast to float of uint8 data to test all cases
534
tensor = tensor.to(dt)
536
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
537
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, antialias=True)
539
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
541
resized_tensor_f = resized_tensor
542
# we need to cast to uint8 to compare with PIL image
543
if resized_tensor_f.dtype == torch.uint8:
544
resized_tensor_f = resized_tensor_f.to(torch.float)
546
_assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}")
548
accepted_tol = 1.0 + 1e-5
549
if interpolation == BICUBIC:
550
# this overall mean value to make the tests pass
551
# High value is mostly required for test cases with
552
# downsampling and upsampling where we can not exactly
553
# match PIL implementation.
556
_assert_approx_equal_tensor_to_pil(
557
resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max", msg=f"{size}, {interpolation}, {dt}"
560
if isinstance(size, int):
567
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
568
assert_equal(resized_tensor, resize_result)
571
def check_functional_vs_PIL_vs_scripted(
572
fn, fn_pil, fn_t, config, device, dtype, channels=3, tol=2.0 + 1e-10, agg_method="max"
575
script_fn = torch.jit.script(fn)
576
torch.manual_seed(15)
577
tensor, pil_img = _create_data(26, 34, channels=channels, device=device)
578
batch_tensors = _create_data_batch(16, 18, num_samples=4, channels=channels, device=device)
580
if dtype is not None:
581
tensor = F.convert_image_dtype(tensor, dtype)
582
batch_tensors = F.convert_image_dtype(batch_tensors, dtype)
584
out_fn_t = fn_t(tensor, **config)
585
out_pil = fn_pil(pil_img, **config)
586
out_scripted = script_fn(tensor, **config)
587
assert out_fn_t.dtype == out_scripted.dtype
588
assert out_fn_t.size()[1:] == out_pil.size[::-1]
590
rbg_tensor = out_fn_t
592
if out_fn_t.dtype != torch.uint8:
593
rbg_tensor = F.convert_image_dtype(out_fn_t, torch.uint8)
595
# Check that max difference does not exceed 2 in [0, 255] range
596
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
597
_assert_approx_equal_tensor_to_pil(rbg_tensor.float(), out_pil, tol=tol, agg_method=agg_method)
600
if out_fn_t.dtype == torch.uint8 and "cuda" in torch.device(device).type:
602
assert out_fn_t.allclose(out_scripted, atol=atol)
604
# FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
605
_test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config)
608
@pytest.mark.parametrize("device", cpu_and_cuda())
609
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
610
@pytest.mark.parametrize("config", [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
611
@pytest.mark.parametrize("channels", [1, 3])
612
def test_adjust_brightness(device, dtype, config, channels):
613
check_functional_vs_PIL_vs_scripted(
615
F_pil.adjust_brightness,
616
F_t.adjust_brightness,
624
@pytest.mark.parametrize("device", cpu_and_cuda())
625
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
626
@pytest.mark.parametrize("channels", [1, 3])
627
def test_invert(device, dtype, channels):
628
check_functional_vs_PIL_vs_scripted(
629
F.invert, F_pil.invert, F_t.invert, {}, device, dtype, channels, tol=1.0, agg_method="max"
633
@pytest.mark.parametrize("device", cpu_and_cuda())
634
@pytest.mark.parametrize("config", [{"bits": bits} for bits in range(0, 8)])
635
@pytest.mark.parametrize("channels", [1, 3])
636
def test_posterize(device, config, channels):
637
check_functional_vs_PIL_vs_scripted(
650
@pytest.mark.parametrize("device", cpu_and_cuda())
651
@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]])
652
@pytest.mark.parametrize("channels", [1, 3])
653
def test_solarize1(device, config, channels):
654
check_functional_vs_PIL_vs_scripted(
667
@pytest.mark.parametrize("device", cpu_and_cuda())
668
@pytest.mark.parametrize("dtype", (torch.float32, torch.float64))
669
@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]])
670
@pytest.mark.parametrize("channels", [1, 3])
671
def test_solarize2(device, dtype, config, channels):
672
check_functional_vs_PIL_vs_scripted(
674
lambda img, threshold: F_pil.solarize(img, 255 * threshold),
685
@pytest.mark.parametrize(
686
("dtype", "threshold"),
690
for dtype, threshold in itertools.product(
691
[torch.float32, torch.float16],
692
[0.0, 0.25, 0.5, 0.75, 1.0],
695
*[(torch.uint8, threshold) for threshold in [0, 64, 128, 192, 255]],
696
*[(torch.int64, threshold) for threshold in [0, 2**32, 2**63 - 1]],
699
@pytest.mark.parametrize("device", cpu_and_cuda())
700
def test_solarize_threshold_within_bound(threshold, dtype, device):
701
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
702
img = make_img((3, 12, 23), dtype=dtype, device=device)
703
F_t.solarize(img, threshold)
706
@pytest.mark.parametrize(
707
("dtype", "threshold"),
709
(torch.float32, 1.5),
710
(torch.float16, 1.5),
712
(torch.int64, 2**64),
715
@pytest.mark.parametrize("device", cpu_and_cuda())
716
def test_solarize_threshold_above_bound(threshold, dtype, device):
717
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
718
img = make_img((3, 12, 23), dtype=dtype, device=device)
719
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
720
F_t.solarize(img, threshold)
723
@pytest.mark.parametrize("device", cpu_and_cuda())
724
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
725
@pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
726
@pytest.mark.parametrize("channels", [1, 3])
727
def test_adjust_sharpness(device, dtype, config, channels):
728
check_functional_vs_PIL_vs_scripted(
730
F_pil.adjust_sharpness,
731
F_t.adjust_sharpness,
739
@pytest.mark.parametrize("device", cpu_and_cuda())
740
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
741
@pytest.mark.parametrize("channels", [1, 3])
742
def test_autocontrast(device, dtype, channels):
743
check_functional_vs_PIL_vs_scripted(
744
F.autocontrast, F_pil.autocontrast, F_t.autocontrast, {}, device, dtype, channels, tol=1.0, agg_method="max"
748
@pytest.mark.parametrize("device", cpu_and_cuda())
749
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
750
@pytest.mark.parametrize("channels", [1, 3])
751
def test_autocontrast_equal_minmax(device, dtype, channels):
752
a = _create_data_batch(32, 32, num_samples=1, channels=channels, device=device)
754
assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
757
assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
760
@pytest.mark.parametrize("device", cpu_and_cuda())
761
@pytest.mark.parametrize("channels", [1, 3])
762
def test_equalize(device, channels):
763
torch.use_deterministic_algorithms(False)
764
check_functional_vs_PIL_vs_scripted(
777
@pytest.mark.parametrize("device", cpu_and_cuda())
778
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
779
@pytest.mark.parametrize("config", [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
780
@pytest.mark.parametrize("channels", [1, 3])
781
def test_adjust_contrast(device, dtype, config, channels):
782
check_functional_vs_PIL_vs_scripted(
783
F.adjust_contrast, F_pil.adjust_contrast, F_t.adjust_contrast, config, device, dtype, channels
787
@pytest.mark.parametrize("device", cpu_and_cuda())
788
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
789
@pytest.mark.parametrize("config", [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
790
@pytest.mark.parametrize("channels", [1, 3])
791
def test_adjust_saturation(device, dtype, config, channels):
792
check_functional_vs_PIL_vs_scripted(
793
F.adjust_saturation, F_pil.adjust_saturation, F_t.adjust_saturation, config, device, dtype, channels
797
@pytest.mark.parametrize("device", cpu_and_cuda())
798
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
799
@pytest.mark.parametrize("config", [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
800
@pytest.mark.parametrize("channels", [1, 3])
801
def test_adjust_hue(device, dtype, config, channels):
802
check_functional_vs_PIL_vs_scripted(
803
F.adjust_hue, F_pil.adjust_hue, F_t.adjust_hue, config, device, dtype, channels, tol=16.1, agg_method="max"
807
@pytest.mark.parametrize("device", cpu_and_cuda())
808
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
809
@pytest.mark.parametrize("config", [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
810
@pytest.mark.parametrize("channels", [1, 3])
811
def test_adjust_gamma(device, dtype, config, channels):
812
check_functional_vs_PIL_vs_scripted(
823
@pytest.mark.parametrize("device", cpu_and_cuda())
824
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
825
@pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]])
826
@pytest.mark.parametrize(
829
{"padding_mode": "constant", "fill": 0},
830
{"padding_mode": "constant", "fill": 10},
831
{"padding_mode": "constant", "fill": 20.2},
832
{"padding_mode": "edge"},
833
{"padding_mode": "reflect"},
834
{"padding_mode": "symmetric"},
837
def test_pad(device, dt, pad, config):
838
script_fn = torch.jit.script(F.pad)
839
tensor, pil_img = _create_data(7, 8, device=device)
840
batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
842
if dt == torch.float16 and device == "cpu":
843
# skip float16 on CPU case
847
# This is a trivial cast to float of uint8 data to test all cases
848
tensor = tensor.to(dt)
849
batch_tensors = batch_tensors.to(dt)
851
pad_tensor = F_t.pad(tensor, pad, **config)
852
pad_pil_img = F_pil.pad(pil_img, pad, **config)
854
pad_tensor_8b = pad_tensor
855
# we need to cast to uint8 to compare with PIL image
856
if pad_tensor_8b.dtype != torch.uint8:
857
pad_tensor_8b = pad_tensor_8b.to(torch.uint8)
859
_assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg=f"{pad}, {config}")
861
if isinstance(pad, int):
867
pad_tensor_script = script_fn(tensor, script_pad, **config)
868
assert_equal(pad_tensor, pad_tensor_script, msg=f"{pad}, {config}")
870
_test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config)
873
@pytest.mark.parametrize("device", cpu_and_cuda())
874
@pytest.mark.parametrize("mode", [NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC])
875
def test_resized_crop(device, mode):
876
# test values of F.resized_crop in several cases:
877
# 1) resize to the same size, crop to the same size => should be identity
878
tensor, _ = _create_data(26, 36, device=device)
880
out_tensor = F.resized_crop(
881
tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode, antialias=True
883
assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
885
# 2) resize by half and crop a TL corner
886
tensor, _ = _create_data(26, 36, device=device)
887
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
888
expected_out_tensor = tensor[:, :20:2, :30:2]
892
msg=f"{expected_out_tensor[0, :10, :10]} vs {out_tensor[0, :10, :10]}",
895
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
904
interpolation=NEAREST,
908
@pytest.mark.parametrize("device", cpu_and_cuda())
909
@pytest.mark.parametrize(
912
(F_t.get_dimensions, ()),
913
(F_t.get_image_size, ()),
914
(F_t.get_image_num_channels, ()),
917
(F_t.crop, (1, 2, 4, 5)),
918
(F_t.adjust_brightness, (0.0,)),
919
(F_t.adjust_contrast, (1.0,)),
920
(F_t.adjust_hue, (-0.5,)),
921
(F_t.adjust_saturation, (2.0,)),
922
(F_t.pad, ([2], 2, "constant")),
923
(F_t.resize, ([10, 11],)),
924
(F_t.perspective, ([0.2])),
925
(F_t.gaussian_blur, ((2, 2), (0.7, 0.5))),
927
(F_t.posterize, (0,)),
928
(F_t.solarize, (0.3,)),
929
(F_t.adjust_sharpness, (0.3,)),
930
(F_t.autocontrast, ()),
934
def test_assert_image_tensor(device, func, args):
936
tensor = torch.rand(*shape, dtype=torch.float, device=device)
937
with pytest.raises(Exception, match=r"Tensor is not a torch image."):
941
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_vflip(device):
    """Vertical flip: tensor vs PIL, eager vs TorchScript, and batched input."""
    scripted = torch.jit.script(F.vflip)

    tensor_img, pil_img = _create_data(16, 18, device=device)
    eager_out = F.vflip(tensor_img)
    _assert_equal_tensor_to_pil(eager_out, F.vflip(pil_img))

    # The scripted function must match the eager tensor result.
    assert_equal(eager_out, scripted(tensor_img))

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.vflip)
958
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_hflip(device):
    """Horizontal flip: tensor vs PIL, eager vs TorchScript, and batched input."""
    scripted = torch.jit.script(F.hflip)

    tensor_img, pil_img = _create_data(16, 18, device=device)
    eager_out = F.hflip(tensor_img)
    _assert_equal_tensor_to_pil(eager_out, F.hflip(pil_img))

    # The scripted function must match the eager tensor result.
    assert_equal(eager_out, scripted(tensor_img))

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.hflip)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "top, left, height, width",
    [
        (1, 2, 4, 5),  # crop inside top-left corner
        (2, 12, 3, 4),  # crop inside top-right corner
        (8, 3, 5, 6),  # crop inside bottom-left corner
        (8, 11, 4, 3),  # crop inside bottom-right corner
        (50, 50, 10, 10),  # crop outside the image
        (-50, -50, 10, 10),  # crop outside the image
    ],
)
def test_crop(device, top, left, height, width):
    """Crop the image from ``_create_data(16, 18)`` and compare against PIL.

    Checks the eager and TorchScript tensor results against the PIL crop, and
    checks that cropping a batch matches cropping each sample individually.
    """
    script_crop = torch.jit.script(F.crop)

    img_tensor, pil_img = _create_data(16, 18, device=device)

    pil_img_cropped = F.crop(pil_img, top, left, height, width)

    img_tensor_cropped = F.crop(img_tensor, top, left, height, width)
    _assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)

    img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
    _assert_equal_tensor_to_pil(img_tensor_cropped, pil_img_cropped)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("image_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
@pytest.mark.parametrize("fn", [F.gaussian_blur, torch.jit.script(F.gaussian_blur)])
def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
    """Compare F.gaussian_blur (eager and scripted) against stored OpenCV outputs.

    The reference asset holds ``cv2.GaussianBlur`` results for:

    * ``np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))`` with
      ``ksize=(3, 3)`` / ``(3, 5)`` and ``sigmaX=0.8`` / ``0.5``;
    * ``np.arange(26 * 28, dtype="uint8").reshape((26, 28))`` with
      ``ksize=(23, 23)`` and ``sigmaX=1.7``.

    Results are keyed by ``"{H}_{W}_{C}__{kh}_{kw}_{sigma}"``. Parameter
    combinations without a stored reference are skipped.
    """
    # Checked first so the skipped case does not pay for loading the asset.
    if dt == torch.float16 and device == "cpu":
        pytest.skip("float16 is not tested on CPU")

    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    # weights_only=False unpickles arbitrary objects. That is acceptable only
    # because this is a trusted asset shipped with the test suite.
    true_cv2_results = torch.load(p, weights_only=False)

    if image_size == "small":
        # HWC uint8 array -> CHW tensor
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        pytest.skip(f"no OpenCV reference result for {gt_key}")

    # The reference is stored in HWC order; bring it to CHW with the input's dtype/device.
    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    out = fn(tensor, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_hsv2rgb(device):
    """Check F_t._hsv2rgb against colorsys, its TorchScript version and batching.

    Random HSV images are converted and compared per pixel with
    ``colorsys.hsv_to_rgb``. The batched check uses ``_test_fn_on_batch``.
    """
    scripted_fn = torch.jit.script(F_t._hsv2rgb)
    shape = (3, 100, 150)
    for _ in range(10):
        hsv_img = torch.rand(*shape, dtype=torch.float, device=device)
        rgb_img = F_t._hsv2rgb(hsv_img)
        # (3, H, W) -> (H * W, 3): one RGB triple per row, in the same pixel
        # order as the flattened channels fed to colorsys below.
        ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)

        h, s, v = (channel.flatten().cpu().numpy() for channel in hsv_img.unbind(0))
        rgb = [colorsys.hsv_to_rgb(h1, s1, v1) for h1, s1, v1 in zip(h, s, v)]
        colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=device)
        torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)

        s_rgb_img = scripted_fn(hsv_img)
        torch.testing.assert_close(rgb_img, s_rgb_img)

    batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch_tensors, F_t._hsv2rgb)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_rgb2hsv(device):
    """Check F_t._rgb2hsv against colorsys, its TorchScript version and batching.

    Hue is circular (0 and 1 denote the same hue), so it is compared through
    the point (cos, sin) of the angle ``2 * pi * h`` rather than directly.
    """
    scripted_fn = torch.jit.script(F_t._rgb2hsv)
    shape = (3, 150, 100)
    for _ in range(10):
        rgb_img = torch.rand(*shape, dtype=torch.float, device=device)
        hsv_img = F_t._rgb2hsv(rgb_img)
        # (3, H, W) -> (H * W, 3): one HSV triple per row, in the same pixel
        # order as the flattened channels fed to colorsys below.
        ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)

        r, g, b = (channel.flatten().cpu().numpy() for channel in rgb_img.unbind(dim=-3))
        hsv = [colorsys.rgb_to_hsv(r1, g1, b1) for r1, g1, b1 in zip(r, g, b)]
        colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=device)

        ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
        colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)

        # sin alone cannot distinguish h from 0.5 - h. Checking cos as well
        # pins the hue on the unit circle while still treating 0 and 1 as equal.
        angle_ref = colorsys_img_h * 2 * math.pi
        angle_ft = ft_hsv_img_h * 2 * math.pi
        max_diff_h = max(
            (angle_ref.sin() - angle_ft.sin()).abs().max().item(),
            (angle_ref.cos() - angle_ft.cos()).abs().max().item(),
        )
        max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max().item()
        max_diff = max(max_diff_h, max_diff_sv)
        assert max_diff < 1e-5, f"max HSV difference vs colorsys is {max_diff}"

        s_hsv_img = scripted_fn(rgb_img)
        torch.testing.assert_close(hsv_img, s_hsv_img, rtol=1e-5, atol=1e-7)

    batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch_tensors, F_t._rgb2hsv)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_output_channels", (3, 1))
def test_rgb_to_grayscale(device, num_output_channels):
    """F.rgb_to_grayscale on tensors must approximately match PIL and exactly match TorchScript.

    Also checks that converting a batch matches converting each sample.
    """
    scripted = torch.jit.script(F.rgb_to_grayscale)
    kwargs = {"num_output_channels": num_output_channels}

    tensor_img, pil_img = _create_data(32, 34, device=device)

    expected_pil = F.rgb_to_grayscale(pil_img, **kwargs)
    eager_out = F.rgb_to_grayscale(tensor_img, **kwargs)
    # tol/agg_method are forwarded to the common_utils helper. With
    # agg_method="max" this presumably bounds the worst per-pixel difference
    # to about one intensity level; confirm in common_utils.
    _assert_approx_equal_tensor_to_pil(eager_out.float(), expected_pil, tol=1.0 + 1e-10, agg_method="max")

    scripted_out = scripted(tensor_img, **kwargs)
    assert_equal(scripted_out, eager_out)

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.rgb_to_grayscale, **kwargs)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_center_crop(device):
    """F.center_crop on tensors (eager and scripted) must match PIL.

    Also checks that cropping a batch matches cropping each sample.
    """
    scripted_center_crop = torch.jit.script(F.center_crop)

    tensor_img, pil_img = _create_data(32, 34, device=device)
    expected = F.center_crop(pil_img, [10, 11])

    # Eager first, then the scripted version, each compared with the PIL crop.
    for crop_fn in (F.center_crop, scripted_center_crop):
        _assert_equal_tensor_to_pil(crop_fn(tensor_img, [10, 11]), expected)

    batch = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch, F.center_crop, output_size=[10, 11])


def _assert_multi_crop_consistency(device, crop_fn, num_crops):
    """Shared checks for multi-crop functionals such as ``F.five_crop`` and ``F.ten_crop``.

    Each crop of the tensor result (eager and TorchScript) must equal the
    corresponding PIL crop. Cropping a batch must equal cropping each sample,
    and the scripted batch result must equal the eager one.

    Args:
        device: device on which the test tensors are created.
        crop_fn: eager functional under test, called as ``crop_fn(img, [10, 11])``.
        num_crops: number of crops ``crop_fn`` is expected to return.
    """
    size = [10, 11]
    scripted_crop_fn = torch.jit.script(crop_fn)

    img_tensor, pil_img = _create_data(32, 34, device=device)
    pil_crops = crop_fn(pil_img, size)
    assert len(pil_crops) == num_crops

    # Counts are asserted before zip(), which would otherwise truncate silently.
    for tensor_crops in (crop_fn(img_tensor, size), scripted_crop_fn(img_tensor, size)):
        assert len(tensor_crops) == num_crops
        for tensor_crop, pil_crop in zip(tensor_crops, pil_crops):
            _assert_equal_tensor_to_pil(tensor_crop, pil_crop)

    # Cropping the whole batch must give, for each sample, the same crops as
    # cropping that sample on its own.
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    batch_crops = crop_fn(batch_tensors, size)
    for i, sample in enumerate(batch_tensors):
        sample_crops = crop_fn(sample, size)
        assert len(sample_crops) == len(batch_crops)
        for sample_crop, batch_crop in zip(sample_crops, batch_crops):
            assert_equal(sample_crop, batch_crop[i, ...])

    # scriptable function test on the batch
    scripted_batch_crops = scripted_crop_fn(batch_tensors, size)
    assert len(scripted_batch_crops) == len(batch_crops)
    for batch_crop, scripted_batch_crop in zip(batch_crops, scripted_batch_crops):
        assert_equal(batch_crop, scripted_batch_crop)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_five_crop(device):
    """F.five_crop: tensor vs PIL, eager vs TorchScript, per-sample vs batched."""
    _assert_multi_crop_consistency(device, F.five_crop, num_crops=5)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_ten_crop(device):
    """F.ten_crop: tensor vs PIL, eager vs TorchScript, per-sample vs batched."""
    _assert_multi_crop_consistency(device, F.ten_crop, num_crops=10)


def test_elastic_transform_asserts():
    """F.elastic_transform must reject invalid arguments with specific errors.

    The first case passes both an invalid image and an invalid displacement,
    and expects the displacement error.
    """
    invalid_calls = [
        ("abc", None, TypeError, "Argument displacement should be a Tensor"),
        ("abc", torch.rand(1), TypeError, "img should be PIL Image or Tensor"),
        (torch.rand(1, 3, 32, 24), torch.rand(1, 2), ValueError, "Argument displacement shape should"),
    ]
    for img, displacement, expected_exc, pattern in invalid_calls:
        with pytest.raises(expected_exc, match=pattern):
            _ = F.elastic_transform(img, displacement=displacement)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
    "fill",
    [None, [255, 255, 255], (2.0,)],
)
def test_elastic_transform_consistency(device, interpolation, dt, fill):
    """Eager and TorchScript F.elastic_transform must agree, and batching must be consistent.

    There is no PIL implementation of elastic_transform, so this test does not
    compare tensor results against PIL.
    """
    script_elastic_transform = torch.jit.script(F.elastic_transform)
    img_tensor, _ = _create_data(32, 34, device=device)
    if dt is not None:
        img_tensor = img_tensor.to(dt)

    # The last get_params argument matches the spatial size passed to _create_data.
    kwargs = {
        "displacement": T.ElasticTransform.get_params([1.5, 1.5], [2.0, 2.0], [32, 34]),
        "interpolation": interpolation,
        "fill": fill,
    }

    out_tensor1 = F.elastic_transform(img_tensor, **kwargs)
    out_tensor2 = script_elastic_transform(img_tensor, **kwargs)
    assert_equal(out_tensor1, out_tensor2)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    # The batch samples are 16x18, so a matching displacement field is needed.
    kwargs["displacement"] = T.ElasticTransform.get_params([1.5, 1.5], [2.0, 2.0], [16, 18])
    if dt is not None:
        batch_tensors = batch_tensors.to(dt)
    _test_fn_on_batch(batch_tensors, F.elastic_transform, **kwargs)


if __name__ == "__main__":
    # Propagate pytest's exit status so failures are visible to the calling shell / CI.
    raise SystemExit(pytest.main([__file__]))