import math
import os
import random
import re
import sys
from functools import partial

import numpy as np
import pytest
import torch
import torchvision.transforms as transforms
import torchvision.transforms._functional_tensor as F_t
import torchvision.transforms.functional as F
from PIL import Image
from torch._utils_internal import get_file_path_2

try:
    import accimage
except ImportError:
    accimage = None

try:
    from scipy import stats
except ImportError:
    stats = None

from common_utils import assert_equal, cycle_over, float_dtypes, int_dtypes


GRACE_HOPPER = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)


def _get_grayscale_test_image(img, fill=None):
    img = img.convert("L")
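    # A PIL "L" image has a single band, so a multi-channel fill tuple is
    # reduced to its first component before being used as a fill value.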
    fill = (fill[0],) if isinstance(fill, tuple) else fill
    return img, fill


class TestConvertImageDtype:
    @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(float_dtypes()))
    def test_float_to_float(self, input_dtype, output_dtype):
        input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
        transform = transforms.ConvertImageDtype(output_dtype)
        transform_script = torch.jit.script(F.convert_image_dtype)

        output_image = transform(input_image)
        output_image_script = transform_script(input_image, output_dtype)

        torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

        actual_min, actual_max = output_image.tolist()
        desired_min, desired_max = 0.0, 1.0

        assert abs(actual_min - desired_min) < 1e-7
        assert abs(actual_max - desired_max) < 1e-7

    @pytest.mark.parametrize("input_dtype", float_dtypes())
    @pytest.mark.parametrize("output_dtype", int_dtypes())
    def test_float_to_int(self, input_dtype, output_dtype):
        input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
        transform = transforms.ConvertImageDtype(output_dtype)
        transform_script = torch.jit.script(F.convert_image_dtype)

        if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
            input_dtype == torch.float64 and output_dtype == torch.int64
        ):
            with pytest.raises(RuntimeError):
                transform(input_image)
        else:
            output_image = transform(input_image)
            output_image_script = transform_script(input_image, output_dtype)

            torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

            actual_min, actual_max = output_image.tolist()
            desired_min, desired_max = 0, torch.iinfo(output_dtype).max

            assert actual_min == desired_min
            assert actual_max == desired_max

    @pytest.mark.parametrize("input_dtype", int_dtypes())
    @pytest.mark.parametrize("output_dtype", float_dtypes())
    def test_int_to_float(self, input_dtype, output_dtype):
        input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
        transform = transforms.ConvertImageDtype(output_dtype)
        transform_script = torch.jit.script(F.convert_image_dtype)

        output_image = transform(input_image)
        output_image_script = transform_script(input_image, output_dtype)

        torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

        actual_min, actual_max = output_image.tolist()
        desired_min, desired_max = 0.0, 1.0

        assert abs(actual_min - desired_min) < 1e-7
        assert actual_min >= desired_min
        assert abs(actual_max - desired_max) < 1e-7
        assert actual_max <= desired_max

    @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes()))
    def test_dtype_int_to_int(self, input_dtype, output_dtype):
        input_max = torch.iinfo(input_dtype).max
        input_image = torch.tensor((0, input_max), dtype=input_dtype)
        output_max = torch.iinfo(output_dtype).max

        transform = transforms.ConvertImageDtype(output_dtype)
        transform_script = torch.jit.script(F.convert_image_dtype)

        output_image = transform(input_image)
        output_image_script = transform_script(input_image, output_dtype)

        torch.testing.assert_close(
            output_image_script,
            output_image,
            rtol=0.0,
            atol=1e-6,
            msg=f"{output_image_script} vs {output_image}",
        )

        actual_min, actual_max = output_image.tolist()
        desired_min, desired_max = 0, output_max

        # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
        if input_max >= output_max:
            error_term = 0
        else:
            error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)
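            # e.g. for an illustrative uint8 -> int16 conversion: values are scaled
            # by (32767 + 1) // (255 + 1) = 128, so 255 maps to 255 * 128 = 32640,
            # i.e. desired_max + error_term with error_term = 1 - 128 = -127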

        assert actual_min == desired_min
        assert actual_max == (desired_max + error_term)

    @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes()))
    def test_int_to_int_consistency(self, input_dtype, output_dtype):
        input_max = torch.iinfo(input_dtype).max
        input_image = torch.tensor((0, input_max), dtype=input_dtype)

        output_max = torch.iinfo(output_dtype).max
        if output_max <= input_max:
            return

        transform = transforms.ConvertImageDtype(output_dtype)
        inverse_transform = transforms.ConvertImageDtype(input_dtype)
        output_image = inverse_transform(transform(input_image))

        actual_min, actual_max = output_image.tolist()
        desired_min, desired_max = 0, input_max

        assert actual_min == desired_min
        assert actual_max == desired_max


@pytest.mark.skipif(accimage is None, reason="accimage not available")
class TestAccImage:
    def test_accimage_to_tensor(self):
        trans = transforms.PILToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
        output = trans(accimage.Image(GRACE_HOPPER))

        torch.testing.assert_close(output, expected_output)

    def test_accimage_pil_to_tensor(self):
        trans = transforms.PILToTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
        output = trans(accimage.Image(GRACE_HOPPER))

        assert expected_output.size() == output.size()
        torch.testing.assert_close(output, expected_output)

    def test_accimage_resize(self):
        trans = transforms.Compose(
            [
                transforms.Resize(256, interpolation=Image.LINEAR),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(dtype=torch.float),
            ]
        )

        # Checking if Compose, Resize, PILToTensor and ConvertImageDtype can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
        output = trans(accimage.Image(GRACE_HOPPER))

        assert expected_output.size() == output.size()
        assert np.abs((expected_output - output).mean()) < 1e-3
        assert (expected_output - output).var() < 1e-5
        # note the high absolute tolerance
        torch.testing.assert_close(output.numpy(), expected_output.numpy(), rtol=1e-5, atol=5e-2)

    def test_accimage_crop(self):
        trans = transforms.Compose(
            [transforms.CenterCrop(256), transforms.PILToTensor(), transforms.ConvertImageDtype(dtype=torch.float)]
        )

        # Checking if Compose, CenterCrop, PILToTensor and ConvertImageDtype can be printed as string
        trans.__repr__()

        expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
        output = trans(accimage.Image(GRACE_HOPPER))

        assert expected_output.size() == output.size()
        torch.testing.assert_close(output, expected_output)


class TestToTensor:
    @pytest.mark.parametrize("channels", [1, 3, 4])
    def test_to_tensor(self, channels):
        height, width = 4, 4
        trans = transforms.ToTensor()
        np_rng = np.random.RandomState(0)

        input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        torch.testing.assert_close(output, input_data)

        ndarray = np_rng.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
        output = trans(ndarray)
        expected_output = ndarray.transpose((2, 0, 1)) / 255.0
        torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False)

        ndarray = np_rng.rand(height, width, channels).astype(np.float32)
        output = trans(ndarray)
        expected_output = ndarray.transpose((2, 0, 1))
        torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False)

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert("1")
        output = trans(img)
        torch.testing.assert_close(input_data, output, check_dtype=False)

    def test_to_tensor_errors(self):
        height, width = 4, 4
        trans = transforms.ToTensor()
        np_rng = np.random.RandomState(0)

        with pytest.raises(TypeError):
            trans(np_rng.rand(1, height, width).tolist())

        with pytest.raises(ValueError):
            trans(np_rng.rand(height))

        with pytest.raises(ValueError):
            trans(np_rng.rand(1, 1, height, width))

    @pytest.mark.parametrize("dtype", [torch.float16, torch.float, torch.double])
    def test_to_tensor_with_other_default_dtypes(self, dtype):
        np_rng = np.random.RandomState(0)
        current_def_dtype = torch.get_default_dtype()

        t = transforms.ToTensor()
        np_arr = np_rng.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        img = Image.fromarray(np_arr)

        torch.set_default_dtype(dtype)
        res = t(img)
        assert res.dtype == dtype, f"{res.dtype} vs {dtype}"

        torch.set_default_dtype(current_def_dtype)

    @pytest.mark.parametrize("channels", [1, 3, 4])
    def test_pil_to_tensor(self, channels):
        height, width = 4, 4
        trans = transforms.PILToTensor()
        np_rng = np.random.RandomState(0)

        input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        torch.testing.assert_close(input_data, output)

        input_data = np_rng.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        expected_output = input_data.transpose((2, 0, 1))
        torch.testing.assert_close(output.numpy(), expected_output)

        input_data = torch.as_tensor(np_rng.rand(channels, height, width).astype(np.float32))
        img = transforms.ToPILImage()(input_data)  # CHW -> HWC and (* 255).byte()
        output = trans(img)  # HWC -> CHW
        expected_output = (input_data * 255).byte()
        torch.testing.assert_close(output, expected_output)

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert("1")
        output = trans(img).view(torch.uint8).bool().to(torch.uint8)
        torch.testing.assert_close(input_data, output)

    def test_pil_to_tensor_errors(self):
        height, width = 4, 4
        trans = transforms.PILToTensor()
        np_rng = np.random.RandomState(0)

        with pytest.raises(TypeError):
            trans(np_rng.rand(1, height, width).tolist())

        with pytest.raises(TypeError):
            trans(np_rng.rand(1, height, width))


def test_randomresized_params():
    height = random.randint(24, 32) * 2
    width = random.randint(24, 32) * 2
    img = torch.ones(3, height, width)
    to_pil_image = transforms.ToPILImage()
    img = to_pil_image(img)
    size = 100
    epsilon = 0.05
    min_scale = 0.25
    for _ in range(10):
        scale_min = max(round(random.random(), 2), min_scale)
        scale_range = (scale_min, scale_min + round(random.random(), 2))
        aspect_min = max(round(random.random(), 2), epsilon)
        aspect_ratio_range = (aspect_min, aspect_min + round(random.random(), 2))
        randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range, antialias=True)
        i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range)
        aspect_ratio_obtained = w / h
        assert (
            min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained
            and aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon
        ) or aspect_ratio_obtained == 1.0
        assert isinstance(i, int)
        assert isinstance(j, int)
        assert isinstance(h, int)
        assert isinstance(w, int)


@pytest.mark.parametrize(
    "height, width",
    [
        # height, width
        # square image
        (28, 28),
        (27, 27),
        # rectangular image: h < w
        (28, 34),
        (29, 35),
        # rectangular image: h > w
        (34, 28),
        (35, 29),
    ],
)
@pytest.mark.parametrize(
    "osize",
    [
        # single integer
        22,
        27,
        28,
        36,
        # single integer in tuple/list
        [
            22,
        ],
        (27,),
    ],
)
@pytest.mark.parametrize("max_size", (None, 37, 1000))
def test_resize(height, width, osize, max_size):
    img = Image.new("RGB", size=(width, height), color=127)

    t = transforms.Resize(osize, max_size=max_size, antialias=True)
    result = t(img)

    msg = f"{height}, {width} - {osize} - {max_size}"
    osize = osize[0] if isinstance(osize, (list, tuple)) else osize
    # If size is an int, smaller edge of the image will be matched to this number.
    # i.e., if height > width, then the image will be rescaled to (size * height / width, size).
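    # e.g. (an illustrative check): height=28, width=34, osize=22 -> height is the
    # smaller edge, so the expected output is (w, h) = (int(22 * 34 / 28), 22) = (26, 22)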
    if height < width:
        exp_w, exp_h = (int(osize * width / height), osize)  # (w, h)
        if max_size is not None and max_size < exp_w:
            exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
        assert result.size == (exp_w, exp_h), msg
    elif width < height:
        exp_w, exp_h = (osize, int(osize * height / width))  # (w, h)
        if max_size is not None and max_size < exp_h:
            exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
        assert result.size == (exp_w, exp_h), msg
    else:
        exp_w, exp_h = (osize, osize)  # (w, h)
        if max_size is not None and max_size < osize:
            exp_w, exp_h = max_size, max_size
        assert result.size == (exp_w, exp_h), msg


@pytest.mark.parametrize(
    "height, width",
    [
        # height, width
        # square image
        (28, 28),
        (27, 27),
        # rectangular image: h < w
        (28, 34),
        (29, 35),
        # rectangular image: h > w
        (34, 28),
        (35, 29),
    ],
)
@pytest.mark.parametrize(
    "osize",
    [
        # two integers sequence output
        [22, 22],
        [22, 28],
        [22, 36],
        [27, 22],
        [36, 22],
        [28, 28],
        [28, 37],
        [37, 27],
        [37, 37],
    ],
)
def test_resize_sequence_output(height, width, osize):
    img = Image.new("RGB", size=(width, height), color=127)
    oheight, owidth = osize

    t = transforms.Resize(osize, antialias=True)
    result = t(img)

    assert (owidth, oheight) == result.size


def test_resize_antialias_error():
    osize = [37, 37]
    img = Image.new("RGB", size=(35, 29), color=127)

    with pytest.warns(UserWarning, match=r"Anti-alias option is always applied for PIL Image input"):
        t = transforms.Resize(osize, antialias=False)
        t(img)


@pytest.mark.parametrize("height, width", ((32, 64), (64, 32)))
def test_resize_size_equals_small_edge_size(height, width):
    # Non-regression test for https://github.com/pytorch/vision/issues/5405
    # max_size used to be ignored if size == small_edge_size
    max_size = 40
    img = Image.new("RGB", size=(width, height), color=127)

    small_edge = min(height, width)
    t = transforms.Resize(small_edge, max_size=max_size, antialias=True)
    result = t(img)
    assert max(result.size) == max_size


def test_resize_equal_input_output_sizes():
    # Regression test for https://github.com/pytorch/vision/issues/7518
    height, width = 28, 27
    img = Image.new("RGB", size=(width, height))

    t = transforms.Resize((height, width), antialias=True)
    result = t(img)
    assert result is img


class TestPad:
    @pytest.mark.parametrize("fill", [85, 85.0])
    def test_pad(self, fill):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = torch.ones(3, height, width, dtype=torch.uint8)
        padding = random.randint(1, 20)
        result = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Pad(padding, fill=fill),
                transforms.PILToTensor(),
            ]
        )(img)
        assert result.size(1) == height + 2 * padding
        assert result.size(2) == width + 2 * padding
        # check that all elements in the padded region correspond
        # to the pad value
        h_padded = result[:, :padding, :]
        w_padded = result[:, :, :padding]
        torch.testing.assert_close(h_padded, torch.full_like(h_padded, fill_value=fill), rtol=0.0, atol=0.0)
        torch.testing.assert_close(w_padded, torch.full_like(w_padded, fill_value=fill), rtol=0.0, atol=0.0)
        pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), transforms.ToPILImage()(img))

    def test_pad_with_tuple_of_pad_values(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = transforms.ToPILImage()(torch.ones(3, height, width))

        padding = tuple(random.randint(1, 20) for _ in range(2))
        output = transforms.Pad(padding)(img)
        assert output.size == (width + padding[0] * 2, height + padding[1] * 2)

        padding = [random.randint(1, 20) for _ in range(4)]
        output = transforms.Pad(padding)(img)
        assert output.size[0] == width + padding[0] + padding[2]
        assert output.size[1] == height + padding[1] + padding[3]

        # Checking if Padding can be printed as string
        transforms.Pad(padding).__repr__()

    def test_pad_with_non_constant_padding_modes(self):
        """Unit tests for edge, reflect, symmetric padding"""
        img = torch.zeros(3, 27, 27).byte()
        img[:, :, 0] = 1  # Constant value added to leftmost edge
        img = transforms.ToPILImage()(img)
        img = F.pad(img, 1, (200, 200, 200))

        # pad 3 to all sides
        edge_padded_img = F.pad(img, 3, padding_mode="edge")
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
        edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
        assert_equal(edge_middle_slice, np.asarray([200, 200, 200, 200, 1, 0], dtype=np.uint8))
        assert transforms.PILToTensor()(edge_padded_img).size() == (3, 35, 35)

        # Pad 3 to left/right, 2 to top/bottom
        reflect_padded_img = F.pad(img, (3, 2), padding_mode="reflect")
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
        reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
        assert_equal(reflect_middle_slice, np.asarray([0, 0, 1, 200, 1, 0], dtype=np.uint8))
        assert transforms.PILToTensor()(reflect_padded_img).size() == (3, 33, 35)

        # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
        symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode="symmetric")
        # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
        symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
        assert_equal(symmetric_middle_slice, np.asarray([0, 1, 200, 200, 1, 0], dtype=np.uint8))
        assert transforms.PILToTensor()(symmetric_padded_img).size() == (3, 32, 34)

        # Check negative padding explicitly for symmetric case, since it is not
        # implemented for tensor case to compare to
        # Crop 1 to left, pad 2 to top, pad 3 to right, crop 3 to bottom
        symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode="symmetric")
        symmetric_neg_middle_left = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][:3]
        symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:]
        assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8))
        assert_equal(symmetric_neg_middle_right, np.asarray([200, 200, 0, 0], dtype=np.uint8))
        assert transforms.PILToTensor()(symmetric_padded_img_neg).size() == (3, 28, 31)

    def test_pad_raises_with_invalid_pad_sequence_len(self):
        with pytest.raises(ValueError):
            transforms.Pad(())

        with pytest.raises(ValueError):
            transforms.Pad((1, 2, 3))

        with pytest.raises(ValueError):
            transforms.Pad((1, 2, 3, 4, 5))

    def test_pad_with_mode_F_images(self):
        pad = 2
        transform = transforms.Pad(pad)

        img = Image.new("F", (10, 10))
        padded_img = transform(img)
        assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size])


@pytest.mark.parametrize(
    "fn, trans, kwargs",
    [
        (F.invert, transforms.RandomInvert, {}),
        (F.posterize, transforms.RandomPosterize, {"bits": 4}),
        (F.solarize, transforms.RandomSolarize, {"threshold": 192}),
        (F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}),
        (F.autocontrast, transforms.RandomAutocontrast, {}),
        (F.equalize, transforms.RandomEqualize, {}),
        (F.vflip, transforms.RandomVerticalFlip, {}),
        (F.hflip, transforms.RandomHorizontalFlip, {}),
        (partial(F.to_grayscale, num_output_channels=3), transforms.RandomGrayscale, {}),
    ],
)
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
def test_randomness(fn, trans, kwargs, seed, p):
    torch.manual_seed(seed)
    img = transforms.ToPILImage()(torch.rand(3, 16, 18))

    expected_transformed_img = fn(img, **kwargs)
    randomly_transformed_img = trans(p=p, **kwargs)(img)

    if p == 0:
        assert randomly_transformed_img == img
    elif p == 1:
        assert randomly_transformed_img == expected_transformed_img

    trans(**kwargs).__repr__()


def test_autocontrast_equal_minmax():
    img_tensor = torch.tensor([[[10]], [[128]], [[245]]], dtype=torch.uint8).expand(3, 32, 32)
    img_pil = F.to_pil_image(img_tensor)

    img_tensor = F.autocontrast(img_tensor)
    img_pil = F.autocontrast(img_pil)
    torch.testing.assert_close(img_tensor, F.pil_to_tensor(img_pil))


class TestToPil:
    def _get_1_channel_tensor_various_types():
        img_data_float = torch.Tensor(1, 4, 4).uniform_()
        expected_output = img_data_float.mul(255).int().float().div(255).numpy()
        yield img_data_float, expected_output, "L"

        img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
        expected_output = img_data_byte.float().div(255.0).numpy()
        yield img_data_byte, expected_output, "L"

        img_data_short = torch.ShortTensor(1, 4, 4).random_()
        expected_output = img_data_short.numpy()
        yield img_data_short, expected_output, "I;16" if sys.byteorder == "little" else "I;16B"

        img_data_int = torch.IntTensor(1, 4, 4).random_()
        expected_output = img_data_int.numpy()
        yield img_data_int, expected_output, "I"

    def _get_2d_tensor_various_types():
        img_data_float = torch.Tensor(4, 4).uniform_()
        expected_output = img_data_float.mul(255).int().float().div(255).numpy()
        yield img_data_float, expected_output, "L"

        img_data_byte = torch.ByteTensor(4, 4).random_(0, 255)
        expected_output = img_data_byte.float().div(255.0).numpy()
        yield img_data_byte, expected_output, "L"

        img_data_short = torch.ShortTensor(4, 4).random_()
        expected_output = img_data_short.numpy()
        yield img_data_short, expected_output, "I;16" if sys.byteorder == "little" else "I;16B"

        img_data_int = torch.IntTensor(4, 4).random_()
        expected_output = img_data_int.numpy()
        yield img_data_int, expected_output, "I"

    @pytest.mark.parametrize("with_mode", [False, True])
    @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_1_channel_tensor_various_types())
    def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode):
        transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        img = transform(img_data)
        assert img.mode == expected_mode
        torch.testing.assert_close(expected_output, to_tensor(img).numpy())

    def test_1_channel_float_tensor_to_pil_image(self):
        img_data = torch.Tensor(1, 4, 4).uniform_()
        # 'F' mode for torch.FloatTensor
        img_F_mode = transforms.ToPILImage(mode="F")(img_data)
        assert img_F_mode.mode == "F"
        torch.testing.assert_close(
            np.array(Image.fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode)
        )

    @pytest.mark.parametrize("with_mode", [False, True])
    @pytest.mark.parametrize(
        "img_data, expected_mode",
        [
            (torch.Tensor(4, 4, 1).uniform_().numpy(), "L"),
            (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"),
            (torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16" if sys.byteorder == "little" else "I;16B"),
            (torch.IntTensor(4, 4, 1).random_().numpy(), "I"),
        ],
    )
    def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
        transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
        img = transform(img_data)
        assert img.mode == expected_mode
        if np.issubdtype(img_data.dtype, np.floating):
            img_data = (img_data * 255).astype(np.uint8)
        # note: we explicitly convert img's dtype because pytorch doesn't support uint16
        # and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
        torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype))

    @pytest.mark.parametrize("expected_mode", [None, "LA"])
    def test_2_channel_ndarray_to_pil_image(self, expected_mode):
        img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "LA"  # default should assume LA
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode
        split = img.split()
        for i in range(2):
            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))

    def test_2_channel_ndarray_to_pil_image_error(self):
        img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
        transforms.ToPILImage().__repr__()

        # should raise if we try a mode for 4 or 1 or 3 channel images
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="RGBA")(img_data)
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="RGB")(img_data)

    @pytest.mark.parametrize("expected_mode", [None, "LA"])
    def test_2_channel_tensor_to_pil_image(self, expected_mode):
        img_data = torch.Tensor(2, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)
        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "LA"  # default should assume LA
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode

        split = img.split()
        for i in range(2):
            torch.testing.assert_close(expected_output[i].numpy(), F.to_tensor(split[i]).squeeze(0).numpy())

    def test_2_channel_tensor_to_pil_image_error(self):
        img_data = torch.Tensor(2, 4, 4).uniform_()

        # should raise if we try a mode for 4 or 1 or 3 channel images
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="RGBA")(img_data)
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"):
            transforms.ToPILImage(mode="RGB")(img_data)

    @pytest.mark.parametrize("with_mode", [False, True])
    @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_2d_tensor_various_types())
    def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode):
        transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        img = transform(img_data)
        assert img.mode == expected_mode
        torch.testing.assert_close(expected_output, to_tensor(img).numpy()[0])

    @pytest.mark.parametrize("with_mode", [False, True])
    @pytest.mark.parametrize(
        "img_data, expected_mode",
        [
            (torch.Tensor(4, 4).uniform_().numpy(), "L"),
            (torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"),
            (torch.ShortTensor(4, 4).random_().numpy(), "I;16" if sys.byteorder == "little" else "I;16B"),
            (torch.IntTensor(4, 4).random_().numpy(), "I"),
        ],
    )
    def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
        transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
        img = transform(img_data)
        assert img.mode == expected_mode
        if np.issubdtype(img_data.dtype, np.floating):
            img_data = (img_data * 255).astype(np.uint8)
        np.testing.assert_allclose(img_data, img)

    @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
    def test_3_channel_tensor_to_pil_image(self, expected_mode):
        img_data = torch.Tensor(3, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "RGB"  # default should assume RGB
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode
        split = img.split()
        for i in range(3):
            torch.testing.assert_close(expected_output[i].numpy(), F.to_tensor(split[i]).squeeze(0).numpy())

    def test_3_channel_tensor_to_pil_image_error(self):
        img_data = torch.Tensor(3, 4, 4).uniform_()
        error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs"
        # should raise if we try a mode for 4 or 1 or 2 channel images
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="RGBA")(img_data)
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="LA")(img_data)

        with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
            transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_())

    @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
    def test_3_channel_ndarray_to_pil_image(self, expected_mode):
        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "RGB"  # default should assume RGB
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode
        split = img.split()
        for i in range(3):
            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))

    def test_3_channel_ndarray_to_pil_image_error(self):
        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()

        # Checking if ToPILImage can be printed as string
        transforms.ToPILImage().__repr__()

        error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs"
        # should raise if we try a mode for 4 or 1 or 2 channel images
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="RGBA")(img_data)
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=error_message_3d):
            transforms.ToPILImage(mode="LA")(img_data)

    @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"])
    def test_4_channel_tensor_to_pil_image(self, expected_mode):
        img_data = torch.Tensor(4, 4, 4).uniform_()
        expected_output = img_data.mul(255).int().float().div(255)

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "RGBA"  # default should assume RGBA
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode

        split = img.split()
        for i in range(4):
            torch.testing.assert_close(expected_output[i].numpy(), F.to_tensor(split[i]).squeeze(0).numpy())

    def test_4_channel_tensor_to_pil_image_error(self):
        img_data = torch.Tensor(4, 4, 4).uniform_()

        error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs"
        # should raise if we try a mode for 3 or 1 or 2 channel images
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="RGB")(img_data)
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="LA")(img_data)

    @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"])
    def test_4_channel_ndarray_to_pil_image(self, expected_mode):
        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()

        if expected_mode is None:
            img = transforms.ToPILImage()(img_data)
            assert img.mode == "RGBA"  # default should assume RGBA
        else:
            img = transforms.ToPILImage(mode=expected_mode)(img_data)
            assert img.mode == expected_mode
        split = img.split()
        for i in range(4):
            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))

    def test_4_channel_ndarray_to_pil_image_error(self):
        img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()

        error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs"
        # should raise if we try a mode for 3 or 1 or 2 channel images
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="RGB")(img_data)
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="P")(img_data)
        with pytest.raises(ValueError, match=error_message_4d):
            transforms.ToPILImage(mode="LA")(img_data)

    def test_ndarray_bad_types_to_pil_image(self):
        trans = transforms.ToPILImage()
        reg_msg = r"Input type \w+ is not supported"
        with pytest.raises(TypeError, match=reg_msg):
            trans(np.ones([4, 4, 1], np.int64))
        with pytest.raises(TypeError, match=reg_msg):
            trans(np.ones([4, 4, 1], np.uint16))
        with pytest.raises(TypeError, match=reg_msg):
            trans(np.ones([4, 4, 1], np.uint32))

        with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
            transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
        with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."):
            transforms.ToPILImage()(np.ones([4, 4, 6]))

    def test_tensor_bad_types_to_pil_image(self):
        with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
            transforms.ToPILImage()(torch.ones(1, 3, 4, 4))
        with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."):
            transforms.ToPILImage()(torch.ones(6, 4, 4))


def test_adjust_brightness():
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")

    # test 0
    y_pil = F.adjust_brightness(x_pil, 1)
    y_np = np.array(y_pil)
    torch.testing.assert_close(y_np, x_np)

    # test 1
    y_pil = F.adjust_brightness(x_pil, 0.5)
    y_np = np.array(y_pil)
    y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)

    # test 2
    y_pil = F.adjust_brightness(x_pil, 2)
    y_np = np.array(y_pil)
    y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)


def test_adjust_contrast():
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")

    # test 0
    y_pil = F.adjust_contrast(x_pil, 1)
    y_np = np.array(y_pil)
    torch.testing.assert_close(y_np, x_np)

    # test 1
    y_pil = F.adjust_contrast(x_pil, 0.5)
    y_np = np.array(y_pil)
    y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)

    # test 2
    y_pil = F.adjust_contrast(x_pil, 2)
    y_np = np.array(y_pil)
    y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
    torch.testing.assert_close(y_np, y_ans)


def test_adjust_hue():
    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")

    # hue_factor must lie in [-0.5, 0.5]; check each out-of-range value in its
    # own block so that both calls are actually executed
    with pytest.raises(ValueError):
        F.adjust_hue(x_pil, -0.7)
    with pytest.raises(ValueError):
        F.adjust_hue(x_pil, 1)
955

956
    # test 0: almost same as x_data but not exact.
957
    # probably because hsv <-> rgb floating point ops
958
    y_pil = F.adjust_hue(x_pil, 0)
959
    y_np = np.array(y_pil)
960
    y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
961
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
962
    torch.testing.assert_close(y_np, y_ans)
963

964
    # test 1
965
    y_pil = F.adjust_hue(x_pil, 0.25)
966
    y_np = np.array(y_pil)
967
    y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
968
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
969
    torch.testing.assert_close(y_np, y_ans)
970

971
    # test 2
972
    y_pil = F.adjust_hue(x_pil, -0.25)
973
    y_np = np.array(y_pil)
974
    y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
975
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
976
    torch.testing.assert_close(y_np, y_ans)
977

978

979
def test_adjust_sharpness():
980
    x_shape = [4, 4, 3]
981
    x_data = [
982
        75,
983
        121,
984
        114,
985
        105,
986
        97,
987
        107,
988
        105,
989
        32,
990
        66,
991
        111,
992
        117,
993
        114,
994
        99,
995
        104,
996
        97,
997
        0,
998
        0,
999
        65,
1000
        108,
1001
        101,
1002
        120,
1003
        97,
1004
        110,
1005
        100,
1006
        101,
1007
        114,
1008
        32,
1009
        86,
1010
        114,
1011
        121,
1012
        110,
1013
        105,
1014
        111,
1015
        116,
1016
        105,
1017
        115,
1018
        0,
1019
        0,
1020
        73,
1021
        32,
1022
        108,
1023
        111,
1024
        118,
1025
        101,
1026
        32,
1027
        121,
1028
        111,
1029
        117,
1030
    ]
1031
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1032
    x_pil = Image.fromarray(x_np, mode="RGB")
1033

1034
    # test 0
1035
    y_pil = F.adjust_sharpness(x_pil, 1)
1036
    y_np = np.array(y_pil)
1037
    torch.testing.assert_close(y_np, x_np)
1038

1039
    # test 1
1040
    y_pil = F.adjust_sharpness(x_pil, 0.5)
1041
    y_np = np.array(y_pil)
1042
    y_ans = [
1043
        75,
1044
        121,
1045
        114,
1046
        105,
1047
        97,
1048
        107,
1049
        105,
1050
        32,
1051
        66,
1052
        111,
1053
        117,
1054
        114,
1055
        99,
1056
        104,
1057
        97,
1058
        30,
1059
        30,
1060
        74,
1061
        103,
1062
        96,
1063
        114,
1064
        97,
1065
        110,
1066
        100,
1067
        101,
1068
        114,
1069
        32,
1070
        81,
1071
        103,
1072
        108,
1073
        102,
1074
        101,
1075
        107,
1076
        116,
1077
        105,
1078
        115,
1079
        0,
1080
        0,
1081
        73,
1082
        32,
1083
        108,
1084
        111,
1085
        118,
1086
        101,
1087
        32,
1088
        121,
1089
        111,
1090
        117,
1091
    ]
1092
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
1093
    torch.testing.assert_close(y_np, y_ans)
1094

1095
    # test 2
1096
    y_pil = F.adjust_sharpness(x_pil, 2)
1097
    y_np = np.array(y_pil)
1098
    y_ans = [
1099
        75,
1100
        121,
1101
        114,
1102
        105,
1103
        97,
1104
        107,
1105
        105,
1106
        32,
1107
        66,
1108
        111,
1109
        117,
1110
        114,
1111
        99,
1112
        104,
1113
        97,
1114
        0,
1115
        0,
1116
        46,
1117
        118,
1118
        111,
1119
        132,
1120
        97,
1121
        110,
1122
        100,
1123
        101,
1124
        114,
1125
        32,
1126
        95,
1127
        135,
1128
        146,
1129
        126,
1130
        112,
1131
        119,
1132
        116,
1133
        105,
1134
        115,
1135
        0,
1136
        0,
1137
        73,
1138
        32,
1139
        108,
1140
        111,
1141
        118,
1142
        101,
1143
        32,
1144
        121,
1145
        111,
1146
        117,
1147
    ]
1148
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
1149
    torch.testing.assert_close(y_np, y_ans)
1150

1151
    # test 3
1152
    x_shape = [2, 2, 3]
1153
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
1154
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1155
    x_pil = Image.fromarray(x_np, mode="RGB")
1156
    x_th = torch.tensor(x_np.transpose(2, 0, 1))
1157
    y_pil = F.adjust_sharpness(x_pil, 2)
1158
    y_np = np.array(y_pil).transpose(2, 0, 1)
1159
    y_th = F.adjust_sharpness(x_th, 2)
1160
    torch.testing.assert_close(y_np, y_th.numpy())
1161

1162

1163
def test_adjust_gamma():
1164
    x_shape = [2, 2, 3]
1165
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
1166
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1167
    x_pil = Image.fromarray(x_np, mode="RGB")
1168

1169
    # test 0
1170
    y_pil = F.adjust_gamma(x_pil, 1)
1171
    y_np = np.array(y_pil)
1172
    torch.testing.assert_close(y_np, x_np)
1173

1174
    # test 1
1175
    y_pil = F.adjust_gamma(x_pil, 0.5)
1176
    y_np = np.array(y_pil)
1177
    y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16]
1178
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
1179
    torch.testing.assert_close(y_np, y_ans)
1180

1181
    # test 2
1182
    y_pil = F.adjust_gamma(x_pil, 2)
1183
    y_np = np.array(y_pil)
1184
    y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0]
1185
    y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
1186
    torch.testing.assert_close(y_np, y_ans)
1187

1188

1189
def test_adjusts_L_mode():
1190
    x_shape = [2, 2, 3]
1191
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
1192
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
1193
    x_rgb = Image.fromarray(x_np, mode="RGB")
1194

1195
    x_l = x_rgb.convert("L")
1196
    assert F.adjust_brightness(x_l, 2).mode == "L"
1197
    assert F.adjust_saturation(x_l, 2).mode == "L"
1198
    assert F.adjust_contrast(x_l, 2).mode == "L"
1199
    assert F.adjust_hue(x_l, 0.4).mode == "L"
1200
    assert F.adjust_sharpness(x_l, 2).mode == "L"
1201
    assert F.adjust_gamma(x_l, 0.5).mode == "L"
1202

1203

1204
def test_rotate():
1205
    x = np.zeros((100, 100, 3), dtype=np.uint8)
1206
    x[40, 40] = [255, 255, 255]
1207

1208
    with pytest.raises(TypeError, match=r"img should be PIL Image"):
1209
        F.rotate(x, 10)
1210

1211
    img = F.to_pil_image(x)
1212

1213
    result = F.rotate(img, 45)
1214
    assert result.size == (100, 100)
1215
    r, c, ch = np.where(result)
1216
    assert all(x in r for x in [49, 50])
1217
    assert all(x in c for x in [36])
1218
    assert all(x in ch for x in [0, 1, 2])
1219

1220
    result = F.rotate(img, 45, expand=True)
1221
    assert result.size == (142, 142)
1222
    r, c, ch = np.where(result)
1223
    assert all(x in r for x in [70, 71])
1224
    assert all(x in c for x in [57])
1225
    assert all(x in ch for x in [0, 1, 2])
1226

1227
    result = F.rotate(img, 45, center=(40, 40))
1228
    assert result.size == (100, 100)
1229
    r, c, ch = np.where(result)
1230
    assert all(x in r for x in [40])
1231
    assert all(x in c for x in [40])
1232
    assert all(x in ch for x in [0, 1, 2])
1233

1234
    result_a = F.rotate(img, 90)
1235
    result_b = F.rotate(img, -270)
1236

1237
    assert_equal(np.array(result_a), np.array(result_b))
1238

1239

1240
@pytest.mark.parametrize("mode", ["L", "RGB", "F"])
1241
def test_rotate_fill(mode):
1242
    img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB")
1243

1244
    num_bands = len(mode)
1245
    wrong_num_bands = num_bands + 1
1246
    fill = 127
1247

1248
    img_conv = img.convert(mode)
1249
    img_rot = F.rotate(img_conv, 45.0, fill=fill)
1250
    pixel = img_rot.getpixel((0, 0))
1251

1252
    if not isinstance(pixel, tuple):
1253
        pixel = (pixel,)
1254
    assert pixel == tuple([fill] * num_bands)
1255

1256
    with pytest.raises(ValueError):
1257
        F.rotate(img_conv, 45.0, fill=tuple([fill] * wrong_num_bands))
1258

1259

1260
def test_gaussian_blur_asserts():
1261
    np_img = np.ones((100, 100, 3), dtype=np.uint8) * 255
1262
    img = F.to_pil_image(np_img, "RGB")
1263

1264
    with pytest.raises(ValueError, match=r"If kernel_size is a sequence its length should be 2"):
1265
        F.gaussian_blur(img, [3])
1266
    with pytest.raises(ValueError, match=r"If kernel_size is a sequence its length should be 2"):
1267
        F.gaussian_blur(img, [3, 3, 3])
1268
    with pytest.raises(ValueError, match=r"Kernel size should be a tuple/list of two integers"):
1269
        transforms.GaussianBlur([3, 3, 3])
1270

1271
    with pytest.raises(ValueError, match=r"kernel_size should have odd and positive integers"):
1272
        F.gaussian_blur(img, [4, 4])
1273
    with pytest.raises(ValueError, match=r"Kernel size value should be an odd and positive number"):
1274
        transforms.GaussianBlur([4, 4])
1275

1276
    with pytest.raises(ValueError, match=r"kernel_size should have odd and positive integers"):
1277
        F.gaussian_blur(img, [-3, -3])
1278
    with pytest.raises(ValueError, match=r"Kernel size value should be an odd and positive number"):
1279
        transforms.GaussianBlur([-3, -3])
1280

1281
    with pytest.raises(ValueError, match=r"If sigma is a sequence, its length should be 2"):
1282
        F.gaussian_blur(img, 3, [1, 1, 1])
1283
    with pytest.raises(ValueError, match=r"sigma should be a single number or a list/tuple with length 2"):
1284
        transforms.GaussianBlur(3, [1, 1, 1])
1285

1286
    with pytest.raises(ValueError, match=r"sigma should have positive values"):
1287
        F.gaussian_blur(img, 3, -1.0)
1288
    with pytest.raises(ValueError, match=r"If sigma is a single number, it must be positive"):
1289
        transforms.GaussianBlur(3, -1.0)
1290

1291
    with pytest.raises(TypeError, match=r"kernel_size should be int or a sequence of integers"):
1292
        F.gaussian_blur(img, "kernel_size_string")
1293
    with pytest.raises(ValueError, match=r"Kernel size should be a tuple/list of two integers"):
1294
        transforms.GaussianBlur("kernel_size_string")
1295

1296
    with pytest.raises(TypeError, match=r"sigma should be either float or sequence of floats"):
1297
        F.gaussian_blur(img, 3, "sigma_string")
1298
    with pytest.raises(ValueError, match=r"sigma should be a single number or a list/tuple with length 2"):
1299
        transforms.GaussianBlur(3, "sigma_string")
1300

1301

1302
def test_lambda():
    trans = transforms.Lambda(lambda x: x.add(10))
    x = torch.randn(10)
    y = trans(x)
    assert_equal(y, torch.add(x, 10))

    trans = transforms.Lambda(lambda x: x.add_(10))
    x = torch.randn(10)
    y = trans(x)
    assert_equal(y, x)

    # Checking if Lambda can be printed as string
    trans.__repr__()


def test_to_grayscale():
    """Unit tests for grayscale transform"""

    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")
    x_pil_2 = x_pil.convert("L")
    gray_np = np.array(x_pil_2)

    # Test Set: Grayscale an image with desired number of output channels
    # Case 1: RGB -> 1 channel grayscale
    trans1 = transforms.Grayscale(num_output_channels=1)
    gray_pil_1 = trans1(x_pil)
    gray_np_1 = np.array(gray_pil_1)
    assert gray_pil_1.mode == "L", "mode should be L"
    assert gray_np_1.shape == tuple(x_shape[0:2]), "should be 1 channel"
    assert_equal(gray_np, gray_np_1)

    # Case 2: RGB -> 3 channel grayscale
    trans2 = transforms.Grayscale(num_output_channels=3)
    gray_pil_2 = trans2(x_pil)
    gray_np_2 = np.array(gray_pil_2)
    assert gray_pil_2.mode == "RGB", "mode should be RGB"
    assert gray_np_2.shape == tuple(x_shape), "should be 3 channel"
    assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
    assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
    assert_equal(gray_np, gray_np_2[:, :, 0])

    # Case 3: 1 channel grayscale -> 1 channel grayscale
    trans3 = transforms.Grayscale(num_output_channels=1)
    gray_pil_3 = trans3(x_pil_2)
    gray_np_3 = np.array(gray_pil_3)
    assert gray_pil_3.mode == "L", "mode should be L"
    assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel"
    assert_equal(gray_np, gray_np_3)

    # Case 4: 1 channel grayscale -> 3 channel grayscale
    trans4 = transforms.Grayscale(num_output_channels=3)
    gray_pil_4 = trans4(x_pil_2)
    gray_np_4 = np.array(gray_pil_4)
    assert gray_pil_4.mode == "RGB", "mode should be RGB"
    assert gray_np_4.shape == tuple(x_shape), "should be 3 channel"
    assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
    assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
    assert_equal(gray_np, gray_np_4[:, :, 0])

    # Checking if Grayscale can be printed as string
    trans4.__repr__()


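# Illustrative sketch, not part of the upstream suite: PIL's "L" mode uses the
# ITU-R 601-2 luma transform, L = R * 299/1000 + G * 587/1000 + B * 114/1000,
# so the grayscale reference above can be reproduced by hand (within rounding).
def test_to_grayscale_matches_itu_r_601_luma():
    x_np = np.array([0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1], dtype=np.uint8).reshape(2, 2, 3)
    expected = x_np @ np.array([299, 587, 114]) / 1000
    gray_np = np.array(Image.fromarray(x_np, mode="RGB").convert("L"))
    assert np.abs(gray_np - expected).max() <= 1

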
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
def test_random_apply(p, seed):
    torch.manual_seed(seed)
    random_apply_transform = transforms.RandomApply([transforms.RandomRotation((45, 50))], p=p)
    img = transforms.ToPILImage()(torch.rand(3, 30, 40))
    out = random_apply_transform(img)
    if p == 0:
        assert out == img
    elif p == 1:
        assert out != img

    # Checking if RandomApply can be printed as string
    random_apply_transform.__repr__()


@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("proba_passthrough", (0, 1))
def test_random_choice(proba_passthrough, seed):
    random.seed(seed)  # RandomChoice relies on python builtin random.choice, not pytorch

    random_choice_transform = transforms.RandomChoice(
        [
            lambda x: x,  # passthrough
            transforms.RandomRotation((45, 50)),
        ],
        p=[proba_passthrough, 1 - proba_passthrough],
    )

    img = transforms.ToPILImage()(torch.rand(3, 30, 40))
    out = random_choice_transform(img)
    if proba_passthrough == 1:
        assert out == img
    elif proba_passthrough == 0:
        assert out != img

    # Checking if RandomChoice can be printed as string
    random_choice_transform.__repr__()


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_order():
    random_state = random.getstate()
    random.seed(42)
    random_order_transform = transforms.RandomOrder([transforms.Resize(20, antialias=True), transforms.CenterCrop(10)])
    img = transforms.ToPILImage()(torch.rand(3, 25, 25))
    num_samples = 250
    num_normal_order = 0
    resize_crop_out = transforms.CenterCrop(10)(transforms.Resize(20, antialias=True)(img))
    for _ in range(num_samples):
        out = random_order_transform(img)
        if out == resize_crop_out:
            num_normal_order += 1

    p_value = stats.binomtest(num_normal_order, num_samples, p=0.5).pvalue
    random.setstate(random_state)
    assert p_value > 0.0001

    # Checking if RandomOrder can be printed as string
    random_order_transform.__repr__()


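# Illustrative sketch, not part of the upstream suite: the statistical check
# above in miniature. Under the null hypothesis the two orders are equally
# likely, so a fair-coin count passes the binomial test while a heavily skewed
# count fails it; the sample sizes below are arbitrary.
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_binomtest_fair_coin_sketch():
    assert stats.binomtest(250, 500, p=0.5).pvalue > 0.0001
    assert stats.binomtest(490, 500, p=0.5).pvalue < 0.0001

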
def test_linear_transformation():
    num_samples = 1000
    x = torch.randn(num_samples, 3, 10, 10)
    flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
    # compute principal components
    sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
    u, s, _ = np.linalg.svd(sigma.numpy())
    zca_epsilon = 1e-10  # avoid division by 0
    d = torch.Tensor(np.diag(1.0 / np.sqrt(s + zca_epsilon)))
    u = torch.Tensor(u)
    principal_components = torch.mm(torch.mm(u, d), u.t())
    mean_vector = torch.sum(flat_x, dim=0) / flat_x.size(0)
    # initialize whitening matrix
    whitening = transforms.LinearTransformation(principal_components, mean_vector)
    # estimate covariance and mean using the weak law of large numbers
    num_features = flat_x.size(1)
    cov = 0.0
    mean = 0.0
    for i in x:
        xwhite = whitening(i)
        xwhite = xwhite.view(1, -1).numpy()
        cov += np.dot(xwhite, xwhite.T) / num_features
        mean += np.sum(xwhite) / num_features
    # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
    torch.testing.assert_close(
        cov / num_samples, np.identity(1), rtol=2e-3, atol=1e-8, check_dtype=False, msg="cov not close to 1"
    )
    torch.testing.assert_close(
        mean / num_samples, 0, rtol=1e-3, atol=1e-8, check_dtype=False, msg="mean not close to 0"
    )

    # Checking if LinearTransformation can be printed as string
    whitening.__repr__()


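# Illustrative sketch, not part of the upstream suite: the ZCA whitening matrix
# built above, W = U diag(1 / sqrt(s + eps)) U^T from the SVD of the covariance,
# maps the sample covariance to the identity; dimensions and seed are arbitrary.
def test_zca_whitening_identity_covariance_sketch():
    rng = np.random.RandomState(0)
    x = rng.randn(500, 8) @ rng.randn(8, 8)  # correlated features
    sigma = x.T @ x / x.shape[0]
    u, s, _ = np.linalg.svd(sigma)
    w = u @ np.diag(1.0 / np.sqrt(s + 1e-10)) @ u.T
    xw = x @ w.T
    torch.testing.assert_close(xw.T @ xw / x.shape[0], np.identity(8), rtol=0.0, atol=1e-5)

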
@pytest.mark.parametrize("dtype", int_dtypes())
def test_max_value(dtype):

    assert F_t._max_value(dtype) == torch.iinfo(dtype).max
    # remove float testing as it can lead to errors such as
    # runtime error: 5.7896e+76 is outside the range of representable values of type 'float'
    # for dtype in float_dtypes():
    # self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max)


@pytest.mark.xfail(
    reason="torch.iinfo() is not supported by torchscript. See https://github.com/pytorch/pytorch/issues/41492."
)
def test_max_value_iinfo():
    @torch.jit.script
    def max_value(image: torch.Tensor) -> int:
        return 1 if image.is_floating_point() else torch.iinfo(image.dtype).max


@pytest.mark.parametrize("should_vflip", [True, False])
@pytest.mark.parametrize("single_dim", [True, False])
def test_ten_crop(should_vflip, single_dim):
    to_pil_image = transforms.ToPILImage()
    h = random.randint(5, 25)
    w = random.randint(5, 25)
    crop_h = random.randint(1, h)
    crop_w = random.randint(1, w)
    if single_dim:
        crop_h = min(crop_h, crop_w)
        crop_w = crop_h
        transform = transforms.TenCrop(crop_h, vertical_flip=should_vflip)
        five_crop = transforms.FiveCrop(crop_h)
    else:
        transform = transforms.TenCrop((crop_h, crop_w), vertical_flip=should_vflip)
        five_crop = transforms.FiveCrop((crop_h, crop_w))

    img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
    results = transform(img)
    expected_output = five_crop(img)

    # Checking if FiveCrop and TenCrop can be printed as string
    transform.__repr__()
    five_crop.__repr__()

    if should_vflip:
        vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
        expected_output += five_crop(vflipped_img)
    else:
        hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
        expected_output += five_crop(hflipped_img)

    assert len(results) == 10
    assert results == expected_output


@pytest.mark.parametrize("single_dim", [True, False])
def test_five_crop(single_dim):
    to_pil_image = transforms.ToPILImage()
    h = random.randint(5, 25)
    w = random.randint(5, 25)
    crop_h = random.randint(1, h)
    crop_w = random.randint(1, w)
    if single_dim:
        crop_h = min(crop_h, crop_w)
        crop_w = crop_h
        transform = transforms.FiveCrop(crop_h)
    else:
        transform = transforms.FiveCrop((crop_h, crop_w))

    img = torch.FloatTensor(3, h, w).uniform_()

    results = transform(to_pil_image(img))

    assert len(results) == 5
    for crop in results:
        assert crop.size == (crop_w, crop_h)

    to_pil_image = transforms.ToPILImage()
    tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
    tr = to_pil_image(img[:, 0:crop_h, w - crop_w :])
    bl = to_pil_image(img[:, h - crop_h :, 0:crop_w])
    br = to_pil_image(img[:, h - crop_h :, w - crop_w :])
    center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img))
    expected_output = (tl, tr, bl, br, center)
    assert results == expected_output


@pytest.mark.parametrize("policy", transforms.AutoAugmentPolicy)
@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("grayscale", [True, False])
def test_autoaugment(policy, fill, grayscale):
    random.seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.AutoAugment(policy=policy, fill=fill)
    for _ in range(100):
        img = transform(img)
    transform.__repr__()


@pytest.mark.parametrize("num_ops", [1, 2, 3])
@pytest.mark.parametrize("magnitude", [7, 9, 11])
@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("grayscale", [True, False])
def test_randaugment(num_ops, magnitude, fill, grayscale):
    random.seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
    for _ in range(100):
        img = transform(img)
    transform.__repr__()


@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("num_magnitude_bins", [10, 13, 30])
@pytest.mark.parametrize("grayscale", [True, False])
def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale):
    random.seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
    for _ in range(100):
        img = transform(img)
    transform.__repr__()


@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("severity", [1, 10])
@pytest.mark.parametrize("mixture_width", [1, 2])
@pytest.mark.parametrize("chain_depth", [-1, 2])
@pytest.mark.parametrize("all_ops", [True, False])
@pytest.mark.parametrize("grayscale", [True, False])
def test_augmix(fill, severity, mixture_width, chain_depth, all_ops, grayscale):
    random.seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.AugMix(
        fill=fill, severity=severity, mixture_width=mixture_width, chain_depth=chain_depth, all_ops=all_ops
    )
    for _ in range(100):
        img = transform(img)
    transform.__repr__()


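# Illustrative sketch, not part of the upstream suite: how the auto-augmentation
# transforms exercised above are typically composed into a training pipeline.
# The IMAGENET policy choice and the dtype conversion are arbitrary assumptions.
def test_autoaugment_in_compose_smoke():
    pipeline = transforms.Compose(
        [
            transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float32),
        ]
    )
    out = pipeline(Image.open(GRACE_HOPPER))
    assert out.dtype == torch.float32
    assert out.shape[0] == 3

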
def test_random_crop():
    height = random.randint(10, 32) * 2
    width = random.randint(10, 32) * 2
    oheight = random.randint(5, (height - 2) // 2) * 2
    owidth = random.randint(5, (width - 2) // 2) * 2
    img = torch.ones(3, height, width, dtype=torch.uint8)
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth)),
            transforms.PILToTensor(),
        ]
    )(img)
    assert result.size(1) == oheight
    assert result.size(2) == owidth

    padding = random.randint(1, 20)
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth), padding=padding),
            transforms.PILToTensor(),
        ]
    )(img)
    assert result.size(1) == oheight
    assert result.size(2) == owidth

    result = transforms.Compose(
        [transforms.ToPILImage(), transforms.RandomCrop((height, width)), transforms.PILToTensor()]
    )(img)
    assert result.size(1) == height
    assert result.size(2) == width
    torch.testing.assert_close(result, img)

    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
            transforms.PILToTensor(),
        ]
    )(img)
    assert result.size(1) == height + 1
    assert result.size(2) == width + 1

    t = transforms.RandomCrop(33)
    img = torch.ones(3, 32, 32)
    with pytest.raises(ValueError, match=r"Required crop size .+ is larger than input image size .+"):
        t(img)


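# Illustrative sketch, not part of the upstream suite: RandomCrop.get_params
# samples the top-left corner (i, j) of the crop so that it always fits inside
# the input; the sizes used below are arbitrary.
def test_random_crop_get_params_sketch():
    img = torch.ones(3, 20, 30)
    i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(10, 12))
    assert (h, w) == (10, 12)
    assert 0 <= i <= 20 - 10
    assert 0 <= j <= 30 - 12

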
def test_center_crop():
    height = random.randint(10, 32) * 2
    width = random.randint(10, 32) * 2
    oheight = random.randint(5, (height - 2) // 2) * 2
    owidth = random.randint(5, (width - 2) // 2) * 2

    img = torch.ones(3, height, width, dtype=torch.uint8)
    oh1 = (height - oheight) // 2
    ow1 = (width - owidth) // 2
    imgnarrow = img[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth]
    imgnarrow.fill_(0)
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.PILToTensor(),
        ]
    )(img)
    assert result.sum() == 0
    oheight += 1
    owidth += 1
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.PILToTensor(),
        ]
    )(img)
    sum1 = result.sum()
    assert sum1 > 1
    oheight += 1
    owidth += 1
    result = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.CenterCrop((oheight, owidth)),
            transforms.PILToTensor(),
        ]
    )(img)
    sum2 = result.sum()
    assert sum2 > 0
    assert sum2 > sum1


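# Illustrative sketch, not part of the upstream suite: for the even size
# differences used here, the center-crop origin exercised above is
# top = (height - crop_height) // 2 and left = (width - crop_width) // 2;
# the tiny image makes this checkable by hand.
def test_center_crop_origin_arithmetic_sketch():
    img = torch.arange(5 * 6, dtype=torch.uint8).reshape(1, 5, 6)
    out = transforms.CenterCrop((3, 2))(img)
    top, left = (5 - 3) // 2, (6 - 2) // 2
    assert_equal(out, img[:, top : top + 3, left : left + 2])

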
@pytest.mark.parametrize("odd_image_size", (True, False))
@pytest.mark.parametrize("delta", (1, 3, 5))
@pytest.mark.parametrize("delta_width", (-2, -1, 0, 1, 2))
@pytest.mark.parametrize("delta_height", (-2, -1, 0, 1, 2))
def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
    """Tests when center crop size is larger than image size, along any dimension"""

    # Since height is independent of width, we can ignore images with odd height and even width and vice-versa.
    input_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2)
    if odd_image_size:
        input_image_size = (input_image_size[0] + 1, input_image_size[1] + 1)

    delta_height *= delta
    delta_width *= delta

    img = torch.ones(3, *input_image_size, dtype=torch.uint8)
    crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width)

    # Test both transforms, one with PIL input and one with tensor
    output_pil = transforms.Compose(
        [transforms.ToPILImage(), transforms.CenterCrop(crop_size), transforms.PILToTensor()],
    )(img)
    assert output_pil.size()[1:3] == crop_size

    output_tensor = transforms.CenterCrop(crop_size)(img)
    assert output_tensor.size()[1:3] == crop_size

    # Ensure output for PIL and Tensor are equal
    assert_equal(
        output_tensor,
        output_pil,
        msg=f"image_size: {input_image_size} crop_size: {crop_size}",
    )

    # Check if content in center of both image and cropped output is same.
    center_size = (min(crop_size[0], input_image_size[0]), min(crop_size[1], input_image_size[1]))
    crop_center_tl, input_center_tl = [0, 0], [0, 0]
    for index in range(2):
        if crop_size[index] > input_image_size[index]:
            crop_center_tl[index] = (crop_size[index] - input_image_size[index]) // 2
        else:
            input_center_tl[index] = (input_image_size[index] - crop_size[index]) // 2

    output_center = output_pil[
        :,
        crop_center_tl[0] : crop_center_tl[0] + center_size[0],
        crop_center_tl[1] : crop_center_tl[1] + center_size[1],
    ]

    img_center = img[
        :,
        input_center_tl[0] : input_center_tl[0] + center_size[0],
        input_center_tl[1] : input_center_tl[1] + center_size[1],
    ]

    assert_equal(output_center, img_center)


def test_color_jitter():
    color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)

    x_shape = [2, 2, 3]
    x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
    x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
    x_pil = Image.fromarray(x_np, mode="RGB")
    x_pil_2 = x_pil.convert("L")

    for _ in range(10):
        y_pil = color_jitter(x_pil)
        assert y_pil.mode == x_pil.mode

        y_pil_2 = color_jitter(x_pil_2)
        assert y_pil_2.mode == x_pil_2.mode

    # Checking if ColorJitter can be printed as string
    color_jitter.__repr__()


@pytest.mark.parametrize("hue", [1, (-1, 1)])
def test_color_jitter_hue_out_of_bounds(hue):
    with pytest.raises(ValueError, match=re.escape("hue values should be between (-0.5, 0.5)")):
        transforms.ColorJitter(hue=hue)


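# Illustrative sketch, not part of the upstream suite: ColorJitter.get_params
# samples one factor per property from the given ranges (plus a random order of
# application); the ranges below are arbitrary.
def test_color_jitter_get_params_ranges_sketch():
    _, b, c, s, h = transforms.ColorJitter.get_params(
        brightness=(0.5, 1.5), contrast=(0.8, 1.2), saturation=(0.9, 1.1), hue=(-0.1, 0.1)
    )
    assert 0.5 <= b <= 1.5
    assert 0.8 <= c <= 1.2
    assert 0.9 <= s <= 1.1
    assert -0.1 <= h <= 0.1

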
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_random_erasing(seed):
    torch.random.manual_seed(seed)
    img = torch.ones(3, 128, 128)

    t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.0))
    y, x, h, w, v = t.get_params(
        img,
        t.scale,
        t.ratio,
        [
            t.value,
        ],
    )
    aspect_ratio = h / w
    # Add some tolerance due to the rounding and int conversion used in the transform
    tol = 0.05
    assert 1 / 3 - tol <= aspect_ratio <= 3 + tol

    # Make sure that h > w and h < w are equally likely (log-scale sampling)
    aspect_ratios = []
    random.seed(42)
    trial = 1000
    for _ in range(trial):
        y, x, h, w, v = t.get_params(
            img,
            t.scale,
            t.ratio,
            [
                t.value,
            ],
        )
        aspect_ratios.append(h / w)

    count_bigger_than_one = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1])
    p_value = stats.binomtest(count_bigger_than_one, trial, p=0.5).pvalue
    assert p_value > 0.0001

    # Checking if RandomErasing can be printed as string
    t.__repr__()


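# Illustrative sketch, not part of the upstream suite: sampling the aspect
# ratio uniformly in log space, as the check above relies on, makes ratios r
# and 1/r equally likely; sample count and bounds are arbitrary.
def test_log_uniform_aspect_ratio_symmetry_sketch():
    torch.manual_seed(0)
    lo, hi = math.log(1 / 3), math.log(3.0)
    samples = torch.exp(torch.empty(10000).uniform_(lo, hi))
    frac_above_one = (samples > 1).float().mean().item()
    assert abs(frac_above_one - 0.5) < 0.05

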
def test_random_rotation():

    with pytest.raises(ValueError):
        transforms.RandomRotation(-0.7)

    with pytest.raises(ValueError):
        transforms.RandomRotation([-0.7])

    with pytest.raises(ValueError):
        transforms.RandomRotation([-0.7, 0, 0.7])

    t = transforms.RandomRotation(0, fill=None)
    assert t.fill == 0

    t = transforms.RandomRotation(10)
    angle = t.get_params(t.degrees)
    assert angle > -10 and angle < 10

    t = transforms.RandomRotation((-10, 10))
    angle = t.get_params(t.degrees)
    assert -10 < angle < 10

    # Checking if RandomRotation can be printed as string
    t.__repr__()

    t = transforms.RandomRotation((-10, 10), interpolation=Image.BILINEAR)
    assert t.interpolation == transforms.InterpolationMode.BILINEAR


def test_random_rotation_error():
    # assert fill being either a Sequence or a Number
    with pytest.raises(TypeError):
        transforms.RandomRotation(0, fill={})


def test_randomperspective():
    for _ in range(10):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        img = torch.ones(3, height, width)
        to_pil_image = transforms.ToPILImage()
        img = to_pil_image(img)
        perp = transforms.RandomPerspective()
        startpoints, endpoints = perp.get_params(width, height, 0.5)
        tr_img = F.perspective(img, startpoints, endpoints)
        tr_img2 = F.convert_image_dtype(F.pil_to_tensor(F.perspective(tr_img, endpoints, startpoints)))
        tr_img = F.convert_image_dtype(F.pil_to_tensor(tr_img))
        assert img.size[0] == width
        assert img.size[1] == height
        assert torch.nn.functional.mse_loss(
            tr_img, F.convert_image_dtype(F.pil_to_tensor(img))
        ) + 0.3 > torch.nn.functional.mse_loss(tr_img2, F.convert_image_dtype(F.pil_to_tensor(img)))


@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("mode", ["L", "RGB", "F"])
def test_randomperspective_fill(mode, seed):
    torch.random.manual_seed(seed)

    # assert fill being either a Sequence or a Number
    with pytest.raises(TypeError):
        transforms.RandomPerspective(fill={})

    t = transforms.RandomPerspective(fill=None)
    assert t.fill == 0

    height = 100
    width = 100
    img = torch.ones(3, height, width)
    to_pil_image = transforms.ToPILImage()
    img = to_pil_image(img)
    fill = 127
    num_bands = len(mode)

    img_conv = img.convert(mode)
    perspective = transforms.RandomPerspective(p=1, fill=fill)
    tr_img = perspective(img_conv)
    pixel = tr_img.getpixel((0, 0))

    if not isinstance(pixel, tuple):
        pixel = (pixel,)
    assert pixel == tuple([fill] * num_bands)

    startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
    tr_img = F.perspective(img_conv, startpoints, endpoints, fill=fill)
    pixel = tr_img.getpixel((0, 0))

    if not isinstance(pixel, tuple):
        pixel = (pixel,)
    assert pixel == tuple([fill] * num_bands)

    wrong_num_bands = num_bands + 1
    with pytest.raises(ValueError):
        F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
def test_normalize():
    def samples_from_standard_normal(tensor):
        p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue
        return p_value > 0.0001

    random_state = random.getstate()
    random.seed(42)
    for channels in [1, 3]:
        img = torch.rand(channels, 10, 10)
        mean = [img[c].mean() for c in range(channels)]
        std = [img[c].std() for c in range(channels)]
        normalized = transforms.Normalize(mean, std)(img)
        assert samples_from_standard_normal(normalized)
    random.setstate(random_state)

    # Checking if Normalize can be printed as string
    transforms.Normalize(mean, std).__repr__()

    # Checking the optional in-place behaviour
    tensor = torch.rand((1, 16, 16))
    tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor)
    assert_equal(tensor, tensor_inplace)


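# Illustrative sketch, not part of the upstream suite: Normalize applies
# output[c] = (input[c] - mean[c]) / std[c] per channel; the constants below
# are arbitrary.
def test_normalize_formula_sketch():
    img = torch.rand(3, 4, 4)
    mean, std = [0.5, 0.4, 0.3], [0.2, 0.25, 0.3]
    out = transforms.Normalize(mean, std)(img)
    expected = (img - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
    torch.testing.assert_close(out, expected)

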
@pytest.mark.parametrize("dtype1", [torch.float32, torch.float64])
@pytest.mark.parametrize("dtype2", [torch.int64, torch.float32, torch.float64])
def test_normalize_different_dtype(dtype1, dtype2):
    img = torch.rand(3, 10, 10, dtype=dtype1)
    mean = torch.tensor([1, 2, 3], dtype=dtype2)
    std = torch.tensor([1, 2, 1], dtype=dtype2)
    # checks that it doesn't crash
    transforms.functional.normalize(img, mean, std)


def test_normalize_3d_tensor():
    torch.manual_seed(28)
    n_channels = 3
    img_size = 10
    mean = torch.rand(n_channels)
    std = torch.rand(n_channels)
    img = torch.rand(n_channels, img_size, img_size)
    target = F.normalize(img, mean, std)

    mean_unsqueezed = mean.view(-1, 1, 1)
    std_unsqueezed = std.view(-1, 1, 1)
    result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
    result2 = F.normalize(
        img, mean_unsqueezed.repeat(1, img_size, img_size), std_unsqueezed.repeat(1, img_size, img_size)
    )
    torch.testing.assert_close(target, result1)
    torch.testing.assert_close(target, result2)


class TestAffine:
    @pytest.fixture(scope="class")
    def input_img(self):
        input_img = np.zeros((40, 40, 3), dtype=np.uint8)
        for pt in [(16, 16), (20, 16), (20, 20)]:
            for i in range(-5, 5):
                for j in range(-5, 5):
                    input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
        return input_img

    def test_affine_translate_seq(self, input_img):
        with pytest.raises(TypeError, match=r"Argument translate should be a sequence"):
            F.affine(input_img, 10, translate=0, scale=1, shear=1)

    @pytest.fixture(scope="class")
    def pil_image(self, input_img):
        return F.to_pil_image(input_img)

    def _to_3x3_inv(self, inv_result_matrix):
        result_matrix = np.zeros((3, 3))
        result_matrix[:2, :] = np.array(inv_result_matrix).reshape((2, 3))
        result_matrix[2, 2] = 1
        return np.linalg.inv(result_matrix)

    def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img, center=None):

        a_rad = math.radians(angle)
        s_rad = [math.radians(sh_) for sh_ in shear]
        cnt = [20, 20] if center is None else center
        cx, cy = cnt
        tx, ty = translate
        sx, sy = s_rad
        rot = a_rad

        # 1) Check transformation matrix:
        C = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
        T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        Cinv = np.linalg.inv(C)

        RS = np.array(
            [
                [scale * math.cos(rot), -scale * math.sin(rot), 0],
                [scale * math.sin(rot), scale * math.cos(rot), 0],
                [0, 0, 1],
            ]
        )

        SHx = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])

        SHy = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])

        RSS = np.matmul(RS, np.matmul(SHy, SHx))

        true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv)))

        result_matrix = self._to_3x3_inv(
            F._get_inverse_affine_matrix(center=cnt, angle=angle, translate=translate, scale=scale, shear=shear)
        )
        assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10
        # 2) Perform inverse mapping:
        true_result = np.zeros((40, 40, 3), dtype=np.uint8)
        inv_true_matrix = np.linalg.inv(true_matrix)
        for y in range(true_result.shape[0]):
            for x in range(true_result.shape[1]):
                # Same as for PIL:
                # https://github.com/python-pillow/Pillow/blob/71f8ec6a0cfc1008076a023c0756542539d057ab/
                # src/libImaging/Geometry.c#L1060
                input_pt = np.array([x + 0.5, y + 0.5, 1.0])
                res = np.floor(np.dot(inv_true_matrix, input_pt)).astype(int)
                _x, _y = res[:2]
                if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
                    true_result[y, x, :] = input_img[_y, _x, :]

        result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear, center=center)
        assert result.size == pil_image.size
        # Compute number of different pixels:
        np_result = np.array(result)
        n_diff_pixels = np.sum(np_result != true_result) / 3
        # Accept 3 wrong pixels
        error_msg = (
            f"angle={angle}, translate={translate}, scale={scale}, shear={shear}\nn diff pixels={n_diff_pixels}\n"
        )
        assert n_diff_pixels < 3, error_msg

    def test_transformation_discrete(self, pil_image, input_img):
        # Test rotation
        angle = 45
        self._test_transformation(
            angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
        )

        # Test rotation with top-left as center
        angle = 45
        self._test_transformation(
            angle=angle,
            translate=(0, 0),
            scale=1.0,
            shear=(0.0, 0.0),
            pil_image=pil_image,
            input_img=input_img,
            center=[0, 0],
        )

        # Test translation
        translate = [10, 15]
        self._test_transformation(
            angle=0.0, translate=translate, scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
        )

        # Test scale
        scale = 1.2
        self._test_transformation(
            angle=0.0, translate=(0.0, 0.0), scale=scale, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
        )

        # Test shear
        shear = [45.0, 25.0]
        self._test_transformation(
            angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=shear, pil_image=pil_image, input_img=input_img
        )

        # Test shear with top-left as center
        shear = [45.0, 25.0]
        self._test_transformation(
            angle=0.0,
            translate=(0.0, 0.0),
            scale=1.0,
            shear=shear,
            pil_image=pil_image,
            input_img=input_img,
            center=[0, 0],
        )

    @pytest.mark.parametrize("angle", range(-90, 90, 36))
    @pytest.mark.parametrize("translate", range(-10, 10, 5))
    @pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])
    @pytest.mark.parametrize("shear", range(-15, 15, 5))
    def test_transformation_range(self, angle, translate, scale, shear, pil_image, input_img):
        self._test_transformation(
            angle=angle,
            translate=(translate, translate),
            scale=scale,
            shear=(shear, shear),
            pil_image=pil_image,
            input_img=input_img,
        )


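# Illustrative sketch, not part of the upstream suite: the matrix identity
# checked above composes as M = T * C * RSS * C^-1 (move the center to the
# origin, apply rotation-scale-shear, move it back, then translate). For a pure
# 90-degree rotation about the origin it reduces to the familiar rotation block.
def test_affine_matrix_pure_rotation_sketch():
    inv = F._get_inverse_affine_matrix(center=[0, 0], angle=90.0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0])
    forward = np.linalg.inv(np.vstack([np.array(inv).reshape(2, 3), [0.0, 0.0, 1.0]]))
    expected = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    assert np.abs(forward - expected).max() < 1e-6

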
def test_random_affine():

    with pytest.raises(ValueError):
        transforms.RandomAffine(-0.7)
    with pytest.raises(ValueError):
        transforms.RandomAffine([-0.7])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-0.7, 0, 0.7])
    with pytest.raises(TypeError):
        transforms.RandomAffine([-90, 90], translate=2.0)
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[-1.0, 1.0])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[-1.0, 0.0, 1.0])

    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.0])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[-1.0, 1.0])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, -0.5])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 3.0, -0.5])

    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=-7)
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10])
    with pytest.raises(ValueError):
        transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10])

    # assert fill being either a Sequence or a Number
    with pytest.raises(TypeError):
        transforms.RandomAffine(0, fill={})

    t = transforms.RandomAffine(0, fill=None)
    assert t.fill == 0

    x = np.zeros((100, 100, 3), dtype=np.uint8)
    img = F.to_pil_image(x)

    t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40])
    for _ in range(100):
        angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, img_size=img.size)
        assert -10 < angle < 10
        assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5
        assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5
        assert 0.7 < scale < 1.3
        assert -10 < shear[0] < 10
        assert -20 < shear[1] < 40

    # Checking if RandomAffine can be printed as string
    t.__repr__()

    t = transforms.RandomAffine(10, interpolation=transforms.InterpolationMode.BILINEAR)
    assert "bilinear" in t.__repr__()

    t = transforms.RandomAffine(10, interpolation=Image.BILINEAR)
    assert t.interpolation == transforms.InterpolationMode.BILINEAR


def test_elastic_transformation():
    with pytest.raises(TypeError, match=r"alpha should be float or a sequence of floats"):
        transforms.ElasticTransform(alpha=True, sigma=2.0)
    with pytest.raises(TypeError, match=r"alpha should be a sequence of floats"):
        transforms.ElasticTransform(alpha=[1.0, True], sigma=2.0)
    with pytest.raises(ValueError, match=r"alpha is a sequence its length should be 2"):
        transforms.ElasticTransform(alpha=[1.0, 0.0, 1.0], sigma=2.0)

    with pytest.raises(TypeError, match=r"sigma should be float or a sequence of floats"):
        transforms.ElasticTransform(alpha=2.0, sigma=True)
    with pytest.raises(TypeError, match=r"sigma should be a sequence of floats"):
        transforms.ElasticTransform(alpha=2.0, sigma=[1.0, True])
    with pytest.raises(ValueError, match=r"sigma is a sequence its length should be 2"):
        transforms.ElasticTransform(alpha=2.0, sigma=[1.0, 0.0, 1.0])

    t = transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=Image.BILINEAR)
    assert t.interpolation == transforms.InterpolationMode.BILINEAR

    with pytest.raises(TypeError, match=r"fill should be int or float"):
        transforms.ElasticTransform(alpha=1.0, sigma=1.0, fill={})

    x = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    img = F.to_pil_image(x)
    t = transforms.ElasticTransform(alpha=0.0, sigma=0.0)
    transformed_img = t(img)
    assert transformed_img == img

    # Smoke test on PIL images
    t = transforms.ElasticTransform(alpha=0.5, sigma=0.23)
    transformed_img = t(img)
    assert isinstance(transformed_img, Image.Image)

    # Checking if ElasticTransform can be printed as string
    t.__repr__()


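# Illustrative sketch, not part of the upstream suite: on tensor input,
# ElasticTransform keeps the spatial shape; alpha scales the magnitude of the
# displacement field and sigma its smoothness. The values below are arbitrary.
def test_elastic_transform_tensor_shape_sketch():
    x = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    out = transforms.ElasticTransform(alpha=25.0, sigma=3.0)(x)
    assert isinstance(out, torch.Tensor)
    assert out.shape == x.shape

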
def test_random_grayscale_with_grayscale_input():
    transform = transforms.RandomGrayscale(p=1.0)

    image_tensor = torch.randint(0, 256, (1, 16, 16), dtype=torch.uint8)
    output_tensor = transform(image_tensor)
    torch.testing.assert_close(output_tensor, image_tensor)

    image_pil = F.to_pil_image(image_tensor)
    output_pil = transform(image_pil)
    torch.testing.assert_close(F.pil_to_tensor(output_pil), image_tensor)


if __name__ == "__main__":
    pytest.main([__file__])