# vision / test_transforms_v2.py (fork count: 0) - 6171 lines · 251.2 KB
1
import contextlib
2
import decimal
3
import functools
4
import inspect
5
import itertools
6
import math
7
import pickle
8
import random
9
import re
10
import sys
11
from copy import deepcopy
12
from pathlib import Path
13
from unittest import mock
14

15
import numpy as np
16
import PIL.Image
17
import pytest
18

19
import torch
20
import torchvision.ops
21
import torchvision.transforms.v2 as transforms
22

23
from common_utils import (
24
    assert_equal,
25
    cache,
26
    cpu_and_cuda,
27
    freeze_rng_state,
28
    ignore_jit_no_profile_information_warning,
29
    make_bounding_boxes,
30
    make_detection_masks,
31
    make_image,
32
    make_image_pil,
33
    make_image_tensor,
34
    make_segmentation_mask,
35
    make_video,
36
    make_video_tensor,
37
    needs_cuda,
38
    set_rng_seed,
39
)
40

41
from torch import nn
42
from torch.testing import assert_close
43
from torch.utils._pytree import tree_flatten, tree_map
44
from torch.utils.data import DataLoader, default_collate
45
from torchvision import tv_tensors
46
from torchvision.ops.boxes import box_iou
47

48
from torchvision.transforms._functional_tensor import _max_value as get_max_value
49
from torchvision.transforms.functional import pil_modes_mapping, to_pil_image
50
from torchvision.transforms.v2 import functional as F
51
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
52
from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs
53
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
54

55

56
# Module-level pytest marks: every warning raised while running this module's tests is turned into an error.
pytestmark = [pytest.mark.filterwarnings("error")]

# torchscript relies on some AST functionality that got deprecated in Python 3.12. Those deprecation warnings have to
# be ignored explicitly, otherwise the "error" filter above would turn them into failures.
if sys.version_info >= (3, 12):
    pytestmark.append(pytest.mark.filterwarnings("ignore::DeprecationWarning"))
63

64

65
@pytest.fixture(autouse=True)
def fix_rng_seed():
    """Autouse fixture that calls ``set_rng_seed(0)`` before handing control to the test."""
    set_rng_seed(0)
    yield None
69

70

71
def _to_tolerances(maybe_tolerance_dict):
72
    if not isinstance(maybe_tolerance_dict, dict):
73
        return dict(rtol=None, atol=None)
74

75
    tolerances = dict(rtol=0, atol=0)
76
    tolerances.update(maybe_tolerance_dict)
77
    return tolerances
78

79

80
def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs):
81
    """Checks if the kernel produces closes results for inputs on GPU and CPU."""
82
    if input.device.type != "cuda":
83
        return
84

85
    input_cuda = input.as_subclass(torch.Tensor)
86
    input_cpu = input_cuda.to("cpu")
87

88
    with freeze_rng_state():
89
        actual = kernel(input_cuda, *args, **kwargs)
90
    with freeze_rng_state():
91
        expected = kernel(input_cpu, *args, **kwargs)
92

93
    assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol)
94

95

96
@cache
def _script(obj):
    """Return ``torch.jit.script(obj)``, wrapped in the ``cache`` decorator imported from ``common_utils``.

    Raises:
        AssertionError: If scripting raises any exception. The original exception is chained as the cause.
    """
    try:
        scripted = torch.jit.script(obj)
    except Exception as error:
        name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
        raise AssertionError(f"Trying to `torch.jit.script` `{name}` raised the error above.") from error
    else:
        return scripted
103

104

105
def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):
    """Checks if the kernel is scriptable and if the scripted output is close to the eager one.

    Nothing is checked unless ``input`` lives on the CPU. The kernels are called with the input viewed as a plain
    ``torch.Tensor``, each inside ``freeze_rng_state()``.
    """
    if input.device.type != "cpu":
        return

    scripted_kernel = _script(kernel)
    pure_input = input.as_subclass(torch.Tensor)

    with ignore_jit_no_profile_information_warning(), freeze_rng_state():
        actual = scripted_kernel(pure_input, *args, **kwargs)

    with freeze_rng_state():
        expected = kernel(pure_input, *args, **kwargs)

    assert_close(actual, expected, rtol=rtol, atol=atol)
120

121

122
def _check_kernel_batched_vs_unbatched(kernel, input, *args, rtol, atol, **kwargs):
123
    """Checks if the kernel produces close results for batched and unbatched inputs."""
124
    unbatched_input = input.as_subclass(torch.Tensor)
125

126
    for batch_dims in [(2,), (2, 1)]:
127
        repeats = [*batch_dims, *[1] * input.ndim]
128

129
        actual = kernel(unbatched_input.repeat(repeats), *args, **kwargs)
130

131
        expected = kernel(unbatched_input, *args, **kwargs)
132
        # We can't directly call `.repeat()` on the output, since some kernel also return some additional metadata
133
        if isinstance(expected, torch.Tensor):
134
            expected = expected.repeat(repeats)
135
        else:
136
            tensor, *metadata = expected
137
            expected = (tensor.repeat(repeats), *metadata)
138

139
        assert_close(actual, expected, rtol=rtol, atol=atol)
140

141
    for degenerate_batch_dims in [(0,), (5, 0), (0, 5)]:
142
        degenerate_batched_input = torch.empty(
143
            degenerate_batch_dims + input.shape, dtype=input.dtype, device=input.device
144
        )
145

146
        output = kernel(degenerate_batched_input, *args, **kwargs)
147
        # Most kernels just return a tensor, but some also return some additional metadata
148
        if not isinstance(output, torch.Tensor):
149
            output, *_ = output
150

151
        assert output.shape[: -input.ndim] == degenerate_batch_dims
152

153

154
def check_kernel(
    kernel,
    input,
    *args,
    check_cuda_vs_cpu=True,
    check_scripted_vs_eager=True,
    check_batched_vs_unbatched=True,
    **kwargs,
):
    """Run the standard set of kernel checks.

    The kernel is called once with ``input`` viewed as a plain ``torch.Tensor``. Afterwards it is asserted that the
    input's ``_version`` is unchanged, that the output device equals the input device and, except for
    ``F.to_dtype_image`` / ``F.to_dtype_video``, that the output dtype equals the input dtype.

    Each ``check_*`` flag enables the corresponding ``_check_kernel_*`` helper when truthy. The flag is converted
    with ``_to_tolerances``, so a ``dict`` can be passed to set ``rtol`` / ``atol``.
    """
    version_before = input._version

    output = kernel(input.as_subclass(torch.Tensor), *args, **kwargs)
    # Most kernels just return a tensor, but some also return some additional metadata
    if not isinstance(output, torch.Tensor):
        output, *_ = output

    # check that no inplace operation happened
    assert input._version == version_before

    if kernel not in {F.to_dtype_image, F.to_dtype_video}:
        assert output.dtype == input.dtype
    assert output.device == input.device

    optional_checks = (
        (check_cuda_vs_cpu, _check_kernel_cuda_vs_cpu),
        (check_scripted_vs_eager, _check_kernel_scripted_vs_eager),
        (check_batched_vs_unbatched, _check_kernel_batched_vs_unbatched),
    )
    for flag, run_check in optional_checks:
        if flag:
            run_check(kernel, input, *args, **kwargs, **_to_tolerances(flag))
185

186

187
def _check_functional_scripted_smoke(functional, input, *args, **kwargs):
    """Checks if the functional can be scripted and the scripted version can be called without error.

    Only runs for ``tv_tensors.Image`` inputs, which are passed to the scripted functional as plain tensors.
    """
    if isinstance(input, tv_tensors.Image):
        scripted_functional = _script(functional)
        with ignore_jit_no_profile_information_warning():
            scripted_functional(input.as_subclass(torch.Tensor), *args, **kwargs)
195

196

197
def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs):
    """Run the standard set of checks for a functional.

    Asserts that
    - an input of an unknown type raises a ``TypeError`` whose message contains that type,
    - ``torch._C._log_api_usage_once`` is called with ``"{module}.{name}"`` of the functional,
    - the output is an instance of the input's type,
    - bounding boxes keep their format, unless the functional is ``F.convert_bounding_box_format``.

    If ``check_scripted_smoke`` is truthy, ``_check_functional_scripted_smoke`` is run as well.
    """
    unsupported_input = object()
    expected_message = re.escape(str(type(unsupported_input)))
    with pytest.raises(TypeError, match=expected_message):
        functional(unsupported_input, *args, **kwargs)

    with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
        output = functional(input, *args, **kwargs)

        spy.assert_any_call(f"{functional.__module__}.{functional.__name__}")

    assert isinstance(output, type(input))

    format_has_to_be_preserved = (
        isinstance(input, tv_tensors.BoundingBoxes) and functional is not F.convert_bounding_box_format
    )
    if format_has_to_be_preserved:
        assert output.format == input.format

    if check_scripted_smoke:
        _check_functional_scripted_smoke(functional, input, *args, **kwargs)
214

215

216
def check_functional_kernel_signature_match(functional, *, kernel, input_type):
    """Checks if the signature of the functional matches the kernel signature.

    The first parameter of both callables is ignored. Every remaining kernel parameter has to be present on the
    functional, in the same relative order, and has to compare equal to its functional counterpart as
    ``inspect.Parameter``. Functional parameters without a kernel equivalent are skipped.

    Args:
        functional: Callable whose signature is the reference.
        kernel: Callable whose parameters have to be found on ``functional``.
        input_type: Class the kernel is registered for. For ``tv_tensors.BoundingBoxes`` the kernel parameters
            ``format`` and ``canvas_size`` are ignored. For ``PIL.Image.Image`` subclasses annotations are ignored.

    Raises:
        AssertionError: If a kernel parameter is missing on the functional or differs from its counterpart.
    """
    functional_params = list(inspect.signature(functional).parameters.values())[1:]
    kernel_params = list(inspect.signature(kernel).parameters.values())[1:]

    if issubclass(input_type, tv_tensors.TVTensor):
        # We filter out metadata that is implicitly passed to the functional through the input tv_tensor, but has to be
        # explicitly passed to the kernel.
        explicit_metadata = {
            tv_tensors.BoundingBoxes: {"format", "canvas_size"},
        }
        implicit_names = explicit_metadata.get(input_type, set())
        kernel_params = [param for param in kernel_params if param.name not in implicit_names]

    ignore_annotations = issubclass(input_type, PIL.Image.Image)

    # The loop is driven by the kernel parameters and the functional ones are pulled on demand. Zipping both
    # sequences instead would stop silently as soon as the functional parameters are exhausted, and trailing kernel
    # parameters that are missing on the functional would never be reported.
    remaining_functional_params = iter(functional_params)
    for kernel_param in kernel_params:
        # In general, the functional parameters are a superset of the kernel parameters. Thus, we skip functional
        # parameters that have no kernel equivalent while keeping the order intact.
        for functional_param in remaining_functional_params:
            if functional_param.name == kernel_param.name:
                break
        else:
            raise AssertionError(
                f"Parameter `{kernel_param.name}` of kernel `{kernel.__name__}` "
                f"has no corresponding parameter on the functional `{functional.__name__}`."
            )

        if ignore_annotations:
            # PIL kernels often have more correct annotations, since they are not limited by JIT. Thus, we don't check
            # them in the first place.
            functional_param._annotation = kernel_param._annotation = inspect.Parameter.empty

        assert functional_param == kernel_param
248

249

250
def _check_transform_v1_compatibility(transform, input, *, rtol, atol):
    """Compare a v2 transform against its v1 equivalent.

    Nothing is checked unless ``input`` is exactly a ``torch.Tensor`` or a PIL image, and the transform's
    ``_v1_transform_cls`` attribute is not ``None``. Otherwise it is asserted that

    - ``get_params`` of the v2 transform class is the very same object as the v1 one, if v1 defines it,
    - the v2 and v1 outputs are close after converting both with ``F.to_image``,
    - for tensor inputs, the v1 transform can be scripted and the scripted version can be called without error.
    """
    if type(input) is not torch.Tensor and not isinstance(input, PIL.Image.Image):
        return

    v1_cls = transform._v1_transform_cls
    if v1_cls is None:
        return

    if hasattr(v1_cls, "get_params"):
        assert type(transform).get_params is v1_cls.get_params

    v1_transform = v1_cls(**transform._extract_params_for_v1_transform())

    # Both transforms are run inside `freeze_rng_state()`, v2 first.
    outputs = []
    for candidate in (transform, v1_transform):
        with freeze_rng_state():
            outputs.append(candidate(input))
    output_v2, output_v1 = outputs

    assert_close(F.to_image(output_v2), F.to_image(output_v1), rtol=rtol, atol=atol)

    if not isinstance(input, PIL.Image.Image):
        _script(v1_transform)(input)
278

279

280
def _make_transform_sample(transform, *, image_or_video, adapter):
    """Build a large sample ``dict`` that mixes tv_tensors, PIL images and arbitrary Python objects.

    All generated items share the size of ``image_or_video`` and, where applicable, its device (``"cpu"`` for
    non-tensor inputs). If ``adapter`` is not ``None``, the result of ``adapter(transform, sample, device)`` is
    returned instead of the sample itself.
    """
    if isinstance(image_or_video, torch.Tensor):
        device = image_or_video.device
    else:
        device = "cpu"
    size = F.get_size(image_or_video)
    box_formats = tv_tensors.BoundingBoxFormat

    degenerate_xyxy_rows = [
        [0, 0, 0, 0],  # x1 == x2, y1 == y2
        [0, 0, 0, 1],  # x1 == x2
        [0, 0, 1, 0],  # y1 == y2
        [2, 0, 1, 1],  # x1 > x2, y1 < y2
        [0, 2, 1, 1],  # x1 < x2, y1 > y2
        [2, 2, 1, 1],  # x1 > x2, y1 > y2
    ]
    # These rows are used for both the XYWH and the CXCYWH boxes
    degenerate_wh_rows = [
        [0, 0, 0, 0],  # zero width and height
        [0, 0, 0, 1],  # zero width
        [0, 0, 1, 0],  # zero height
        [0, 0, 1, -1],  # negative height
        [0, 0, -1, 1],  # negative width
        [0, 0, -1, -1],  # negative height and width
    ]

    def make_degenerate_boxes(rows, format):
        return tv_tensors.BoundingBoxes(rows, format=format, canvas_size=size, device=device)

    # The insertion order matters: callers also turn the values into lists and tuples.
    sample = {
        "image_or_video": image_or_video,
        "image_tv_tensor": make_image(size, device=device),
        "video_tv_tensor": make_video(size, device=device),
        "image_pil": make_image_pil(size),
        "bounding_boxes_xyxy": make_bounding_boxes(size, format=box_formats.XYXY, device=device),
        "bounding_boxes_xywh": make_bounding_boxes(size, format=box_formats.XYWH, device=device),
        "bounding_boxes_cxcywh": make_bounding_boxes(size, format=box_formats.CXCYWH, device=device),
        "bounding_boxes_degenerate_xyxy": make_degenerate_boxes(degenerate_xyxy_rows, box_formats.XYXY),
        "bounding_boxes_degenerate_xywh": make_degenerate_boxes(degenerate_wh_rows, box_formats.XYWH),
        "bounding_boxes_degenerate_cxcywh": make_degenerate_boxes(degenerate_wh_rows, box_formats.CXCYWH),
        "detection_mask": make_detection_masks(size, device=device),
        "segmentation_mask": make_segmentation_mask(size, device=device),
        "int": 0,
        "float": 0.0,
        "bool": True,
        "none": None,
        "str": "str",
        "path": Path.cwd(),
        "object": object(),
        "tensor": torch.empty(5),
        "array": np.empty(5),
    }

    if adapter is None:
        return sample
    return adapter(transform, sample, device)
345

346

347
def _check_transform_sample_input_smoke(transform, input, *, adapter):
    """Input / output convention checks, using a big sample with different parts as input.

    Nothing is checked unless ``input`` is a pure tensor, a PIL image, a ``tv_tensors.Image`` or a
    ``tv_tensors.Video``. The sample from ``_make_transform_sample`` is passed to the transform as ``dict``, ``list``
    and ``tuple``. For each container it is asserted that the output has the same tree structure, that transformed
    items keep their exact type, and that all other items are passed through as the very same object.

    Finally, each degenerate bounding box entry of the sample has to be removed completely by
    ``transforms.SanitizeBoundingBoxes``.
    """
    supported_types = (is_pure_tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video)
    if not check_type(input, supported_types):
        return

    sample = _make_transform_sample(
        # adapter might change transform inplace
        transform=deepcopy(transform) if adapter is not None else transform,
        image_or_video=input,
        adapter=adapter,
    )

    for container_type in (dict, list, tuple):
        container = sample if container_type is dict else container_type(sample.values())

        flat_inputs, input_spec = tree_flatten(container)

        with freeze_rng_state():
            torch.manual_seed(0)
            output = transform(container)
        flat_outputs, output_spec = tree_flatten(output)

        assert output_spec == input_spec

        needs_transform = transforms.Transform()._needs_transform_list(flat_inputs)
        for output_item, input_item, should_be_transformed in zip(flat_outputs, flat_inputs, needs_transform):
            if should_be_transformed:
                assert type(output_item) is type(input_item)
            else:
                assert output_item is input_item

    # Enforce that the transform does not turn a degenerate bounding box, e.g. marked by RandomIoUCrop (or any other
    # future transform that does this), back into a valid one.
    degenerate_entries = [
        value
        for name, value in sample.items()
        if "degenerate" in name and isinstance(value, tv_tensors.BoundingBoxes)
    ]
    for degenerate_bounding_boxes in degenerate_entries:
        target = dict(
            boxes=degenerate_bounding_boxes,
            labels=torch.randint(10, (degenerate_bounding_boxes.shape[0],), device=degenerate_bounding_boxes.device),
        )
        assert transforms.SanitizeBoundingBoxes()(target)["boxes"].shape == (0, 4)
394

395

396
def check_transform(transform, input, check_v1_compatibility=True, check_sample_input=True):
    """Run the standard set of checks for a transform and return its output for ``input``.

    Asserts that the transform survives a pickle round trip, that the output is an instance of the input's type and
    that bounding boxes keep their format unless the transform is a ``transforms.ConvertBoundingBoxFormat``.

    Args:
        transform: Transform instance under test.
        input: Input that is passed to the transform.
        check_v1_compatibility: If truthy, runs ``_check_transform_v1_compatibility``. The value is converted with
            ``_to_tolerances``, so a ``dict`` can be passed to set ``rtol`` / ``atol``.
        check_sample_input: If truthy, runs ``_check_transform_sample_input_smoke``. A callable is forwarded as the
            ``adapter``.
    """
    pickle.loads(pickle.dumps(transform))

    output = transform(input)
    assert isinstance(output, type(input))

    if isinstance(input, tv_tensors.BoundingBoxes) and not isinstance(transform, transforms.ConvertBoundingBoxFormat):
        assert output.format == input.format

    if check_sample_input:
        adapter = check_sample_input if callable(check_sample_input) else None
        _check_transform_sample_input_smoke(transform, input, adapter=adapter)

    if check_v1_compatibility:
        tolerances = _to_tolerances(check_v1_compatibility)
        _check_transform_v1_compatibility(transform, input, **tolerances)

    return output
414

415

416
def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
    """Wrap a transform class so it can be called like a functional.

    The returned callable takes the input first. All remaining arguments, together with
    ``transform_specific_kwargs``, are used to instantiate ``transform_cls``, and the instance is applied to the
    input. The callable's ``__name__`` is set to the name of the class.
    """

    def wrapper(input, *args, **kwargs):
        return transform_cls(*args, **transform_specific_kwargs, **kwargs)(input)

    wrapper.__name__ = transform_cls.__name__

    return wrapper
424

425

426
def param_value_parametrization(**kwargs):
    """Helper function to turn

    @pytest.mark.parametrize(
        ("param", "value"),
        [
            ("a", 1),
            ("a", 2),
            ("a", 3),
            ("b", -1.0),
            ("b", 1.0),
        ],
    )

    into

    @param_value_parametrization(a=[1, 2, 3], b=[-1.0, 1.0])
    """
    pairs = []
    for param, values in kwargs.items():
        pairs.extend((param, value) for value in values)

    return pytest.mark.parametrize(("param", "value"), pairs)
446

447

448
def adapt_fill(value, *, dtype):
    """Adapt fill values in the range [0.0, 1.0] to the value range of the dtype.

    Args:
        value: ``None``, an ``int`` or ``float``, or a ``list`` / ``tuple`` of those.
        dtype: dtype whose maximum value, as returned by ``get_max_value``, is used as scale.

    Returns:
        ``None`` if ``value`` is ``None``. Otherwise the scaled value(s), converted to ``float`` for floating point
        dtypes and to ``int`` otherwise. Sequences keep their container type.

    Raises:
        ValueError: If ``value`` is of any other type.
    """
    if value is None:
        return value

    max_value = get_max_value(dtype)
    convert = float if dtype.is_floating_point else int

    def scale(item):
        return convert(item * max_value)

    if isinstance(value, (int, float)):
        return scale(value)

    if isinstance(value, (list, tuple)):
        return type(value)(scale(item) for item in value)

    raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'.")
462

463

464
# Fill values covering `None`, scalars, and single- as well as three-element lists and tuples of ints and floats
EXHAUSTIVE_TYPE_FILLS = [
    None,
    1,
    0.5,
    [1],
    [0.2],
    (0,),
    (0.7,),
    [1, 0, 1],
    [0.1, 0.2, 0.3],
    (0, 1, 0),
    (0.9, 0.234, 0.314),
]
# Subset of the fills above: `None`, float scalars and lists with more than one element
CORRECTNESS_FILLS = [
    fill
    for fill in EXHAUSTIVE_TYPE_FILLS
    if fill is None or isinstance(fill, float) or (isinstance(fill, list) and len(fill) > 1)
]
480

481

482
# `list(transforms.InterpolationMode)` cannot be used here, since it includes some PIL-only modes as well. Thus, the
# supported modes are looked up by name.
INTERPOLATION_MODES = [
    transforms.InterpolationMode[name] for name in ("NEAREST", "NEAREST_EXACT", "BILINEAR", "BICUBIC")
]
489

490

491
def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
    """Reference implementation that applies an affine matrix to bounding boxes, one box at a time.

    Each box is converted to XYXY, its four corner points are multiplied with the transposed ``affine_matrix``, and
    the axis-aligned box enclosing the transformed corners is converted back to the original format.

    Args:
        bounding_boxes: Boxes to transform. ``format`` and ``canvas_size`` are read from this object.
        affine_matrix: numpy array that is applied to the homogeneous corner points ``[x, y, 1]``.
        new_canvas_size: Canvas size of the output. Falls back to the input canvas size if falsy.
        clamp: If truthy, the boxes are clamped to the canvas and cast back to the input dtype. Otherwise they are
            returned unclamped, in the dtype the computation produced.

    Returns:
        ``tv_tensors.BoundingBoxes`` with the shape and format of the input.
    """
    format = bounding_boxes.format
    canvas_size = new_canvas_size or bounding_boxes.canvas_size
    xyxy = tv_tensors.BoundingBoxFormat.XYXY

    def affine_bounding_boxes(bounding_boxes):
        input_dtype = bounding_boxes.dtype
        input_device = bounding_boxes.device

        # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
        float_copy = bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True)
        input_xyxy = F.convert_bounding_box_format(float_copy, old_format=format, new_format=xyxy, inplace=True)
        x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist()

        # Homogeneous coordinates of the four corners
        corners = np.array([[x, y, 1.0] for y in (y1, y2) for x in (x1, x2)])
        transformed_corners = np.matmul(corners, affine_matrix.astype(corners.dtype).T)
        xs = transformed_corners[:, 0]
        ys = transformed_corners[:, 1]

        # NOTE(review): `torch.Tensor(...)` presumably yields the default floating point dtype rather than float64;
        # confirm whether that is intended for the unclamped path below.
        output_xyxy = torch.Tensor(
            [
                float(np.min(xs)),
                float(np.min(ys)),
                float(np.max(xs)),
                float(np.max(ys)),
            ]
        )

        output = F.convert_bounding_box_format(output_xyxy, old_format=xyxy, new_format=format)

        if clamp:
            # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
            output = F.clamp_bounding_boxes(output, format=format, canvas_size=canvas_size)
            output_dtype = input_dtype
        else:
            # The dtype of the computed box is kept, so the caller can perform any additional operation without a
            # prior cast to the input dtype
            output_dtype = output.dtype

        return output.to(dtype=output_dtype, device=input_device)

    transformed_boxes = [affine_bounding_boxes(box) for box in bounding_boxes.reshape(-1, 4).unbind()]

    return tv_tensors.BoundingBoxes(
        torch.cat(transformed_boxes, dim=0).reshape(bounding_boxes.shape),
        format=format,
        canvas_size=canvas_size,
    )
552

553

554
class TestResize:
555
    INPUT_SIZE = (17, 11)
556
    OUTPUT_SIZES = [17, [17], (17,), None, [12, 13], (12, 13)]
557

558
    def _make_max_size_kwarg(self, *, use_max_size, size):
559
        if size is None:
560
            max_size = min(list(self.INPUT_SIZE))
561
        elif use_max_size:
562
            if not (isinstance(size, int) or len(size) == 1):
563
                # This would result in an `ValueError`
564
                return None
565

566
            max_size = (size if isinstance(size, int) else size[0]) + 1
567
        else:
568
            max_size = None
569

570
        return dict(max_size=max_size)
571

572
    def _compute_output_size(self, *, input_size, size, max_size):
573
        if size is None:
574
            size = max_size
575

576
        elif not (isinstance(size, int) or len(size) == 1):
577
            return tuple(size)
578

579
        elif not isinstance(size, int):
580
            size = size[0]
581

582
        old_height, old_width = input_size
583
        ratio = old_width / old_height
584
        if ratio > 1:
585
            new_height = size
586
            new_width = int(ratio * new_height)
587
        else:
588
            new_width = size
589
            new_height = int(new_width / ratio)
590

591
        if max_size is not None and max(new_height, new_width) > max_size:
592
            # Need to recompute the aspect ratio, since it might have changed due to rounding
593
            ratio = new_width / new_height
594
            if ratio > 1:
595
                new_width = max_size
596
                new_height = int(new_width / ratio)
597
            else:
598
                new_height = max_size
599
                new_width = int(new_height * ratio)
600

601
        return new_height, new_width
602

603
    @pytest.mark.parametrize("size", OUTPUT_SIZES)
    @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
    @pytest.mark.parametrize("use_max_size", [True, False])
    @pytest.mark.parametrize("antialias", [True, False])
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype, device):
        """Runs ``check_kernel`` for ``F.resize_image``; returns early for invalid size / max_size combinations."""
        max_size_kwarg = self._make_max_size_kwarg(use_max_size=use_max_size, size=size)
        if not max_size_kwarg:
            return

        # In contrast to CPU, there is no native `InterpolationMode.BICUBIC` implementation for uint8 images on CUDA.
        # Internally, it uses the float path. Thus, we need to test with an enormous tolerance here to account for that.
        is_uint8_bicubic = interpolation is transforms.InterpolationMode.BICUBIC and dtype is torch.uint8
        atol = 30 if is_uint8_bicubic else 1
        if dtype.is_floating_point:
            # The tolerance is given on the uint8 scale and has to be rescaled for floating point images
            atol = atol / 255

        check_kernel(
            F.resize_image,
            make_image(self.INPUT_SIZE, dtype=dtype, device=device),
            size=size,
            interpolation=interpolation,
            **max_size_kwarg,
            antialias=antialias,
            check_cuda_vs_cpu=dict(rtol=0, atol=atol),
            check_scripted_vs_eager=not isinstance(size, int),
        )
628

629
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("size", OUTPUT_SIZES)
    @pytest.mark.parametrize("use_max_size", [True, False])
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device):
        """Runs ``check_kernel`` for ``F.resize_bounding_boxes``; returns early for invalid size / max_size."""
        max_size_kwarg = self._make_max_size_kwarg(use_max_size=use_max_size, size=size)
        if not max_size_kwarg:
            return

        boxes = make_bounding_boxes(format=format, canvas_size=self.INPUT_SIZE, dtype=dtype, device=device)
        check_kernel(
            F.resize_bounding_boxes,
            boxes,
            canvas_size=boxes.canvas_size,
            size=size,
            **max_size_kwarg,
            check_scripted_vs_eager=not isinstance(size, int),
        )
652

653
    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
    def test_kernel_mask(self, make_mask):
        """Runs ``check_kernel`` for ``F.resize_mask`` with the last entry of ``OUTPUT_SIZES``."""
        mask = make_mask(self.INPUT_SIZE)
        check_kernel(F.resize_mask, mask, size=self.OUTPUT_SIZES[-1])
656

657
    def test_kernel_video(self):
        """Runs ``check_kernel`` for ``F.resize_video`` with the last entry of ``OUTPUT_SIZES`` and antialiasing."""
        video = make_video(self.INPUT_SIZE)
        check_kernel(F.resize_video, video, size=self.OUTPUT_SIZES[-1], antialias=True)
659

660
    @pytest.mark.parametrize("size", OUTPUT_SIZES)
    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_functional(self, size, make_input):
        """Runs ``check_functional`` for ``F.resize``. ``max_size`` is only set if ``size`` is ``None``."""
        max_size_kwarg = self._make_max_size_kwarg(use_max_size=size is None, size=size)
        input = make_input(self.INPUT_SIZE)

        check_functional(
            F.resize,
            input,
            size=size,
            **max_size_kwarg,
            antialias=True,
            check_scripted_smoke=not isinstance(size, int),
        )
676

677
    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.resize_image, torch.Tensor),
            (F._geometry._resize_image_pil, PIL.Image.Image),
            (F.resize_image, tv_tensors.Image),
            (F.resize_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.resize_mask, tv_tensors.Mask),
            (F.resize_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """Checks that the signature of ``F.resize`` matches the signature of each listed kernel."""
        functional = F.resize
        check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)
690

691
    @pytest.mark.parametrize("size", OUTPUT_SIZES)
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize(
        "make_input",
        [
            make_image_tensor,
            make_image_pil,
            make_image,
            make_bounding_boxes,
            make_segmentation_mask,
            make_detection_masks,
            make_video,
        ],
    )
    def test_transform(self, size, device, make_input):
        """Runs ``check_transform`` for ``transforms.Resize``. v1 compatibility is skipped if ``size`` is ``None``."""
        max_size_kwarg = self._make_max_size_kwarg(use_max_size=size is None, size=size)

        if size is None:
            v1_compatibility = False
        else:
            # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
            v1_compatibility = dict(rtol=0, atol=1)

        check_transform(
            transforms.Resize(size=size, **max_size_kwarg, antialias=True),
            make_input(self.INPUT_SIZE, device=device),
            check_v1_compatibility=v1_compatibility,
        )
714

715
    def _check_output_size(self, input, output, *, size, max_size):
        """Asserts that the size of ``output`` equals the size computed by ``_compute_output_size`` for ``input``."""
        expected_size = self._compute_output_size(input_size=F.get_size(input), size=size, max_size=max_size)
        assert tuple(F.get_size(output)) == expected_size
719

720
    @pytest.mark.parametrize("size", OUTPUT_SIZES)
    # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
    # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
    # A list is used instead of a set difference to keep the parametrization order deterministic: the iteration order
    # of a set of enum members depends on their hashes and thus can differ between interpreter processes.
    @pytest.mark.parametrize(
        "interpolation",
        [mode for mode in INTERPOLATION_MODES if mode is not transforms.InterpolationMode.NEAREST],
    )
    @pytest.mark.parametrize("use_max_size", [True, False])
    @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
    def test_image_correctness(self, size, interpolation, use_max_size, fn):
        """Compares resizing a uint8 image tensor against resizing the same image with PIL.

        The output size is checked against ``_compute_output_size`` and the values have to match the PIL result with
        an absolute tolerance of 1. Invalid size / max_size combinations return early.
        """
        if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
            return

        image = make_image(self.INPUT_SIZE, dtype=torch.uint8)

        actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True)
        expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg))

        self._check_output_size(image, actual, size=size, **max_size_kwarg)
        torch.testing.assert_close(actual, expected, atol=1, rtol=0)
737

738
    def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=None):
        """Reference implementation for resizing bounding boxes.

        The input is returned as is if the canvas size does not change. Otherwise the boxes are scaled with
        ``reference_affine_bounding_boxes_helper`` by the ratio of new to old canvas size.
        """
        old_height, old_width = bounding_boxes.canvas_size
        new_height, new_width = self._compute_output_size(
            input_size=bounding_boxes.canvas_size, size=size, max_size=max_size
        )

        if (new_height, new_width) != (old_height, old_width):
            # Pure scaling: x is scaled by the width ratio, y by the height ratio, and there is no translation
            scale_matrix = np.array(
                [
                    [new_width / old_width, 0, 0],
                    [0, new_height / old_height, 0],
                ],
            )
            return reference_affine_bounding_boxes_helper(
                bounding_boxes,
                affine_matrix=scale_matrix,
                new_canvas_size=(new_height, new_width),
            )

        return bounding_boxes
759

760
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("size", OUTPUT_SIZES)
    @pytest.mark.parametrize("use_max_size", [True, False])
    @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
    def test_bounding_boxes_correctness(self, format, size, use_max_size, fn):
        """Compares resized bounding boxes against ``_reference_resize_bounding_boxes``.

        Invalid size / max_size combinations return early.
        """
        max_size_kwarg = self._make_max_size_kwarg(use_max_size=use_max_size, size=size)
        if not max_size_kwarg:
            return

        boxes = make_bounding_boxes(format=format, canvas_size=self.INPUT_SIZE)

        actual = fn(boxes, size=size, **max_size_kwarg)
        expected = self._reference_resize_bounding_boxes(boxes, size=size, **max_size_kwarg)

        self._check_output_size(boxes, actual, size=size, **max_size_kwarg)
        torch.testing.assert_close(actual, expected)
775

776
    @pytest.mark.parametrize("interpolation", set(transforms.InterpolationMode) - set(INTERPOLATION_MODES))
777
    @pytest.mark.parametrize(
778
        "make_input",
779
        [make_image_tensor, make_image_pil, make_image, make_video],
780
    )
781
    def test_pil_interpolation_compat_smoke(self, interpolation, make_input):
        """Smoke test for the interpolation modes that are not part of ``INTERPOLATION_MODES``.

        PIL images are expected to resize without an error, every other input is expected to raise a
        ``NotImplementedError``.
        """
        inpt = make_input(self.INPUT_SIZE)

        if isinstance(inpt, PIL.Image.Image):
            expectation = contextlib.nullcontext()
        else:
            # This error is triggered in PyTorch core
            expectation = pytest.raises(NotImplementedError, match=f"got {interpolation.value.lower()}")

        with expectation:
            F.resize(inpt, size=self.OUTPUT_SIZES[0], interpolation=interpolation)
795

796
    def test_functional_pil_antialias_warning(self):
        """Passing ``antialias=False`` together with a PIL image has to emit a ``UserWarning``."""
        expected_message = "Anti-alias option is always applied for PIL Image input"
        target_size = self.OUTPUT_SIZES[0]

        with pytest.warns(UserWarning, match=expected_message):
            F.resize(make_image_pil(self.INPUT_SIZE), size=target_size, antialias=False)
799

800
    @pytest.mark.parametrize("size", OUTPUT_SIZES)
801
    @pytest.mark.parametrize(
802
        "make_input",
803
        [
804
            make_image_tensor,
805
            make_image_pil,
806
            make_image,
807
            make_bounding_boxes,
808
            make_segmentation_mask,
809
            make_detection_masks,
810
            make_video,
811
        ],
812
    )
813
    def test_max_size_error(self, size, make_input):
        """Invalid ``size`` / ``max_size`` combinations have to raise a ``ValueError`` with a matching message."""
        if size is None:
            # value can be anything other than an integer
            max_size, match = None, "max_size must be an integer when size is None"
        elif isinstance(size, int) or len(size) == 1:
            requested_size = size if isinstance(size, int) else size[0]
            max_size, match = requested_size - 1, "must be strictly greater than the requested size"
        else:
            # value can be anything other than None
            max_size, match = -1, "size should be an int or a sequence of length 1"

        with pytest.raises(ValueError, match=match):
            F.resize(make_input(self.INPUT_SIZE), size=size, max_size=max_size, antialias=True)

        # Lists that do not hold exactly one value are additionally checked with an otherwise plausible `max_size`
        if not isinstance(size, list) or len(size) == 1:
            return

        with pytest.raises(ValueError, match="max_size should only be passed if size is None or specifies"):
            F.resize(make_input(self.INPUT_SIZE), size=size, max_size=500)
832

833
    @pytest.mark.parametrize(
834
        "input_size, max_size, expected_size",
835
        [
836
            ((10, 10), 10, (10, 10)),
837
            ((10, 20), 40, (20, 40)),
838
            ((20, 10), 40, (40, 20)),
839
            ((10, 20), 10, (5, 10)),
840
            ((20, 10), 10, (10, 5)),
841
        ],
842
    )
843
    @pytest.mark.parametrize(
844
        "make_input",
845
        [
846
            make_image_tensor,
847
            make_image_pil,
848
            make_image,
849
            make_bounding_boxes,
850
            make_segmentation_mask,
851
            make_detection_masks,
852
            make_video,
853
        ],
854
    )
855
    def test_resize_size_none(self, input_size, max_size, expected_size, make_input):
        """With ``size=None`` the output size is determined by ``max_size`` alone."""
        resized = F.resize(make_input(input_size), size=None, max_size=max_size)

        actual_size = F.get_size(resized)[-2:]
        assert actual_size == list(expected_size)
859

860
    @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
861
    @pytest.mark.parametrize(
862
        "make_input",
863
        [make_image_tensor, make_image_pil, make_image, make_video],
864
    )
865
    def test_interpolation_int(self, interpolation, make_input):
        """An integer interpolation code has to give the same result as the matching ``InterpolationMode``."""
        inpt = make_input(self.INPUT_SIZE)

        # `InterpolationMode.NEAREST_EXACT` has no proper corresponding integer equivalent. Internally, we map it to
        # `0` to be the same as `InterpolationMode.NEAREST` for PIL. However, for the tensor backend there is a
        # difference and thus we don't test it here.
        is_nearest_exact = interpolation is transforms.InterpolationMode.NEAREST_EXACT
        if is_nearest_exact and isinstance(inpt, torch.Tensor):
            return

        shared_kwargs = dict(size=self.OUTPUT_SIZES[0], antialias=True)

        expected = F.resize(inpt, interpolation=interpolation, **shared_kwargs)
        actual = F.resize(inpt, interpolation=pil_modes_mapping[interpolation], **shared_kwargs)

        assert_equal(actual, expected)
880

881
    def test_transform_unknown_size_error(self):
        """``transforms.Resize`` has to reject a ``size`` of an unsupported type with a ``ValueError``."""
        expected_message = "size can be an integer, a sequence of one or two integers, or None"

        with pytest.raises(ValueError, match=expected_message):
            transforms.Resize(size=object())
884

885
    @pytest.mark.parametrize(
886
        "size", [min(INPUT_SIZE), [min(INPUT_SIZE)], (min(INPUT_SIZE),), list(INPUT_SIZE), tuple(INPUT_SIZE)]
887
    )
888
    @pytest.mark.parametrize(
889
        "make_input",
890
        [
891
            make_image_tensor,
892
            make_image_pil,
893
            make_image,
894
            make_bounding_boxes,
895
            make_segmentation_mask,
896
            make_detection_masks,
897
            make_video,
898
        ],
899
    )
900
    def test_noop(self, size, make_input):
        """Resizing to a ``size`` that already matches the input must hand back the input itself.

        Every parametrized ``size`` is derived from ``INPUT_SIZE`` (its smaller edge or the full size, as an int,
        list or tuple), so each of them is expected to describe a no-op for an input of size ``INPUT_SIZE``.
        """
        input = make_input(self.INPUT_SIZE)

        # Pass the parametrized `size`. Using `F.get_size(input)` here would leave the parameter unused and make
        # all parametrizations exercise the exact same call.
        output = F.resize(input, size=size, antialias=True)

        # This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
        # is a good reason to break this, feel free to downgrade to an equality check.
        if isinstance(input, tv_tensors.TVTensor):
            # We can't test identity directly, since that checks for the identity of the Python object. Since all
            # tv_tensors unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
            # that the underlying storage is the same
            assert output.data_ptr() == input.data_ptr()
        else:
            assert output is input
914

915
    @pytest.mark.parametrize(
916
        "make_input",
917
        [
918
            make_image_tensor,
919
            make_image_pil,
920
            make_image,
921
            make_bounding_boxes,
922
            make_segmentation_mask,
923
            make_detection_masks,
924
            make_video,
925
        ],
926
    )
927
    def test_no_regression_5405(self, make_input):
        """Non-regression test: ``max_size`` must not be ignored if ``size`` equals the small edge size.

        See https://github.com/pytorch/vision/issues/5405
        """
        inpt = make_input(self.INPUT_SIZE)

        small_edge_size = min(F.get_size(inpt))
        limit = small_edge_size + 1

        output = F.resize(inpt, size=small_edge_size, max_size=limit, antialias=True)

        longest_output_edge = max(F.get_size(output))
        assert longest_output_edge == limit
938

939
    def _make_image(self, *args, batch_dims=(), memory_format=torch.contiguous_format, **kwargs):
        """Create an image through ``make_image``, emulating channels-last memory for non-4D shapes.

        ``torch.channels_last`` is only available for 4D tensors, i.e. (B, C, H, W). Images coming from PIL or our
        own I/O functions have no batch dimension and are thus 3D, i.e. (C, H, W), although their data is laid out
        channels last in memory. To emulate this, the image is created with a single batch dimension and a view
        with the requested shape is taken afterwards. The memory layout then is channels last although PyTorch
        doesn't recognize it as such.
        """
        wants_channels_last = memory_format is torch.channels_last
        emulate_channels_last = wants_channels_last and len(batch_dims) != 1

        if emulate_channels_last:
            # Collapse all batch dimensions into one, so that the image is created as a 4D tensor
            creation_batch_dims = (math.prod(batch_dims),)
        else:
            creation_batch_dims = batch_dims

        image = make_image(*args, batch_dims=creation_batch_dims, memory_format=memory_format, **kwargs)
        if not emulate_channels_last:
            return image

        # Restore the requested batch dimensions as a view that shares memory with the created image
        reshaped = image.view(*batch_dims, *image.shape[-3:])
        return tv_tensors.wrap(reshaped, like=image)
958

959
    def _check_stride(self, image, *, memory_format):
        """Assert that ``image.stride()`` is the stride expected for ``memory_format``.

        Raises a ``ValueError`` for any memory format other than ``torch.contiguous_format`` and
        ``torch.channels_last``.
        """
        num_channels, height, width = F.get_dimensions(image)

        if memory_format is torch.channels_last:
            expected_stride = (1, width * num_channels, num_channels)
        elif memory_format is torch.contiguous_format:
            expected_stride = (height * width, width, 1)
        else:
            raise ValueError(f"Unknown memory_format: {memory_format}")

        actual_stride = image.stride()
        assert actual_stride == expected_stride
969

970
    # TODO: We can remove this test and the related torchvision workaround
    #  once the related PyTorch issue is fixed: https://github.com/pytorch/pytorch/issues/68430
972
    @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
973
    @pytest.mark.parametrize("antialias", [True, False])
974
    @pytest.mark.parametrize("memory_format", [torch.contiguous_format, torch.channels_last])
975
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
976
    @pytest.mark.parametrize("device", cpu_and_cuda())
977
    def test_kernel_image_memory_format_consistency(self, interpolation, antialias, memory_format, dtype, device):
        """The output of ``F.resize_image`` has to have the same kind of stride layout as its input."""
        image = self._make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format)

        # Smoke test to make sure we aren't starting with wrong assumptions
        self._check_stride(image, memory_format=memory_format)

        resized = F.resize_image(
            image,
            size=self.OUTPUT_SIZES[0],
            interpolation=interpolation,
            antialias=antialias,
        )

        self._check_stride(resized, memory_format=memory_format)
988

989
    def test_float16_no_rounding(self):
        """Make sure resizing doesn't round float16 images.

        Non-regression test for https://github.com/pytorch/vision/issues/7667
        """
        image = make_image_tensor(self.INPUT_SIZE, dtype=torch.float16)

        resized = F.resize_image(image, size=self.OUTPUT_SIZES[0], antialias=True)
        assert resized.dtype is torch.float16

        # If the values had been rounded, the output would be equal to its own rounded version
        total_rounding_difference = (resized.round() - resized).abs().sum()
        assert total_rounding_difference > 0
998

999

1000
class TestHorizontalFlip:
    """Tests for the horizontal flip kernels, ``F.horizontal_flip`` and ``transforms.RandomHorizontalFlip``."""

    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.horizontal_flip_image, image)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_bounding_boxes(self, format, dtype, device):
        boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
        # The kernel works on plain tensors, so the box metadata is passed explicitly
        kernel_kwargs = dict(format=format, canvas_size=boxes.canvas_size)
        check_kernel(F.horizontal_flip_bounding_boxes, boxes, **kernel_kwargs)

    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
    def test_kernel_mask(self, make_mask):
        mask = make_mask()
        check_kernel(F.horizontal_flip_mask, mask)

    def test_kernel_video(self):
        video = make_video()
        check_kernel(F.horizontal_flip_video, video)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_functional(self, make_input):
        inpt = make_input()
        check_functional(F.horizontal_flip, inpt)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.horizontal_flip_image, torch.Tensor),
            (F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
            (F.horizontal_flip_image, tv_tensors.Image),
            (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.horizontal_flip_mask, tv_tensors.Mask),
            (F.horizontal_flip_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, device):
        # p=1 makes the random transform flip unconditionally
        always_flip = transforms.RandomHorizontalFlip(p=1)
        check_transform(always_flip, make_input(device=device))

    @pytest.mark.parametrize(
        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
    )
    def test_image_correctness(self, fn):
        """Compare the tensor result against flipping the same image through PIL."""
        image = make_image(dtype=torch.uint8, device="cpu")

        actual = fn(image)

        flipped_pil_image = F.horizontal_flip(F.to_pil_image(image))
        expected = F.to_image(flipped_pil_image)

        torch.testing.assert_close(actual, expected)

    def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes):
        """Reference implementation: express the flip as the affine map ``x -> canvas_width - x``."""
        canvas_width = bounding_boxes.canvas_size[1]
        flip_matrix = np.array([[-1, 0, canvas_width], [0, 1, 0]])

        return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=flip_matrix)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize(
        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
    )
    def test_bounding_boxes_correctness(self, format, fn):
        boxes = make_bounding_boxes(format=format)

        actual = fn(boxes)
        expected = self._reference_horizontal_flip_bounding_boxes(boxes)

        torch.testing.assert_close(actual, expected)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform_noop(self, make_input, device):
        """With ``p=0`` the transform must leave its input unchanged."""
        inpt = make_input(device=device)

        never_flip = transforms.RandomHorizontalFlip(p=0)
        output = never_flip(inpt)

        assert_equal(output, inpt)
1100

1101

1102
class TestAffine:
    """Tests for the affine kernels, the ``F.affine`` functional and the ``transforms.RandomAffine`` transform."""

    # For every `F.affine` parameter, one value per supported argument type
    _EXHAUSTIVE_TYPE_AFFINE_KWARGS = dict(
        # float, int
        angle=[-10.9, 18],
        # two-list of float, two-list of int, two-tuple of float, two-tuple of int
        translate=[[6.3, -0.6], [1, -3], (16.6, -6.6), (-2, 4)],
        # float
        scale=[0.5],
        # float, int,
        # one-list of float, one-list of int, one-tuple of float, one-tuple of int
        # two-list of float, two-list of int, two-tuple of float, two-tuple of int
        shear=[35.6, 38, [-37.7], [-23], (5.3,), (-52,), [5.4, 21.8], [-47, 51], (-11.2, 36.7), (8, -53)],
        # None
        # two-list of float, two-list of int, two-tuple of float, two-tuple of int
        center=[None, [1.2, 4.9], [-3, 1], (2.5, -4.7), (3, 2)],
    )
    # The special case for shear makes sure we pick a value that is supported while JIT scripting
    _MINIMAL_AFFINE_KWARGS = {
        k: vs[0] if k != "shear" else next(v for v in vs if isinstance(v, list))
        for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items()
    }
    # Subset used by the correctness tests: only `None`, floats and lists with more than one element are kept
    _CORRECTNESS_AFFINE_KWARGS = {
        k: [v for v in vs if v is None or isinstance(v, float) or (isinstance(v, list) and len(v) > 1)]
        for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items()
    }

    # For every `RandomAffine` range parameter, one value per supported argument type
    _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES = dict(
        degrees=[30, (-15, 20)],
        translate=[None, (0.5, 0.5)],
        scale=[None, (0.75, 1.25)],
        shear=[None, (12, 30, -17, 5), 10, (-5, 12)],
    )
    # The first value that is not `None` for each range parameter
    _CORRECTNESS_TRANSFORM_AFFINE_RANGES = {
        k: next(v for v in vs if v is not None) for k, vs in _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES.items()
    }

    def _check_kernel(self, kernel, input, *args, **kwargs):
        """Call ``check_kernel`` with the minimal affine kwargs, overridden by the given ``kwargs``."""
        kwargs_ = self._MINIMAL_AFFINE_KWARGS.copy()
        kwargs_.update(kwargs)
        check_kernel(kernel, input, *args, **kwargs_)

    @param_value_parametrization(
        angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"],
        translate=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"],
        shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
        center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
        interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
        fill=EXHAUSTIVE_TYPE_FILLS,
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, param, value, dtype, device):
        if param == "fill":
            value = adapt_fill(value, dtype=dtype)
        self._check_kernel(
            F.affine_image,
            make_image(dtype=dtype, device=device),
            **{param: value},
            # Scalar `shear` and `fill` values are excluded from the scripted-vs-eager comparison
            check_scripted_vs_eager=not (param in {"shear", "fill"} and isinstance(value, (int, float))),
            # Bilinear interpolation on uint8 is compared between CUDA and CPU with a tolerance of one intensity step
            check_cuda_vs_cpu=(
                dict(atol=1, rtol=0)
                if dtype is torch.uint8 and param == "interpolation" and value is transforms.InterpolationMode.BILINEAR
                else True
            ),
        )

    @param_value_parametrization(
        angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"],
        translate=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"],
        shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
        center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
    )
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
        bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
        self._check_kernel(
            F.affine_bounding_boxes,
            bounding_boxes,
            format=format,
            canvas_size=bounding_boxes.canvas_size,
            **{param: value},
            check_scripted_vs_eager=not (param == "shear" and isinstance(value, (int, float))),
        )

    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
    def test_kernel_mask(self, make_mask):
        self._check_kernel(F.affine_mask, make_mask())

    def test_kernel_video(self):
        self._check_kernel(F.affine_video, make_video())

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_functional(self, make_input):
        check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.affine_image, torch.Tensor),
            (F._geometry._affine_image_pil, PIL.Image.Image),
            (F.affine_image, tv_tensors.Image),
            (F.affine_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.affine_mask, tv_tensors.Mask),
            (F.affine_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        check_functional_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, device):
        input = make_input(device=device)

        check_transform(transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), input)

    @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
    @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
    @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
    @pytest.mark.parametrize("shear", _CORRECTNESS_AFFINE_KWARGS["shear"])
    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    @pytest.mark.parametrize(
        "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
    )
    @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
    def test_functional_image_correctness(self, angle, translate, scale, shear, center, interpolation, fill):
        """Compare ``F.affine`` on a tensor image against the same call on the PIL version of that image."""
        image = make_image(dtype=torch.uint8, device="cpu")

        fill = adapt_fill(fill, dtype=torch.uint8)

        actual = F.affine(
            image,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            center=center,
            interpolation=interpolation,
            fill=fill,
        )
        expected = F.to_image(
            F.affine(
                F.to_pil_image(image),
                angle=angle,
                translate=translate,
                scale=scale,
                shear=shear,
                center=center,
                interpolation=interpolation,
                fill=fill,
            )
        )

        mae = (actual.float() - expected.float()).abs().mean()
        # The parentheses are required: `assert mae < 2 if cond else 8` parses as `assert (mae < 2) if cond else 8`,
        # which asserts the always-truthy constant `8` for every non-NEAREST interpolation.
        assert mae < (2 if interpolation is transforms.InterpolationMode.NEAREST else 8)

    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    @pytest.mark.parametrize(
        "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
    )
    @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_image_correctness(self, center, interpolation, fill, seed):
        """Compare ``RandomAffine`` on a tensor image against the PIL version, using the same seed for both."""
        image = make_image(dtype=torch.uint8, device="cpu")

        fill = adapt_fill(fill, dtype=torch.uint8)

        transform = transforms.RandomAffine(
            **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center, interpolation=interpolation, fill=fill
        )

        # Re-seeding before each call makes both calls start from the same RNG state
        torch.manual_seed(seed)
        actual = transform(image)

        torch.manual_seed(seed)
        expected = F.to_image(transform(F.to_pil_image(image)))

        mae = (actual.float() - expected.float()).abs().mean()
        # The parentheses are required: `assert mae < 2 if cond else 8` parses as `assert (mae < 2) if cond else 8`,
        # which asserts the always-truthy constant `8` for every non-NEAREST interpolation.
        assert mae < (2 if interpolation is transforms.InterpolationMode.NEAREST else 8)

    def _compute_affine_matrix(self, *, angle, translate, scale, shear, center):
        """Compute the reference affine matrix and return its first two rows.

        ``angle`` and ``shear`` are converted with ``math.radians``. A scalar ``shear`` is treated as
        ``[shear, 0.0]``. The result is ``T @ C @ (RS @ SHy @ SHx) @ C^-1``, where ``T`` is the translation, ``C``
        the translation to ``center``, ``RS`` the scaled rotation and ``SHx`` / ``SHy`` the shear matrices.
        """
        rot = math.radians(angle)
        cx, cy = center
        tx, ty = translate
        sx, sy = [math.radians(s) for s in ([shear, 0.0] if isinstance(shear, (int, float)) else shear)]

        c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
        t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
        c_matrix_inv = np.linalg.inv(c_matrix)
        rs_matrix = np.array(
            [
                [scale * math.cos(rot), -scale * math.sin(rot), 0],
                [scale * math.sin(rot), scale * math.cos(rot), 0],
                [0, 0, 1],
            ]
        )
        shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
        shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
        rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
        true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
        # Drop the constant homogeneous row
        return true_matrix[:2, :]

    def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate, scale, shear, center):
        """Reference implementation: transform the boxes with the matrix from ``_compute_affine_matrix``."""
        if center is None:
            # `canvas_size` is reversed first, then halved, to obtain the default center
            center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]

        return reference_affine_bounding_boxes_helper(
            bounding_boxes,
            affine_matrix=self._compute_affine_matrix(
                angle=angle, translate=translate, scale=scale, shear=shear, center=center
            ),
        )

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
    @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
    @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
    @pytest.mark.parametrize("shear", _CORRECTNESS_AFFINE_KWARGS["shear"])
    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    def test_functional_bounding_boxes_correctness(self, format, angle, translate, scale, shear, center):
        bounding_boxes = make_bounding_boxes(format=format)

        actual = F.affine(
            bounding_boxes,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            center=center,
        )
        expected = self._reference_affine_bounding_boxes(
            bounding_boxes,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            center=center,
        )

        torch.testing.assert_close(actual, expected)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_bounding_boxes_correctness(self, format, center, seed):
        bounding_boxes = make_bounding_boxes(format=format)

        transform = transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center)

        # Sample the parameters once for the reference ...
        torch.manual_seed(seed)
        params = transform._get_params([bounding_boxes])

        # ... and re-seed so that the transform call starts from the same RNG state
        torch.manual_seed(seed)
        actual = transform(bounding_boxes)

        expected = self._reference_affine_bounding_boxes(bounding_boxes, **params, center=center)

        torch.testing.assert_close(actual, expected)

    @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"])
    @pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["translate"])
    @pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["scale"])
    @pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["shear"])
    @pytest.mark.parametrize("seed", list(range(10)))
    def test_transform_get_params_bounds(self, degrees, translate, scale, shear, seed):
        """Check that the sampled parameters stay inside the ranges the transform was configured with."""
        image = make_image()
        height, width = F.get_size(image)

        transform = transforms.RandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear)

        torch.manual_seed(seed)
        params = transform._get_params([image])

        if isinstance(degrees, (int, float)):
            assert -degrees <= params["angle"] <= degrees
        else:
            assert degrees[0] <= params["angle"] <= degrees[1]

        if translate is not None:
            # `translate` holds fractions of the image width and height
            width_max = int(round(translate[0] * width))
            height_max = int(round(translate[1] * height))
            assert -width_max <= params["translate"][0] <= width_max
            assert -height_max <= params["translate"][1] <= height_max
        else:
            assert params["translate"] == (0, 0)

        if scale is not None:
            assert scale[0] <= params["scale"] <= scale[1]
        else:
            assert params["scale"] == 1.0

        if shear is not None:
            if isinstance(shear, (int, float)):
                assert -shear <= params["shear"][0] <= shear
                assert params["shear"][1] == 0.0
            elif len(shear) == 2:
                assert shear[0] <= params["shear"][0] <= shear[1]
                assert params["shear"][1] == 0.0
            elif len(shear) == 4:
                assert shear[0] <= params["shear"][0] <= shear[1]
                assert shear[2] <= params["shear"][1] <= shear[3]
        else:
            assert params["shear"] == (0, 0)

    @pytest.mark.parametrize("param", ["degrees", "translate", "scale", "shear", "center"])
    @pytest.mark.parametrize("value", [0, [0], [0, 0, 0]])
    def test_transform_sequence_len_errors(self, param, value):
        # The scalar value is not used as an error case for `degrees` and `shear`
        if param in {"degrees", "shear"} and not isinstance(value, list):
            return

        kwargs = {param: value}
        if param != "degrees":
            # `degrees` is always passed, so supply one unless it is the parameter under test
            kwargs["degrees"] = 0

        with pytest.raises(
            ValueError if isinstance(value, list) else TypeError, match=f"{param} should be a sequence of length 2"
        ):
            transforms.RandomAffine(**kwargs)

    def test_transform_negative_degrees_error(self):
        with pytest.raises(ValueError, match="If degrees is a single number, it must be positive"):
            transforms.RandomAffine(degrees=-1)

    @pytest.mark.parametrize("translate", [[-1, 0], [2, 0], [-1, 2]])
    def test_transform_translate_range_error(self, translate):
        with pytest.raises(ValueError, match="translation values should be between 0 and 1"):
            transforms.RandomAffine(degrees=0, translate=translate)

    @pytest.mark.parametrize("scale", [[-1, 0], [0, -1], [-1, -1]])
    def test_transform_scale_range_error(self, scale):
        with pytest.raises(ValueError, match="scale values should be positive"):
            transforms.RandomAffine(degrees=0, scale=scale)

    def test_transform_negative_shear_error(self):
        with pytest.raises(ValueError, match="If shear is a single number, it must be positive"):
            transforms.RandomAffine(degrees=0, shear=-1)

    def test_transform_unknown_fill_error(self):
        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.RandomAffine(degrees=0, fill="fill")
1448

1449

1450
class TestVerticalFlip:
    """Tests for the vertical flip kernels, ``F.vertical_flip`` and ``transforms.RandomVerticalFlip``."""

    # Input factories exercised by the functional- and transform-level tests below.
    _INPUT_FACTORIES = [
        make_image_tensor,
        make_image_pil,
        make_image,
        make_bounding_boxes,
        make_segmentation_mask,
        make_video,
    ]

    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the generic kernel checks on the image kernel."""
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.vertical_flip_image, image)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_bounding_boxes(self, format, dtype, device):
        """Run the generic kernel checks on the bounding boxes kernel."""
        boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
        # The kernel takes the metadata explicitly rather than reading it off the tv_tensor.
        metadata = dict(format=format, canvas_size=boxes.canvas_size)
        check_kernel(F.vertical_flip_bounding_boxes, boxes, **metadata)

    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
    def test_kernel_mask(self, make_mask):
        """Run the generic kernel checks on the mask kernel."""
        mask = make_mask()
        check_kernel(F.vertical_flip_mask, mask)

    def test_kernel_video(self):
        """Run the generic kernel checks on the video kernel."""
        video = make_video()
        check_kernel(F.vertical_flip_video, video)

    @pytest.mark.parametrize("make_input", _INPUT_FACTORIES)
    def test_functional(self, make_input):
        """Run the generic functional checks for every supported input type."""
        sample = make_input()
        check_functional(F.vertical_flip, sample)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.vertical_flip_image, torch.Tensor),
            (F._geometry._vertical_flip_image_pil, PIL.Image.Image),
            (F.vertical_flip_image, tv_tensors.Image),
            (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.vertical_flip_mask, tv_tensors.Mask),
            (F.vertical_flip_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """Check that the functional's signature matches the kernel registered for ``input_type``."""
        check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", _INPUT_FACTORIES)
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, device):
        """Run the generic transform checks with a flip that is always applied (``p=1``)."""
        always_flip = transforms.RandomVerticalFlip(p=1)
        sample = make_input(device=device)
        check_transform(always_flip, sample)

    @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
    def test_image_correctness(self, fn):
        """Compare the tensor result against the result of flipping the equivalent PIL image."""
        image = make_image(dtype=torch.uint8, device="cpu")

        actual = fn(image)

        # Reference: round-trip through PIL and flip there.
        pil_image = F.to_pil_image(image)
        expected = F.to_image(F.vertical_flip(pil_image))

        torch.testing.assert_close(actual, expected)

    def _reference_vertical_flip_bounding_boxes(self, bounding_boxes):
        """Compute the expected boxes by expressing the flip as the affine map ``y -> canvas_size[0] - y``."""
        flip_offset = bounding_boxes.canvas_size[0]
        affine_matrix = np.array(
            [
                [1, 0, 0],
                [0, -1, flip_offset],
            ],
        )
        return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
    def test_bounding_boxes_correctness(self, format, fn):
        """Compare flipped bounding boxes against the affine reference implementation."""
        boxes = make_bounding_boxes(format=format)

        actual = fn(boxes)
        expected = self._reference_vertical_flip_bounding_boxes(boxes)

        torch.testing.assert_close(actual, expected)

    @pytest.mark.parametrize("make_input", _INPUT_FACTORIES)
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform_noop(self, make_input, device):
        """With ``p=0`` the transform has to return something equal to its input."""
        sample = make_input(device=device)

        never_flip = transforms.RandomVerticalFlip(p=0)

        assert_equal(never_flip(sample), sample)
1546

1547

1548
class TestRotate:
    """Tests for the rotate kernels, the ``F.rotate`` functional and the ``transforms.RandomRotation`` transform."""

    # Every supported *type* of value per functional keyword argument.
    _EXHAUSTIVE_TYPE_AFFINE_KWARGS = dict(
        # float, int
        angle=[-10.9, 18],
        # None
        # two-list of float, two-list of int, two-tuple of float, two-tuple of int
        center=[None, [1.2, 4.9], [-3, 1], (2.5, -4.7), (3, 2)],
    )
    # First value of each keyword argument; used wherever only "some valid call" is needed.
    _MINIMAL_AFFINE_KWARGS = {k: vs[0] for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items()}
    # Reduced set for the (more expensive) correctness tests: only None, floats and lists are kept.
    _CORRECTNESS_AFFINE_KWARGS = {
        k: [v for v in vs if v is None or isinstance(v, (float, list))]
        for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items()
    }

    # Every supported *type* of value for the transform's range argument.
    _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES = dict(
        degrees=[30, (-15, 20)],
    )
    _CORRECTNESS_TRANSFORM_AFFINE_RANGES = {k: vs[0] for k, vs in _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES.items()}

    @param_value_parametrization(
        angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"],
        interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
        expand=[False, True],
        center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
        fill=EXHAUSTIVE_TYPE_FILLS,
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, param, value, dtype, device):
        """Run the generic kernel checks on the image kernel, varying one keyword argument at a time."""
        kwargs = {param: value}
        if param != "angle":
            # `angle` is always passed, so fall back to the minimal value when it is not the one under test.
            kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]
        check_kernel(
            F.rotate_image,
            make_image(dtype=dtype, device=device),
            **kwargs,
            # The scripted-vs-eager comparison is disabled for scalar fills.
            check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
        )

    @param_value_parametrization(
        angle=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"],
        expand=[False, True],
        center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
    )
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
        """Run the generic kernel checks on the bounding boxes kernel, varying one keyword argument at a time."""
        kwargs = {param: value}
        if param != "angle":
            kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]

        bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)

        check_kernel(
            F.rotate_bounding_boxes,
            bounding_boxes,
            format=format,
            canvas_size=bounding_boxes.canvas_size,
            **kwargs,
        )

    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
    def test_kernel_mask(self, make_mask):
        """Run the generic kernel checks on the mask kernel."""
        check_kernel(F.rotate_mask, make_mask(), **self._MINIMAL_AFFINE_KWARGS)

    def test_kernel_video(self):
        """Run the generic kernel checks on the video kernel."""
        check_kernel(F.rotate_video, make_video(), **self._MINIMAL_AFFINE_KWARGS)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_functional(self, make_input):
        """Run the generic functional checks for every supported input type."""
        check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.rotate_image, torch.Tensor),
            (F._geometry._rotate_image_pil, PIL.Image.Image),
            (F.rotate_image, tv_tensors.Image),
            (F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.rotate_mask, tv_tensors.Mask),
            (F.rotate_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """Check that the functional's signature matches the kernel registered for ``input_type``."""
        check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, device):
        """Run the generic transform checks on ``RandomRotation`` for every supported input type."""
        check_transform(
            transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), make_input(device=device)
        )

    @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    @pytest.mark.parametrize(
        "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
    )
    @pytest.mark.parametrize("expand", [False, True])
    @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
    def test_functional_image_correctness(self, angle, center, interpolation, expand, fill):
        """Compare ``F.rotate`` on a tensor image against the same call on the equivalent PIL image."""
        image = make_image(dtype=torch.uint8, device="cpu")

        fill = adapt_fill(fill, dtype=torch.uint8)

        actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill)
        expected = F.to_image(
            F.rotate(
                F.to_pil_image(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill
            )
        )

        mae = (actual.float() - expected.float()).abs().mean()
        # The parentheses matter: without them the statement parses as `assert (mae < 1) if ... else 6`, which
        # asserts the always-truthy constant 6 for every non-NEAREST interpolation mode.
        assert mae < (1 if interpolation is transforms.InterpolationMode.NEAREST else 6)

    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    @pytest.mark.parametrize(
        "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
    )
    @pytest.mark.parametrize("expand", [False, True])
    @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_image_correctness(self, center, interpolation, expand, fill, seed):
        """Compare ``RandomRotation`` on a tensor image against the same transform on the equivalent PIL image."""
        image = make_image(dtype=torch.uint8, device="cpu")

        fill = adapt_fill(fill, dtype=torch.uint8)

        transform = transforms.RandomRotation(
            **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES,
            center=center,
            interpolation=interpolation,
            expand=expand,
            fill=fill,
        )

        # Re-seed before each call so that both invocations start from the same RNG state.
        torch.manual_seed(seed)
        actual = transform(image)

        torch.manual_seed(seed)
        expected = F.to_image(transform(F.to_pil_image(image)))

        mae = (actual.float() - expected.float()).abs().mean()
        # See test_functional_image_correctness for why the threshold is parenthesized.
        assert mae < (1 if interpolation is transforms.InterpolationMode.NEAREST else 6)

    def _compute_output_canvas_size(self, *, expand, canvas_size, affine_matrix):
        """Return ``(new_canvas_size, (recenter_x, recenter_y))`` for a rotation described by ``affine_matrix``.

        Without ``expand`` the canvas size is unchanged and the offset is zero. With ``expand`` the four corners
        of the input canvas are transformed; the offset is the per-axis minimum of the transformed corners and
        the new size is the (truncated) per-axis extent.
        """
        if not expand:
            return canvas_size, (0.0, 0.0)

        input_height, input_width = canvas_size

        # Canvas corners in homogeneous (x, y, 1) coordinates.
        input_image_frame = np.array(
            [
                [0.0, 0.0, 1.0],
                [0.0, input_height, 1.0],
                [input_width, input_height, 1.0],
                [input_width, 0.0, 1.0],
            ],
            dtype=np.float64,
        )
        output_image_frame = np.matmul(input_image_frame, affine_matrix.astype(input_image_frame.dtype).T)

        recenter_x = float(np.min(output_image_frame[:, 0]))
        recenter_y = float(np.min(output_image_frame[:, 1]))

        output_width = int(np.max(output_image_frame[:, 0]) - recenter_x)
        output_height = int(np.max(output_image_frame[:, 1]) - recenter_y)

        return (output_height, output_width), (recenter_x, recenter_y)

    def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy):
        """Subtract the ``recenter_xy`` offset from the coordinates of ``bounding_boxes``.

        For XYXY all four entries are shifted; for every other format only the first two entries are shifted.
        """
        x, y = recenter_xy
        if bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYXY:
            translate = [x, y, x, y]
        else:
            translate = [x, y, 0.0, 0.0]
        return tv_tensors.wrap(
            (bounding_boxes.to(torch.float64) - torch.tensor(translate)).to(bounding_boxes.dtype), like=bounding_boxes
        )

    def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, center):
        """Compute the expected result of rotating ``bounding_boxes`` through an explicit affine matrix."""
        if center is None:
            # canvas_size is (height, width), hence the reversal to obtain (cx, cy).
            center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
        cx, cy = center

        a = np.cos(angle * np.pi / 180.0)
        b = np.sin(angle * np.pi / 180.0)
        affine_matrix = np.array(
            [
                [a, b, cx - cx * a - b * cy],
                [-b, a, cy + cx * b - a * cy],
            ],
        )

        new_canvas_size, recenter_xy = self._compute_output_canvas_size(
            expand=expand, canvas_size=bounding_boxes.canvas_size, affine_matrix=affine_matrix
        )

        # Clamping is deferred until after the boxes have been recentered.
        output = reference_affine_bounding_boxes_helper(
            bounding_boxes,
            affine_matrix=affine_matrix,
            new_canvas_size=new_canvas_size,
            clamp=False,
        )

        return F.clamp_bounding_boxes(self._recenter_bounding_boxes_after_expand(output, recenter_xy=recenter_xy)).to(
            bounding_boxes
        )

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
    @pytest.mark.parametrize("expand", [False, True])
    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    def test_functional_bounding_boxes_correctness(self, format, angle, expand, center):
        """Compare ``F.rotate`` on bounding boxes against the affine reference implementation."""
        bounding_boxes = make_bounding_boxes(format=format)

        actual = F.rotate(bounding_boxes, angle=angle, expand=expand, center=center)
        expected = self._reference_rotate_bounding_boxes(bounding_boxes, angle=angle, expand=expand, center=center)

        torch.testing.assert_close(actual, expected)
        # With `expand` the canvas sizes are allowed to differ by up to 2.
        torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("expand", [False, True])
    @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_bounding_boxes_correctness(self, format, expand, center, seed):
        """Compare ``RandomRotation`` on bounding boxes against the affine reference implementation."""
        bounding_boxes = make_bounding_boxes(format=format)

        transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center)

        # Sample the parameters once to feed the reference, then re-seed so the transform draws the same ones.
        torch.manual_seed(seed)
        params = transform._get_params([bounding_boxes])

        torch.manual_seed(seed)
        actual = transform(bounding_boxes)

        expected = self._reference_rotate_bounding_boxes(bounding_boxes, **params, expand=expand, center=center)

        torch.testing.assert_close(actual, expected)
        torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

    @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"])
    @pytest.mark.parametrize("seed", list(range(10)))
    def test_transform_get_params_bounds(self, degrees, seed):
        """Check that the sampled angle lies within the range described by ``degrees``."""
        transform = transforms.RandomRotation(degrees=degrees)

        torch.manual_seed(seed)
        params = transform._get_params([])

        if isinstance(degrees, (int, float)):
            # A scalar describes the symmetric range [-degrees, degrees].
            assert -degrees <= params["angle"] <= degrees
        else:
            assert degrees[0] <= params["angle"] <= degrees[1]

    @pytest.mark.parametrize("param", ["degrees", "center"])
    @pytest.mark.parametrize("value", [0, [0], [0, 0, 0]])
    def test_transform_sequence_len_errors(self, param, value):
        """Check that ``RandomRotation`` rejects arguments that are not a sequence of length 2."""
        if param == "degrees" and not isinstance(value, list):
            # The scalar case is not checked for `degrees`.
            return

        kwargs = {param: value}
        if param != "degrees":
            kwargs["degrees"] = 0

        with pytest.raises(
            ValueError if isinstance(value, list) else TypeError, match=f"{param} should be a sequence of length 2"
        ):
            transforms.RandomRotation(**kwargs)

    def test_transform_negative_degrees_error(self):
        """A negative scalar ``degrees`` has to be rejected by ``RandomRotation``."""
        with pytest.raises(ValueError, match="If degrees is a single number, it must be positive"):
            # This class tests rotation, so exercise RandomRotation (not RandomAffine) here.
            transforms.RandomRotation(degrees=-1)

    def test_transform_unknown_fill_error(self):
        """A string ``fill`` has to be rejected by ``RandomRotation``."""
        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            # This class tests rotation, so exercise RandomRotation (not RandomAffine) here.
            transforms.RandomRotation(degrees=0, fill="fill")

    @pytest.mark.parametrize("size", [(11, 17), (16, 16)])
    @pytest.mark.parametrize("angle", [0, 90, 180, 270])
    @pytest.mark.parametrize("expand", [False, True])
    def test_functional_image_fast_path_correctness(self, size, angle, expand):
        """For multiples of 90 degrees the tensor result is compared to the PIL result with default tolerances."""
        image = make_image(size, dtype=torch.uint8, device="cpu")

        actual = F.rotate(image, angle=angle, expand=expand)
        expected = F.to_image(F.rotate(F.to_pil_image(image), angle=angle, expand=expand))

        torch.testing.assert_close(actual, expected)
1842

1843

1844
class TestContainerTransforms:
    """Tests for the container transforms ``Compose``, ``RandomApply``, ``RandomChoice`` and ``RandomOrder``."""

    class BuiltinTransform(transforms.Transform):
        """Identity transform built on the v2 ``Transform`` base class."""

        def _transform(self, inpt, params):
            return inpt

    class PackedInputTransform(nn.Module):
        """Identity module that expects the whole sample as a single two-element argument."""

        def forward(self, sample):
            assert len(sample) == 2
            return sample

    class UnpackedInputTransform(nn.Module):
        """Identity module that expects the sample as two separate positional arguments."""

        def forward(self, image, label):
            return image, label

    @pytest.mark.parametrize(
        "transform_cls", [transforms.Compose, functools.partial(transforms.RandomApply, p=1), transforms.RandomOrder]
    )
    @pytest.mark.parametrize(
        "wrapped_transform_clss",
        [
            [BuiltinTransform],
            [PackedInputTransform],
            [UnpackedInputTransform],
            [BuiltinTransform, BuiltinTransform],
            [PackedInputTransform, PackedInputTransform],
            [UnpackedInputTransform, UnpackedInputTransform],
            [BuiltinTransform, PackedInputTransform, BuiltinTransform],
            [BuiltinTransform, UnpackedInputTransform, BuiltinTransform],
            [PackedInputTransform, BuiltinTransform, PackedInputTransform],
            [UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform],
        ],
    )
    @pytest.mark.parametrize("unpack", [True, False])
    def test_packed_unpacked(self, transform_cls, wrapped_transform_clss, unpack):
        """Check how containers forward packed (one tuple) and unpacked (two arguments) samples."""
        needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in wrapped_transform_clss)
        needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in wrapped_transform_clss)
        # Sanity check of the parametrization itself: the two calling conventions are never mixed.
        assert not (needs_packed_inputs and needs_unpacked_inputs)

        transform = transform_cls([cls() for cls in wrapped_transform_clss])

        image = make_image()
        label = 3
        packed_input = (image, label)

        def call_transform():
            if unpack:
                return transform(*packed_input)
            else:
                return transform(packed_input)

        if needs_unpacked_inputs and not unpack:
            with pytest.raises(TypeError, match="missing 1 required positional argument"):
                call_transform()
        elif needs_packed_inputs and unpack:
            with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"):
                call_transform()
        else:
            output = call_transform()

            assert isinstance(output, tuple) and len(output) == 2
            assert output[0] is image
            assert output[1] is label

    def test_compose(self):
        """``Compose`` has to apply the wrapped transforms one after the other."""
        transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(p=1),
                transforms.RandomVerticalFlip(p=1),
            ]
        )

        input = make_image()

        actual = check_transform(transform, input)
        expected = F.vertical_flip(F.horizontal_flip(input))

        assert_equal(actual, expected)

    @pytest.mark.parametrize("p", [0.0, 1.0])
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
        """``RandomApply`` has to apply all wrapped transforms for ``p=1`` and return the input for ``p=0``."""
        transform = transforms.RandomApply(
            sequence_type(
                [
                    transforms.RandomHorizontalFlip(p=1),
                    transforms.RandomVerticalFlip(p=1),
                ]
            ),
            p=p,
        )

        # This needs to be a pure tensor (or a PIL image), because otherwise check_transform skips the v1 compatibility
        # check
        input = make_image_tensor()
        output = check_transform(transform, input, check_v1_compatibility=issubclass(sequence_type, nn.ModuleList))

        if p == 1:
            assert_equal(output, F.vertical_flip(F.horizontal_flip(input)))
        else:
            assert output is input

    @pytest.mark.parametrize("p", [(0, 1), (1, 0)])
    def test_random_choice(self, p):
        """``RandomChoice`` with degenerate probabilities has to pick the transform with probability 1."""
        transform = transforms.RandomChoice(
            [
                transforms.RandomHorizontalFlip(p=1),
                transforms.RandomVerticalFlip(p=1),
            ],
            p=p,
        )

        input = make_image()
        output = check_transform(transform, input)

        p_horz, p_vert = p
        if p_horz:
            assert_equal(output, F.horizontal_flip(input))
        else:
            assert_equal(output, F.vertical_flip(input))

    def test_random_order(self):
        """``RandomOrder`` has to apply all wrapped transforms, in whichever order."""
        # The container under test is RandomOrder; using Compose here would merely duplicate test_compose.
        transform = transforms.RandomOrder(
            [
                transforms.RandomHorizontalFlip(p=1),
                transforms.RandomVerticalFlip(p=1),
            ]
        )

        input = make_image()

        actual = check_transform(transform, input)
        # We can't really check whether the transforms are actually applied in random order. However, horizontal and
        # vertical flip are commutative. Meaning, even under the assumption that the transform applies them in random
        # order, we can use a fixed order to compute the expected value.
        expected = F.vertical_flip(F.horizontal_flip(input))

        assert_equal(actual, expected)

    def test_errors(self):
        """Check the argument validation of the container transforms."""
        # A bare callable instead of a sequence of callables.
        for cls in [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]:
            with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
                cls(lambda x: x)

        with pytest.raises(ValueError, match="at least one transform"):
            transforms.Compose([])

        # Probabilities outside of [0, 1].
        for p in [-1, 2]:
            with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
                transforms.RandomApply([lambda x: x], p=p)

        # Number of probabilities differs from the number of transforms.
        for transforms_, p in [([lambda x: x], []), ([], [1.0])]:
            with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
                transforms.RandomChoice(transforms_, p=p)
1997

1998

1999
class TestToDtype:
    """Tests for the ``to_dtype`` kernels, the ``F.to_dtype`` functional and the ``transforms.ToDtype`` transform."""

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.to_dtype_image, make_image_tensor),
            (F.to_dtype_image, make_image),
            (F.to_dtype_video, make_video),
        ],
    )
    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("scale", (True, False))
    def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, scale):
        """Run the generic kernel checks for every input / output dtype combination."""
        check_kernel(
            kernel,
            make_input(dtype=input_dtype, device=device),
            dtype=output_dtype,
            scale=scale,
        )

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("scale", (True, False))
    def test_functional(self, make_input, input_dtype, output_dtype, device, scale):
        """Run the generic functional checks for every input / output dtype combination."""
        check_functional(
            F.to_dtype,
            make_input(dtype=input_dtype, device=device),
            dtype=output_dtype,
            scale=scale,
        )

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("scale", (True, False))
    @pytest.mark.parametrize("as_dict", (True, False))
    def test_transform(self, make_input, input_dtype, output_dtype, device, scale, as_dict):
        """Run the generic transform checks, passing ``dtype`` either directly or as a per-type dict."""
        input = make_input(dtype=input_dtype, device=device)
        if as_dict:
            output_dtype = {type(input): output_dtype}
        # The sample input check is skipped for the dict form.
        check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict)

    def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
        """Reference dtype conversion that processes the image value by value in pure Python.

        Without ``scale`` this is a plain ``image.to(dtype)``. With ``scale`` the values are mapped between the
        value ranges of the input and output dtype.
        """
        input_dtype = image.dtype
        output_dtype = dtype

        if not scale:
            return image.to(dtype)

        if output_dtype == input_dtype:
            return image

        def fn(value):
            if input_dtype.is_floating_point:
                if output_dtype.is_floating_point:
                    # float -> float: no rescaling
                    return value
                else:
                    # float -> int: multiply by the maximum value of the output dtype
                    return round(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
            else:
                input_max_value = torch.iinfo(input_dtype).max

                if output_dtype.is_floating_point:
                    # int -> float: divide by the maximum value of the input dtype
                    return float(decimal.Decimal(value) / input_max_value)
                else:
                    # int -> int: rescale by the ratio of the number of representable values
                    output_max_value = torch.iinfo(output_dtype).max

                    if input_max_value > output_max_value:
                        factor = (input_max_value + 1) // (output_max_value + 1)
                        return value / factor
                    else:
                        factor = (output_max_value + 1) // (input_max_value + 1)
                        return value * factor

        return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)

    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("scale", (True, False))
    def test_image_correctness(self, input_dtype, output_dtype, device, scale):
        """Compare ``F.to_dtype`` against the value-by-value reference implementation."""
        if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
            pytest.xfail("uint8 to uint16 conversion is not supported on cuda")

        input = make_image(dtype=input_dtype, device=device)

        out = F.to_dtype(input, dtype=output_dtype, scale=scale)
        expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)

        if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
            # Scaled float -> int conversions are allowed to be off by one.
            torch.testing.assert_close(out, expected, atol=1, rtol=0)
        else:
            torch.testing.assert_close(out, expected)

    def was_scaled(self, inpt):
        """Heuristically tell whether ``inpt`` was scaled by checking that its maximum does not exceed 1."""
        # this assumes the target dtype is float
        return inpt.max() <= 1

    def make_inpt_with_bbox_and_mask(self, make_input):
        """Build a sample dict with an input, bounding boxes and masks and return it alongside the three dtypes."""
        H, W = 10, 10
        inpt_dtype = torch.uint8
        bbox_dtype = torch.float32
        mask_dtype = torch.bool
        sample = {
            "inpt": make_input(size=(H, W), dtype=inpt_dtype),
            "bbox": make_bounding_boxes(canvas_size=(H, W), dtype=bbox_dtype),
            "mask": make_detection_masks(size=(H, W), dtype=mask_dtype),
        }

        return sample, inpt_dtype, bbox_dtype, mask_dtype

    @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
    @pytest.mark.parametrize("scale", (True, False))
    def test_dtype_not_a_dict(self, make_input, scale):
        # assert only inpt gets transformed when dtype isn't a dict

        sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
        out = transforms.ToDtype(dtype=torch.float32, scale=scale)(sample)

        assert out["inpt"].dtype != inpt_dtype
        assert out["inpt"].dtype == torch.float32
        if scale:
            assert self.was_scaled(out["inpt"])
        else:
            assert not self.was_scaled(out["inpt"])
        assert out["bbox"].dtype == bbox_dtype
        assert out["mask"].dtype == mask_dtype

    @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
    def test_others_catch_all_and_none(self, make_input):
        # make sure "others" works as a catch-all and that None means no conversion

        sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
        out = transforms.ToDtype(dtype={tv_tensors.Mask: torch.int64, "others": None})(sample)
        assert out["inpt"].dtype == inpt_dtype
        assert out["bbox"].dtype == bbox_dtype
        assert out["mask"].dtype != mask_dtype
        assert out["mask"].dtype == torch.int64

    @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
    def test_typical_use_case(self, make_input):
        # Typical use-case: want to convert dtype and scale for inpt and just dtype for masks.
        # This just makes sure we now have a decent API for this

        sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
        out = transforms.ToDtype(
            dtype={type(sample["inpt"]): torch.float32, tv_tensors.Mask: torch.int64, "others": None}, scale=True
        )(sample)
        assert out["inpt"].dtype != inpt_dtype
        assert out["inpt"].dtype == torch.float32
        assert self.was_scaled(out["inpt"])
        assert out["bbox"].dtype == bbox_dtype
        assert out["mask"].dtype != mask_dtype
        assert out["mask"].dtype == torch.int64

    @pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
    def test_errors_warnings(self, make_input):
        """Check the errors and warnings raised for incomplete or questionable ``dtype`` dicts."""
        sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)

        # No entry for the non-mask items and no "others" fallback. The call is expected to raise, so its
        # result is deliberately not bound to a name.
        with pytest.raises(ValueError, match="No dtype was specified for"):
            transforms.ToDtype(dtype={tv_tensors.Mask: torch.float32})(sample)
        with pytest.warns(UserWarning, match=re.escape("plain `torch.Tensor` will *not* be transformed")):
            transforms.ToDtype(dtype={torch.Tensor: torch.float32, tv_tensors.Image: torch.float32})
        with pytest.warns(UserWarning, match="no scaling will be done"):
            out = transforms.ToDtype(dtype={"others": None}, scale=True)(sample)
        assert out["inpt"].dtype == inpt_dtype
        assert out["bbox"].dtype == bbox_dtype
        assert out["mask"].dtype == mask_dtype

    def test_uint16(self):
        # These checks are probably already covered above but since uint16 is a
        # newly supported dtype,  we want to be extra careful, hence this
        # explicit test
        img_uint16 = torch.randint(0, 65535, (256, 512), dtype=torch.uint16)

        img_uint8 = F.to_dtype(img_uint16, torch.uint8, scale=True)
        img_float32 = F.to_dtype(img_uint16, torch.float32, scale=True)
        img_int32 = F.to_dtype(img_uint16, torch.int32, scale=True)

        assert_equal(img_uint8, (img_uint16 / 256).to(torch.uint8))
        assert_close(img_float32, (img_uint16 / 65535))

        assert_close(F.to_dtype(img_float32, torch.uint16, scale=True), img_uint16, rtol=0, atol=1)
        # Ideally we'd check against (img_uint16 & 0xFF00) but bitwise and isn't supported for it yet
        # so we simulate it by scaling down and up again.
        assert_equal(F.to_dtype(img_uint8, torch.uint16, scale=True), ((img_uint16 / 256).to(torch.uint16) * 256))
        assert_equal(F.to_dtype(img_int32, torch.uint16, scale=True), img_uint16)

        assert_equal(F.to_dtype(img_float32, torch.uint8, scale=True), img_uint8)
        assert_close(F.to_dtype(img_uint8, torch.float32, scale=True), img_float32, rtol=0, atol=1e-2)
2197

2198

2199
class TestAdjustBrightness:
    """Tests for ``F.adjust_brightness`` and its image / video kernels."""

    # Factors used by the correctness test; the first one doubles as the default for the smoke tests.
    _CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0]
    _DEFAULT_BRIGHTNESS_FACTOR = _CORRECTNESS_BRIGHTNESS_FACTORS[0]

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.adjust_brightness_image, make_image),
            (F.adjust_brightness_video, make_video),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel(self, kernel, make_input, dtype, device):
        """Run the generic kernel checks on the image and video kernels."""
        inpt = make_input(dtype=dtype, device=device)
        check_kernel(kernel, inpt, brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_functional(self, make_input):
        """Run the generic functional checks for every supported input kind."""
        inpt = make_input()
        check_functional(F.adjust_brightness, inpt, brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.adjust_brightness_image, torch.Tensor),
            (F._color._adjust_brightness_image_pil, PIL.Image.Image),
            (F.adjust_brightness_image, tv_tensors.Image),
            (F.adjust_brightness_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The functional and the kernel registered for ``input_type`` must have matching signatures."""
        check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
    def test_image_correctness(self, brightness_factor):
        """Compare the tensor result against the result obtained through the PIL path."""
        image = make_image(dtype=torch.uint8, device="cpu")

        # Reference: convert to PIL, adjust there, convert back to a tensor image.
        pil_result = F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor)
        expected = F.to_image(pil_result)

        actual = F.adjust_brightness(image, brightness_factor=brightness_factor)

        torch.testing.assert_close(actual, expected)
2239

2240

2241
class TestCutMixMixUp:
    """Tests for the batch-level ``transforms.CutMix`` and ``transforms.MixUp`` transforms."""

    class DummyDataset:
        """Map-style dataset that returns a random image and a label derived from the index."""

        def __init__(self, size, num_classes, one_hot_labels):
            self.size = size
            self.num_classes = num_classes
            self.one_hot_labels = one_hot_labels
            assert size < num_classes

        def __getitem__(self, idx):
            img = torch.rand(3, 100, 100)
            label = idx  # This ensures all labels in a batch are unique and makes testing easier
            if self.one_hot_labels:
                label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes)
            return img, label

        def __len__(self):
            return self.size

    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
    @pytest.mark.parametrize("one_hot_labels", (True, False))
    def test_supported_input_structure(self, T, one_hot_labels):
        """Check the transform on unpacked inputs, packed inputs, and inside a collate function."""

        batch_size = 32
        num_classes = 100

        dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels)

        cutmix_mixup = T(num_classes=num_classes)

        dl = DataLoader(dataset, batch_size=batch_size)

        # The conditional expression must be evaluated *before* the comparison. Without the parentheses
        # `assert a == b if cond else c` parses as `assert ((a == b) if cond else c)`, which asserts a
        # non-empty (always truthy) tuple whenever `one_hot_labels` is False.
        expected_input_target_shape = (batch_size, num_classes) if one_hot_labels else (batch_size,)

        # Input sanity checks
        img, target = next(iter(dl))
        input_img_size = img.shape[-3:]
        assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
        assert target.shape == expected_input_target_shape

        def check_output(img, target):
            # Image shape is unchanged, the labels are one-hot-like rows that sum to one
            # and have exactly two non-zero entries each.
            assert img.shape == (batch_size, *input_img_size)
            assert target.shape == (batch_size, num_classes)
            torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
            num_non_zero_labels = (target != 0).sum(axis=-1)
            assert (num_non_zero_labels == 2).all()

        # After Dataloader, as unpacked input
        img, target = next(iter(dl))
        assert target.shape == expected_input_target_shape
        img, target = cutmix_mixup(img, target)
        check_output(img, target)

        # After Dataloader, as packed input
        packed_from_dl = next(iter(dl))
        assert isinstance(packed_from_dl, list)
        img, target = cutmix_mixup(packed_from_dl)
        check_output(img, target)

        # As collation function. We expect default_collate to be used by users.
        def collate_fn_1(batch):
            return cutmix_mixup(default_collate(batch))

        def collate_fn_2(batch):
            return cutmix_mixup(*default_collate(batch))

        for collate_fn in (collate_fn_1, collate_fn_2):
            dl = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
            img, target = next(iter(dl))
            check_output(img, target)

    @needs_cuda
    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
    def test_cpu_vs_gpu(self, T):
        """Compare the CUDA and CPU results of the transform on the same batch."""
        num_classes = 10
        batch_size = 3
        H, W = 12, 12

        imgs = torch.rand(batch_size, 3, H, W)
        labels = torch.randint(0, num_classes, (batch_size,))
        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

        _check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)

    @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
    def test_error(self, T):
        """Every kind of invalid input must raise a ``ValueError`` with the expected message."""

        num_classes = 10
        batch_size = 9

        imgs = torch.rand(batch_size, 3, 12, 12)
        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

        for input_with_bad_type in (
            F.to_pil_image(imgs[0]),
            tv_tensors.Mask(torch.rand(12, 12)),
            tv_tensors.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12),
        ):
            with pytest.raises(ValueError, match="does not support PIL images, "):
                cutmix_mixup(input_with_bad_type)

        with pytest.raises(ValueError, match="Could not infer where the labels are"):
            cutmix_mixup({"img": imgs, "Nothing_else": 3})

        with pytest.raises(ValueError, match="labels should be index based"):
            # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
            # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
            cutmix_mixup(imgs)

        with pytest.raises(ValueError, match="When using the default labels_getter"):
            cutmix_mixup(imgs, "not_a_tensor")

        with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
            cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))

        with pytest.raises(ValueError, match="does not match the batch size of the labels"):
            cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))

        with pytest.raises(ValueError, match="When passing 2D labels"):
            wrong_num_classes = num_classes + 1
            T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes)))

        with pytest.raises(ValueError, match="but got a tensor of shape"):
            cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4)))

        with pytest.raises(ValueError, match="num_classes must be passed"):
            T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,)))
2365

2366

2367
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
@pytest.mark.parametrize("sample_type", (tuple, list, dict))
def test_labels_getter_default_heuristic(key, sample_type):
    """The default heuristic must return the labels tensor stored under a "label"-like key."""
    find_labels = transforms._utils._find_labels_default_heuristic

    labels = torch.arange(10)
    sample = {key: labels, "another_key": "whatever"}
    if sample_type is not dict:
        # Nest the dict as the second element of a tuple / list.
        sample = sample_type((None, sample, "whatever_again"))
    assert find_labels(sample) is labels

    if key.lower() == "labels":
        return

    # If "labels" is in the dict (case-insensitive),
    # it takes precedence over other keys which would otherwise be a match
    competing = {key: "something_else", "labels": labels}
    assert find_labels(competing) is labels
2381

2382

2383
class TestShapeGetters:
    """Tests for the ``F.get_dimensions`` / ``get_num_channels`` / ``get_size`` / ``get_num_frames`` getters."""

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.get_dimensions_image, make_image_tensor),
            (F._meta._get_dimensions_image_pil, make_image_pil),
            (F.get_dimensions_image, make_image),
            (F.get_dimensions_video, make_video),
        ],
    )
    def test_get_dimensions(self, kernel, make_input):
        """Kernel and functional must both report ``[channels, height, width]``."""
        size = (10, 10)
        color_space, num_channels = "RGB", 3
        expected = [num_channels, *size]

        inpt = make_input(size, color_space=color_space)

        from_kernel = kernel(inpt)
        from_functional = F.get_dimensions(inpt)
        assert from_kernel == from_functional == expected

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.get_num_channels_image, make_image_tensor),
            (F._meta._get_num_channels_image_pil, make_image_pil),
            (F.get_num_channels_image, make_image),
            (F.get_num_channels_video, make_video),
        ],
    )
    def test_get_num_channels(self, kernel, make_input):
        """Kernel and functional must both report the number of channels."""
        color_space, num_channels = "RGB", 3

        inpt = make_input(color_space=color_space)

        from_kernel = kernel(inpt)
        from_functional = F.get_num_channels(inpt)
        assert from_kernel == from_functional == num_channels

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.get_size_image, make_image_tensor),
            (F._meta._get_size_image_pil, make_image_pil),
            (F.get_size_image, make_image),
            (F.get_size_bounding_boxes, make_bounding_boxes),
            (F.get_size_mask, make_detection_masks),
            (F.get_size_mask, make_segmentation_mask),
            (F.get_size_video, make_video),
        ],
    )
    def test_get_size(self, kernel, make_input):
        """Kernel and functional must both report the size the input was created with, as a list."""
        size = (10, 10)

        inpt = make_input(size)

        from_kernel = kernel(inpt)
        from_functional = F.get_size(inpt)
        assert from_kernel == from_functional == list(size)

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.get_num_frames_video, make_video_tensor),
            (F.get_num_frames_video, make_video),
        ],
    )
    def test_get_num_frames(self, kernel, make_input):
        """Kernel and functional must both report the number of frames."""
        num_frames = 4

        inpt = make_input(num_frames=num_frames)

        from_kernel = kernel(inpt)
        from_functional = F.get_num_frames(inpt)
        assert from_kernel == from_functional == num_frames

    @pytest.mark.parametrize(
        ("functional", "make_input"),
        [
            (F.get_dimensions, make_bounding_boxes),
            (F.get_dimensions, make_detection_masks),
            (F.get_dimensions, make_segmentation_mask),
            (F.get_num_channels, make_bounding_boxes),
            (F.get_num_channels, make_detection_masks),
            (F.get_num_channels, make_segmentation_mask),
            (F.get_num_frames, make_image_pil),
            (F.get_num_frames, make_image),
            (F.get_num_frames, make_bounding_boxes),
            (F.get_num_frames, make_detection_masks),
            (F.get_num_frames, make_segmentation_mask),
        ],
    )
    def test_unsupported_types(self, functional, make_input):
        """Unsupported input types must raise a ``TypeError`` that mentions the offending type."""
        inpt = make_input()

        # The type's string representation contains regex meta characters, so escape it.
        type_pattern = re.escape(str(type(inpt)))
        with pytest.raises(TypeError, match=type_pattern):
            functional(inpt)
2471

2472

2473
class TestRegisterKernel:
    """Tests for the public ``F.register_kernel`` decorator."""

    @pytest.mark.parametrize("functional", (F.resize, "resize"))
    def test_register_kernel(self, functional):
        """A kernel registered for a custom TVTensor is dispatched to; other input types are unaffected.

        ``functional`` is passed both as the functional object and as its name.
        """

        class CustomTVTensor(tv_tensors.TVTensor):
            pass

        kernel_was_called = False

        @F.register_kernel(functional, CustomTVTensor)
        def new_resize(dp, *args, **kwargs):
            nonlocal kernel_was_called
            kernel_was_called = True
            return dp

        t = transforms.Resize(size=(224, 224), antialias=True)

        my_dp = CustomTVTensor(torch.rand(3, 10, 10))
        out = t(my_dp)
        assert out is my_dp
        assert kernel_was_called

        # Sanity check to make sure we didn't override the kernel of other types.
        # These comparisons have to be asserted; as bare expressions they would never fail.
        assert t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
        assert t(tv_tensors.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)

    def test_errors(self):
        """Invalid registrations must raise a ``ValueError`` with the expected message."""
        with pytest.raises(ValueError, match="Could not find functional with name"):
            F.register_kernel("bad_name", tv_tensors.Image)

        with pytest.raises(ValueError, match="Kernels can only be registered on functionals"):
            F.register_kernel(tv_tensors.Image, F.resize)

        with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
            F.register_kernel(F.resize, object)

        with pytest.raises(ValueError, match="cannot be registered for the builtin tv_tensor classes"):
            F.register_kernel(F.resize, tv_tensors.Image)(F.resize_image)

        class CustomTVTensor(tv_tensors.TVTensor):
            pass

        def resize_custom_tv_tensor():
            pass

        # The first registration succeeds ...
        F.register_kernel(F.resize, CustomTVTensor)(resize_custom_tv_tensor)

        # ... registering a second kernel for the same type is an error.
        with pytest.raises(ValueError, match="already has a kernel registered for type"):
            F.register_kernel(F.resize, CustomTVTensor)(resize_custom_tv_tensor)
2521

2522

2523
class TestGetKernel:
    """Tests for the internal ``_get_kernel`` dispatch lookup."""

    # We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
    # would also be fine
    KERNELS = {
        torch.Tensor: F.resize_image,
        PIL.Image.Image: F._geometry._resize_image_pil,
        tv_tensors.Image: F.resize_image,
        tv_tensors.BoundingBoxes: F.resize_bounding_boxes,
        tv_tensors.Mask: F.resize_mask,
        tv_tensors.Video: F.resize_video,
    }

    @pytest.mark.parametrize("input_type", [str, int, object])
    def test_unsupported_types(self, input_type):
        """Looking up a kernel for a type without a registration raises a ``TypeError``."""
        with pytest.raises(TypeError, match="supports inputs of type"):
            _get_kernel(F.resize, input_type)

    def test_exact_match(self):
        """A kernel registered without wrapper is returned as-is for exactly its input type."""

        # We cannot use F.resize together with the self.KERNELS mapping directly here, since this is only the
        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
        # here, register the kernels without wrapper, and check the exact matching afterwards.
        def resize_with_pure_kernels():
            pass

        for input_type, kernel in self.KERNELS.items():
            register = _register_kernel_internal(resize_with_pure_kernels, input_type, tv_tensor_wrapper=False)
            register(kernel)

            assert _get_kernel(resize_with_pure_kernels, input_type) is kernel

    def test_builtin_tv_tensor_subclass(self):
        """Subclasses of the builtin tv_tensors are dispatched to the kernel of their builtin superclass."""

        # We cannot use F.resize together with the self.KERNELS mapping directly here, since this is only the
        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
        # here, register the kernels without wrapper, and check if subclasses of our builtin tv_tensors get dispatched
        # to the kernel of the corresponding superclass
        def resize_with_pure_kernels():
            pass

        class MyImage(tv_tensors.Image):
            pass

        class MyBoundingBoxes(tv_tensors.BoundingBoxes):
            pass

        class MyMask(tv_tensors.Mask):
            pass

        class MyVideo(tv_tensors.Video):
            pass

        custom_subclasses = [MyImage, MyBoundingBoxes, MyMask, MyVideo]
        for custom_tv_tensor_subclass in custom_subclasses:
            # The entry right after the subclass itself in the MRO is the builtin class it derives from.
            builtin_tv_tensor_class = custom_tv_tensor_subclass.__mro__[1]
            builtin_tv_tensor_kernel = self.KERNELS[builtin_tv_tensor_class]

            register = _register_kernel_internal(
                resize_with_pure_kernels, builtin_tv_tensor_class, tv_tensor_wrapper=False
            )
            register(builtin_tv_tensor_kernel)

            found = _get_kernel(resize_with_pure_kernels, custom_tv_tensor_subclass)
            assert found is builtin_tv_tensor_kernel

    def test_tv_tensor_subclass(self):
        """A direct ``TVTensor`` subclass has no kernel until one is registered for it explicitly."""

        class MyTVTensor(tv_tensors.TVTensor):
            pass

        with pytest.raises(TypeError, match="supports inputs of type"):
            _get_kernel(F.resize, MyTVTensor)

        def resize_my_tv_tensor():
            pass

        _register_kernel_internal(F.resize, MyTVTensor, tv_tensor_wrapper=False)(resize_my_tv_tensor)

        assert _get_kernel(F.resize, MyTVTensor) is resize_my_tv_tensor

    def test_pil_image_subclass(self):
        """Both ``PIL.Image.Image`` and a subclass of it resolve to a working resize kernel."""
        asset_path = Path(__file__).parent / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"
        opened_image = PIL.Image.open(asset_path)
        loaded_image = opened_image.convert("RGB")

        # check the assumptions
        assert isinstance(opened_image, PIL.Image.Image)
        assert type(opened_image) is not PIL.Image.Image

        assert type(loaded_image) is PIL.Image.Image

        size = [17, 11]
        for image in [opened_image, loaded_image]:
            kernel = _get_kernel(F.resize, type(image))

            resized = kernel(image, size=size)

            assert F.get_size(resized) == size
2617

2618

2619
class TestPermuteChannels:
    """Tests for ``F.permute_channels`` and its image / video kernels."""

    _DEFAULT_PERMUTATION = [2, 0, 1]

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.permute_channels_image, make_image_tensor),
            # FIXME
            # check_kernel does not support PIL kernel, but it should
            (F.permute_channels_image, make_image),
            (F.permute_channels_video, make_video),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel(self, kernel, make_input, dtype, device):
        """Run the generic kernel checks on the image and video kernels."""
        check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_functional(self, make_input):
        """Run the generic functional checks for every supported input kind."""
        check_functional(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.permute_channels_image, torch.Tensor),
            (F._color._permute_channels_image_pil, PIL.Image.Image),
            (F.permute_channels_image, tv_tensors.Image),
            (F.permute_channels_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The functional and the kernel registered for ``input_type`` must have matching signatures."""
        check_functional_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)

    def reference_image_correctness(self, image, permutation):
        """Reference implementation: split along the channel dim (-3) and re-concatenate in the new order."""
        channel_images = image.split(1, dim=-3)
        permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation]
        return tv_tensors.Image(torch.concat(permuted_channel_images, dim=-3))

    # Each permutation is listed once; a repeated entry would only re-run an identical case.
    @pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 1, 0], [0, 1, 2]])
    @pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)])
    def test_image_correctness(self, permutation, batch_dims):
        """Compare ``F.permute_channels`` against the split-and-concatenate reference."""
        image = make_image(batch_dims=batch_dims)

        actual = F.permute_channels(image, permutation=permutation)
        expected = self.reference_image_correctness(image, permutation=permutation)

        torch.testing.assert_close(actual, expected)
2667

2668

2669
class TestElastic:
    """Tests for ``F.elastic``, its kernels and ``transforms.ElasticTransform``."""

    def _make_displacement(self, inpt):
        """Create a random float32 displacement of shape ``(1, *F.get_size(inpt), 2)``.

        The displacement lives on the device of ``inpt`` if that is a tensor and on the CPU otherwise.
        """
        shape = (1, *F.get_size(inpt), 2)
        device = inpt.device if isinstance(inpt, torch.Tensor) else "cpu"
        return torch.rand(shape, dtype=torch.float32, device=device)

    @param_value_parametrization(
        interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
        fill=EXHAUSTIVE_TYPE_FILLS,
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8, torch.float16])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, param, value, dtype, device):
        """Run the generic kernel checks on the image kernel for each interpolation / fill value."""
        image = make_image_tensor(dtype=dtype, device=device)
        displacement = self._make_displacement(image)

        # The scripted-vs-eager comparison is skipped for scalar fills,
        # the cuda-vs-cpu comparison is skipped for float16.
        scalar_fill = param == "fill" and isinstance(value, (int, float))
        compare_cuda_with_cpu = dtype is not torch.float16

        check_kernel(
            F.elastic_image,
            image,
            displacement=displacement,
            **{param: value},
            check_scripted_vs_eager=not scalar_fill,
            check_cuda_vs_cpu=compare_cuda_with_cpu,
        )

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_bounding_boxes(self, format, dtype, device):
        """Run the generic kernel checks on the bounding boxes kernel."""
        boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
        displacement = self._make_displacement(boxes)

        check_kernel(
            F.elastic_bounding_boxes,
            boxes,
            format=boxes.format,
            canvas_size=boxes.canvas_size,
            displacement=displacement,
        )

    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
    def test_kernel_mask(self, make_mask):
        """Run the generic kernel checks on the mask kernel."""
        mask = make_mask()
        displacement = self._make_displacement(mask)
        check_kernel(F.elastic_mask, mask, displacement=displacement)

    def test_kernel_video(self):
        """Run the generic kernel checks on the video kernel."""
        video = make_video()
        displacement = self._make_displacement(video)
        check_kernel(F.elastic_video, video, displacement=displacement)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_functional(self, make_input):
        """Run the generic functional checks for every supported input kind."""
        inpt = make_input()
        displacement = self._make_displacement(inpt)
        check_functional(F.elastic, inpt, displacement=displacement)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.elastic_image, torch.Tensor),
            (F._geometry._elastic_image_pil, PIL.Image.Image),
            (F.elastic_image, tv_tensors.Image),
            (F.elastic_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.elastic_mask, tv_tensors.Mask),
            (F.elastic_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The functional and the kernel registered for ``input_type`` must have matching signatures."""
        check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_displacement_error(self, make_input):
        """A missing or wrongly shaped displacement must be rejected."""
        inpt = make_input()

        with pytest.raises(TypeError, match="displacement should be a Tensor"):
            F.elastic(inpt, displacement=None)

        # Only the spatial dimensions, i.e. without the leading 1 and the trailing 2.
        wrongly_shaped = torch.rand(F.get_size(inpt))
        with pytest.raises(ValueError, match="displacement shape should be"):
            F.elastic(inpt, displacement=wrongly_shaped)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
    @pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, size, device):
        """Run the generic transform checks on ``ElasticTransform``."""
        # We have to skip the v1 compatibility check on M1 because it's flaky: Mismatched elements: 35 / 89205 (0.0%)
        # See https://github.com/pytorch/vision/issues/8154
        # All other platforms are fine, so the differences do not come from something we own in torchvision
        if sys.platform == "darwin":
            check_v1_compatibility = False
        else:
            check_v1_compatibility = dict(rtol=0, atol=1)

        check_transform(
            transforms.ElasticTransform(),
            make_input(size, device=device),
            check_v1_compatibility=check_v1_compatibility,
        )
2773

2774

2775
class TestToPureTensor:
    """Tests for ``transforms.ToPureTensor``."""

    def test_correctness(self):
        """TVTensors must come out as plain tensors; every other value must keep its type."""
        sample = {
            "img": make_image(),
            "img_tensor": make_image_tensor(),
            "img_pil": make_image_pil(),
            "mask": make_detection_masks(),
            "video": make_video(),
            "bbox": make_bounding_boxes(),
            "str": "str",
        }

        out = transforms.ToPureTensor()(sample)

        for original, converted in zip(sample.values(), out.values()):
            if not isinstance(original, tv_tensors.TVTensor):
                assert isinstance(converted, type(original))
                continue

            assert isinstance(converted, torch.Tensor) and not isinstance(converted, tv_tensors.TVTensor)
2794

2795

2796
class TestCrop:
    """Tests for ``F.crop``, its per-type kernels, and the ``transforms.RandomCrop`` transform."""

    INPUT_SIZE = (21, 11)

    # Crop windows applied to inputs of INPUT_SIZE.
    CORRECTNESS_CROP_KWARGS = [
        # center
        dict(top=5, left=5, height=10, width=5),
        # larger than input, i.e. pad
        dict(top=-5, left=-5, height=30, width=20),
        # sides: left, right, top, bottom
        dict(top=-5, left=-5, height=30, width=10),
        dict(top=-5, left=5, height=30, width=10),
        dict(top=-5, left=-5, height=20, width=20),
        dict(top=5, left=-5, height=20, width=20),
        # corners: top-left, top-right, bottom-left, bottom-right
        dict(top=-5, left=-5, height=20, width=10),
        dict(top=-5, left=5, height=20, width=10),
        dict(top=5, left=-5, height=20, width=10),
        dict(top=5, left=5, height=20, width=10),
    ]
    MINIMAL_CROP_KWARGS = CORRECTNESS_CROP_KWARGS[0]

    @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, kwargs, dtype, device):
        """Run the generic kernel checks on ``F.crop_image`` for every crop window."""
        image = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
        check_kernel(F.crop_image, image, **kwargs)

    @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_bounding_box(self, kwargs, format, dtype, device):
        """Run the generic kernel checks on ``F.crop_bounding_boxes`` for every format and crop window."""
        boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)
        check_kernel(F.crop_bounding_boxes, boxes, format=format, **kwargs)

    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
    def test_kernel_mask(self, make_mask):
        """Run the generic kernel checks on ``F.crop_mask`` with the minimal crop window."""
        mask = make_mask(self.INPUT_SIZE)
        check_kernel(F.crop_mask, mask, **self.MINIMAL_CROP_KWARGS)

    def test_kernel_video(self):
        """Run the generic kernel checks on ``F.crop_video`` with the minimal crop window."""
        video = make_video(self.INPUT_SIZE)
        check_kernel(F.crop_video, video, **self.MINIMAL_CROP_KWARGS)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_functional(self, make_input):
        """Run the generic functional checks on ``F.crop`` for every supported input type."""
        sample = make_input(self.INPUT_SIZE)
        check_functional(F.crop, sample, **self.MINIMAL_CROP_KWARGS)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.crop_image, torch.Tensor),
            (F._geometry._crop_image_pil, PIL.Image.Image),
            (F.crop_image, tv_tensors.Image),
            (F.crop_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.crop_mask, tv_tensors.Mask),
            (F.crop_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """Check that the signature of ``F.crop`` matches the kernel used for ``input_type``."""
        check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
    def test_functional_image_correctness(self, kwargs):
        """Cropping a tensor image must give the same result as cropping its PIL counterpart."""
        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
        pil_image = F.to_pil_image(image)

        actual = F.crop(image, **kwargs)
        expected = F.to_image(F.crop(pil_image, **kwargs))

        assert_equal(actual, expected)

    @param_value_parametrization(
        size=[(10, 5), (25, 15), (25, 5), (10, 15)],
        fill=EXHAUSTIVE_TYPE_FILLS,
    )
    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_transform(self, param, value, make_input):
        """Run the generic transform checks on ``RandomCrop(pad_if_needed=True)`` while varying one parameter."""
        sample = make_input(self.INPUT_SIZE)

        check_sample_input = True
        if param != "fill":
            kwargs = {param: value}
        else:
            if isinstance(value, (tuple, list)):
                if isinstance(sample, tv_tensors.Mask):
                    pytest.skip("F.pad_mask doesn't support non-scalar fill.")
                # Non-scalar fills are not checked against the sample input.
                check_sample_input = False

            fill_dtype = sample.dtype if isinstance(sample, torch.Tensor) else torch.uint8
            kwargs = dict(
                # 1. size is required
                # 2. the fill parameter only has an effect if we need padding
                size=[s + 4 for s in self.INPUT_SIZE],
                fill=adapt_fill(value, dtype=fill_dtype),
            )

        check_transform(
            transforms.RandomCrop(**kwargs, pad_if_needed=True),
            sample,
            check_v1_compatibility=param != "fill" or isinstance(value, (int, float)),
            check_sample_input=check_sample_input,
        )

    @pytest.mark.parametrize("padding", [1, (1, 1), (1, 1, 1, 1)])
    def test_transform_padding(self, padding):
        """A padding of one pixel per side must allow an output that is two pixels larger per dimension."""
        image = make_image(self.INPUT_SIZE)
        target_size = [s + 2 for s in F.get_size(image)]

        cropped = transforms.RandomCrop(target_size, padding=padding)(image)

        assert F.get_size(cropped) == target_size

    @pytest.mark.parametrize("padding", [None, 1, (1, 1), (1, 1, 1, 1)])
    def test_transform_insufficient_padding(self, padding):
        """Requesting an output larger than the (padded) input must raise a ``ValueError``."""
        image = make_image(self.INPUT_SIZE)
        target_size = [s + 3 for s in F.get_size(image)]
        transform = transforms.RandomCrop(target_size, padding=padding)

        with pytest.raises(ValueError, match="larger than (padded )?input image size"):
            transform(image)

    def test_transform_pad_if_needed(self):
        """With ``pad_if_needed=True`` an output twice the input size must be produced."""
        image = make_image(self.INPUT_SIZE)
        target_size = [s * 2 for s in F.get_size(image)]

        cropped = transforms.RandomCrop(target_size, pad_if_needed=True)(image)

        assert F.get_size(cropped) == target_size

    @param_value_parametrization(
        size=[(10, 5), (25, 15), (25, 5), (10, 15)],
        fill=CORRECTNESS_FILLS,
        padding_mode=["constant", "edge", "reflect", "symmetric"],
    )
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_image_correctness(self, param, value, seed):
        """``RandomCrop`` on a tensor image must match ``RandomCrop`` on the PIL image under the same seed."""
        if param == "size":
            kwargs = {"size": value}
        else:
            if param == "fill":
                value = adapt_fill(value, dtype=torch.uint8)
            # 1. size is required
            # 2. the fill / padding_mode parameters only have an effect if we need padding
            kwargs = {param: value, "size": [s + 4 for s in self.INPUT_SIZE]}

        transform = transforms.RandomCrop(pad_if_needed=True, **kwargs)
        image = make_image(self.INPUT_SIZE)

        with freeze_rng_state():
            # Re-seed before each call so both runs sample the same crop parameters.
            torch.manual_seed(seed)
            actual = transform(image)

            torch.manual_seed(seed)
            expected = F.to_image(transform(F.to_pil_image(image)))

        assert_equal(actual, expected)

    def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width):
        """Compute reference boxes by translating by ``(-left, -top)`` onto a ``(height, width)`` canvas."""
        translation = np.array([[1, 0, -left], [0, 1, -top]])
        return reference_affine_bounding_boxes_helper(
            bounding_boxes, affine_matrix=translation, new_canvas_size=(height, width)
        )

    @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device):
        """Compare ``F.crop`` on bounding boxes against the affine reference implementation."""
        boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)

        actual = F.crop(boxes, **kwargs)
        expected = self._reference_crop_bounding_boxes(boxes, **kwargs)

        # The explicit tolerances allow a deviation of up to 1 from the reference.
        assert_equal(actual, expected, atol=1, rtol=0)
        assert_equal(F.get_size(actual), F.get_size(expected))

    @pytest.mark.parametrize("output_size", [(17, 11), (11, 17), (11, 11)])
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, device, seed):
        """Compare ``RandomCrop`` on bounding boxes against the reference using the sampled crop parameters."""
        input_size = [s * 2 for s in output_size]
        boxes = make_bounding_boxes(input_size, format=format, dtype=dtype, device=device)

        transform = transforms.RandomCrop(output_size)

        with freeze_rng_state():
            torch.manual_seed(seed)
            crop_params = transform._get_params([boxes])
            # Strip the entries the reference does not accept; the input is larger than the output, so only
            # cropping (and no padding) is expected.
            assert not crop_params.pop("needs_pad")
            crop_params.pop("padding")
            assert crop_params.pop("needs_crop")

            # Re-seed so the actual transform samples the same parameters as above.
            torch.manual_seed(seed)
            actual = transform(boxes)

        expected = self._reference_crop_bounding_boxes(boxes, **crop_params)

        assert_equal(actual, expected)
        assert_equal(F.get_size(actual), F.get_size(expected))

    def test_errors(self):
        """Invalid ``RandomCrop`` constructor arguments must raise the documented errors."""
        with pytest.raises(ValueError, match="Please provide only two dimensions"):
            transforms.RandomCrop([10, 12, 14])

        with pytest.raises(TypeError, match="Got inappropriate padding arg"):
            transforms.RandomCrop([10, 12], padding="abc")

        with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
            transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7])

        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.RandomCrop([10, 12], padding=1, fill="abc")

        with pytest.raises(ValueError, match="Padding mode should be either"):
            transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")
3027

3028

3029
class TestErase:
    """Tests for ``F.erase``, its per-type kernels, and the ``transforms.RandomErasing`` transform."""

    INPUT_SIZE = (17, 11)
    # Keyword arguments for the erase kernels: region offsets (i, j), region size (h, w) and the fill value v.
    FUNCTIONAL_KWARGS = dict(
        i=2,
        j=2,
        h=10,
        w=8,
        v=torch.tensor(0.0, dtype=torch.float32, device="cpu").reshape(-1, 1, 1),
    )

    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the generic kernel checks on ``F.erase_image``."""
        image = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
        check_kernel(F.erase_image, image, **self.FUNCTIONAL_KWARGS)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image_inplace(self, dtype, device):
        """``inplace=True`` must modify and return the input; the default must return a new tensor."""
        image = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
        version_before = image._version

        out_of_place = F.erase_image(image, **self.FUNCTIONAL_KWARGS)
        assert out_of_place.data_ptr() != image.data_ptr()
        assert out_of_place is not image

        in_place = F.erase_image(image, **self.FUNCTIONAL_KWARGS, inplace=True)
        assert in_place.data_ptr() == image.data_ptr()
        # The version counter must have moved on after the in-place call.
        assert in_place._version > version_before
        assert in_place is image

        assert_equal(in_place, out_of_place)

    def test_kernel_video(self):
        """Run the generic kernel checks on ``F.erase_video``."""
        video = make_video(self.INPUT_SIZE)
        check_kernel(F.erase_video, video, **self.FUNCTIONAL_KWARGS)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_video],
    )
    def test_functional(self, make_input):
        """Run the generic functional checks on ``F.erase`` for every supported input type."""
        sample = make_input()
        check_functional(F.erase, sample, **self.FUNCTIONAL_KWARGS)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.erase_image, torch.Tensor),
            (F._augment._erase_image_pil, PIL.Image.Image),
            (F.erase_image, tv_tensors.Image),
            (F.erase_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """Check that the signature of ``F.erase`` matches the kernel used for ``input_type``."""
        check_functional_kernel_signature_match(F.erase, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_video],
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, device):
        """Run the generic transform checks on ``RandomErasing(p=1)``, expecting its pass-through warning."""
        sample = make_input(device=device)
        # v1 compatibility is only checked for non-PIL inputs.
        compare_with_v1 = not isinstance(sample, PIL.Image.Image)

        with pytest.warns(UserWarning, match="currently passing through inputs of type"):
            check_transform(
                transforms.RandomErasing(p=1),
                sample,
                check_v1_compatibility=compare_with_v1,
            )

    def _reference_erase_image(self, image, *, i, j, h, w, v):
        """Reference erase: fill the ``h`` x ``w`` region at ``(i, j)`` with ``v`` via boolean masking."""
        erase_mask = torch.zeros_like(image, dtype=torch.bool)
        erase_mask[..., i : i + h, j : j + w] = True
        keep_mask = ~erase_mask

        # The broadcasting and type casting logic is handled automagically in the kernel through indexing
        fill_values = torch.broadcast_to(v, (*image.shape[:-2], h, w)).to(image)

        result = torch.empty_like(image)
        result[erase_mask] = fill_values.flatten()
        result[keep_mask] = image[keep_mask]

        return result

    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_functional_image_correctness(self, dtype, device):
        """Compare ``F.erase`` on an image against the masking-based reference."""
        image = make_image(dtype=dtype, device=device)

        actual = F.erase(image, **self.FUNCTIONAL_KWARGS)
        expected = self._reference_erase_image(image, **self.FUNCTIONAL_KWARGS)

        assert_equal(actual, expected)

    @param_value_parametrization(
        scale=[(0.1, 0.2), [0.0, 1.0]],
        ratio=[(0.3, 0.7), [0.1, 5.0]],
        value=[0, 0.5, (0, 1, 0), [-0.2, 0.0, 1.3], "random"],
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_image_correctness(self, param, value, dtype, device, seed):
        """Compare ``RandomErasing`` against the reference applied with the parameters sampled under the same seed."""
        transform = transforms.RandomErasing(**{param: value}, p=1)
        image = make_image(dtype=dtype, device=device)

        with freeze_rng_state():
            torch.manual_seed(seed)
            # This emulates the random apply check that happens before _get_params is called
            torch.rand(1)
            sampled_params = transform._get_params([image])

            torch.manual_seed(seed)
            actual = transform(image)

        expected = self._reference_erase_image(image, **sampled_params)

        assert_equal(actual, expected)

    def test_transform_errors(self):
        """Invalid ``RandomErasing`` arguments must raise, at construction or when sampling parameters."""
        with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
            transforms.RandomErasing(value={})

        with pytest.raises(ValueError, match="If value is str, it should be 'random'"):
            transforms.RandomErasing(value="abc")

        with pytest.raises(TypeError, match="Scale should be a sequence"):
            transforms.RandomErasing(scale=123)

        with pytest.raises(TypeError, match="Ratio should be a sequence"):
            transforms.RandomErasing(ratio=123)

        with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
            transforms.RandomErasing(scale=[-1, 2])

        # A four-element value passes construction; the error is only expected from _get_params.
        four_value_transform = transforms.RandomErasing(value=[1, 2, 3, 4])

        with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
            four_value_transform._get_params([make_image()])
3163

3164

3165
class TestGaussianBlur:
    """Tests for ``F.gaussian_blur``, its per-type kernels, and the ``transforms.GaussianBlur`` transform."""

    @pytest.mark.parametrize("kernel_size", [1, 3, (3, 1), [3, 5]])
    @pytest.mark.parametrize("sigma", [None, 1.0, 1, (0.5,), [0.3], (0.3, 0.7), [0.9, 0.2]])
    def test_kernel_image(self, kernel_size, sigma):
        """Run the generic kernel checks on ``F.gaussian_blur_image``."""
        # The scripted-vs-eager comparison is only enabled when neither argument is a bare scalar.
        compare_scripted = not isinstance(kernel_size, int) and not isinstance(sigma, (float, int))
        check_kernel(
            F.gaussian_blur_image,
            make_image(),
            kernel_size=kernel_size,
            sigma=sigma,
            check_scripted_vs_eager=compare_scripted,
        )

    def test_kernel_image_errors(self):
        """Invalid ``kernel_size`` / ``sigma`` arguments to the image kernel must raise."""
        image = make_image_tensor()

        with pytest.raises(ValueError, match="kernel_size is a sequence its length should be 2"):
            F.gaussian_blur_image(image, kernel_size=[1, 2, 3])

        # An even and a negative kernel size are both rejected with the same message.
        for bad_kernel_size in (2, -1):
            with pytest.raises(ValueError, match="kernel_size should have odd and positive integers"):
                F.gaussian_blur_image(image, kernel_size=bad_kernel_size)

        with pytest.raises(ValueError, match="sigma is a sequence, its length should be 2"):
            F.gaussian_blur_image(image, kernel_size=1, sigma=[1, 2, 3])

        with pytest.raises(TypeError, match="sigma should be either float or sequence of floats"):
            F.gaussian_blur_image(image, kernel_size=1, sigma=object())

        with pytest.raises(ValueError, match="sigma should have positive values"):
            F.gaussian_blur_image(image, kernel_size=1, sigma=-1)

    def test_kernel_video(self):
        """Run the generic kernel checks on ``F.gaussian_blur_video``."""
        check_kernel(F.gaussian_blur_video, make_video(), kernel_size=(3, 3))

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_video],
    )
    def test_functional(self, make_input):
        """Run the generic functional checks on ``F.gaussian_blur`` for every supported input type."""
        sample = make_input()
        check_functional(F.gaussian_blur, sample, kernel_size=(3, 3))

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.gaussian_blur_image, torch.Tensor),
            (F._misc._gaussian_blur_image_pil, PIL.Image.Image),
            (F.gaussian_blur_image, tv_tensors.Image),
            (F.gaussian_blur_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """Check that the signature of ``F.gaussian_blur`` matches the kernel used for ``input_type``."""
        check_functional_kernel_signature_match(F.gaussian_blur, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("sigma", [5, 2.0, (0.5, 2), [1.3, 2.7]])
    def test_transform(self, make_input, device, sigma):
        """Run the generic transform checks on ``GaussianBlur`` for every input type and sigma form."""
        transform = transforms.GaussianBlur(kernel_size=3, sigma=sigma)
        check_transform(transform, make_input(device=device))

    def test_assertions(self):
        """Invalid ``GaussianBlur`` constructor arguments must raise."""
        with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"):
            transforms.GaussianBlur([10, 12, 14])

        with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
            transforms.GaussianBlur(4)

        with pytest.raises(ValueError, match="If sigma is a sequence its length should be 1 or 2. Got 3"):
            transforms.GaussianBlur(3, sigma=[1, 2, 3])

        with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
            transforms.GaussianBlur(3, sigma=-1.0)

        with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
            transforms.GaussianBlur(3, sigma=[2.0, 1.0])

        with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
            transforms.GaussianBlur(3, sigma={})

    @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0], (10, 12.0), [10]])
    def test__get_params(self, sigma):
        """The sampled sigma must equal a fixed sigma, or lie within the configured ``[min, max]`` range."""
        transform = transforms.GaussianBlur(3, sigma=sigma)
        sampled = transform._get_params([])["sigma"]

        if isinstance(sigma, float):
            assert sampled[0] == sampled[1] == sigma
        elif isinstance(sigma, list) and len(sigma) == 1:
            assert sampled[0] == sampled[1] == sigma[0]
        else:
            low, high = sigma[0], sigma[1]
            assert low <= sampled[0] <= high
            assert low <= sampled[1] <= high

    # The snippet below documents how the entries of the reference asset are keyed / were produced:
    # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    # {
    #     "10_12_3__3_3_0.8": cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8),
    #     "10_12_3__3_3_0.5": cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5),
    #     "10_12_3__3_5_0.8": cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8),
    #     "10_12_3__3_5_0.5": cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5),
    #     "26_28_1__23_23_1.7": cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7),
    # }
    # NOTE: weights_only=False performs a full unpickle; the file is read from the "assets" directory next to
    # this test module.
    REFERENCE_GAUSSIAN_BLUR_IMAGE_RESULTS = torch.load(
        Path(__file__).parent / "assets" / "gaussian_blur_opencv_results.pt",
        weights_only=False,
    )

    @pytest.mark.parametrize(
        ("dimensions", "kernel_size", "sigma"),
        [
            ((3, 10, 12), (3, 3), 0.8),
            ((3, 10, 12), (3, 3), 0.5),
            ((3, 10, 12), (3, 5), 0.8),
            ((3, 10, 12), (3, 5), 0.5),
            ((1, 26, 28), (23, 23), 1.7),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.float16])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_functional_image_correctness(self, dimensions, kernel_size, sigma, dtype, device):
        """Compare ``F.gaussian_blur_image`` against the stored reference results within a tolerance of 1."""
        if dtype is torch.float16 and device == "cpu":
            pytest.skip("The CPU implementation of float16 on CPU differs from opencv")

        num_channels, height, width = dimensions
        kernel_height, kernel_width = kernel_size[0], kernel_size[1]

        reference_key = f"{height}_{width}_{num_channels}__{kernel_height}_{kernel_width}_{sigma}"
        # The reference is reshaped as (height, width, channels) and then permuted to channels-first.
        reference = torch.tensor(self.REFERENCE_GAUSSIAN_BLUR_IMAGE_RESULTS[reference_key])
        expected = reference.reshape(height, width, num_channels).permute(2, 0, 1).to(dtype=dtype, device=device)

        # Build the same arange-based input as in the snippet above, laid out channels-first.
        pixels = torch.arange(num_channels * height * width, dtype=torch.uint8)
        image = tv_tensors.Image(
            pixels.reshape(height, width, num_channels).permute(2, 0, 1),
            dtype=dtype,
            device=device,
        )

        actual = F.gaussian_blur_image(image, kernel_size=kernel_size, sigma=sigma)

        torch.testing.assert_close(actual, expected, rtol=0, atol=1)
3310

3311

3312
class TestGaussianNoise:
    """Tests for ``F.gaussian_noise`` and the ``transforms.GaussianNoise`` transform."""

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image, make_video],
    )
    def test_kernel(self, make_input):
        """Run the generic kernel checks on ``F.gaussian_noise`` with float32 inputs."""
        sample = make_input(dtype=torch.float32)
        check_kernel(
            F.gaussian_noise,
            sample,
            # This cannot pass because the noise on a batch is not per-image
            check_batched_vs_unbatched=False,
        )

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image, make_video],
    )
    def test_functional(self, make_input):
        """Run the generic functional checks on ``F.gaussian_noise`` with float32 inputs."""
        sample = make_input(dtype=torch.float32)
        check_functional(F.gaussian_noise, sample)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.gaussian_noise, torch.Tensor),
            (F.gaussian_noise_image, tv_tensors.Image),
            (F.gaussian_noise_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """Check that the signature of ``F.gaussian_noise`` matches the kernel used for ``input_type``."""
        check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image, make_video],
    )
    def test_transform(self, make_input):
        """Run the generic transform checks on ``GaussianNoise`` with a float-converting sample adapter."""

        def adapter(_, input, __):
            # This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
            # Same for PIL images
            for key, value in input.items():
                if isinstance(value, torch.Tensor) and not value.is_floating_point():
                    input[key] = value.to(torch.float32)
                elif isinstance(value, PIL.Image.Image):
                    input[key] = F.pil_to_tensor(value).to(torch.float32)
            return input

        sample = make_input(dtype=torch.float32)
        check_transform(transforms.GaussianNoise(), sample, check_sample_input=adapter)

    def test_bad_input(self):
        """PIL images, non-float tensors and a negative ``sigma`` must be rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."):
            F.gaussian_noise(make_image_pil())
        with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"):
            F.gaussian_noise(make_image(dtype=torch.uint8))
        with pytest.raises(ValueError, match="sigma shouldn't be negative"):
            F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1)

    def test_clip(self):
        """With ``clip=True`` a large mean must saturate the output at 1 (or 0); without it the shift must remain."""
        image = make_image(dtype=torch.float32)

        shifted_up = F.gaussian_noise(image, mean=100, clip=False)
        assert shifted_up.min() > 50

        clipped_up = F.gaussian_noise(image, mean=100, clip=True)
        assert (clipped_up == 1).all()

        shifted_down = F.gaussian_noise(image, mean=-100, clip=False)
        assert shifted_down.min() < -50

        clipped_down = F.gaussian_noise(image, mean=-100, clip=True)
        assert (clipped_down == 0).all()
3382

3383

3384
class TestAutoAugmentTransforms:
    """Correctness and smoke tests for the auto-augmentation transforms (AutoAugment, RandAugment, ...)."""

    # These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling.
    # It's typically very hard to test the effect on some parameters without heavy mocking logic.
    # This class adds correctness tests for the kernels that are specific to those transforms. The rest of kernels, e.g.
    # rotate, are tested in their respective classes. The rest of the tests here are mostly smoke tests.

    def _reference_shear_translate(self, image, *, transform_id, magnitude, interpolation, fill):
        """Reference ShearX/Y and TranslateX/Y implemented through ``PIL.Image.Image.transform``.

        Tensor inputs are converted to PIL before the affine transform and back to an image tensor afterwards;
        PIL inputs are returned as PIL images.
        """
        is_pil = isinstance(image, PIL.Image.Image)
        pil_image = image if is_pil else F.to_pil_image(image)

        # Affine coefficients (a, b, c, d, e, f) as expected by PIL.Image.AFFINE.
        matrix = {
            "ShearX": (1, magnitude, 0, 0, 1, 0),
            "ShearY": (1, 0, 0, magnitude, 1, 0),
            "TranslateX": (1, 0, -int(magnitude), 0, 1, 0),
            "TranslateY": (1, 0, 0, 0, 1, -int(magnitude)),
        }[transform_id]

        output = pil_image.transform(
            pil_image.size, PIL.Image.AFFINE, matrix, resample=pil_modes_mapping[interpolation], fill=fill
        )

        return output if is_pil else F.to_image(output)

    @pytest.mark.parametrize("transform_id", ["ShearX", "ShearY", "TranslateX", "TranslateY"])
    @pytest.mark.parametrize("magnitude", [0.3, -0.2, 0.0])
    @pytest.mark.parametrize(
        "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
    )
    @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
    @pytest.mark.parametrize("input_type", ["Tensor", "PIL"])
    def test_correctness_shear_translate(self, transform_id, magnitude, interpolation, fill, input_type):
        """Compare the AA-native shear / translate ops against the PIL-based reference."""
        # ShearX/Y and TranslateX/Y are the only ops that are native to the AA transforms. They are modeled after the
        # reference implementation:
        # https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L362
        # All other ops are checked in their respective dedicated tests.

        image = make_image(dtype=torch.uint8, device="cpu")
        if input_type == "PIL":
            image = F.to_pil_image(image)

        if "Translate" in transform_id:
            # For TranslateX/Y magnitude is a value in pixels
            magnitude *= min(F.get_size(image))

        actual = transforms.AutoAugment()._apply_image_or_video_transform(
            image,
            transform_id=transform_id,
            magnitude=magnitude,
            interpolation=interpolation,
            fill={type(image): fill},
        )
        expected = self._reference_shear_translate(
            image, transform_id=transform_id, magnitude=magnitude, interpolation=interpolation, fill=fill
        )

        if input_type == "PIL":
            actual, expected = F.to_image(actual), F.to_image(expected)

        if "Shear" in transform_id and input_type == "Tensor":
            # Tensor shear is only required to be close on average rather than per pixel.
            max_mae = 12 if interpolation is transforms.InterpolationMode.NEAREST else 5
            mae = (actual.float() - expected.float()).abs().mean()
            assert mae < max_mae
        else:
            assert_close(actual, expected, rtol=0, atol=1)

    def _sample_input_adapter(self, transform, input, device):
        """Reduce a sample dict to what the AA transforms accept.

        Bounding boxes and masks are dropped, and only the first image / video / pure tensor / PIL image entry
        is kept. All remaining entries are passed through. ``transform`` and ``device`` are unused.
        """
        adapted_input = {}
        image_or_video_found = False
        for key, value in input.items():
            if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
                # AA transforms don't support bounding boxes or masks
                continue
            if check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
                if image_or_video_found:
                    # AA transforms only support a single image or video
                    continue
                image_or_video_found = True
            adapted_input[key] = value
        return adapted_input

    @pytest.mark.parametrize(
        "transform",
        [transforms.AutoAugment(), transforms.RandAugment(), transforms.TrivialAugmentWide(), transforms.AugMix()],
    )
    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform_smoke(self, transform, make_input, dtype, device):
        """Smoke-test the AA transforms (eager and scripted) under a per-parametrization random seed."""
        # Local import: zlib is only needed here to derive a stable seed.
        import zlib

        if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"):
            pytest.skip(
                "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
                "will degenerate to that anyway."
            )
        input = make_input(dtype=dtype, device=device)

        with freeze_rng_state():
            # By default every test starts from the same random seed. This leads to minimal coverage of the sampling
            # that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage,
            # we build a reproducible random seed from the input type, dtype, and device.
            # The builtin hash() is not usable for this: function hashes are identity based and str hashes are
            # randomized per process, so the resulting seed would differ between runs. A CRC32 of a textual key
            # is stable across processes.
            seed_key = f"{make_input.__name__}-{dtype}-{device}"
            torch.manual_seed(zlib.crc32(seed_key.encode()))

            # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1
            # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks
            # here and only check if we can script the v2 transform and subsequently call the result.
            check_transform(
                transform, input, check_v1_compatibility=False, check_sample_input=self._sample_input_adapter
            )

            if type(input) is torch.Tensor and dtype is torch.uint8:
                _script(transform)(input)

    def test_auto_augment_policy_error(self):
        """``AutoAugment(policy=None)`` must raise a ``ValueError``."""
        with pytest.raises(ValueError, match="provided policy"):
            transforms.AutoAugment(policy=None)

    @pytest.mark.parametrize("severity", [0, 11])
    def test_aug_mix_severity_error(self, severity):
        """``AugMix`` must reject the severities 0 and 11."""
        with pytest.raises(ValueError, match="severity must be between"):
            transforms.AugMix(severity=severity)
3507

3508

3509
class TestConvertBoundingBoxFormat:
    """Tests for ``F.convert_bounding_box_format`` and ``transforms.ConvertBoundingBoxFormat``."""

    # Every ordered (old, new) pair of two distinct bounding box formats.
    old_new_formats = list(itertools.permutations(tv_tensors.BoundingBoxFormat, 2))

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    def test_kernel(self, old_format, new_format):
        """Run the generic kernel checks for every format pair."""
        boxes = make_bounding_boxes(format=old_format)
        check_kernel(F.convert_bounding_box_format, boxes, new_format=new_format, old_format=old_format)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("inplace", [False, True])
    def test_kernel_noop(self, format, inplace):
        """Converting to the same format must hand back the very same, unmodified tensor."""
        boxes = make_bounding_boxes(format=format).as_subclass(torch.Tensor)
        version_before = boxes._version

        result = F.convert_bounding_box_format(boxes, old_format=format, new_format=format, inplace=inplace)

        assert result is boxes
        assert result.data_ptr() == boxes.data_ptr()
        assert result._version == version_before

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    def test_kernel_inplace(self, old_format, new_format):
        """Compare the out-of-place conversion against the ``inplace=True`` one."""
        boxes = make_bounding_boxes(format=old_format).as_subclass(torch.Tensor)
        version_before = boxes._version

        # Without `inplace`, the result has to be a different object backed by different memory.
        converted_copy = F.convert_bounding_box_format(boxes, old_format=old_format, new_format=new_format)
        assert converted_copy.data_ptr() != boxes.data_ptr()
        assert converted_copy is not boxes

        # With `inplace=True`, the input object itself is returned and its version counter has advanced.
        converted_inplace = F.convert_bounding_box_format(
            boxes, old_format=old_format, new_format=new_format, inplace=True
        )
        assert converted_inplace.data_ptr() == boxes.data_ptr()
        assert converted_inplace._version > version_before
        assert converted_inplace is boxes

        assert_equal(converted_inplace, converted_copy)

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    def test_functional(self, old_format, new_format):
        """Run the generic functional checks for every format pair."""
        boxes = make_bounding_boxes(format=old_format)
        check_functional(F.convert_bounding_box_format, boxes, new_format=new_format)

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    @pytest.mark.parametrize("format_type", ["enum", "str"])
    def test_transform(self, old_format, new_format, format_type):
        """Run the generic transform checks, passing the target format as enum member or as its name."""
        target = new_format.name if format_type == "str" else new_format
        check_transform(transforms.ConvertBoundingBoxFormat(target), make_bounding_boxes(format=old_format))

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    def test_strings(self, old_format, new_format):
        """Formats given by name must produce the same result as the reference conversion."""
        # Non-regression test for https://github.com/pytorch/vision/issues/8258
        boxes = tv_tensors.BoundingBoxes(torch.tensor([[10, 10, 20, 20]]), format=old_format, canvas_size=(50, 50))
        expected = self._reference_convert_bounding_box_format(boxes, new_format)

        old_name = old_format.name
        new_name = new_format.name

        outputs = (
            F.convert_bounding_box_format(boxes, new_format=new_name),
            F.convert_bounding_box_format(boxes.as_subclass(torch.Tensor), old_format=old_name, new_format=new_name),
            transforms.ConvertBoundingBoxFormat(new_name)(boxes),
        )
        for actual in outputs:
            assert_equal(actual, expected)

    def _reference_convert_bounding_box_format(self, bounding_boxes, new_format):
        """Reference conversion built on ``torchvision.ops.box_convert``.

        The converted coordinates are cast back to the dtype of ``bounding_boxes`` and re-wrapped like
        ``bounding_boxes`` with ``format=new_format``.
        """
        converted = torchvision.ops.box_convert(
            bounding_boxes.as_subclass(torch.Tensor),
            in_fmt=bounding_boxes.format.name.lower(),
            out_fmt=new_format.name.lower(),
        )
        return tv_tensors.wrap(converted.to(bounding_boxes.dtype), like=bounding_boxes, format=new_format)

    @pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
    @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("fn_type", ["functional", "transform"])
    def test_correctness(self, old_format, new_format, dtype, device, fn_type):
        """The functional and the transform must both agree with the reference conversion."""
        bounding_boxes = make_bounding_boxes(format=old_format, dtype=dtype, device=device)

        convert = (
            functools.partial(F.convert_bounding_box_format, new_format=new_format)
            if fn_type == "functional"
            else transforms.ConvertBoundingBoxFormat(format=new_format)
        )

        actual = convert(bounding_boxes)
        expected = self._reference_convert_bounding_box_format(bounding_boxes, new_format)

        assert_equal(actual, expected)

    def test_errors(self):
        """Check the errors raised for a missing ``new_format`` and for misuse of ``old_format``."""
        boxes = make_bounding_boxes()
        plain = boxes.as_subclass(torch.Tensor)

        # `new_format` is mandatory for both the TVTensor and the pure tensor input.
        for candidate in (boxes, plain):
            with pytest.raises(TypeError, match="missing 1 required argument: 'new_format'"):
                F.convert_bounding_box_format(candidate)

        # A pure tensor carries no format, so `old_format` has to be supplied.
        with pytest.raises(ValueError, match="`old_format` has to be passed"):
            F.convert_bounding_box_format(plain, new_format=boxes.format)

        # A `BoundingBoxes` input already knows its format, so passing `old_format` is rejected.
        with pytest.raises(ValueError, match="`old_format` must not be passed"):
            F.convert_bounding_box_format(boxes, old_format=boxes.format, new_format=boxes.format)
3623

3624

3625
class TestResizedCrop:
    """Tests for ``F.resized_crop``, its per-type kernels, and ``transforms.RandomResizedCrop``."""

    INPUT_SIZE = (17, 11)
    CROP_KWARGS = dict(top=2, left=2, height=5, width=7)
    OUTPUT_SIZE = (19, 32)

    @pytest.mark.parametrize(
        ("kernel", "make_input"),
        [
            (F.resized_crop_image, make_image),
            (F.resized_crop_bounding_boxes, make_bounding_boxes),
            (F.resized_crop_mask, make_segmentation_mask),
            (F.resized_crop_mask, make_detection_masks),
            (F.resized_crop_video, make_video),
        ],
    )
    def test_kernel(self, kernel, make_input):
        """Run the generic kernel checks with the extra keyword arguments each input type needs."""
        input = make_input(self.INPUT_SIZE)
        if isinstance(input, tv_tensors.BoundingBoxes):
            extra_kwargs = dict(format=input.format)
        elif isinstance(input, tv_tensors.Mask):
            extra_kwargs = dict()
        else:
            extra_kwargs = dict(antialias=True)

        check_kernel(kernel, input, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, **extra_kwargs)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_functional(self, make_input):
        """Run the generic functional checks for every supported input type."""
        check_functional(
            F.resized_crop, make_input(self.INPUT_SIZE), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, antialias=True
        )

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.resized_crop_image, torch.Tensor),
            (F._geometry._resized_crop_image_pil, PIL.Image.Image),
            (F.resized_crop_image, tv_tensors.Image),
            (F.resized_crop_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.resized_crop_mask, tv_tensors.Mask),
            (F.resized_crop_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of each kernel has to match the one of the functional."""
        check_functional_kernel_signature_match(F.resized_crop, kernel=kernel, input_type=input_type)

    @param_value_parametrization(
        scale=[(0.1, 0.2), [0.0, 1.0]],
        ratio=[(0.3, 0.7), [0.1, 5.0]],
    )
    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_transform(self, param, value, make_input):
        """Run the generic transform checks for ``RandomResizedCrop`` with one varied parameter."""
        check_transform(
            transforms.RandomResizedCrop(size=self.OUTPUT_SIZE, **{param: value}, antialias=True),
            make_input(self.INPUT_SIZE),
            check_v1_compatibility=dict(rtol=0, atol=1),
        )

    # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
    # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
    # The remaining modes are sorted by name: iterating a bare `set` follows hash order, which is not guaranteed
    # to be the same in every process, whereas the parametrization order should be deterministic.
    @pytest.mark.parametrize(
        "interpolation",
        sorted(set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST}, key=lambda mode: mode.name),
    )
    def test_functional_image_correctness(self, interpolation):
        """The tensor result must match the PIL result up to an absolute difference of 1."""
        image = make_image(self.INPUT_SIZE, dtype=torch.uint8)

        actual = F.resized_crop(
            image, **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation, antialias=True
        )
        expected = F.to_image(
            F.resized_crop(
                F.to_pil_image(image), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, interpolation=interpolation
            )
        )

        torch.testing.assert_close(actual, expected, atol=1, rtol=0)

    def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width, size):
        """Reference implementation that expresses crop + resize as a single affine transformation.

        Args:
            bounding_boxes: Boxes to transform.
            top, left: Offset of the crop window.
            height, width: Size of the crop window.
            size: ``(new_height, new_width)`` the crop is resized to; also used as the new canvas size.
        """
        new_height, new_width = size

        # Cropping is a translation by the negated window offset ...
        crop_affine_matrix = np.array(
            [
                [1, 0, -left],
                [0, 1, -top],
                [0, 0, 1],
            ],
        )
        # ... and resizing is a per-axis scaling from the crop size to the output size.
        resize_affine_matrix = np.array(
            [
                [new_width / width, 0, 0],
                [0, new_height / height, 0],
                [0, 0, 1],
            ],
        )
        # The crop is applied first, the resize second. Only the top two rows are passed on.
        affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]

        return reference_affine_bounding_boxes_helper(
            bounding_boxes,
            affine_matrix=affine_matrix,
            new_canvas_size=size,
        )

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    def test_functional_bounding_boxes_correctness(self, format):
        """The functional must agree with the affine reference, including the reported size."""
        bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)

        actual = F.resized_crop(bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE)
        expected = self._reference_resized_crop_bounding_boxes(
            bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE
        )

        assert_equal(actual, expected)
        assert_equal(F.get_size(actual), F.get_size(expected))

    def test_transform_errors_warnings(self):
        """Check the errors and warnings raised for invalid ``size``, ``scale``, and ``ratio`` arguments."""
        with pytest.raises(ValueError, match="provide only two dimensions"):
            transforms.RandomResizedCrop(size=(1, 2, 3))

        with pytest.raises(TypeError, match="Scale should be a sequence"):
            transforms.RandomResizedCrop(size=self.INPUT_SIZE, scale=123)

        with pytest.raises(TypeError, match="Ratio should be a sequence"):
            transforms.RandomResizedCrop(size=self.INPUT_SIZE, ratio=123)

        # A decreasing range only triggers a warning, not an error.
        for param in ["scale", "ratio"]:
            with pytest.warns(match="Scale and ratio should be of kind"):
                transforms.RandomResizedCrop(size=self.INPUT_SIZE, **{param: [1, 0]})
3756

3757

3758
class TestPad:
    """Tests for ``F.pad``, its per-type kernels, and ``transforms.Pad``."""

    EXHAUSTIVE_TYPE_PADDINGS = [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    # Subset used by the correctness tests: the plain int and the lists with more than one element.
    CORRECTNESS_PADDINGS = [
        candidate
        for candidate in EXHAUSTIVE_TYPE_PADDINGS
        if isinstance(candidate, int) or (isinstance(candidate, list) and len(candidate) > 1)
    ]
    PADDING_MODES = ["constant", "symmetric", "edge", "reflect"]

    @param_value_parametrization(
        padding=EXHAUSTIVE_TYPE_PADDINGS,
        fill=EXHAUSTIVE_TYPE_FILLS,
        padding_mode=PADDING_MODES,
    )
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, param, value, dtype, device):
        """Run the generic kernel checks for the image kernel with one varied parameter."""
        if param == "fill":
            value = adapt_fill(value, dtype=dtype)
        kwargs = {param: value}
        # `padding` is always needed; supply a default whenever it is not the parameter under test.
        kwargs.setdefault("padding", [1])

        image = make_image(dtype=dtype, device=device)

        # The scripted-vs-eager comparison is disabled for an int `padding` and for some `fill` types.
        int_padding = param == "padding" and isinstance(value, int)
        # See https://github.com/pytorch/vision/pull/7252#issue-1585585521 for details
        excluded_fill = param == "fill" and (
            isinstance(value, tuple) or (isinstance(value, list) and any(isinstance(v, int) for v in value))
        )

        check_kernel(
            F.pad_image,
            image,
            **kwargs,
            check_scripted_vs_eager=not (int_padding or excluded_fill),
        )

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    def test_kernel_bounding_boxes(self, format):
        """Run the generic kernel checks for the bounding boxes kernel."""
        boxes = make_bounding_boxes(format=format)
        check_kernel(
            F.pad_bounding_boxes,
            boxes,
            format=boxes.format,
            canvas_size=boxes.canvas_size,
            padding=[1],
        )

    @pytest.mark.parametrize("padding_mode", ["symmetric", "edge", "reflect"])
    def test_kernel_bounding_boxes_errors(self, padding_mode):
        """The bounding boxes kernel must reject the parametrized non-constant padding modes."""
        boxes = make_bounding_boxes()
        with pytest.raises(ValueError, match=f"'{padding_mode}' is not supported"):
            F.pad_bounding_boxes(
                boxes,
                format=boxes.format,
                canvas_size=boxes.canvas_size,
                padding=[1],
                padding_mode=padding_mode,
            )

    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
    def test_kernel_mask(self, make_mask):
        """Run the generic kernel checks for the mask kernel."""
        mask = make_mask()
        check_kernel(F.pad_mask, mask, padding=[1])

    @pytest.mark.parametrize("fill", [[1], (0,), [1, 0, 1], (0, 1, 0)])
    def test_kernel_mask_errors(self, fill):
        """The mask kernel must reject sequence fills."""
        with pytest.raises(ValueError, match="Non-scalar fill value is not supported"):
            F.pad_mask(make_segmentation_mask(), padding=[1], fill=fill)

    def test_kernel_video(self):
        """Run the generic kernel checks for the video kernel."""
        video = make_video()
        check_kernel(F.pad_video, video, padding=[1])

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_functional(self, make_input):
        """Run the generic functional checks for every supported input type."""
        sample = make_input()
        check_functional(F.pad, sample, padding=[1])

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.pad_image, torch.Tensor),
            # The PIL kernel uses fill=0 as default rather than fill=None as all others.
            # Since the whole fill story is already really inconsistent, we won't introduce yet another case to allow
            # for this test to pass.
            # See https://github.com/pytorch/vision/issues/6623 for a discussion.
            # (F._geometry._pad_image_pil, PIL.Image.Image),
            (F.pad_image, tv_tensors.Image),
            (F.pad_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.pad_mask, tv_tensors.Mask),
            (F.pad_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of each kernel has to match the one of the functional."""
        check_functional_kernel_signature_match(F.pad, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_transform(self, make_input):
        """Run the generic transform checks for every supported input type."""
        sample = make_input()
        check_transform(transforms.Pad(padding=[1]), sample)

    def test_transform_errors(self):
        """Check the errors raised by ``Pad`` for invalid ``padding``, ``fill``, and ``padding_mode``."""
        with pytest.raises(TypeError, match="Got inappropriate padding arg"):
            transforms.Pad("abc")

        with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
            transforms.Pad([-0.7, 0, 0.7])

        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.Pad(12, fill="abc")

        with pytest.raises(ValueError, match="Padding mode should be either"):
            transforms.Pad(12, padding_mode="abc")

    @pytest.mark.parametrize("padding", CORRECTNESS_PADDINGS)
    @pytest.mark.parametrize(
        ("padding_mode", "fill"),
        [
            *[("constant", fill) for fill in CORRECTNESS_FILLS],
            *[(padding_mode, None) for padding_mode in ["symmetric", "edge", "reflect"]],
        ],
    )
    @pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
    def test_image_correctness(self, padding, padding_mode, fill, fn):
        """The tensor result must be equal to the result obtained through the PIL image path."""
        image = make_image(dtype=torch.uint8, device="cpu")

        fill = adapt_fill(fill, dtype=torch.uint8)
        pad_kwargs = dict(padding=padding, padding_mode=padding_mode, fill=fill)

        actual = fn(image, **pad_kwargs)
        expected = F.to_image(F.pad(F.to_pil_image(image), **pad_kwargs))

        assert_equal(actual, expected)

    def _reference_pad_bounding_boxes(self, bounding_boxes, *, padding):
        """Reference implementation that expresses padding as a translation on an enlarged canvas.

        ``padding`` may be an int or a sequence of one, two, or four ints; it is expanded to
        ``(left, top, right, bottom)`` by repetition.
        """
        if isinstance(padding, int):
            padding = [padding]
        repeats = 4 // len(padding)
        left, top, right, bottom = padding * repeats

        # Only the left / top padding moves the boxes; right / bottom merely enlarge the canvas.
        affine_matrix = np.array(
            [
                [1, 0, left],
                [0, 1, top],
            ],
        )

        old_height, old_width = bounding_boxes.canvas_size[0], bounding_boxes.canvas_size[1]
        new_canvas_size = (old_height + top + bottom, old_width + left + right)

        return reference_affine_bounding_boxes_helper(
            bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=new_canvas_size
        )

    @pytest.mark.parametrize("padding", CORRECTNESS_PADDINGS)
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
    def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
        """The functional and the transform must both agree with the affine reference."""
        boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)

        actual = fn(boxes, padding=padding)
        expected = self._reference_pad_bounding_boxes(boxes, padding=padding)

        assert_equal(actual, expected)
3929

3930

3931
class TestCenterCrop:
    """Tests for ``F.center_crop``, its per-type kernels, and ``transforms.CenterCrop``."""

    INPUT_SIZE = (17, 11)
    OUTPUT_SIZES = [(3, 5), (5, 3), (4, 4), (21, 9), (13, 15), (19, 14), 3, (4,), [5], INPUT_SIZE]

    @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
    @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, output_size, dtype, device):
        """Run the generic kernel checks for the image kernel.

        The scripted-vs-eager comparison is disabled when ``output_size`` is a plain int.
        """
        check_kernel(
            F.center_crop_image,
            make_image(self.INPUT_SIZE, dtype=dtype, device=device),
            output_size=output_size,
            check_scripted_vs_eager=not isinstance(output_size, int),
        )

    @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    def test_kernel_bounding_boxes(self, output_size, format):
        """Run the generic kernel checks for the bounding boxes kernel.

        The scripted-vs-eager comparison is disabled when ``output_size`` is a plain int.
        """
        bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
        check_kernel(
            F.center_crop_bounding_boxes,
            bounding_boxes,
            format=bounding_boxes.format,
            canvas_size=bounding_boxes.canvas_size,
            output_size=output_size,
            check_scripted_vs_eager=not isinstance(output_size, int),
        )

    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
    def test_kernel_mask(self, make_mask):
        """Run the generic kernel checks for the mask kernel."""
        check_kernel(F.center_crop_mask, make_mask(), output_size=self.OUTPUT_SIZES[0])

    def test_kernel_video(self):
        """Run the generic kernel checks for the video kernel."""
        check_kernel(F.center_crop_video, make_video(self.INPUT_SIZE), output_size=self.OUTPUT_SIZES[0])

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_functional(self, make_input):
        """Run the generic functional checks for every supported input type."""
        check_functional(F.center_crop, make_input(self.INPUT_SIZE), output_size=self.OUTPUT_SIZES[0])

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.center_crop_image, torch.Tensor),
            (F._geometry._center_crop_image_pil, PIL.Image.Image),
            (F.center_crop_image, tv_tensors.Image),
            (F.center_crop_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.center_crop_mask, tv_tensors.Mask),
            (F.center_crop_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of each kernel has to match the one of the functional."""
        check_functional_kernel_signature_match(F.center_crop, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    def test_transform(self, make_input):
        """Run the generic transform checks for every supported input type."""
        check_transform(transforms.CenterCrop(self.OUTPUT_SIZES[0]), make_input(self.INPUT_SIZE))

    @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
    @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
    def test_image_correctness(self, output_size, fn):
        """The tensor result must be equal to the result obtained through the PIL image path."""
        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

        actual = fn(image, output_size)
        expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))

        assert_equal(actual, expected)

    def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
        """Reference implementation that expresses the center crop as a translation.

        Args:
            bounding_boxes: Boxes to transform; their ``canvas_size`` is read as ``(height, width)``.
            output_size: An int, or a sequence of one or two ints. An int or a one-element sequence is
                expanded to a square ``(size, size)``.
        """
        image_height, image_width = bounding_boxes.canvas_size
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        elif len(output_size) == 1:
            # Rebind instead of `output_size *= 2`: the augmented assignment extends a list in place, which would
            # permanently turn the shared `[5]` entry of `OUTPUT_SIZES` into `[5, 5]` for every later test.
            output_size = output_size * 2
        crop_height, crop_width = output_size

        top = int(round((image_height - crop_height) / 2))
        left = int(round((image_width - crop_width) / 2))

        # Cropping moves the origin to the top-left corner of the crop window.
        affine_matrix = np.array(
            [
                [1, 0, -left],
                [0, 1, -top],
            ],
        )
        return reference_affine_bounding_boxes_helper(
            bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=output_size
        )

    @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
    def test_bounding_boxes_correctness(self, output_size, format, dtype, device, fn):
        """The functional and the transform must both agree with the affine reference."""
        bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format, dtype=dtype, device=device)

        actual = fn(bounding_boxes, output_size)
        expected = self._reference_center_crop_bounding_boxes(bounding_boxes, output_size)

        assert_equal(actual, expected)
4037

4038

4039
class TestPerspective:
4040
    COEFFICIENTS = [
4041
        [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
4042
        [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
4043
    ]
4044
    START_END_POINTS = [
4045
        ([[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]),
4046
        ([[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]),
4047
        ([[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]),
4048
    ]
4049
    MINIMAL_KWARGS = dict(startpoints=None, endpoints=None, coefficients=COEFFICIENTS[0])
4050

4051
    @param_value_parametrization(
4052
        coefficients=COEFFICIENTS,
4053
        start_end_points=START_END_POINTS,
4054
        fill=EXHAUSTIVE_TYPE_FILLS,
4055
    )
4056
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
4057
    @pytest.mark.parametrize("device", cpu_and_cuda())
4058
    def test_kernel_image(self, param, value, dtype, device):
4059
        if param == "start_end_points":
4060
            kwargs = dict(zip(["startpoints", "endpoints"], value))
4061
        else:
4062
            kwargs = {"startpoints": None, "endpoints": None, param: value}
4063
        if param == "fill":
4064
            kwargs["coefficients"] = self.COEFFICIENTS[0]
4065

4066
        check_kernel(
4067
            F.perspective_image,
4068
            make_image(dtype=dtype, device=device),
4069
            **kwargs,
4070
            check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
4071
        )
4072

4073
    def test_kernel_image_error(self):
4074
        image = make_image_tensor()
4075

4076
        with pytest.raises(ValueError, match="startpoints/endpoints or the coefficients must have non `None` values"):
4077
            F.perspective_image(image, startpoints=None, endpoints=None)
4078

4079
        with pytest.raises(
4080
            ValueError, match="startpoints/endpoints and the coefficients shouldn't be defined concurrently"
4081
        ):
4082
            startpoints, endpoints = self.START_END_POINTS[0]
4083
            coefficients = self.COEFFICIENTS[0]
4084
            F.perspective_image(image, startpoints=startpoints, endpoints=endpoints, coefficients=coefficients)
4085

4086
        with pytest.raises(ValueError, match="coefficients should have 8 float values"):
4087
            F.perspective_image(image, startpoints=None, endpoints=None, coefficients=list(range(7)))
4088

4089
    @param_value_parametrization(
4090
        coefficients=COEFFICIENTS,
4091
        start_end_points=START_END_POINTS,
4092
    )
4093
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
4094
    def test_kernel_bounding_boxes(self, param, value, format):
4095
        if param == "start_end_points":
4096
            kwargs = dict(zip(["startpoints", "endpoints"], value))
4097
        else:
4098
            kwargs = {"startpoints": None, "endpoints": None, param: value}
4099

4100
        bounding_boxes = make_bounding_boxes(format=format)
4101

4102
        check_kernel(
4103
            F.perspective_bounding_boxes,
4104
            bounding_boxes,
4105
            format=bounding_boxes.format,
4106
            canvas_size=bounding_boxes.canvas_size,
4107
            **kwargs,
4108
        )
4109

4110
    def test_kernel_bounding_boxes_error(self):
4111
        bounding_boxes = make_bounding_boxes()
4112
        format, canvas_size = bounding_boxes.format, bounding_boxes.canvas_size
4113
        bounding_boxes = bounding_boxes.as_subclass(torch.Tensor)
4114

4115
        with pytest.raises(RuntimeError, match="Denominator is zero"):
4116
            F.perspective_bounding_boxes(
4117
                bounding_boxes,
4118
                format=format,
4119
                canvas_size=canvas_size,
4120
                startpoints=None,
4121
                endpoints=None,
4122
                coefficients=[0.0] * 8,
4123
            )
4124

4125
    @pytest.mark.parametrize("make_mask", [make_segmentation_mask, make_detection_masks])
4126
    def test_kernel_mask(self, make_mask):
4127
        check_kernel(F.perspective_mask, make_mask(), **self.MINIMAL_KWARGS)
4128

4129
    def test_kernel_video(self):
4130
        check_kernel(F.perspective_video, make_video(), **self.MINIMAL_KWARGS)
4131

4132
    @pytest.mark.parametrize(
4133
        "make_input",
4134
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
4135
    )
4136
    def test_functional(self, make_input):
4137
        check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS)
4138

4139
    @pytest.mark.parametrize(
4140
        ("kernel", "input_type"),
4141
        [
4142
            (F.perspective_image, torch.Tensor),
4143
            (F._geometry._perspective_image_pil, PIL.Image.Image),
4144
            (F.perspective_image, tv_tensors.Image),
4145
            (F.perspective_bounding_boxes, tv_tensors.BoundingBoxes),
4146
            (F.perspective_mask, tv_tensors.Mask),
4147
            (F.perspective_video, tv_tensors.Video),
4148
        ],
4149
    )
4150
    def test_functional_signature(self, kernel, input_type):
4151
        check_functional_kernel_signature_match(F.perspective, kernel=kernel, input_type=input_type)
4152

4153
    @pytest.mark.parametrize("distortion_scale", [0.5, 0.0, 1.0])
4154
    @pytest.mark.parametrize(
4155
        "make_input",
4156
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
4157
    )
4158
    def test_transform(self, distortion_scale, make_input):
4159
        check_transform(transforms.RandomPerspective(distortion_scale=distortion_scale, p=1), make_input())
4160

4161
    @pytest.mark.parametrize("distortion_scale", [-1, 2])
4162
    def test_transform_error(self, distortion_scale):
4163
        with pytest.raises(ValueError, match="distortion_scale value should be between 0 and 1"):
4164
            transforms.RandomPerspective(distortion_scale=distortion_scale)
4165

4166
    @pytest.mark.parametrize("coefficients", COEFFICIENTS)
    @pytest.mark.parametrize(
        "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
    )
    @pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
    def test_image_functional_correctness(self, coefficients, interpolation, fill):
        """Compare ``F.perspective`` on a uint8 tensor image against the same call on its PIL counterpart.

        NEAREST results must match exactly. For BILINEAR only a statistical bound is enforced: fewer than 7% of
        the values may differ by more than 1, and the mean absolute error must stay below 3.
        """
        image = make_image(dtype=torch.uint8, device="cpu")

        # Both code paths receive exactly the same keyword arguments.
        perspective_kwargs = dict(
            startpoints=None,
            endpoints=None,
            coefficients=coefficients,
            interpolation=interpolation,
            fill=fill,
        )

        actual = F.perspective(image, **perspective_kwargs)
        pil_result = F.perspective(F.to_pil_image(image), **perspective_kwargs)
        expected = F.to_image(pil_result)

        if interpolation is not transforms.InterpolationMode.BILINEAR:
            assert_equal(actual, expected)
            return

        abs_diff = (actual.float() - expected.float()).abs()
        assert (abs_diff > 1).float().mean() < 7e-2
        mae = abs_diff.mean()
        assert mae < 3
4195

4196
    def _reference_perspective_bounding_boxes(self, bounding_boxes, *, startpoints, endpoints):
        """NumPy-based reference for applying a perspective transform to bounding boxes.

        Each box is handled on its own: it is converted to XYXY in float64 on the CPU, its four corners are mapped
        through the perspective coefficients, and the axis-aligned extent of the mapped corners becomes the new
        box. The result is converted back to the input format, clamped to the canvas and cast back to the input
        dtype and device.

        Args:
            bounding_boxes: Boxes to transform; ``format``, ``canvas_size``, ``dtype`` and ``device`` are read from it.
            startpoints: Passed to ``_get_perspective_coeffs`` as its second argument.
            endpoints: Passed to ``_get_perspective_coeffs`` as its first argument.

        Returns:
            A ``tv_tensors.BoundingBoxes`` with the same shape, format and canvas size as the input.
        """
        box_format = bounding_boxes.format
        canvas_size = bounding_boxes.canvas_size
        out_dtype = bounding_boxes.dtype
        out_device = bounding_boxes.device
        xyxy = tv_tensors.BoundingBoxFormat.XYXY

        # Note the swapped argument order: endpoints first, startpoints second.
        coefficients = _get_perspective_coeffs(endpoints, startpoints)

        # Rows producing the x and y numerators of the projective mapping.
        numerator_matrix = np.array(
            [
                [coefficients[0], coefficients[1], coefficients[2]],
                [coefficients[3], coefficients[4], coefficients[5]],
            ]
        )
        # The same denominator row is used for both output coordinates.
        denominator_row = [coefficients[6], coefficients[7], 1.0]
        denominator_matrix = np.array([denominator_row, denominator_row])

        def transform_single_box(box):
            # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
            box_xyxy = F.convert_bounding_box_format(
                box.to(dtype=torch.float64, device="cpu", copy=True),
                old_format=box_format,
                new_format=xyxy,
                inplace=True,
            )
            x1, y1, x2, y2 = box_xyxy.squeeze(0).tolist()

            # Homogeneous coordinates of the four corners.
            corners = np.array(
                [
                    [x1, y1, 1.0],
                    [x2, y1, 1.0],
                    [x1, y2, 1.0],
                    [x2, y2, 1.0],
                ]
            )
            warped = (corners @ numerator_matrix.T) / (corners @ denominator_matrix.T)
            warped_x = warped[:, 0]
            warped_y = warped[:, 1]

            # Axis-aligned extent of the warped corners.
            warped_xyxy = torch.Tensor(
                [
                    float(np.min(warped_x)),
                    float(np.min(warped_y)),
                    float(np.max(warped_x)),
                    float(np.max(warped_y)),
                ]
            )
            converted = F.convert_bounding_box_format(warped_xyxy, old_format=xyxy, new_format=box_format)

            # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
            clamped = F.clamp_bounding_boxes(converted, format=box_format, canvas_size=canvas_size)
            return clamped.to(dtype=out_dtype, device=out_device)

        per_box_results = [transform_single_box(box) for box in bounding_boxes.reshape(-1, 4).unbind()]
        stacked = torch.cat(per_box_results, dim=0).reshape(bounding_boxes.shape)
        return tv_tensors.BoundingBoxes(stacked, format=box_format, canvas_size=canvas_size)
4267

4268
    @pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS)
    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device):
        """``F.perspective`` on bounding boxes must agree with the NumPy reference within an absolute error of 1."""
        boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
        point_kwargs = dict(startpoints=startpoints, endpoints=endpoints)

        actual = F.perspective(boxes, **point_kwargs)
        expected = self._reference_perspective_bounding_boxes(boxes, **point_kwargs)

        assert_close(actual, expected, rtol=0, atol=1)
4281

4282

4283
class TestEqualize:
    """Tests for ``F.equalize``, its image/video kernels and ``transforms.RandomEqualize``."""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the shared kernel checks on ``F.equalize_image``."""
        check_kernel(F.equalize_image, make_image(dtype=dtype, device=device))

    def test_kernel_video(self):
        """Run the shared kernel checks on ``F.equalize_video``."""
        # Use the dedicated video kernel (the one registered for ``tv_tensors.Video`` in
        # ``test_functional_signature`` below) rather than the image kernel, so that it is covered directly.
        check_kernel(F.equalize_video, make_video())

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks on ``F.equalize`` for each input factory."""
        check_functional(F.equalize, make_input())

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.equalize_image, torch.Tensor),
            (F._color._equalize_image_pil, PIL.Image.Image),
            (F.equalize_image, tv_tensors.Image),
            (F.equalize_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of ``F.equalize`` must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.equalize, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_video],
    )
    def test_transform(self, make_input):
        """Run the shared transform checks on ``RandomEqualize(p=1)``."""
        check_transform(transforms.RandomEqualize(p=1), make_input())

    @pytest.mark.parametrize(("low", "high"), [(0, 64), (64, 192), (192, 256), (0, 1), (127, 128), (255, 256)])
    @pytest.mark.parametrize("fn", [F.equalize, transform_cls_to_functional(transforms.RandomEqualize, p=1)])
    def test_image_correctness(self, low, high, fn):
        """The tensor result must equal the PIL result for images whose values are restricted to ``[low, high)``."""
        # We are not using the default `make_image` here since that uniformly samples the values over the whole value
        # range. Since the whole point of F.equalize is to transform an arbitrary distribution of values into a uniform
        # one over the full range, the information gain is low if we already provide something really close to the
        # expected value.
        image = tv_tensors.Image(
            torch.testing.make_tensor((3, 117, 253), dtype=torch.uint8, device="cpu", low=low, high=high)
        )

        actual = fn(image)
        expected = F.to_image(F.equalize(F.to_pil_image(image)))

        assert_equal(actual, expected)
4330

4331

4332
class TestUniformTemporalSubsample:
    """Tests for ``F.uniform_temporal_subsample`` and ``transforms.UniformTemporalSubsample``."""

    def test_kernel_video(self):
        """Run the shared kernel checks on the video kernel with two samples."""
        video = make_video()
        check_kernel(F.uniform_temporal_subsample_video, video, num_samples=2)

    @pytest.mark.parametrize("make_input", [make_video_tensor, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks for each video input factory."""
        video = make_input()
        check_functional(F.uniform_temporal_subsample, video, num_samples=2)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.uniform_temporal_subsample_video, torch.Tensor),
            (F.uniform_temporal_subsample_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The functional's signature must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.uniform_temporal_subsample, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_video_tensor, make_video])
    def test_transform(self, make_input):
        """Run the shared transform checks on ``UniformTemporalSubsample(num_samples=2)``."""
        transform = transforms.UniformTemporalSubsample(num_samples=2)
        check_transform(transform, make_input())

    def _reference_uniform_temporal_subsample_video(self, video, *, num_samples):
        """Reference implementation: pick ``num_samples`` frames at evenly spaced positions along dim ``-4``."""
        # Adapted from
        # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
        num_frames = video.shape[-4]
        assert num_samples > 0 and num_frames > 0
        last_index = num_frames - 1
        # Sample by nearest neighbor interpolation if num_samples > t.
        positions = torch.linspace(0, last_index, num_samples, device=video.device)
        frame_indices = torch.clamp(positions, 0, last_index).long()
        selected = torch.index_select(video, -4, frame_indices)
        return tv_tensors.Video(selected)

    CORRECTNESS_NUM_FRAMES = 5

    @pytest.mark.parametrize("num_samples", list(range(1, CORRECTNESS_NUM_FRAMES + 1)))
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize(
        "fn", [F.uniform_temporal_subsample, transform_cls_to_functional(transforms.UniformTemporalSubsample)]
    )
    def test_video_correctness(self, num_samples, dtype, device, fn):
        """Functional and transform must reproduce the reference exactly for 1 to ``CORRECTNESS_NUM_FRAMES`` samples."""
        video = make_video(num_frames=self.CORRECTNESS_NUM_FRAMES, dtype=dtype, device=device)

        actual = fn(video, num_samples=num_samples)
        expected = self._reference_uniform_temporal_subsample_video(video, num_samples=num_samples)

        assert_equal(actual, expected)
4379

4380

4381
class TestNormalize:
    """Tests for ``F.normalize``, its image/video kernels and ``transforms.Normalize``."""

    # (mean, std) pairs used by the correctness test; the first pair is the default for the other tests.
    MEANS_STDS = [
        ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    ]
    MEAN, STD = MEANS_STDS[0]

    @pytest.mark.parametrize(("mean", "std"), [*MEANS_STDS, (0.5, 2.0)])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, mean, std, device):
        """Run the shared kernel checks on ``F.normalize_image``."""
        # NOTE(review): the parametrized `mean` / `std` are not forwarded; the call below always uses
        # `self.MEAN` / `self.STD`, so every parametrization exercises the same statistics. Confirm whether
        # forwarding them (including the scalar pair) is supported by `check_kernel` before changing this.
        image = make_image(dtype=torch.float32, device=device)
        check_kernel(F.normalize_image, image, mean=self.MEAN, std=self.STD)

    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image_inplace(self, device):
        """``inplace=True`` must write into the input's storage; the default must produce a separate tensor."""
        input = make_image_tensor(dtype=torch.float32, device=device)
        version_before = input._version

        # Out-of-place first, so that `input` is still unmodified when it is used.
        output_out_of_place = F.normalize_image(input, mean=self.MEAN, std=self.STD)
        assert output_out_of_place.data_ptr() != input.data_ptr()
        assert output_out_of_place is not input

        output_inplace = F.normalize_image(input, mean=self.MEAN, std=self.STD, inplace=True)
        assert output_inplace.data_ptr() == input.data_ptr()
        assert output_inplace._version > version_before
        assert output_inplace is input

        assert_equal(output_inplace, output_out_of_place)

    def test_kernel_video(self):
        """Run the shared kernel checks on ``F.normalize_video``."""
        video = make_video(dtype=torch.float32)
        check_kernel(F.normalize_video, video, mean=self.MEAN, std=self.STD)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks on ``F.normalize`` for each float32 input factory."""
        sample = make_input(dtype=torch.float32)
        check_functional(F.normalize, sample, mean=self.MEAN, std=self.STD)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.normalize_image, torch.Tensor),
            (F.normalize_image, tv_tensors.Image),
            (F.normalize_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of ``F.normalize`` must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.normalize, kernel=kernel, input_type=input_type)

    def test_functional_error(self):
        """``F.normalize_image`` must reject uint8 input, 2D input and any ``std`` containing a zero."""
        with pytest.raises(TypeError, match="should be a float tensor"):
            F.normalize_image(make_image(dtype=torch.uint8), mean=self.MEAN, std=self.STD)

        with pytest.raises(ValueError, match="tensor image of size"):
            F.normalize_image(torch.rand(16, 16, dtype=torch.float32), mean=self.MEAN, std=self.STD)

        zero_stds = [0, [0, 0, 0], [0, 1, 1]]
        for std in zero_stds:
            with pytest.raises(ValueError, match="std evaluated to zero, leading to division by zero"):
                F.normalize_image(make_image(dtype=torch.float32), mean=self.MEAN, std=std)

    def _sample_input_adapter(self, transform, input, device):
        """Build a copy of the ``input`` dict that ``Normalize`` can consume.

        PIL images are dropped, pure tensors / images / videos are converted to float32 with scaling, and all
        other values are passed through unchanged. ``transform`` and ``device`` are not used.
        """
        adapted_input = {}
        for key, value in input.items():
            # normalize doesn't support PIL images
            if isinstance(value, PIL.Image.Image):
                continue
            # normalize doesn't support integer images
            if check_type(value, (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)):
                value = F.to_dtype(value, torch.float32, scale=True)
            adapted_input[key] = value
        return adapted_input

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
    def test_transform(self, make_input):
        """Run the shared transform checks on ``Normalize``, adapting the sample input via ``_sample_input_adapter``."""
        transform = transforms.Normalize(mean=self.MEAN, std=self.STD)
        sample = make_input(dtype=torch.float32)
        check_transform(transform, sample, check_sample_input=self._sample_input_adapter)

    def _reference_normalize_image(self, image, *, mean, std):
        """NumPy reference: ``(image - mean) / std`` with the statistics reshaped to ``(-1, 1, 1)``."""
        array = image.numpy()
        # The statistics are created in the image's dtype before being reshaped for broadcasting.
        mean_array = np.array(mean, dtype=array.dtype).reshape((-1, 1, 1))
        std_array = np.array(std, dtype=array.dtype).reshape((-1, 1, 1))
        return tv_tensors.Image((array - mean_array) / std_array)

    @pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
    @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
    @pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
    def test_correctness_image(self, mean, std, dtype, fn):
        """Functional and transform must reproduce the NumPy reference exactly."""
        image = make_image(dtype=dtype)

        actual = fn(image, mean=mean, std=std)
        expected = self._reference_normalize_image(image, mean=mean, std=std)

        assert_equal(actual, expected)
4473

4474

4475
class TestClampBoundingBoxes:
    """Tests for ``F.clamp_bounding_boxes`` and ``transforms.ClampBoundingBoxes``."""

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel(self, format, dtype, device):
        """Run the shared kernel checks, passing the boxes' own format and canvas size explicitly."""
        bounding_boxes = make_bounding_boxes(format=format, dtype=dtype, device=device)
        metadata = dict(format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size)
        check_kernel(F.clamp_bounding_boxes, bounding_boxes, **metadata)

    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
    def test_functional(self, format):
        """Run the shared functional checks for each bounding box format."""
        bounding_boxes = make_bounding_boxes(format=format)
        check_functional(F.clamp_bounding_boxes, bounding_boxes)

    def test_errors(self):
        """Pure tensors need both ``format`` and ``canvas_size``; tv_tensor inputs must be given neither."""
        input_tv_tensor = make_bounding_boxes()
        input_pure_tensor = input_tv_tensor.as_subclass(torch.Tensor)
        format, canvas_size = input_tv_tensor.format, input_tv_tensor.canvas_size

        # At least one of the two pieces of metadata is missing for a pure tensor.
        incomplete_metadata = [(None, None), (format, None), (None, canvas_size)]
        for format_, canvas_size_ in incomplete_metadata:
            with pytest.raises(
                ValueError, match="For pure tensor inputs, `format` and `canvas_size` have to be passed."
            ):
                F.clamp_bounding_boxes(input_pure_tensor, format=format_, canvas_size=canvas_size_)

        # At least one of the two pieces of metadata is passed alongside a tv_tensor.
        redundant_metadata = [(format, canvas_size), (format, None), (None, canvas_size)]
        for format_, canvas_size_ in redundant_metadata:
            with pytest.raises(
                ValueError, match="For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed."
            ):
                F.clamp_bounding_boxes(input_tv_tensor, format=format_, canvas_size=canvas_size_)

    def test_transform(self):
        """Run the shared transform checks on ``ClampBoundingBoxes``."""
        transform = transforms.ClampBoundingBoxes()
        check_transform(transform, make_bounding_boxes())
4511

4512

4513
class TestInvert:
    """Tests for ``F.invert``, its image/video kernels and ``transforms.RandomInvert``."""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the shared kernel checks on ``F.invert_image``."""
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.invert_image, image)

    def test_kernel_video(self):
        """Run the shared kernel checks on ``F.invert_video``."""
        video = make_video()
        check_kernel(F.invert_video, video)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks on ``F.invert`` for each input factory."""
        sample = make_input()
        check_functional(F.invert, sample)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.invert_image, torch.Tensor),
            (F._color._invert_image_pil, PIL.Image.Image),
            (F.invert_image, tv_tensors.Image),
            (F.invert_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of ``F.invert`` must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.invert, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_transform(self, make_input):
        """Run the shared transform checks on ``RandomInvert(p=1)``."""
        transform = transforms.RandomInvert(p=1)
        check_transform(transform, make_input())

    @pytest.mark.parametrize("fn", [F.invert, transform_cls_to_functional(transforms.RandomInvert, p=1)])
    def test_correctness_image(self, fn):
        """The tensor result must equal the result obtained through the PIL code path."""
        image = make_image(dtype=torch.uint8, device="cpu")

        actual = fn(image)
        pil_result = F.invert(F.to_pil_image(image))
        expected = F.to_image(pil_result)

        assert_equal(actual, expected)
4550

4551

4552
class TestPosterize:
    """Tests for ``F.posterize``, its image/video kernels and ``transforms.RandomPosterize``."""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the shared kernel checks on ``F.posterize_image`` with ``bits=1``."""
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.posterize_image, image, bits=1)

    def test_kernel_video(self):
        """Run the shared kernel checks on ``F.posterize_video`` with ``bits=1``."""
        video = make_video()
        check_kernel(F.posterize_video, video, bits=1)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks on ``F.posterize`` for each input factory."""
        sample = make_input()
        check_functional(F.posterize, sample, bits=1)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.posterize_image, torch.Tensor),
            (F._color._posterize_image_pil, PIL.Image.Image),
            (F.posterize_image, tv_tensors.Image),
            (F.posterize_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of ``F.posterize`` must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.posterize, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_transform(self, make_input):
        """Run the shared transform checks on ``RandomPosterize(bits=1, p=1)``."""
        transform = transforms.RandomPosterize(bits=1, p=1)
        check_transform(transform, make_input())

    @pytest.mark.parametrize("bits", [1, 4, 8])
    @pytest.mark.parametrize("fn", [F.posterize, transform_cls_to_functional(transforms.RandomPosterize, p=1)])
    def test_correctness_image(self, bits, fn):
        """The tensor result must equal the result obtained through the PIL code path."""
        image = make_image(dtype=torch.uint8, device="cpu")

        actual = fn(image, bits=bits)
        pil_result = F.posterize(F.to_pil_image(image), bits=bits)
        expected = F.to_image(pil_result)

        assert_equal(actual, expected)
4590

4591

4592
class TestSolarize:
    """Tests for ``F.solarize``, its image/video kernels and ``transforms.RandomSolarize``."""

    def _make_threshold(self, input, *, factor=0.5):
        """Return ``factor`` times the maximum value of the input's dtype.

        Non-tensor inputs are treated as ``torch.uint8``. The result is a ``float`` for floating point dtypes
        and an ``int`` otherwise.
        """
        dtype = input.dtype if isinstance(input, torch.Tensor) else torch.uint8
        scaled = get_max_value(dtype) * factor
        if dtype.is_floating_point:
            return float(scaled)
        return int(scaled)

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the shared kernel checks on ``F.solarize_image`` with a mid-range threshold."""
        image = make_image(dtype=dtype, device=device)
        threshold = self._make_threshold(image)
        check_kernel(F.solarize_image, image, threshold=threshold)

    def test_kernel_video(self):
        """Run the shared kernel checks on ``F.solarize_video`` with a mid-range threshold."""
        video = make_video()
        threshold = self._make_threshold(video)
        check_kernel(F.solarize_video, video, threshold=threshold)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks on ``F.solarize`` for each input factory."""
        input = make_input()
        threshold = self._make_threshold(input)
        check_functional(F.solarize, input, threshold=threshold)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.solarize_image, torch.Tensor),
            (F._color._solarize_image_pil, PIL.Image.Image),
            (F.solarize_image, tv_tensors.Image),
            (F.solarize_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of ``F.solarize`` must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.solarize, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize(("dtype", "threshold"), [(torch.uint8, 256), (torch.float, 1.5)])
    def test_functional_error(self, dtype, threshold):
        """The parametrized dtype/threshold pairs must be rejected with a ``TypeError``."""
        expected_message = "Threshold should be less or equal the maximum value of the dtype"
        with pytest.raises(TypeError, match=expected_message):
            F.solarize(make_image(dtype=dtype), threshold=threshold)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_transform(self, make_input):
        """Run the shared transform checks on ``RandomSolarize`` with ``p=1``."""
        input = make_input()
        transform = transforms.RandomSolarize(threshold=self._make_threshold(input), p=1)
        check_transform(transform, input)

    @pytest.mark.parametrize("threshold_factor", [0.0, 0.1, 0.5, 0.9, 1.0])
    @pytest.mark.parametrize("fn", [F.solarize, transform_cls_to_functional(transforms.RandomSolarize, p=1)])
    def test_correctness_image(self, threshold_factor, fn):
        """The tensor result must equal the result obtained through the PIL code path."""
        image = make_image(dtype=torch.uint8, device="cpu")
        threshold = self._make_threshold(image, factor=threshold_factor)

        actual = fn(image, threshold=threshold)
        pil_result = F.solarize(F.to_pil_image(image), threshold=threshold)
        expected = F.to_image(pil_result)

        assert_equal(actual, expected)
4644

4645

4646
class TestAutocontrast:
    """Tests for ``F.autocontrast``, its image/video kernels and ``transforms.RandomAutocontrast``."""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the shared kernel checks on ``F.autocontrast_image``."""
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.autocontrast_image, image)

    def test_kernel_video(self):
        """Run the shared kernel checks on ``F.autocontrast_video``."""
        video = make_video()
        check_kernel(F.autocontrast_video, video)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks on ``F.autocontrast`` for each input factory."""
        sample = make_input()
        check_functional(F.autocontrast, sample)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.autocontrast_image, torch.Tensor),
            (F._color._autocontrast_image_pil, PIL.Image.Image),
            (F.autocontrast_image, tv_tensors.Image),
            (F.autocontrast_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of ``F.autocontrast`` must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.autocontrast, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_transform(self, make_input):
        """Run the shared transform checks, passing ``rtol=0, atol=1`` for the v1 compatibility check."""
        transform = transforms.RandomAutocontrast(p=1)
        sample = make_input()
        check_transform(transform, sample, check_v1_compatibility=dict(rtol=0, atol=1))

    @pytest.mark.parametrize("fn", [F.autocontrast, transform_cls_to_functional(transforms.RandomAutocontrast, p=1)])
    def test_correctness_image(self, fn):
        """The tensor result must be within an absolute error of 1 of the PIL result."""
        image = make_image(dtype=torch.uint8, device="cpu")

        actual = fn(image)
        pil_result = F.autocontrast(F.to_pil_image(image))
        expected = F.to_image(pil_result)

        assert_close(actual, expected, rtol=0, atol=1)
4683

4684

4685
class TestAdjustSharpness:
    """Tests for ``F.adjust_sharpness``, its image/video kernels and ``transforms.RandomAdjustSharpness``."""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the shared kernel checks on ``F.adjust_sharpness_image``."""
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.adjust_sharpness_image, image, sharpness_factor=0.5)

    def test_kernel_video(self):
        """Run the shared kernel checks on ``F.adjust_sharpness_video``."""
        video = make_video()
        check_kernel(F.adjust_sharpness_video, video, sharpness_factor=0.5)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks on ``F.adjust_sharpness`` for each input factory."""
        sample = make_input()
        check_functional(F.adjust_sharpness, sample, sharpness_factor=0.5)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.adjust_sharpness_image, torch.Tensor),
            (F._color._adjust_sharpness_image_pil, PIL.Image.Image),
            (F.adjust_sharpness_image, tv_tensors.Image),
            (F.adjust_sharpness_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of ``F.adjust_sharpness`` must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.adjust_sharpness, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_transform(self, make_input):
        """Run the shared transform checks on ``RandomAdjustSharpness(sharpness_factor=0.5, p=1)``."""
        transform = transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1)
        check_transform(transform, make_input())

    def test_functional_error(self):
        """An RGBA image must raise ``TypeError``; a negative factor must raise ``ValueError``."""
        with pytest.raises(TypeError, match="can have 1 or 3 channels"):
            F.adjust_sharpness(make_image(color_space="RGBA"), sharpness_factor=0.5)

        with pytest.raises(ValueError, match="is not non-negative"):
            F.adjust_sharpness(make_image(), sharpness_factor=-1)

    @pytest.mark.parametrize("sharpness_factor", [0.1, 0.5, 1.0])
    @pytest.mark.parametrize(
        "fn", [F.adjust_sharpness, transform_cls_to_functional(transforms.RandomAdjustSharpness, p=1)]
    )
    def test_correctness_image(self, sharpness_factor, fn):
        """The tensor result must equal the result obtained through the PIL code path."""
        image = make_image(dtype=torch.uint8, device="cpu")

        actual = fn(image, sharpness_factor=sharpness_factor)
        pil_result = F.adjust_sharpness(F.to_pil_image(image), sharpness_factor=sharpness_factor)
        expected = F.to_image(pil_result)

        assert_equal(actual, expected)
4732

4733

4734
class TestAdjustContrast:
    """Tests for ``F.adjust_contrast`` and its image/video kernels."""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the shared kernel checks on ``F.adjust_contrast_image``."""
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.adjust_contrast_image, image, contrast_factor=0.5)

    def test_kernel_video(self):
        """Run the shared kernel checks on ``F.adjust_contrast_video``."""
        video = make_video()
        check_kernel(F.adjust_contrast_video, video, contrast_factor=0.5)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks on ``F.adjust_contrast`` for each input factory."""
        sample = make_input()
        check_functional(F.adjust_contrast, sample, contrast_factor=0.5)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.adjust_contrast_image, torch.Tensor),
            (F._color._adjust_contrast_image_pil, PIL.Image.Image),
            (F.adjust_contrast_image, tv_tensors.Image),
            (F.adjust_contrast_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of ``F.adjust_contrast`` must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.adjust_contrast, kernel=kernel, input_type=input_type)

    def test_functional_error(self):
        """An RGBA image must raise ``TypeError``; a negative factor must raise ``ValueError``."""
        with pytest.raises(TypeError, match="permitted channel values are 1 or 3"):
            F.adjust_contrast(make_image(color_space="RGBA"), contrast_factor=0.5)

        with pytest.raises(ValueError, match="is not non-negative"):
            F.adjust_contrast(make_image(), contrast_factor=-1)

    @pytest.mark.parametrize("contrast_factor", [0.1, 0.5, 1.0])
    def test_correctness_image(self, contrast_factor):
        """The tensor result must be within an absolute error of 1 of the PIL result."""
        image = make_image(dtype=torch.uint8, device="cpu")

        actual = F.adjust_contrast(image, contrast_factor=contrast_factor)
        pil_result = F.adjust_contrast(F.to_pil_image(image), contrast_factor=contrast_factor)
        expected = F.to_image(pil_result)

        assert_close(actual, expected, rtol=0, atol=1)
4774

4775

4776
class TestAdjustGamma:
    """Tests for ``F.adjust_gamma`` and its image/video kernels."""

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        """Run the shared kernel checks on ``F.adjust_gamma_image``."""
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.adjust_gamma_image, image, gamma=0.5)

    def test_kernel_video(self):
        """Run the shared kernel checks on ``F.adjust_gamma_video``."""
        video = make_video()
        check_kernel(F.adjust_gamma_video, video, gamma=0.5)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
    def test_functional(self, make_input):
        """Run the shared functional checks on ``F.adjust_gamma`` for each input factory."""
        sample = make_input()
        check_functional(F.adjust_gamma, sample, gamma=0.5)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.adjust_gamma_image, torch.Tensor),
            (F._color._adjust_gamma_image_pil, PIL.Image.Image),
            (F.adjust_gamma_image, tv_tensors.Image),
            (F.adjust_gamma_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        """The signature of ``F.adjust_gamma`` must match the kernel used for each input type."""
        check_functional_kernel_signature_match(F.adjust_gamma, kernel=kernel, input_type=input_type)

    def test_functional_error(self):
        """A negative ``gamma`` must raise ``ValueError``."""
        expected_message = "Gamma should be a non-negative real number"
        with pytest.raises(ValueError, match=expected_message):
            F.adjust_gamma(make_image(), gamma=-1)

    @pytest.mark.parametrize("gamma", [0.1, 0.5, 1.0])
    @pytest.mark.parametrize("gain", [0.1, 1.0, 2.0])
    def test_correctness_image(self, gamma, gain):
        """The tensor result must equal the result obtained through the PIL code path."""
        image = make_image(dtype=torch.uint8, device="cpu")

        actual = F.adjust_gamma(image, gamma=gamma, gain=gain)
        pil_result = F.adjust_gamma(F.to_pil_image(image), gamma=gamma, gain=gain)
        expected = F.to_image(pil_result)

        assert_equal(actual, expected)
4814

4815

4816
class TestAdjustHue:
    # Kernel, functional, signature, error and PIL-parity checks for `F.adjust_hue`.

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.adjust_hue_image, image, hue_factor=0.25)

    def test_kernel_video(self):
        video = make_video()
        check_kernel(F.adjust_hue_video, video, hue_factor=0.25)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
    def test_functional(self, make_input):
        sample = make_input()
        check_functional(F.adjust_hue, sample, hue_factor=0.25)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.adjust_hue_image, torch.Tensor),
            (F._color._adjust_hue_image_pil, PIL.Image.Image),
            (F.adjust_hue_image, tv_tensors.Image),
            (F.adjust_hue_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        check_functional_kernel_signature_match(F.adjust_hue, kernel=kernel, input_type=input_type)

    def test_functional_error(self):
        # An RGBA image is expected to be rejected because of its channel count.
        with pytest.raises(TypeError, match="permitted channel values are 1 or 3"):
            F.adjust_hue(make_image(color_space="RGBA"), hue_factor=0.25)

        # Hue factors outside of [-0.5, 0.5] are expected to be rejected. The brackets in the message are regex
        # metacharacters and thus need to be escaped.
        out_of_range_message = re.escape("is not in [-0.5, 0.5]")
        for hue_factor in (-1, 1):
            with pytest.raises(ValueError, match=out_of_range_message):
                F.adjust_hue(make_image(), hue_factor=hue_factor)

    @pytest.mark.parametrize("hue_factor", [-0.5, -0.3, 0.0, 0.2, 0.5])
    def test_correctness_image(self, hue_factor):
        tensor_image = make_image(dtype=torch.uint8, device="cpu")

        actual = F.adjust_hue(tensor_image, hue_factor=hue_factor)
        expected = F.to_image(F.adjust_hue(F.to_pil_image(tensor_image), hue_factor=hue_factor))

        # Tensor and PIL results are only compared through their mean absolute error rather than element-wise.
        absolute_error = (actual.float() - expected.float()).abs()
        mae = absolute_error.mean()
        assert mae < 2
4858

4859

4860
class TestAdjustSaturation:
    # Kernel, functional, signature, error and PIL-parity checks for `F.adjust_saturation`.

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.adjust_saturation_image, image, saturation_factor=0.5)

    def test_kernel_video(self):
        video = make_video()
        check_kernel(F.adjust_saturation_video, video, saturation_factor=0.5)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
    def test_functional(self, make_input):
        sample = make_input()
        check_functional(F.adjust_saturation, sample, saturation_factor=0.5)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.adjust_saturation_image, torch.Tensor),
            (F._color._adjust_saturation_image_pil, PIL.Image.Image),
            (F.adjust_saturation_image, tv_tensors.Image),
            (F.adjust_saturation_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type)

    def test_functional_error(self):
        # An RGBA image is expected to be rejected because of its channel count.
        with pytest.raises(TypeError, match="permitted channel values are 1 or 3"):
            F.adjust_saturation(make_image(color_space="RGBA"), saturation_factor=0.5)

        # A negative saturation factor is expected to be rejected.
        with pytest.raises(ValueError, match="is not non-negative"):
            F.adjust_saturation(make_image(), saturation_factor=-1)

    @pytest.mark.parametrize("saturation_factor", [0.1, 0.5, 1.0])
    def test_correctness_image(self, saturation_factor):
        tensor_image = make_image(dtype=torch.uint8, device="cpu")
        adjust = functools.partial(F.adjust_saturation, saturation_factor=saturation_factor)

        actual = adjust(tensor_image)
        expected = F.to_image(adjust(F.to_pil_image(tensor_image)))

        # Tensor and PIL results may differ by at most one intensity level.
        assert_close(actual, expected, rtol=0, atol=1)
4900

4901

4902
class TestFiveTenCrop:
    # Tests for the `five_crop` / `ten_crop` kernels and functionals as well as the `FiveCrop` / `TenCrop` transforms.

    INPUT_SIZE = (17, 11)
    OUTPUT_SIZE = (3, 5)

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("kernel", [F.five_crop_image, F.ten_crop_image])
    def test_kernel_image(self, dtype, device, kernel):
        image = make_image(self.INPUT_SIZE, dtype=dtype, device=device)
        check_kernel(kernel, image, size=self.OUTPUT_SIZE, check_batched_vs_unbatched=False)

    @pytest.mark.parametrize("kernel", [F.five_crop_video, F.ten_crop_video])
    def test_kernel_video(self, kernel):
        video = make_video(self.INPUT_SIZE)
        check_kernel(kernel, video, size=self.OUTPUT_SIZE, check_batched_vs_unbatched=False)

    def _functional_wrapper(self, fn):
        # check_functional requires a single output rather than a sequence. To make five_crop / ten_crop compatible
        # with it, the returned callable only hands back the first element of whatever `fn` produces.
        @functools.wraps(fn)
        def first_output_only(*args, **kwargs):
            return fn(*args, **kwargs)[0]

        return first_output_only

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_video],
    )
    @pytest.mark.parametrize("functional", [F.five_crop, F.ten_crop])
    def test_functional(self, make_input, functional):
        wrapped_functional = self._functional_wrapper(functional)
        sample = make_input(self.INPUT_SIZE)
        check_functional(wrapped_functional, sample, size=self.OUTPUT_SIZE, check_scripted_smoke=False)

    @pytest.mark.parametrize(
        ("functional", "kernel", "input_type"),
        [
            (F.five_crop, F.five_crop_image, torch.Tensor),
            (F.five_crop, F._geometry._five_crop_image_pil, PIL.Image.Image),
            (F.five_crop, F.five_crop_image, tv_tensors.Image),
            (F.five_crop, F.five_crop_video, tv_tensors.Video),
            (F.ten_crop, F.ten_crop_image, torch.Tensor),
            (F.ten_crop, F._geometry._ten_crop_image_pil, PIL.Image.Image),
            (F.ten_crop, F.ten_crop_image, tv_tensors.Image),
            (F.ten_crop, F.ten_crop_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, functional, kernel, input_type):
        check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)

    class _TransformWrapper(nn.Module):
        # check_transform requires a single output rather than a sequence. To make FiveCrop / TenCrop compatible with
        # it, this module only forwards the first output of the wrapped transform.
        _v1_transform_cls = None

        def __init__(self, five_ten_crop_transform):
            super().__init__()
            self.five_ten_crop_transform = five_ten_crop_transform
            # The wrapper class is recorded as its own `_v1_transform_cls` upon instantiation.
            type(self)._v1_transform_cls = type(self)

        def _extract_params_for_v1_transform(self):
            return {"five_ten_crop_transform": self.five_ten_crop_transform}

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self.five_ten_crop_transform(input)[0]

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_video],
    )
    @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
    def test_transform(self, make_input, transform_cls):
        wrapped_transform = self._TransformWrapper(transform_cls(size=self.OUTPUT_SIZE))
        sample = make_input(self.INPUT_SIZE)
        check_transform(wrapped_transform, sample, check_sample_input=False)

    @pytest.mark.parametrize("make_input", [make_bounding_boxes, make_detection_masks])
    @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
    def test_transform_error(self, make_input, transform_cls):
        # Bounding boxes and detection masks are expected to be rejected by both transforms.
        transform = transform_cls(size=self.OUTPUT_SIZE)

        with pytest.raises(TypeError, match="not supported"):
            transform(make_input(self.INPUT_SIZE))

    @pytest.mark.parametrize("fn", [F.five_crop, transform_cls_to_functional(transforms.FiveCrop)])
    def test_correctness_image_five_crop(self, fn):
        tensor_image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

        actual = fn(tensor_image, size=self.OUTPUT_SIZE)
        pil_crops = F.five_crop(F.to_pil_image(tensor_image), size=self.OUTPUT_SIZE)

        assert isinstance(actual, tuple)
        assert_equal(actual, list(map(F.to_image, pil_crops)))

    @pytest.mark.parametrize("fn_or_class", [F.ten_crop, transforms.TenCrop])
    @pytest.mark.parametrize("vertical_flip", [False, True])
    def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
        crop_kwargs = {"size": self.OUTPUT_SIZE, "vertical_flip": vertical_flip}

        if fn_or_class is not transforms.TenCrop:
            # The functional takes its configuration at call time.
            fn = fn_or_class
            kwargs = crop_kwargs
        else:
            # The transform is configured upfront, so nothing is left to pass at call time.
            fn = transform_cls_to_functional(fn_or_class, **crop_kwargs)
            kwargs = {}

        tensor_image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

        actual = fn(tensor_image, **kwargs)
        pil_crops = F.ten_crop(F.to_pil_image(tensor_image), size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)

        assert isinstance(actual, tuple)
        assert_equal(actual, list(map(F.to_image, pil_crops)))
5024

5025

5026
class TestColorJitter:
    # Smoke, no-op, error and PIL-parity checks for the `ColorJitter` transform.

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_video],
    )
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, dtype, device):
        is_default_parametrization = dtype is torch.uint8 and device == "cpu"
        if make_input is make_image_pil and not is_default_parametrization:
            pytest.skip(
                "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
                "will degenerate to that anyway."
            )

        jitter = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25)
        check_transform(jitter, make_input(dtype=dtype, device=device))

    def test_transform_noop(self):
        # A ColorJitter without any arguments has to hand back the very same object: same identity, same storage
        # and no in-place modification, i.e. an unchanged version counter.
        input = make_image()
        version_before = input._version

        output = transforms.ColorJitter()(input)

        assert output is input
        assert output.data_ptr() == input.data_ptr()
        assert output._version == version_before

    def test_transform_error(self):
        with pytest.raises(ValueError, match="must be non negative"):
            transforms.ColorJitter(brightness=-1)

        # Neither an arbitrary object nor a sequence of length 3 is an acceptable `brightness`.
        for invalid_brightness in (object(), [1, 2, 3]):
            with pytest.raises(TypeError, match="single number or a sequence with length 2"):
                transforms.ColorJitter(brightness=invalid_brightness)

        with pytest.raises(ValueError, match="values should be between"):
            transforms.ColorJitter(brightness=(-1, 0.5))

        with pytest.raises(ValueError, match="values should be between"):
            transforms.ColorJitter(hue=1)

    @pytest.mark.parametrize("brightness", [None, 0.1, (0.2, 0.3)])
    @pytest.mark.parametrize("contrast", [None, 0.4, (0.5, 0.6)])
    @pytest.mark.parametrize("saturation", [None, 0.7, (0.8, 0.9)])
    @pytest.mark.parametrize("hue", [None, 0.3, (-0.1, 0.2)])
    def test_transform_correctness(self, brightness, contrast, saturation, hue):
        tensor_image = make_image(dtype=torch.uint8, device="cpu")
        jitter = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)

        with freeze_rng_state():
            # Re-seeding before each call makes the tensor and the PIL run start from the same RNG state.
            torch.manual_seed(0)
            actual = jitter(tensor_image)

            torch.manual_seed(0)
            expected = F.to_image(jitter(F.to_pil_image(tensor_image)))

        absolute_error = (actual.float() - expected.float()).abs()
        mae = absolute_error.mean()
        assert mae < 2
5088

5089

5090
class TestRgbToGrayscale:
    # Tests for `F.rgb_to_grayscale` and the `Grayscale` / `RandomGrayscale` transforms.

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        image = make_image(dtype=dtype, device=device)
        check_kernel(F.rgb_to_grayscale_image, image)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
    def test_functional(self, make_input):
        sample = make_input()
        check_functional(F.rgb_to_grayscale, sample)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.rgb_to_grayscale_image, torch.Tensor),
            (F._color._rgb_to_grayscale_image_pil, PIL.Image.Image),
            (F.rgb_to_grayscale_image, tv_tensors.Image),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        check_functional_kernel_signature_match(F.rgb_to_grayscale, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("transform", [transforms.Grayscale(), transforms.RandomGrayscale(p=1)])
    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
    def test_transform(self, transform, make_input):
        sample = make_input()
        check_transform(transform, sample)

    @pytest.mark.parametrize("num_output_channels", [1, 3])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
    def test_image_correctness(self, num_output_channels, color_space, fn):
        tensor_image = make_image(dtype=torch.uint8, device="cpu", color_space=color_space)

        actual = fn(tensor_image, num_output_channels=num_output_channels)
        pil_result = F.rgb_to_grayscale(F.to_pil_image(tensor_image), num_output_channels=num_output_channels)
        expected = F.to_image(pil_result)

        # The explicit tolerances allow a difference of one intensity level to the PIL result.
        assert_equal(actual, expected, rtol=0, atol=1)

    def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
        gray_image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")

        expanded = F.rgb_to_grayscale(gray_image, num_output_channels=3)
        # Right after the conversion, the first two channels agree at the top-left pixel ...
        assert_equal(expanded[0][0][0], expanded[1][0][0])
        # ... but writing to channel 0 must not leak into channel 1.
        expanded[0][0][0] = expanded[0][0][0] + 1
        assert expanded[0][0][0] != expanded[1][0][0]

    @pytest.mark.parametrize("num_input_channels", [1, 3])
    def test_random_transform_correctness(self, num_input_channels):
        color_space_by_num_channels = {
            1: "GRAY",
            3: "RGB",
        }
        tensor_image = make_image(
            color_space=color_space_by_num_channels[num_input_channels],
            dtype=torch.uint8,
            device="cpu",
        )

        # With p=1 the transform is always applied.
        actual = transforms.RandomGrayscale(p=1)(tensor_image)
        pil_result = F.rgb_to_grayscale(F.to_pil_image(tensor_image), num_output_channels=num_input_channels)
        expected = F.to_image(pil_result)

        assert_equal(actual, expected, rtol=0, atol=1)
5152

5153

5154
class TestGrayscaleToRgb:
    # Tests for `F.grayscale_to_rgb` and the `RGB` transform.

    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_kernel_image(self, dtype, device):
        check_kernel(F.grayscale_to_rgb_image, make_image(dtype=dtype, device=device))

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
    def test_functional(self, make_input):
        check_functional(F.grayscale_to_rgb, make_input())

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            # These have to be the kernels of `grayscale_to_rgb` itself. Checking against the `rgb_to_grayscale`
            # kernels instead would not compare the signatures of the kernels this functional dispatches to.
            (F.grayscale_to_rgb_image, torch.Tensor),
            # The PIL kernel is looked up in the dispatch registry rather than referenced by its private name.
            (_get_kernel(F.grayscale_to_rgb, PIL.Image.Image), PIL.Image.Image),
            (F.grayscale_to_rgb_image, tv_tensors.Image),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        check_functional_kernel_signature_match(F.grayscale_to_rgb, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image])
    def test_transform(self, make_input):
        check_transform(transforms.RGB(), make_input(color_space="GRAY"))

    @pytest.mark.parametrize("fn", [F.grayscale_to_rgb, transform_cls_to_functional(transforms.RGB)])
    def test_image_correctness(self, fn):
        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")

        actual = fn(image)
        expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))

        # The explicit tolerances allow a difference of one intensity level to the PIL result.
        assert_equal(actual, expected, rtol=0, atol=1)

    def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")

        output_image = F.grayscale_to_rgb(image)
        # Right after the conversion, the first two channels agree at the top-left pixel ...
        assert_equal(output_image[0][0][0], output_image[1][0][0])
        # ... but writing to channel 0 must not leak into channel 1.
        output_image[0][0][0] = output_image[0][0][0] + 1
        assert output_image[0][0][0] != output_image[1][0][0]

    def test_rgb_image_is_unchanged(self):
        image = make_image(dtype=torch.uint8, device="cpu", color_space="RGB")
        # Sanity check that the input indeed has three channels before asserting it passes through unchanged.
        assert_equal(image.shape[-3], 3)
        assert_equal(F.grayscale_to_rgb(image), image)
5200

5201

5202
class TestRandomZoomOut:
    # Tests are light because this largely relies on the already tested `pad` kernels.

    @pytest.mark.parametrize(
        "make_input",
        [
            make_image_tensor,
            make_image_pil,
            make_image,
            make_bounding_boxes,
            make_segmentation_mask,
            make_detection_masks,
            make_video,
        ],
    )
    def test_transform(self, make_input):
        check_transform(transforms.RandomZoomOut(p=1), make_input())

    def test_transform_error(self):
        # Non-sequences are expected to raise a TypeError, whereas a list of the wrong length raises a ValueError.
        for side_range in [None, 1, [1, 2, 3]]:
            with pytest.raises(
                ValueError if isinstance(side_range, list) else TypeError, match="should be a sequence of length 2"
            ):
                transforms.RandomZoomOut(side_range=side_range)

        for side_range in [[0.5, 1.5], [2.0, 1.0]]:
            with pytest.raises(ValueError, match="Invalid side range"):
                transforms.RandomZoomOut(side_range=side_range)

    @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
    @pytest.mark.parametrize(
        "make_input",
        [
            make_image_tensor,
            make_image_pil,
            make_image,
            make_bounding_boxes,
            make_segmentation_mask,
            make_detection_masks,
            make_video,
        ],
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform_params_correctness(self, side_range, make_input, device):
        if make_input is make_image_pil and device != "cpu":
            pytest.skip("PIL image tests with parametrization device!='cpu' will degenerate to that anyway.")

        transform = transforms.RandomZoomOut(side_range=side_range)

        # The input has to be created on the parametrized device. Otherwise, every CUDA case would silently repeat
        # the corresponding CPU case.
        input = make_input(device=device)
        height, width = F.get_size(input)

        params = transform._get_params([input])
        assert "padding" in params

        padding = params["padding"]
        assert len(padding) == 4

        # No side may be padded by a negative amount or by more than the largest zoom factor allows. Entries 0 and 2
        # are bounded by the width, entries 1 and 3 by the height.
        assert 0 <= padding[0] <= (side_range[1] - 1) * width
        assert 0 <= padding[1] <= (side_range[1] - 1) * height
        assert 0 <= padding[2] <= (side_range[1] - 1) * width
        assert 0 <= padding[3] <= (side_range[1] - 1) * height
5264

5265

5266
class TestRandomPhotometricDistort:
    # Tests are light because this largely relies on the already tested
    # `adjust_{brightness,contrast,saturation,hue}` and `permute_channels` kernels.

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_video],
    )
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, dtype, device):
        is_default_parametrization = dtype is torch.uint8 and device == "cpu"
        if make_input is make_image_pil and not is_default_parametrization:
            pytest.skip(
                "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
                "will degenerate to that anyway."
            )

        distort = transforms.RandomPhotometricDistort(
            brightness=(0.3, 0.4), contrast=(0.5, 0.6), saturation=(0.7, 0.8), hue=(-0.1, 0.2), p=1
        )
        check_transform(distort, make_input(dtype=dtype, device=device))
5289

5290

5291
class TestScaleJitter:
    # Tests are light because this largely relies on the already tested `resize` kernels.

    INPUT_SIZE = (17, 11)
    TARGET_SIZE = (12, 13)

    @pytest.mark.parametrize(
        "make_input",
        [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, device):
        if make_input is make_image_pil and device != "cpu":
            pytest.skip("PIL image tests with parametrization device!='cpu' will degenerate to that anyway.")

        jitter = transforms.ScaleJitter(self.TARGET_SIZE)
        check_transform(jitter, make_input(self.INPUT_SIZE, device=device))

    def test__get_params(self):
        input_size, target_size = self.INPUT_SIZE, self.TARGET_SIZE
        scale_range = (0.5, 1.5)

        transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
        params = transform._get_params([make_image(input_size)])

        assert "size" in params
        size = params["size"]

        assert isinstance(size, tuple) and len(size) == 2
        height, width = size

        # The sampled output size has to lie between the sizes obtained from the two ends of the scale range.
        base_ratio = min(target_size[1] / input_size[0], target_size[0] / input_size[1])
        r_min = base_ratio * scale_range[0]
        r_max = base_ratio * scale_range[1]

        assert int(input_size[0] * r_min) <= height <= int(input_size[0] * r_max)
        assert int(input_size[1] * r_min) <= width <= int(input_size[1] * r_max)
5327

5328

5329
class TestLinearTransform:
    # Smoke and error checks for the `LinearTransformation` transform.

    def _make_matrix_and_vector(self, input, *, device=None):
        # Creates a random square matrix and a random vector whose size matches the number of elements of `input`.
        # Unless a device is passed explicitly, both live on the device of `input`.
        target_device = device if device else input.device
        num_elements = math.prod(F.get_dimensions(input))
        # The matrix is sampled before the vector, which keeps the RNG consumption order fixed.
        transformation_matrix = torch.randn((num_elements, num_elements), device=target_device)
        mean_vector = torch.randn((num_elements,), device=target_device)
        return transformation_matrix, mean_vector

    def _sample_input_adapter(self, transform, input, device):
        # PIL images are dropped from the sample input; every other entry is kept as is.
        adapted = {}
        for key, value in input.items():
            if isinstance(value, PIL.Image.Image):
                continue
            adapted[key] = value
        return adapted

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, dtype, device):
        input = make_input(dtype=dtype, device=device)
        transformation_matrix, mean_vector = self._make_matrix_and_vector(input)
        check_transform(
            transforms.LinearTransformation(transformation_matrix, mean_vector),
            input,
            check_sample_input=self._sample_input_adapter,
            # Compat check is failing on M1 with:
            # AssertionError: Tensor-likes are not close!
            # Mismatched elements: 1 / 561 (0.2%)
            # See https://github.com/pytorch/vision/issues/8453
            check_v1_compatibility=(sys.platform != "darwin"),
        )

    def test_transform_error(self):
        with pytest.raises(ValueError, match="transformation_matrix should be square"):
            transforms.LinearTransformation(transformation_matrix=torch.rand(2, 3), mean_vector=torch.rand(2))

        with pytest.raises(ValueError, match="mean_vector should have the same length"):
            transforms.LinearTransformation(transformation_matrix=torch.rand(2, 2), mean_vector=torch.rand(1))

        # Mixing dtypes is expected to be rejected regardless of which parameter has which dtype.
        mixed_dtypes = ((torch.float32, torch.float64), (torch.float64, torch.float32))
        for matrix_dtype, vector_dtype in mixed_dtypes:
            with pytest.raises(ValueError, match="Input tensors should have the same dtype"):
                transforms.LinearTransformation(
                    transformation_matrix=torch.rand(2, 2, dtype=matrix_dtype),
                    mean_vector=torch.rand(2, dtype=vector_dtype),
                )

        image = make_image()

        # A valid 2x2 matrix is expected to be rejected at call time for this image.
        too_small = transforms.LinearTransformation(transformation_matrix=torch.rand(2, 2), mean_vector=torch.rand(2))
        with pytest.raises(ValueError, match="Input tensor and transformation matrix have incompatible shape"):
            too_small(image)

        # Even with matching parameters, PIL images are expected to be rejected.
        matching = transforms.LinearTransformation(*self._make_matrix_and_vector(image))
        with pytest.raises(TypeError, match="does not support PIL images"):
            matching(F.to_pil_image(image))

    @needs_cuda
    def test_transform_error_cuda(self):
        device_mismatches = (("cuda", "cpu"), ("cpu", "cuda"))

        # Matrix and vector on different devices are expected to be rejected at construction time.
        for matrix_device, vector_device in device_mismatches:
            with pytest.raises(ValueError, match="Input tensors should be on the same device"):
                transforms.LinearTransformation(
                    transformation_matrix=torch.rand(2, 2, device=matrix_device),
                    mean_vector=torch.rand(2, device=vector_device),
                )

        # An input on a different device than the parameters is expected to be rejected at call time.
        for input_device, param_device in device_mismatches:
            input = make_image(device=input_device)
            transform = transforms.LinearTransformation(*self._make_matrix_and_vector(input, device=param_device))
            with pytest.raises(
                ValueError, match="Input tensor should be on the same device as transformation matrix and mean vector"
            ):
                transform(input)
5395

5396

5397
def make_image_numpy(*args, **kwargs):
    # Builds an image tensor from the given arguments and returns it as a numpy array with the first axis moved
    # to the end.
    return make_image_tensor(*args, **kwargs).permute((1, 2, 0)).numpy()
5400

5401

5402
class TestToImage:
    # Tests for `F.to_image` and the `ToImage` transform.

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy])
    @pytest.mark.parametrize("fn", [F.to_image, transform_cls_to_functional(transforms.ToImage)])
    def test_functional_and_transform(self, make_input, fn):
        input = make_input()
        output = fn(input)

        assert isinstance(output, tv_tensors.Image)

        # For numpy arrays the size is read off the first two axes of the array.
        if isinstance(input, np.ndarray):
            input_size = list(input.shape[:2])
        else:
            input_size = F.get_size(input)
        assert F.get_size(output) == input_size

        # Tensor inputs must not be copied.
        if isinstance(input, torch.Tensor):
            assert output.data_ptr() == input.data_ptr()

    def test_2d_np_array(self):
        # Non-regression test for https://github.com/pytorch/vision/issues/8255
        array = np.random.rand(10, 10)
        assert F.to_image(array).shape == (1, 10, 10)

    def test_functional_error(self):
        with pytest.raises(TypeError, match="Input can either be a pure Tensor, a numpy array, or a PIL image"):
            F.to_image(object())
5425

5426

5427
class TestToPILImage:
    # Tests for `F.to_pil_image` and the `ToPILImage` transform.

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_numpy])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("fn", [F.to_pil_image, transform_cls_to_functional(transforms.ToPILImage)])
    def test_functional_and_transform(self, make_input, color_space, fn):
        input = make_input(color_space=color_space)
        output = fn(input)

        assert isinstance(output, PIL.Image.Image)

        # For numpy arrays the size is read off the first two axes of the array.
        if isinstance(input, np.ndarray):
            input_size = list(input.shape[:2])
        else:
            input_size = F.get_size(input)
        assert F.get_size(output) == input_size

    def test_functional_error(self):
        with pytest.raises(TypeError, match="pic should be Tensor or ndarray"):
            F.to_pil_image(object())

        # Tensors with 1 or 4 dimensions are expected to be rejected.
        for ndim in (1, 4):
            with pytest.raises(ValueError, match="pic should be 2/3 dimensional"):
                F.to_pil_image(torch.empty((1,) * ndim))

        num_channels = 5
        with pytest.raises(ValueError, match="pic should not have > 4 channels"):
            F.to_pil_image(torch.empty(num_channels, 1, 1))
5451

5452

5453
class TestToTensor:
    # Smoke test for the `ToTensor` transform.

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_image_numpy])
    def test_smoke(self, make_input):
        # Instantiating the transform has to emit the deprecation warning.
        with pytest.warns(UserWarning, match="deprecated and will be removed"):
            transform = transforms.ToTensor()

        input = make_input()
        output = transform(input)

        # For numpy arrays the size is read off the first two axes of the array.
        if isinstance(input, np.ndarray):
            input_size = list(input.shape[:2])
        else:
            input_size = F.get_size(input)
        assert F.get_size(output) == input_size
5464

5465

5466
class TestPILToTensor:
    # Tests for `F.pil_to_tensor` and the `PILToTensor` transform.

    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("fn", [F.pil_to_tensor, transform_cls_to_functional(transforms.PILToTensor)])
    def test_functional_and_transform(self, color_space, fn):
        pil_image = make_image_pil(color_space=color_space)
        output = fn(pil_image)

        # The result has to be a pure tensor, i.e. not any TVTensor subclass.
        assert isinstance(output, torch.Tensor) and not isinstance(output, tv_tensors.TVTensor)
        assert F.get_size(output) == F.get_size(pil_image)

    def test_functional_error(self):
        with pytest.raises(TypeError, match="pic should be PIL Image"):
            F.pil_to_tensor(object())
5479

5480

5481
class TestLambda:
    """Checks that `Lambda` only invokes its callable for the requested input types."""

    @pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
    @pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
    def test_transform(self, input, types):
        # Every invocation of the wrapped callable is recorded here.
        seen = []

        def record_call(value):
            seen.append(value)
            return value

        result = transforms.Lambda(record_call, *types)(input)

        # The callable is an identity, so the very same object comes back whether or not it ran.
        assert result is input

        # Without a type restriction the callable always runs; otherwise only for matching inputs.
        should_run = True if not types else isinstance(input, types)
        assert bool(seen) is should_run
5497

5498

5499
@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(*pair, id=pair[0].__name__)
        for pair in (
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            # NOTE(review): this pair compares a function with itself, so it always passes.
            (F.to_pil_image, F.to_pil_image),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        )
    ],
)
def test_alias(alias, target):
    """Each alias must be the very same object as its target functional."""
    assert alias is target
5515

5516

5517
@pytest.mark.parametrize(
    "make_inputs",
    itertools.permutations(
        [
            make_image_tensor,
            make_image_tensor,
            make_image_pil,
            make_image,
            make_video,
        ],
        3,
    ),
)
def test_pure_tensor_heuristic(make_inputs):
    """Check which items of a flat sample a transform touches when pure tensors are present.

    The assertions encode these expectations:
    - the first pure tensor is transformed only if the sample has no non-pure-tensor items,
    - every further pure tensor is passed through untouched,
    - every other item is transformed.
    """
    flat_inputs = [factory() for factory in make_inputs]

    # Classification is always derived from the original inputs, so a transform that erroneously
    # changes an item's type cannot influence how the outputs are split.
    pure_flags = [is_pure_tensor(item) for item in flat_inputs]

    def split_on_pure_tensor(to_split):
        # `to_split` is structurally aligned with `flat_inputs`. It is divided into:
        # 1. the first pure tensor, or `None` if there is none,
        # 2. a list of the remaining pure tensors,
        # 3. a list of all other items.
        pure = [item for item, flag in zip(to_split, pure_flags) if flag]
        rest = [item for item, flag in zip(to_split, pure_flags) if not flag]
        first = pure[0] if pure else None
        return first, pure[1:], rest

    class CopyCloneTransform(transforms.Transform):
        def _transform(self, inpt, params):
            # Tensors are cloned; everything else is expected to offer `.copy()`.
            if isinstance(inpt, torch.Tensor):
                return inpt.clone()
            return inpt.copy()

        @staticmethod
        def was_applied(output, inpt):
            # Same object back means the transform skipped this item.
            if output is inpt:
                return False

            # A new object was returned: make sure it is a faithful copy and nothing fishy happened.
            assert_equal(output, inpt)
            return True

    first_pure_in, extra_pure_in, others_in = split_on_pure_tensor(flat_inputs)

    transform = CopyCloneTransform()
    transformed_sample = transform(flat_inputs)

    first_pure_out, extra_pure_out, others_out = split_on_pure_tensor(transformed_sample)

    if first_pure_in is not None:
        first_was_applied = transform.was_applied(first_pure_out, first_pure_in)
        if others_in:
            assert not first_was_applied
        else:
            assert first_was_applied

    for out_item, in_item in zip(extra_pure_out, extra_pure_in):
        assert not transform.was_applied(out_item, in_item)

    for in_item, out_item in zip(others_in, others_out):
        assert transform.was_applied(out_item, in_item)
5578

5579

5580
class TestRandomIoUCrop:
    """Tests for parameter sampling and application of `RandomIoUCrop`."""

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
    def test__get_params(self, device, options):
        orig_h, orig_w = size = (24, 32)
        bboxes = tv_tensors.BoundingBoxes(
            torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
            format="XYXY",
            canvas_size=size,
            device=device,
        )
        sample = [make_image(size), bboxes]

        transform = transforms.RandomIoUCrop(sampler_options=options)

        for _ in range(5):
            params = transform._get_params(sample)

            if options == [2.0]:
                # For this option the sampler is expected to hand back empty params.
                assert len(params) == 0
                return

            within = params["is_within_crop_area"]
            assert len(within) > 0
            assert within.dtype == torch.bool

            new_h, new_w = params["height"], params["width"]
            assert int(transform.min_scale * orig_h) <= new_h <= int(transform.max_scale * orig_h)
            assert int(transform.min_scale * orig_w) <= new_w <= int(transform.max_scale * orig_w)

            left, top = params["left"], params["top"]
            crop_box = torch.tensor(
                [[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device
            )
            ious = box_iou(bboxes, crop_box)
            best_iou = ious.max()
            assert best_iou >= options[0] or best_iou >= options[1], f"{ious} vs {options}"

    def test__transform_empty_params(self, mocker):
        transform = transforms.RandomIoUCrop(sampler_options=[2.0])
        sample = [
            tv_tensors.Image(torch.rand(1, 3, 4, 4)),
            tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4)),
            torch.tensor([1]),
        ]
        # Replace transform._get_params with a mock so the (empty) params are under our control.
        transform._get_params = mocker.MagicMock(return_value={})

        output = transform(sample)

        # With empty params the sample is expected to come back unchanged.
        torch.testing.assert_close(output, sample)

    def test_forward_assertion(self):
        transform = transforms.RandomIoUCrop()
        expected_message = "requires input sample to contain tensor or PIL images and bounding boxes"
        with pytest.raises(TypeError, match=expected_message):
            transform(torch.tensor(0))

    def test__transform(self, mocker):
        transform = transforms.RandomIoUCrop()

        size = (32, 24)
        sample = [
            make_image(size),
            make_bounding_boxes(format="XYXY", canvas_size=size, num_boxes=6),
            make_detection_masks(size, num_masks=6),
        ]

        is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
        # Fixed crop params are injected through a mock instead of being sampled.
        transform._get_params = mocker.MagicMock(
            return_value=dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
        )

        _, output_bboxes, output_masks = transform(sample)

        # Boxes flagged as outside the crop area must be zeroed out; the type must be preserved.
        assert isinstance(output_bboxes, tv_tensors.BoundingBoxes)
        assert (output_bboxes[~is_within_crop_area] == 0).all()

        assert isinstance(output_masks, tv_tensors.Mask)
5661

5662

5663
class TestRandomShortestSize:
    """Tests for the size sampled by `RandomShortestSize`."""

    @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
    def test__get_params(self, min_size, max_size):
        transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size, antialias=True)

        params = transform._get_params([make_image((3, 10))])

        assert "size" in params
        size = params["size"]
        assert isinstance(size, tuple) and len(size) == 2

        shorter, longer = min(size), max(size)
        if max_size is None:
            # Without a cap, the shorter edge must be one of the requested candidates.
            assert shorter in min_size
        else:
            # With a cap, neither edge may exceed it.
            assert longer <= max_size
            assert shorter <= max_size
5685

5686

5687
class TestRandomResize:
    """Tests for the size sampled by `RandomResize`."""

    def test__get_params(self):
        lower, upper = 3, 6
        transform = transforms.RandomResize(min_size=lower, max_size=upper, antialias=True)

        # Sample repeatedly; every draw must be a one-element list within [lower, upper).
        for _ in range(10):
            sampled = transform._get_params([])["size"]

            assert isinstance(sampled, list) and len(sampled) == 1
            assert lower <= sampled[0] < upper
5701

5702

5703
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
def test_classification_preset(image_type, label_type, dataset_return_type, to_tensor):
    """Run a classification-style pipeline end to end on the supported sample layouts."""
    image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
    if image_type is PIL.Image:
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_pure_tensor(image)

    label = 1 if label_type is int else torch.tensor([1])

    sample = {"image": image, "label": label} if dataset_return_type is dict else (image, label)

    # Only the deprecated ToTensor is expected to warn when instantiated.
    expect_warning = (
        pytest.warns(UserWarning, match="deprecated and will be removed")
        if to_tensor is transforms.ToTensor
        else contextlib.nullcontext()
    )
    with expect_warning:
        to_tensor = to_tensor()

    pipeline = transforms.Compose(
        [
            transforms.RandomResizedCrop((224, 224), antialias=True),
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandAugment(),
            transforms.TrivialAugmentWide(),
            transforms.AugMix(),
            transforms.AutoAugment(),
            to_tensor,
            # TODO: ConvertImageDtype is a pass-through on PIL images, is that
            # intended?  This results in a failure if we convert to tensor after
            # it, because the image would still be uint8 which makes Normalize
            # fail.
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
            transforms.RandomErasing(p=1),
        ]
    )

    out = pipeline(sample)

    # The container type of the sample must survive the pipeline.
    assert type(out) is type(sample)

    if dataset_return_type is tuple:
        out_image, out_label = out
    else:
        assert out.keys() == sample.keys()
        out_image, out_label = out.values()

    assert out_image.shape[-2:] == (224, 224)
    assert out_label == label
5763

5764

5765
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
@pytest.mark.parametrize("sanitize", (True, False))
def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
    """Run detection-style pipelines end to end on an image / label / boxes / masks sample."""
    # The expected box counts at the end of this test depend on this seed.
    torch.manual_seed(0)

    # Only the deprecated ToTensor is expected to warn when instantiated.
    expect_warning = (
        pytest.warns(UserWarning, match="deprecated and will be removed")
        if to_tensor is transforms.ToTensor
        else contextlib.nullcontext()
    )
    with expect_warning:
        to_tensor = to_tensor()

    # Augmentation-specific head of the pipeline; all variants share the same tail below.
    if data_augmentation == "hflip":
        pipeline = []
    elif data_augmentation == "lsj":
        pipeline = [
            transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
            # Note: RandomCrop is used instead of FixedSizeCrop, because
            # FixedSizeCrop stays in prototype for now, and it expects Label
            # classes which we won't release yet.
            transforms.RandomCrop((1024, 1024), pad_if_needed=True),
        ]
    elif data_augmentation == "multiscale":
        pipeline = [
            transforms.RandomShortestSize(
                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333, antialias=True
            ),
        ]
    elif data_augmentation == "ssd":
        pipeline = [
            transforms.RandomPhotometricDistort(p=1),
            transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), tv_tensors.Mask: 0}, p=1),
            transforms.RandomIoUCrop(),
        ]
    elif data_augmentation == "ssdlite":
        pipeline = [transforms.RandomIoUCrop()]

    pipeline += [
        transforms.RandomHorizontalFlip(p=1),
        to_tensor,
        transforms.ConvertImageDtype(torch.float),
    ]
    if sanitize:
        pipeline.append(transforms.SanitizeBoundingBoxes())
    t = transforms.Compose(pipeline)

    num_boxes = 5
    H = W = 250

    # The inputs below are drawn in a fixed order (image, label, boxes, masks) after seeding.
    image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
    if image_type is PIL.Image:
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_pure_tensor(image)

    label = torch.randint(0, 10, size=(num_boxes,))

    # Draw (x1, y1, w, h), turn it into XYXY by adding the origin, then clamp to the canvas.
    raw_boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
    raw_boxes[:, 2:] += raw_boxes[:, :2]
    raw_boxes = raw_boxes.clamp(min=0, max=min(H, W))
    boxes = tv_tensors.BoundingBoxes(raw_boxes, format="XYXY", canvas_size=(H, W))

    masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

    sample = {
        "image": image,
        "label": label,
        "boxes": boxes,
        "masks": masks,
    }

    out = t(sample)

    if isinstance(to_tensor, transforms.ToTensor) and image_type is not tv_tensors.Image:
        assert is_pure_tensor(out["image"])
    else:
        assert isinstance(out["image"], tv_tensors.Image)
    assert isinstance(out["label"], type(sample["label"]))

    # ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
    # doesn't remove them strictly speaking, it just marks some boxes as
    # degenerate and those boxes will be later removed by
    # SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
    # param is True.
    # Note that the values below are probably specific to the random seed
    # set above (which is fine).
    seed_specific_counts = {
        (True, "ssd"): 5,
        (True, "ssdlite"): 4,
    }
    num_boxes_expected = seed_specific_counts.get((sanitize, data_augmentation), num_boxes)

    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes_expected
5874

5875

5876
class TestSanitizeBoundingBoxes:
    """Tests for the `SanitizeBoundingBoxes` transform and `F.sanitize_bounding_boxes`."""

    def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10, min_area=10):
        """Build a shuffled set of XYXY boxes together with their expected validity.

        Args:
            H, W: canvas height and width the boxes are attached to.
            min_size: minimum box side length used to derive the edge-case boxes.
            min_area: minimum box area used to compute the expected validity.

        Returns:
            A `tv_tensors.BoundingBoxes` and a tuple of booleans aligned with it, where
            `True` marks a box that is expected to survive sanitization.
        """
        # Boxes are [X1, Y1, X2, Y2], so width is X2 - X1 and height is Y2 - Y1.
        boxes_and_validity = [
            ([0, 1, 10, 1], False),  # Y1 == Y2
            ([0, 1, 0, 20], False),  # X1 == X2
            ([0, 0, min_size - 1, 10], False),  # W < min_size
            ([0, 0, 10, min_size - 1], False),  # H < min_size
            ([0, 0, 10, H + 1], False),  # Y2 > H
            ([0, 0, W + 1, 10], False),  # X2 > W
            ([-1, 1, 10, 20], False),  # any < 0
            ([0, 0, -1, 20], False),  # any < 0
            ([0, 0, -10, -1], False),  # any < 0
            ([0, 0, min_size, 10], min_size * 10 >= min_area),  # W == min_size
            ([0, 0, 10, min_size], min_size * 10 >= min_area),  # H == min_size
            ([0, 0, W, H], W * H >= min_area),
            ([1, 1, 30, 20], 29 * 19 >= min_area),
            ([0, 0, 10, 10], 10 * 10 >= min_area),  # a 10 x 10 box has area 100
            ([1, 1, 30, 20], 29 * 19 >= min_area),
        ]

        random.shuffle(boxes_and_validity)  # For test robustness: mix order of wrong and correct cases
        boxes, expected_valid_mask = zip(*boxes_and_validity)
        boxes = tv_tensors.BoundingBoxes(
            boxes,
            format=tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=(H, W),
        )

        return boxes, expected_valid_mask

    @pytest.mark.parametrize("min_size, min_area", ((1, 1), (10, 1), (10, 101)))
    @pytest.mark.parametrize(
        "labels_getter",
        (
            "default",
            lambda inputs: inputs["labels"],
            lambda inputs: (inputs["labels"], inputs["other_labels"]),
            lambda inputs: [inputs["labels"], inputs["other_labels"]],
            None,
            lambda inputs: None,
        ),
    )
    @pytest.mark.parametrize("sample_type", (tuple, dict))
    def test_transform(self, min_size, min_area, labels_getter, sample_type):
        """The transform drops invalid boxes and keeps labels and masks aligned with them."""

        if sample_type is tuple and not isinstance(labels_getter, str):
            # The "lambda inputs: inputs["labels"]" labels_getter used in this test
            # doesn't work if the input is a tuple.
            return

        H, W = 256, 128
        boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size, min_area=min_area)
        valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]

        labels = torch.arange(boxes.shape[0])
        masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
        # other_labels corresponds to properties from COCO like iscrowd, area...
        # We only sanitize it when labels_getter returns a tuple
        other_labels = torch.arange(boxes.shape[0])
        whatever = torch.rand(10)
        input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
        sample = {
            "image": input_img,
            "labels": labels,
            "boxes": boxes,
            "other_labels": other_labels,
            "whatever": whatever,
            "None": None,
            "masks": masks,
        }

        if sample_type is tuple:
            img = sample.pop("image")
            sample = (img, sample)

        out = transforms.SanitizeBoundingBoxes(min_size=min_size, min_area=min_area, labels_getter=labels_getter)(
            sample
        )

        if sample_type is tuple:
            out_image = out[0]
            out_labels = out[1]["labels"]
            out_other_labels = out[1]["other_labels"]
            out_boxes = out[1]["boxes"]
            out_masks = out[1]["masks"]
            out_whatever = out[1]["whatever"]
        else:
            out_image = out["image"]
            out_labels = out["labels"]
            out_other_labels = out["other_labels"]
            out_boxes = out["boxes"]
            out_masks = out["masks"]
            out_whatever = out["whatever"]

        # Entries unrelated to the boxes must be passed through as the very same objects.
        assert out_image is input_img
        assert out_whatever is whatever

        assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
        assert isinstance(out_masks, tv_tensors.Mask)

        if labels_getter is None or (callable(labels_getter) and labels_getter(sample) is None):
            # No labels were requested, so both label tensors must be untouched.
            assert out_labels is labels
            assert out_other_labels is other_labels
        else:
            assert isinstance(out_labels, torch.Tensor)
            assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0]
            # This works because we conveniently set labels to arange(num_boxes)
            assert out_labels.tolist() == valid_indices

            if callable(labels_getter) and isinstance(labels_getter(sample), (tuple, list)):
                # other_labels was returned by the getter, so it is filtered like labels.
                assert_equal(out_other_labels, out_labels)
            else:
                assert_equal(out_other_labels, other_labels)

    @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes))
    def test_functional(self, input_type):
        """The functional returns the filtered boxes and a plain-tensor validity mask."""
        # Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some
        # redundancy with test_transform() in terms of correctness checks. But that's OK.

        H, W, min_size = 256, 128, 10

        boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)

        if input_type is tv_tensors.BoundingBoxes:
            format = canvas_size = None
        else:
            # just passing "XYXY" explicitly to make sure we support strings
            format, canvas_size = "XYXY", boxes.canvas_size
            boxes = boxes.as_subclass(torch.Tensor)

        boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size)

        assert_equal(valid, torch.tensor(expected_valid_mask))
        assert type(valid) is torch.Tensor
        assert boxes.shape[0] == sum(valid)
        assert isinstance(boxes, input_type)

    def test_kernel(self):
        H, W, min_size = 256, 128, 10
        boxes, _ = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)

        # The kernel check runs on a pure tensor, so format and canvas_size are passed explicitly.
        format, canvas_size = boxes.format, boxes.canvas_size
        boxes = boxes.as_subclass(torch.Tensor)

        check_kernel(
            F.sanitize_bounding_boxes,
            input=boxes,
            format=format,
            canvas_size=canvas_size,
            check_batched_vs_unbatched=False,
        )

    def test_no_label(self):
        # Non-regression test for https://github.com/pytorch/vision/issues/7878

        img = make_image()
        boxes = make_bounding_boxes()

        with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
            transforms.SanitizeBoundingBoxes()(img, boxes)

        out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
        assert isinstance(out_img, tv_tensors.Image)
        assert isinstance(out_boxes, tv_tensors.BoundingBoxes)

    def test_errors_transform(self):
        """Invalid constructor arguments and malformed samples raise `ValueError`."""
        good_bbox = tv_tensors.BoundingBoxes(
            [[0, 0, 10, 10]],
            format=tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=(20, 20),
        )

        with pytest.raises(ValueError, match="min_size must be >= 1"):
            transforms.SanitizeBoundingBoxes(min_size=0)
        with pytest.raises(ValueError, match="min_area must be >= 1"):
            transforms.SanitizeBoundingBoxes(min_area=0)
        with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
            transforms.SanitizeBoundingBoxes(labels_getter=12)

        with pytest.raises(ValueError, match="Could not infer where the labels are"):
            bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
            transforms.SanitizeBoundingBoxes()(bad_labels_key)

        with pytest.raises(ValueError, match="must be a tensor"):
            not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
            transforms.SanitizeBoundingBoxes()(not_a_tensor)

        with pytest.raises(ValueError, match="Number of boxes"):
            different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
            transforms.SanitizeBoundingBoxes()(different_sizes)

    def test_errors_functional(self):
        """The functional validates `format` / `canvas_size` against the input type."""

        good_bbox = tv_tensors.BoundingBoxes(
            [[0, 0, 10, 10]],
            format=tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=(20, 20),
        )

        # Pure tensors: each of format and canvas_size is left out once.
        with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
            F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format="XYXY", canvas_size=None)

        with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
            F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format=None, canvas_size=(10, 10))

        # BoundingBoxes: each of format and canvas_size is passed once. The second case
        # previously repeated the first one, leaving the canvas_size-only call untested.
        with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
            F.sanitize_bounding_boxes(good_bbox, format="XYXY", canvas_size=None)

        with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
            F.sanitize_bounding_boxes(good_bbox, format=None, canvas_size=(10, 10))

        with pytest.raises(ValueError, match="bounding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
            F.sanitize_bounding_boxes(good_bbox.tolist())
6089

6090

6091
class TestJPEG:
    """Tests for the JPEG kernels, the `F.jpeg` functional and the `JPEG` transform."""

    @pytest.mark.parametrize("quality", [5, 75])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_kernel_image(self, quality, color_space):
        image = make_image(color_space=color_space)
        check_kernel(F.jpeg_image, image, quality=quality)

    def test_kernel_video(self):
        video = make_video()
        check_kernel(F.jpeg_video, video, quality=5)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    def test_functional(self, make_input):
        sample = make_input()
        check_functional(F.jpeg, sample, quality=5)

    @pytest.mark.parametrize(
        ("kernel", "input_type"),
        [
            (F.jpeg_image, torch.Tensor),
            (F._augment._jpeg_image_pil, PIL.Image.Image),
            (F.jpeg_image, tv_tensors.Image),
            (F.jpeg_video, tv_tensors.Video),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        check_functional_kernel_signature_match(F.jpeg, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    def test_transform(self, make_input, quality, color_space):
        jpeg = transforms.JPEG(quality=quality)
        sample = make_input(color_space=color_space)
        check_transform(jpeg, sample)

    @pytest.mark.parametrize("quality", [5])
    def test_functional_image_correctness(self, quality):
        image = make_image()

        actual = F.jpeg(image, quality=quality)

        # Reference: the same compression routed through a PIL image.
        pil_result = F.jpeg(F.to_pil_image(image), quality=quality)
        expected = F.to_image(pil_result)

        # NOTE: this will fail if torchvision and Pillow use different JPEG encoder/decoder
        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
    @pytest.mark.parametrize("seed", list(range(5)))
    def test_transform_image_correctness(self, quality, color_space, seed):
        image = make_image(color_space=color_space)
        jpeg = transforms.JPEG(quality=quality)

        with freeze_rng_state():
            # Re-seed before each call so both runs draw the same random numbers.
            torch.manual_seed(seed)
            actual = jpeg(image)

            torch.manual_seed(seed)
            pil_result = jpeg(F.to_pil_image(image))
            expected = F.to_image(pil_result)

        torch.testing.assert_close(actual, expected, rtol=0, atol=1)

    @pytest.mark.parametrize("quality", [5, (10, 20)])
    @pytest.mark.parametrize("seed", list(range(10)))
    def test_transform_get_params_bounds(self, quality, seed):
        jpeg = transforms.JPEG(quality=quality)

        with freeze_rng_state():
            torch.manual_seed(seed)
            sampled = jpeg._get_params([])["quality"]

        if not isinstance(quality, int):
            # A pair is an inclusive range the sampled quality must fall into.
            low, high = quality
            assert low <= sampled <= high
        else:
            # A single int must be used as is.
            assert sampled == quality

    @pytest.mark.parametrize("quality", [[0], [0, 0, 0]])
    def test_transform_sequence_len_error(self, quality):
        with pytest.raises(ValueError, match="quality should be a sequence of length 2"):
            transforms.JPEG(quality=quality)

    @pytest.mark.parametrize("quality", [-1, 0, 150])
    def test_transform_invalid_quality_error(self, quality):
        with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"):
            transforms.JPEG(quality=quality)
6172

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.