# vision
#
# Fork
# 0
# /
# test_functional_tensor.py
# 1282 lines · 48.1 KB
1
import colorsys
2
import itertools
3
import math
4
import os
5
from functools import partial
6
from typing import Sequence
7

8
import numpy as np
9
import PIL.Image
10
import pytest
11
import torch
12
import torchvision.transforms as T
13
import torchvision.transforms._functional_pil as F_pil
14
import torchvision.transforms._functional_tensor as F_t
15
import torchvision.transforms.functional as F
16
from common_utils import (
17
    _assert_approx_equal_tensor_to_pil,
18
    _assert_equal_tensor_to_pil,
19
    _create_data,
20
    _create_data_batch,
21
    _test_fn_on_batch,
22
    assert_equal,
23
    cpu_and_cuda,
24
    needs_cuda,
25
)
26
from torchvision.transforms import InterpolationMode
27

28
# Short module-level aliases for the interpolation modes used throughout these tests.
NEAREST = InterpolationMode.NEAREST
NEAREST_EXACT = InterpolationMode.NEAREST_EXACT
BILINEAR = InterpolationMode.BILINEAR
BICUBIC = InterpolationMode.BICUBIC
34

35

36
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions])
def test_image_sizes(device, fn):
    """Size helpers must agree on tensor, PIL, scripted and batched inputs."""
    scripted = torch.jit.script(fn)

    single, pil_image = _create_data(16, 18, 3, device=device)
    expected = fn(single)

    # The PIL and tensor code paths must report the same value.
    assert expected == fn(pil_image)

    # TorchScript must not change the result.
    assert expected == scripted(single)

    # Adding a leading batch dimension must not change the reported value.
    batch = _create_data_batch(16, 18, 3, num_samples=4, device=device)
    assert expected == fn(batch)
52

53

54
@needs_cuda
def test_scale_channel():
    """_scale_channel must produce identical results on CPU and CUDA.

    The implementation uses histc or bincount depending on the device, so the
    two code paths are compared against each other directly.
    """
    # TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed,
    # only use bincount and remove that test.
    channel = torch.randint(0, 256, size=(1_000,)).to("cpu")
    on_cpu = F_t._scale_channel(channel)
    on_gpu = F_t._scale_channel(channel.to("cuda"))
    assert_equal(on_cpu, on_gpu.to("cpu"))
66

67

68
class TestRotate:
    """Compare tensor ``F.rotate`` (eager and scripted) against the PIL backend."""

    ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16]
    scripted_rotate = torch.jit.script(F.rotate)
    IMG_W = 26

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)])
    @pytest.mark.parametrize(
        "center",
        [
            None,
            (int(IMG_W * 0.3), int(IMG_W * 0.4)),
            [int(IMG_W * 0.5), int(IMG_W * 0.6)],
        ],
    )
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize("angle", range(-180, 180, 34))
    @pytest.mark.parametrize("expand", [True, False])
    @pytest.mark.parametrize("fill", [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1], (2.0,)])
    @pytest.mark.parametrize("fn", [F.rotate, scripted_rotate])
    def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn):
        """Nearest-neighbour rotation must match PIL on more than 97% of pixels."""
        img, pil_img = _create_data(height, width, device=device)

        if dt == torch.float16 and torch.device(device).type == "cpu":
            return  # float16 is not exercised on CPU

        if dt is not None:
            img = img.to(dtype=dt)

        # A single-value fill is passed to PIL as a plain int.
        pil_fill = fill
        if fill is not None and len(fill) == 1:
            pil_fill = int(fill[0])

        common = dict(angle=angle, interpolation=NEAREST, expand=expand, center=center)
        expected = F.rotate(pil_img, fill=pil_fill, **common)
        expected = torch.from_numpy(np.array(expected).transpose((2, 0, 1)))

        actual = fn(img, fill=fill, **common).cpu()
        if actual.dtype != torch.uint8:
            actual = actual.to(torch.uint8)

        assert (
            actual.shape == expected.shape
        ), f"{(height, width, NEAREST, dt, angle, expand, center)}: {actual.shape} vs {expected.shape}"

        mismatched = (actual != expected).sum().item() / 3.0
        ratio = mismatched / actual.shape[-1] / actual.shape[-2]
        # Tolerance : less than 3% of different pixels
        assert ratio < 0.03, (
            f"{(height, width, NEAREST, dt, angle, expand, center, fill)}: "
            f"{ratio}\n{actual[0, :7, :7]} vs \n"
            f"{expected[0, :7, :7]}"
        )

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    def test_rotate_batch(self, device, dt):
        """Run a batched rotation through the shared batch checker."""
        if dt == torch.float16 and device == "cpu":
            return  # float16 is not exercised on CPU

        batch = _create_data_batch(26, 36, num_samples=4, device=device)
        if dt is not None:
            batch = batch.to(dtype=dt)

        _test_fn_on_batch(batch, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=(20, 22))

    def test_rotate_interpolation_type(self):
        """A PIL interpolation constant must behave like its InterpolationMode counterpart."""
        img = _create_data(26, 26)[0]
        assert_equal(
            F.rotate(img, 45, interpolation=PIL.Image.BILINEAR),
            F.rotate(img, 45, interpolation=BILINEAR),
        )
152

153

154
class TestAffine:
    """Compare tensor ``F.affine`` (eager and scripted) against the PIL backend."""

    ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16]
    scripted_affine = torch.jit.script(F.affine)

    @staticmethod
    def _as_uint8(t):
        # PIL references are uint8, so any other dtype is cast before comparing.
        return t if t.dtype == torch.uint8 else t.to(torch.uint8)

    @staticmethod
    def _different_pixels_ratio(actual, expected):
        # Count mismatching elements, divide by 3 (presumably the channel count of
        # the test images), then by W and H to get a per-pixel ratio.
        differing = (actual != expected).sum().item() / 3.0
        return differing / actual.shape[-1] / actual.shape[-2]

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    def test_identity_map(self, device, height, width, dt):
        """An affine with no rotation/translation/shear and unit scale is a no-op (square and rectangular)."""
        img, _ = _create_data(height, width, device=device)

        if dt == torch.float16 and device == "cpu":
            return  # float16 is not exercised on CPU

        if dt is not None:
            img = img.to(dtype=dt)

        # 1) identity map, both eager and scripted
        for affine in (F.affine, self.scripted_affine):
            result = affine(img, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
            assert_equal(img, result, msg=f"{result[0, :5, :5]} vs {img[0, :5, :5]}")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("height, width", [(26, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize(
        "angle, config",
        [
            (90, {"k": 1, "dims": (-1, -2)}),
            (45, None),
            (30, None),
            (-30, None),
            (-45, None),
            (-90, {"k": -1, "dims": (-1, -2)}),
            (180, {"k": 2, "dims": (-1, -2)}),
        ],
    )
    @pytest.mark.parametrize("fn", [F.affine, scripted_affine])
    def test_square_rotations(self, device, height, width, dt, angle, config, fn):
        """2) Rotation of a square image; multiples of 90 degrees must equal ``torch.rot90``."""
        img, pil_img = _create_data(height, width, device=device)

        if dt == torch.float16 and device == "cpu":
            return  # float16 is not exercised on CPU

        if dt is not None:
            img = img.to(dtype=dt)

        params = dict(angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
        expected = torch.from_numpy(np.array(F.affine(pil_img, **params)).transpose((2, 0, 1))).to(device)

        result = fn(img, **params)
        if config is not None:
            assert_equal(torch.rot90(img, **config), result)

        result = self._as_uint8(result)
        # Tolerance : less than 6% of different pixels
        assert self._different_pixels_ratio(result, expected) < 0.06

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("height, width", [(32, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120])
    @pytest.mark.parametrize("fn", [F.affine, scripted_affine])
    @pytest.mark.parametrize("center", [None, [0, 0]])
    def test_rect_rotations(self, device, height, width, dt, angle, fn, center):
        """Rotation of a rectangular image, with and without an explicit center."""
        img, pil_img = _create_data(height, width, device=device)

        if dt == torch.float16 and device == "cpu":
            return  # float16 is not exercised on CPU

        if dt is not None:
            img = img.to(dtype=dt)

        params = dict(
            angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center
        )
        expected = torch.from_numpy(np.array(F.affine(pil_img, **params)).transpose((2, 0, 1)))
        result = self._as_uint8(fn(img, **params).cpu())

        # Tolerance : less than 3% of different pixels
        assert self._different_pixels_ratio(result, expected) < 0.03

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize("t", [[10, 12], (-12, -13)])
    @pytest.mark.parametrize("fn", [F.affine, scripted_affine])
    def test_translations(self, device, height, width, dt, t, fn):
        """3) Pure integer translation must match PIL exactly."""
        img, pil_img = _create_data(height, width, device=device)

        if dt == torch.float16 and device == "cpu":
            return  # float16 is not exercised on CPU

        if dt is not None:
            img = img.to(dtype=dt)

        params = dict(angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
        expected_pil = F.affine(pil_img, **params)
        result = self._as_uint8(fn(img, **params))

        _assert_equal_tensor_to_pil(result, expected_pil)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    @pytest.mark.parametrize(
        "a, t, s, sh, f",
        [
            (45.5, [5, 6], 1.0, [0.0, 0.0], None),
            (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
            (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
            (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
            (85, (10, -10), 0.7, [0.0, 0.0], [1]),
            (0, [0, 0], 1.0, [35.0], (2.0,)),
            (-25, [0, 0], 1.2, [0.0, 15.0], None),
            (-45, [-10, 0], 0.7, [2.0, 5.0], None),
            (-45, [-10, -10], 1.2, [4.0, 5.0], None),
            (-90, [0, 0], 1.0, [0.0, 0.0], None),
        ],
    )
    @pytest.mark.parametrize("fn", [F.affine, scripted_affine])
    def test_all_ops(self, device, height, width, dt, a, t, s, sh, f, fn):
        """4) Combined rotation + translation + scale + shear, with various fills."""
        img, pil_img = _create_data(height, width, device=device)

        if dt == torch.float16 and device == "cpu":
            return  # float16 is not exercised on CPU

        if dt is not None:
            img = img.to(dtype=dt)

        # A single-value fill is passed to PIL as a plain int.
        pil_fill = int(f[0]) if f is not None and len(f) == 1 else f
        params = dict(angle=a, translate=t, scale=s, shear=sh, interpolation=NEAREST)

        expected = torch.from_numpy(np.array(F.affine(pil_img, fill=pil_fill, **params)).transpose((2, 0, 1)))
        result = self._as_uint8(fn(img, fill=f, **params).cpu())

        # Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
        limit = 0.06 if device == "cuda" else 0.05
        assert self._different_pixels_ratio(result, expected) < limit

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dt", ALL_DTYPES)
    def test_batches(self, device, dt):
        """Run a batched affine through the shared batch checker."""
        if dt == torch.float16 and device == "cpu":
            return  # float16 is not exercised on CPU

        batch = _create_data_batch(26, 36, num_samples=4, device=device)
        if dt is not None:
            batch = batch.to(dtype=dt)

        _test_fn_on_batch(batch, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0])

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_interpolation_type(self, device):
        """A PIL interpolation constant must behave like its InterpolationMode counterpart."""
        img, _ = _create_data(26, 26, device=device)

        common = dict(translate=[0, 0], scale=1.0, shear=[0.0, 0.0])
        with_pil_constant = F.affine(img, 45, interpolation=PIL.Image.BILINEAR, **common)
        with_enum = F.affine(img, 45, interpolation=BILINEAR, **common)
        assert_equal(with_pil_constant, with_enum)
350

351

352
def _get_data_dims_and_points_for_perspective():
    """Build ``(data_dims, (startpoints, endpoints))`` cases for the perspective tests.

    Ideally we would parametrize independently over data dims and points, but we
    also want to test some points that depend on the data dims. Pytest doesn't
    support covariant parametrization, so the combinations are assembled manually.

    Returns:
        list of ``(dims, (startpoints, endpoints))`` tuples, where ``dims`` is one
        of the tuples passed as ``_create_data(*dims)`` by the callers.
    """
    data_dims = [(26, 34), (26, 26)]
    points = [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ]

    # Up to here, we could just have used 2 @parametrize decorators.
    dims_and_points = list(itertools.product(data_dims, points))

    # Below is the covariant part: the points depend on the data dims. These cases
    # must be appended to ``dims_and_points`` itself. Appending to ``points`` after
    # the product has been taken silently drops them.
    n = 10
    # NOTE(review): RandomPerspective.get_params presumably draws from the global CPU
    # torch RNG (no generator is passed). Fork that RNG and seed only the CPU generator
    # so that test collection is reproducible and the caller's RNG state is untouched.
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(0)
        for dim in data_dims:
            dims_and_points += [(dim, T.RandomPerspective.get_params(dim[1], dim[0], i / n)) for i in range(n)]
    return dims_and_points
373

374

375
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("fill", (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1], (2.0,)))
@pytest.mark.parametrize("fn", [F.perspective, torch.jit.script(F.perspective)])
def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn):
    """Tensor perspective (eager and scripted) must match PIL on more than 95% of pixels."""
    if dt == torch.float16 and device == "cpu":
        return  # float16 is not exercised on CPU

    dims, (start_points, end_points) = dims_and_points

    img, pil_img = _create_data(*dims, device=device)
    if dt is not None:
        img = img.to(dtype=dt)

    # A single-value fill is passed to PIL as a plain int.
    pil_fill = int(fill[0]) if fill is not None and len(fill) == 1 else fill
    warp = dict(startpoints=start_points, endpoints=end_points, interpolation=NEAREST)

    expected = torch.from_numpy(np.array(F.perspective(pil_img, fill=pil_fill, **warp)).transpose((2, 0, 1)))
    actual = fn(img, fill=fill, **warp).cpu()
    if actual.dtype != torch.uint8:
        actual = actual.to(torch.uint8)

    differing = (actual != expected).sum().item() / 3.0
    ratio = differing / actual.shape[-1] / actual.shape[-2]
    # Tolerance : less than 5% of different pixels
    assert ratio < 0.05
407

408

409
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
def test_perspective_batch(device, dims_and_points, dt):
    """Run a batched perspective warp through the shared batch checker."""
    if dt == torch.float16 and device == "cpu":
        return  # float16 is not exercised on CPU

    dims, (start_points, end_points) = dims_and_points

    batch = _create_data_batch(*dims, num_samples=4, device=device)
    if dt is not None:
        batch = batch.to(dtype=dt)

    # On float16 CUDA, border pixels of the scripted and eager results may be
    # entirely different because of small rounding errors, so that equivalence
    # check is disabled there (atol = -1).
    if dt == torch.float16 and device == "cuda":
        scripted_atol = -1
    else:
        scripted_atol = 1e-8

    _test_fn_on_batch(
        batch,
        F.perspective,
        scripted_fn_atol=scripted_atol,
        startpoints=start_points,
        endpoints=end_points,
        interpolation=NEAREST,
    )
435

436

437
def test_perspective_interpolation_type():
    """A PIL interpolation constant must behave like its InterpolationMode counterpart."""
    warp = dict(
        startpoints=[[0, 0], [33, 0], [33, 25], [0, 25]],
        endpoints=[[3, 2], [32, 3], [30, 24], [2, 25]],
    )
    img = torch.randint(0, 256, (3, 26, 26))

    assert_equal(
        F.perspective(img, interpolation=PIL.Image.BILINEAR, **warp),
        F.perspective(img, interpolation=BILINEAR, **warp),
    )
445

446

447
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [32, 26, [32], [32, 32], (32, 32), [26, 35]])
@pytest.mark.parametrize("max_size", [None, 34, 40, 1000])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
def test_resize(device, dt, size, max_size, interpolation):
    """Compare tensor ``F.resize`` with the PIL backend, TorchScript and batched inputs."""

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    # max_size is only allowed when ``size`` gives the length of the smaller edge.
    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
        return  # unsupported

    torch.manual_seed(12)
    script_fn = torch.jit.script(F.resize)
    tensor, pil_img = _create_data(26, 36, device=device)
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)

    if dt is not None:
        # This is a trivial cast to float of uint8 data to test all cases
        tensor = tensor.to(dt)
        batch_tensors = batch_tensors.to(dt)

    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size, antialias=True)
    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size, antialias=True)

    # PIL's ``size`` is (width, height); the tensor's trailing dims are (height, width).
    assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

    if interpolation != NEAREST:
        # We can not check values if mode = NEAREST, as results are different
        # E.g. resized_tensor  = [[a, a, b, c, d, d, e, ...]]
        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
        resized_tensor_f = resized_tensor
        # uint8 output is promoted to float before the approximate comparison with
        # the PIL image; float outputs are compared as-is.
        if resized_tensor_f.dtype == torch.uint8:
            resized_tensor_f = resized_tensor_f.to(torch.float)

        # Pay attention to high tolerance for MAE
        _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=3.0)

    # The scripted signature takes a list for ``size``, so a bare int is wrapped.
    script_size = [size] if isinstance(size, int) else size

    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True)
    assert_equal(resized_tensor, resize_result)

    _test_fn_on_batch(
        batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size, antialias=True
    )
499

500

501
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_resize_asserts(device):
    """PIL/enum interpolation equivalence and ``max_size`` argument validation for resize."""
    img, pil_img = _create_data(26, 36, device=device)

    # A PIL interpolation constant must behave like its InterpolationMode counterpart.
    assert_equal(
        F.resize(img, size=32, interpolation=PIL.Image.BILINEAR),
        F.resize(img, size=32, interpolation=BILINEAR),
    )

    smaller_edge_msg = "max_size should only be passed if size specifies the length of the smaller edge"
    for candidate in (img, pil_img):
        with pytest.raises(ValueError, match=smaller_edge_msg):
            F.resize(candidate, size=(32, 34), max_size=35)
        with pytest.raises(ValueError, match="max_size = 32 must be strictly greater"):
            F.resize(candidate, size=32, max_size=32)
516

517

518
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("size", [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC])
def test_resize_antialias(device, dt, size, interpolation):
    """Antialiased tensor resize must closely match PIL and be unchanged by TorchScript."""

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    torch.manual_seed(12)
    script_fn = torch.jit.script(F.resize)
    tensor, pil_img = _create_data(320, 290, device=device)

    if dt is not None:
        # This is a trivial cast to float of uint8 data to test all cases
        tensor = tensor.to(dt)

    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, antialias=True)

    # PIL's ``size`` is (width, height); the tensor's trailing dims are (height, width).
    assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

    resized_tensor_f = resized_tensor
    # uint8 output is promoted to float before the approximate comparisons with
    # the PIL image; float outputs are compared as-is.
    if resized_tensor_f.dtype == torch.uint8:
        resized_tensor_f = resized_tensor_f.to(torch.float)

    # Default aggregation: tight tolerance.
    _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}")

    # Max-abs-difference check (agg_method="max").
    accepted_tol = 1.0 + 1e-5
    if interpolation == BICUBIC:
        # Bicubic needs a much looser max tolerance to pass. It is mostly required
        # for the cases that mix downsampling and upsampling, where we can not
        # exactly match the PIL implementation.
        accepted_tol = 15.0

    _assert_approx_equal_tensor_to_pil(
        resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max", msg=f"{size}, {interpolation}, {dt}"
    )

    # The scripted signature takes a list for ``size``, so a bare int is wrapped.
    script_size = [size] if isinstance(size, int) else size

    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
    assert_equal(resized_tensor, resize_result)
569

570

571
def check_functional_vs_PIL_vs_scripted(
    fn, fn_pil, fn_t, config, device, dtype, channels=3, tol=2.0 + 1e-10, agg_method="max"
):
    """Cross-check one functional op against its PIL backend, tensor backend and TorchScript.

    Args:
        fn: public functional (scripted here and also used for the batch check).
        fn_pil: PIL-backend implementation, called on a PIL image.
        fn_t: tensor-backend implementation, called on a tensor.
        config: keyword arguments forwarded to every implementation.
        device: device the tensors are created on.
        dtype: if not None, input tensors are converted with ``F.convert_image_dtype``.
        channels: number of image channels.
        tol, agg_method: tolerance and aggregation for the tensor-vs-PIL comparison.
    """
    scripted = torch.jit.script(fn)
    torch.manual_seed(15)
    img, pil_img = _create_data(26, 34, channels=channels, device=device)
    batch = _create_data_batch(16, 18, num_samples=4, channels=channels, device=device)

    if dtype is not None:
        img = F.convert_image_dtype(img, dtype)
        batch = F.convert_image_dtype(batch, dtype)

    from_tensor = fn_t(img, **config)
    from_pil = fn_pil(pil_img, **config)
    from_script = scripted(img, **config)
    assert from_tensor.dtype == from_script.dtype
    # PIL's ``size`` is (width, height); the tensor's trailing dims are (height, width).
    assert from_tensor.size()[1:] == from_pil.size[::-1]

    if from_tensor.dtype == torch.uint8:
        as_uint8 = from_tensor
    else:
        as_uint8 = F.convert_image_dtype(from_tensor, torch.uint8)

    # The difference to PIL is measured in the [0, 255] range. Exact matching is not
    # possible because convert_image_dtype and PIL results are not fully compatible.
    _assert_approx_equal_tensor_to_pil(as_uint8.float(), from_pil, tol=tol, agg_method=agg_method)

    # uint8 results on CUDA may differ by one unit between eager and scripted runs.
    if from_tensor.dtype == torch.uint8 and "cuda" in torch.device(device).type:
        atol = 1.0
    else:
        atol = 1e-6
    assert from_tensor.allclose(from_script, atol=atol)

    # FIXME: fn will be scripted again in _test_fn_on_batch. We could avoid that.
    _test_fn_on_batch(batch, fn, scripted_fn_atol=atol, **config)
606

607

608
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [dict(brightness_factor=f) for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_brightness(device, dtype, config, channels):
    """Cross-check adjust_brightness across the PIL, tensor and scripted implementations."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.adjust_brightness,
        fn_pil=F_pil.adjust_brightness,
        fn_t=F_t.adjust_brightness,
        config=config,
        device=device,
        dtype=dtype,
        channels=channels,
    )
622

623

624
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3])
def test_invert(device, dtype, channels):
    """Cross-check invert across the PIL, tensor and scripted implementations."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.invert,
        fn_pil=F_pil.invert,
        fn_t=F_t.invert,
        config={},
        device=device,
        dtype=dtype,
        channels=channels,
        tol=1.0,
        agg_method="max",
    )
631

632

633
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("config", [dict(bits=b) for b in range(8)])
@pytest.mark.parametrize("channels", [1, 3])
def test_posterize(device, config, channels):
    """Cross-check posterize (uint8 input only) across the PIL, tensor and scripted implementations."""
    check_functional_vs_PIL_vs_scripted(
        F.posterize, F_pil.posterize, F_t.posterize, config, device, None, channels, tol=1.0, agg_method="max"
    )
648

649

650
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("config", [dict(threshold=t) for t in (0, 64, 128, 192, 255)])
@pytest.mark.parametrize("channels", [1, 3])
def test_solarize1(device, config, channels):
    """Cross-check solarize on uint8 images with integer thresholds."""
    check_functional_vs_PIL_vs_scripted(
        F.solarize, F_pil.solarize, F_t.solarize, config, device, None, channels, tol=1.0, agg_method="max"
    )
665

666

667
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (torch.float32, torch.float64))
@pytest.mark.parametrize("config", [dict(threshold=t) for t in (0.0, 0.25, 0.5, 0.75, 1.0)])
@pytest.mark.parametrize("channels", [1, 3])
def test_solarize2(device, dtype, config, channels):
    """Cross-check solarize on float images with thresholds in [0, 1]."""

    def pil_solarize(img, threshold):
        # The PIL backend works in the [0, 255] range, so the float threshold is rescaled.
        return F_pil.solarize(img, 255 * threshold)

    check_functional_vs_PIL_vs_scripted(
        F.solarize, pil_solarize, F_t.solarize, config, device, dtype, channels, tol=1.0, agg_method="max"
    )
683

684

685
@pytest.mark.parametrize(
    ("dtype", "threshold"),
    [
        *itertools.product([torch.float32, torch.float16], [0.0, 0.25, 0.5, 0.75, 1.0]),
        *[(torch.uint8, t) for t in (0, 64, 128, 192, 255)],
        *[(torch.int64, t) for t in (0, 2**32, 2**63 - 1)],
    ],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_solarize_threshold_within_bound(threshold, dtype, device):
    """Thresholds up to the dtype's upper bound must be accepted without raising."""
    if dtype.is_floating_point:
        make_img = torch.rand
    else:
        make_img = partial(torch.randint, 0, torch.iinfo(dtype).max)

    F_t.solarize(make_img((3, 12, 23), dtype=dtype, device=device), threshold)
704

705

706
@pytest.mark.parametrize(
    ("dtype", "threshold"),
    [
        (torch.float32, 1.5),
        (torch.float16, 1.5),
        (torch.uint8, 260),
        (torch.int64, 2**64),
    ],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_solarize_threshold_above_bound(threshold, dtype, device):
    """A threshold above the dtype's upper bound must be rejected with a TypeError."""
    if dtype.is_floating_point:
        make_img = torch.rand
    else:
        make_img = partial(torch.randint, 0, torch.iinfo(dtype).max)

    image = make_img((3, 12, 23), dtype=dtype, device=device)
    with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
        F_t.solarize(image, threshold)
721

722

723
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [dict(sharpness_factor=f) for f in (0.2, 0.5, 1.0, 1.5, 2.0)])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_sharpness(device, dtype, config, channels):
    """Cross-check adjust_sharpness across the PIL, tensor and scripted implementations."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.adjust_sharpness,
        fn_pil=F_pil.adjust_sharpness,
        fn_t=F_t.adjust_sharpness,
        config=config,
        device=device,
        dtype=dtype,
        channels=channels,
    )
737

738

739
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3])
def test_autocontrast(device, dtype, channels):
    """Cross-check autocontrast across the PIL, tensor and scripted implementations."""
    check_functional_vs_PIL_vs_scripted(
        fn=F.autocontrast,
        fn_pil=F_pil.autocontrast,
        fn_t=F_t.autocontrast,
        config={},
        device=device,
        dtype=dtype,
        channels=channels,
        tol=1.0,
        agg_method="max",
    )
746

747

748
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("channels", [1, 3])
def test_autocontrast_equal_minmax(device, dtype, channels):
    """Batched and unbatched autocontrast must agree, including when a slice is constant.

    The parametrized ``dtype`` is applied to the data. Previously it was ignored, so
    every dtype case ran the same float32 test.
    """
    a = _create_data_batch(32, 32, num_samples=1, channels=channels, device=device)
    if dtype is not None:
        a = a.to(dtype)
    # Division promotes integer data to floating point. For dtype=None this is the
    # same float32 data as before.
    a = a / 2.0 + 0.3
    assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()

    # Make a[0, 0] (presumably the first channel of the first image) constant, so
    # that its min equals its max.
    a[0, 0] = 0.7
    assert (F.autocontrast(a)[0] == F.autocontrast(a[0])).all()
758

759

760
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("channels", [1, 3])
def test_equalize(device, channels):
    """Run the functional / PIL / tensor comparison for ``equalize``.

    Deterministic-algorithm mode is switched off while the comparison runs.
    The caller's previous setting (including ``warn_only``) is restored
    afterwards, so this test no longer changes global torch state for the
    tests that follow it.
    """
    # equalize presumably hits kernels without a deterministic implementation
    # (histc / bincount depending on device, cf. test_scale_channel).
    was_deterministic = torch.are_deterministic_algorithms_enabled()
    was_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
    torch.use_deterministic_algorithms(False)
    try:
        check_functional_vs_PIL_vs_scripted(
            F.equalize,
            F_pil.equalize,
            F_t.equalize,
            {},
            device,
            dtype=None,
            channels=channels,
            tol=1.0,
            agg_method="max",
        )
    finally:
        torch.use_deterministic_algorithms(was_deterministic, warn_only=was_warn_only)
775

776

777
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_contrast(device, dtype, config, channels):
    """Run the functional / PIL / tensor comparison for ``adjust_contrast``.

    All checking is delegated to ``check_functional_vs_PIL_vs_scripted``
    (defined elsewhere in this file) with its default tolerance.
    """
    implementations = (F.adjust_contrast, F_pil.adjust_contrast, F_t.adjust_contrast)
    check_functional_vs_PIL_vs_scripted(*implementations, config, device, dtype=dtype, channels=channels)
785

786

787
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_saturation(device, dtype, config, channels):
    """Run the functional / PIL / tensor comparison for ``adjust_saturation``.

    All checking is delegated to ``check_functional_vs_PIL_vs_scripted``
    (defined elsewhere in this file) with its default tolerance.
    """
    implementations = (F.adjust_saturation, F_pil.adjust_saturation, F_t.adjust_saturation)
    check_functional_vs_PIL_vs_scripted(*implementations, config, device, dtype=dtype, channels=channels)
795

796

797
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_hue(device, dtype, config, channels):
    """Run the functional / PIL / tensor comparison for ``adjust_hue``.

    Hue results are compared with a much looser tolerance (16.1, max
    aggregation) than most other color adjustments in this file.
    """
    implementations = (F.adjust_hue, F_pil.adjust_hue, F_t.adjust_hue)
    check_functional_vs_PIL_vs_scripted(
        *implementations,
        config,
        device,
        dtype=dtype,
        channels=channels,
        tol=16.1,
        agg_method="max",
    )
805

806

807
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64))
@pytest.mark.parametrize("config", [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
@pytest.mark.parametrize("channels", [1, 3])
def test_adjust_gamma(device, dtype, config, channels):
    """Run the functional / PIL / tensor comparison for ``adjust_gamma``.

    Each config pairs one gamma value with one gain value. All checking is
    delegated to ``check_functional_vs_PIL_vs_scripted`` with its default
    tolerance.
    """
    implementations = (F.adjust_gamma, F_pil.adjust_gamma, F_t.adjust_gamma)
    check_functional_vs_PIL_vs_scripted(*implementations, config, device, dtype=dtype, channels=channels)
821

822

823
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]])
@pytest.mark.parametrize(
    "config",
    [
        {"padding_mode": "constant", "fill": 0},
        {"padding_mode": "constant", "fill": 10},
        {"padding_mode": "constant", "fill": 20.2},
        {"padding_mode": "edge"},
        {"padding_mode": "reflect"},
        {"padding_mode": "symmetric"},
    ],
)
def test_pad(device, dt, pad, config):
    """Check tensor padding against PIL, TorchScript and batched inputs.

    The float16/CPU combination is reported as skipped instead of silently
    passing. The check also runs before any data is created or scripted.
    """
    if dt == torch.float16 and device == "cpu":
        pytest.skip("skip float16 on CPU case")

    script_fn = torch.jit.script(F.pad)
    tensor, pil_img = _create_data(7, 8, device=device)
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)

    if dt is not None:
        # This is a trivial cast to float of uint8 data to test all cases
        tensor = tensor.to(dt)
        batch_tensors = batch_tensors.to(dt)

    pad_tensor = F_t.pad(tensor, pad, **config)
    pad_pil_img = F_pil.pad(pil_img, pad, **config)

    # We need to cast to uint8 to compare with the PIL image.
    pad_tensor_8b = pad_tensor if pad_tensor.dtype == torch.uint8 else pad_tensor.to(torch.uint8)
    _assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg=f"{pad}, {config}")

    # A bare int is wrapped in a list for the scripted call, presumably
    # because the scripted signature only accepts a list of ints.
    script_pad = [pad] if isinstance(pad, int) else pad
    pad_tensor_script = script_fn(tensor, script_pad, **config)
    assert_equal(pad_tensor, pad_tensor_script, msg=f"{pad}, {config}")

    _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config)
871

872

873
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("mode", [NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC])
def test_resized_crop(device, mode):
    """Exercise ``F.resized_crop`` on an identity case, a half-size crop and a batch.

    Only the identity case uses the parametrized ``mode``. The half-size case
    is pinned to NEAREST because its expected output (every other pixel) only
    holds for that interpolation.
    """
    # Case 1: crop the whole image and resize it to its own size -> identity.
    img, _ = _create_data(26, 36, device=device)
    identity = F.resized_crop(
        img, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode, antialias=True
    )
    assert_equal(img, identity, msg=f"{identity[0, :5, :5]} vs {img[0, :5, :5]}")

    # Case 2: halve the top-left 20x30 corner with nearest interpolation.
    img, _ = _create_data(26, 36, device=device)
    halved = F.resized_crop(img, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
    expected = img[:, :20:2, :30:2]
    assert_equal(expected, halved, msg=f"{expected[0, :10, :10]} vs {halved[0, :10, :10]}")

    # Case 3: batched input, checked by the _test_fn_on_batch helper.
    batch = _create_data_batch(26, 36, num_samples=4, device=device)
    crop_kwargs = dict(top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST)
    _test_fn_on_batch(batch, F.resized_crop, **crop_kwargs)
906

907

908
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "func, args",
    [
        (F_t.get_dimensions, ()),
        (F_t.get_image_size, ()),
        (F_t.get_image_num_channels, ()),
        (F_t.vflip, ()),
        (F_t.hflip, ()),
        (F_t.crop, (1, 2, 4, 5)),
        (F_t.adjust_brightness, (0.0,)),
        (F_t.adjust_contrast, (1.0,)),
        (F_t.adjust_hue, (-0.5,)),
        (F_t.adjust_saturation, (2.0,)),
        (F_t.pad, ([2], 2, "constant")),
        (F_t.resize, ([10, 11],)),
        # A 1-tuple holding the coefficient list. The old ``([0.2])`` was just
        # a list, so ``*args`` passed the bare float 0.2 instead of ``[0.2]``.
        (F_t.perspective, ([0.2],)),
        (F_t.gaussian_blur, ((2, 2), (0.7, 0.5))),
        (F_t.invert, ()),
        (F_t.posterize, (0,)),
        (F_t.solarize, (0.3,)),
        (F_t.adjust_sharpness, (0.3,)),
        (F_t.autocontrast, ()),
        (F_t.equalize, ()),
    ],
)
def test_assert_image_tensor(device, func, args):
    """Every tensor op must reject a 1-D tensor as "not a torch image".

    ``args`` supplies the remaining positional arguments, so each call is
    otherwise well-formed. The image check is expected to fire first.
    """
    shape = (100,)
    tensor = torch.rand(*shape, dtype=torch.float, device=device)
    with pytest.raises(Exception, match=r"Tensor is not a torch image."):
        func(tensor, *args)
939

940

941
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_vflip(device):
    """``F.vflip`` on tensors must match PIL, its scripted version and the batch check."""
    script_vflip = torch.jit.script(F.vflip)

    img_tensor, pil_img = _create_data(16, 18, device=device)
    vflipped_img = F.vflip(img_tensor)
    _assert_equal_tensor_to_pil(vflipped_img, F.vflip(pil_img))

    # The scripted function must reproduce the eager result.
    assert_equal(vflipped_img, script_vflip(img_tensor))

    _test_fn_on_batch(_create_data_batch(16, 18, num_samples=4, device=device), F.vflip)
956

957

958
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_hflip(device):
    """``F.hflip`` on tensors must match PIL, its scripted version and the batch check."""
    script_hflip = torch.jit.script(F.hflip)

    img_tensor, pil_img = _create_data(16, 18, device=device)
    hflipped_img = F.hflip(img_tensor)
    _assert_equal_tensor_to_pil(hflipped_img, F.hflip(pil_img))

    # The scripted function must reproduce the eager result.
    assert_equal(hflipped_img, script_hflip(img_tensor))

    _test_fn_on_batch(_create_data_batch(16, 18, num_samples=4, device=device), F.hflip)
973

974

975
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "top, left, height, width",
    [
        (1, 2, 4, 5),  # crop inside top-left corner
        (2, 12, 3, 4),  # crop inside top-right corner
        (8, 3, 5, 6),  # crop inside bottom-left corner
        (8, 11, 4, 3),  # crop inside bottom-right corner
        (50, 50, 10, 10),  # crop outside the image
        (-50, -50, 10, 10),  # crop outside the image
    ],
)
def test_crop(device, top, left, height, width):
    """Eager and scripted tensor crops must match the PIL crop; batches are checked too."""
    script_crop = torch.jit.script(F.crop)
    img_tensor, pil_img = _create_data(16, 18, device=device)
    box = (top, left, height, width)

    pil_img_cropped = F.crop(pil_img, *box)
    for crop_fn in (F.crop, script_crop):
        _assert_equal_tensor_to_pil(crop_fn(img_tensor, *box), pil_img_cropped)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)
1002

1003

1004
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("image_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
@pytest.mark.parametrize("fn", [F.gaussian_blur, torch.jit.script(F.gaussian_blur)])
def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
    """Compare ``F.gaussian_blur`` (eager and scripted) with stored OpenCV outputs.

    Reference outputs live in ``assets/gaussian_blur_opencv_results.pt`` and
    were produced as sketched in the comment below. Combinations that are not
    tested report as skipped rather than silently passing: float16 on CPU,
    and parameter sets with no stored reference.
    """
    # true_cv2_results = {
    #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #     "3_3_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #     "3_3_0.5": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #     "3_5_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #     "3_5_0.5": ...
    #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    #     "23_23_1.7": ...
    # }
    if dt == torch.float16 and device == "cpu":
        # Checked before loading the reference file or building any tensors.
        pytest.skip("skip float16 on CPU case")

    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    # NOTE: weights_only=False unpickles arbitrary objects. That is acceptable
    # only because this is an asset bundled with the test suite.
    true_cv2_results = torch.load(p, weights_only=False)

    if image_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    # Reference key layout: "<H>_<W>_<C>__<kh>_<kw>_<sigma>".
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        pytest.skip(f"no OpenCV reference stored for {gt_key}")

    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    out = fn(tensor, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
1057

1058

1059
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_hsv2rgb(device):
    """``F_t._hsv2rgb`` must agree with ``colorsys.hsv_to_rgb`` and its scripted version.

    Ten random HSV images are converted and compared pixel by pixel. A batch
    check via ``_test_fn_on_batch`` follows.
    """
    scripted_fn = torch.jit.script(F_t._hsv2rgb)
    shape = (3, 100, 150)
    for _ in range(10):
        hsv_img = torch.rand(*shape, dtype=torch.float, device=device)
        rgb_img = F_t._hsv2rgb(hsv_img)
        # (3, H, W) -> (H*W, 3): one RGB triple per row.
        ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)

        h, s, v = (channel.flatten().cpu().numpy() for channel in hsv_img.unbind(0))
        colorsys_img = torch.tensor(
            [colorsys.hsv_to_rgb(h1, s1, v1) for h1, s1, v1 in zip(h, s, v)],
            dtype=torch.float32,
            device=device,
        )
        torch.testing.assert_close(ft_img, colorsys_img, rtol=0.0, atol=1e-5)

        torch.testing.assert_close(rgb_img, scripted_fn(hsv_img))

    batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch_tensors, F_t._hsv2rgb)
1088

1089

1090
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_rgb2hsv(device):
    """``F_t._rgb2hsv`` must agree with ``colorsys.rgb_to_hsv`` and its scripted version.

    Ten random RGB images are converted and compared. A batch check via
    ``_test_fn_on_batch`` follows.
    """
    scripted_fn = torch.jit.script(F_t._rgb2hsv)
    shape = (3, 150, 100)
    for _ in range(10):
        rgb_img = torch.rand(*shape, dtype=torch.float, device=device)
        hsv_img = F_t._rgb2hsv(rgb_img)
        # (3, H, W) -> (H*W, 3): one HSV triple per row.
        ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)

        r, g, b = (channel.flatten().cpu().numpy() for channel in rgb_img.unbind(dim=-3))
        colorsys_img = torch.tensor(
            [colorsys.rgb_to_hsv(r1, g1, b1) for r1, g1, b1 in zip(r, g, b)],
            dtype=torch.float32,
            device=device,
        )

        ft_h, ft_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
        ref_h, ref_sv = torch.split(colorsys_img, [1, 2], dim=1)

        # Hue wraps around (0 and 1 are the same color), so it is compared
        # through sin(2*pi*h). Saturation and value are compared directly.
        max_diff_h = ((ref_h * 2 * math.pi).sin() - (ft_h * 2 * math.pi).sin()).abs().max()
        max_diff_sv = (ref_sv - ft_sv).abs().max()
        assert max(max_diff_h, max_diff_sv) < 1e-5

        torch.testing.assert_close(hsv_img, scripted_fn(rgb_img), rtol=1e-5, atol=1e-7)

    batch_tensors = _create_data_batch(120, 100, num_samples=4, device=device).float()
    _test_fn_on_batch(batch_tensors, F_t._rgb2hsv)
1127

1128

1129
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_output_channels", (3, 1))
def test_rgb_to_grayscale(device, num_output_channels):
    """Tensor grayscale conversion must approximately match PIL and exactly match TorchScript."""
    script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
    img_tensor, pil_img = _create_data(32, 34, device=device)

    gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
    gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
    # Tolerance just above 1.0 with max aggregation: presumably allows at most
    # an off-by-one difference against PIL's rounding.
    _assert_approx_equal_tensor_to_pil(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

    assert_equal(script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels), gray_tensor)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)
1146

1147

1148
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_center_crop(device):
    """Eager and scripted tensor center crops must match PIL; batches are checked too."""
    script_center_crop = torch.jit.script(F.center_crop)
    img_tensor, pil_img = _create_data(32, 34, device=device)

    cropped_pil_image = F.center_crop(pil_img, [10, 11])
    for crop_fn in (F.center_crop, script_center_crop):
        _assert_equal_tensor_to_pil(crop_fn(img_tensor, [10, 11]), cropped_pil_image)

    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])
1164

1165

1166
def _check_five_or_ten_crop(crop_fn, num_crops, device):
    """Shared body of ``test_five_crop`` and ``test_ten_crop``.

    Args:
        crop_fn: ``F.five_crop`` or ``F.ten_crop``.
        num_crops: number of crops ``crop_fn`` must return (5 or 10).
        device: device on which the test tensors are created.
    """
    scripted_crop_fn = torch.jit.script(crop_fn)

    # Single image: eager and scripted tensor crops must match the PIL crops.
    img_tensor, pil_img = _create_data(32, 34, device=device)
    cropped_pil_images = crop_fn(pil_img, [10, 11])
    assert len(cropped_pil_images) == num_crops
    for fn in (crop_fn, scripted_crop_fn):
        cropped_tensors = fn(img_tensor, [10, 11])
        assert len(cropped_tensors) == num_crops
        for cropped_tensor, cropped_pil_image in zip(cropped_tensors, cropped_pil_images):
            _assert_equal_tensor_to_pil(cropped_tensor, cropped_pil_image)

    # Batch: crop j of the batch, sliced at sample i, must equal crop j of sample i alone.
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    tuple_transformed_batches = crop_fn(batch_tensors, [10, 11])
    for i, sample in enumerate(batch_tensors):
        tuple_transformed_imgs = crop_fn(sample, [10, 11])
        assert len(tuple_transformed_imgs) == len(tuple_transformed_batches)
        for true_transformed_img, transformed_batch in zip(tuple_transformed_imgs, tuple_transformed_batches):
            assert_equal(true_transformed_img, transformed_batch[i, ...])

    # The scripted function must agree with the eager one on batches too.
    s_tuple_transformed_batches = scripted_crop_fn(batch_tensors, [10, 11])
    assert len(s_tuple_transformed_batches) == len(tuple_transformed_batches)
    for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
        assert_equal(transformed_batch, s_transformed_batch)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_five_crop(device):
    """Check ``F.five_crop`` against PIL, TorchScript and batched inputs."""
    _check_five_or_ten_crop(F.five_crop, 5, device)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_ten_crop(device):
    """Check ``F.ten_crop`` against PIL, TorchScript and batched inputs."""
    _check_five_or_ten_crop(F.ten_crop, 10, device)
1232

1233

1234
def test_elastic_transform_asserts():
    """``F.elastic_transform`` must reject invalid ``img`` / ``displacement`` arguments."""
    # A missing displacement is reported even though the image is also invalid.
    with pytest.raises(TypeError, match="Argument displacement should be a Tensor"):
        F.elastic_transform("abc", displacement=None)

    # With a valid displacement tensor, the bad image type is reported.
    with pytest.raises(TypeError, match="img should be PIL Image or Tensor"):
        F.elastic_transform("abc", displacement=torch.rand(1))

    # A well-formed image with a wrongly shaped displacement.
    img_tensor = torch.rand(1, 3, 32, 24)
    with pytest.raises(ValueError, match="Argument displacement shape should"):
        F.elastic_transform(img_tensor, displacement=torch.rand(1, 2))
1244

1245

1246
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
    "fill",
    [None, [255, 255, 255], (2.0,)],
)
def test_elastic_transform_consistency(device, interpolation, dt, fill):
    """Eager and scripted ``F.elastic_transform`` must agree; batches are checked too.

    There is no PIL implementation of elastic_transform, so unlike most tests
    in this file no tensor-vs-PIL comparison is made.
    """
    script_elastic_transform = torch.jit.script(F.elastic_transform)
    img_tensor, _ = _create_data(32, 34, device=device)
    if dt is not None:
        img_tensor = img_tensor.to(dt)

    kwargs = {
        "displacement": T.ElasticTransform.get_params([1.5, 1.5], [2.0, 2.0], [32, 34]),
        "interpolation": interpolation,
        "fill": fill,
    }
    assert_equal(F.elastic_transform(img_tensor, **kwargs), script_elastic_transform(img_tensor, **kwargs))

    # The batch uses a smaller image, so a displacement of matching size is drawn.
    batch_tensors = _create_data_batch(16, 18, num_samples=4, device=device)
    kwargs["displacement"] = T.ElasticTransform.get_params([1.5, 1.5], [2.0, 2.0], [16, 18])
    if dt is not None:
        batch_tensors = batch_tensors.to(dt)
    _test_fn_on_batch(batch_tensors, F.elastic_transform, **kwargs)
1279

1280

1281
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    pytest.main(args=[__file__])
1283

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.