6
import torchvision.transforms.v2._utils
7
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image
9
from torchvision import tv_tensors
10
from torchvision.transforms.v2._utils import has_all, has_any
11
from torchvision.transforms.v2.functional import to_pil_image
14
# Sample inputs built once at import time and reused by the parametrized
# cases in this module: one image, one set of XYXY bounding boxes and one set
# of detection masks, all created from the same DEFAULT_SIZE.
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
MASK = make_detection_masks(DEFAULT_SIZE)
19
@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
        # A sample holding an image, boxes and a mask, queried with one or two
        # tv_tensor classes that are all present in it.
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
        # Single-item samples whose only item is none of the requested classes.
        ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
        ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
        ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
            # NOTE(review): expected value was not visible in the reviewed
            # text; inferred as True because every requested class is present
            # in the sample -- confirm.
            True,
        ),
        # An empty sample cannot contain any of the requested classes.
        ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # Callables are accepted as checks in place of classes.
        ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
        # Classes and a callable mixed in one query, run against three
        # different representations of the same image.
        ((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
        (
            (torch.Tensor(IMAGE),),
            (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
            # NOTE(review): expected value not visible in the reviewed text;
            # inferred as True on the assumption that torch.Tensor(IMAGE)
            # satisfies is_pure_tensor -- confirm.
            True,
        ),
        (
            (to_pil_image(IMAGE),),
            (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
            # NOTE(review): expected value not visible in the reviewed text;
            # inferred as True on the assumption that to_pil_image returns a
            # PIL.Image.Image -- confirm.
            True,
        ),
    ],
)
def test_has_any(sample, types, expected):
    """Check that ``has_any(sample, *types)`` returns exactly ``expected``.

    Args:
        sample: Tuple of inputs handed to ``has_any`` as its first argument.
        types: Classes and/or single-argument callables, unpacked as the
            remaining positional arguments of ``has_any``.
        expected: The exact ``bool`` object the call must return; compared
            with ``is`` so a merely truthy or falsy result fails the test.
    """
    assert has_any(sample, *types) is expected
57
@pytest.mark.parametrize(
    ("sample", "types", "expected"),
    [
        # A sample holding an image, boxes and a mask, queried with one or two
        # tv_tensor classes that are all present in it.
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
            # NOTE(review): expected value was not visible in the reviewed
            # text; inferred as True because every requested class is present
            # in the sample -- confirm.
            True,
        ),
        # Two requested classes, one of which is missing from the sample.
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
        ((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # This case repeats the all-three-classes case above; it is kept so
        # the number of parametrized cases stays unchanged.
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
            # NOTE(review): expected value not visible in the reviewed text;
            # inferred as True for the same reason as the identical case
            # above -- confirm.
            True,
        ),
        # Three requested classes, one of which is missing from the sample.
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        ((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        ((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
        # Callables are accepted as checks in place of classes.
        (
            (IMAGE, BOUNDING_BOX, MASK),
            (lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
            # NOTE(review): expected value not visible in the reviewed text;
            # inferred as True because the callable accepts every item of the
            # sample -- confirm.
            True,
        ),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
    ],
)
def test_has_all(sample, types, expected):
    """Check that ``has_all(sample, *types)`` returns exactly ``expected``.

    Args:
        sample: Tuple of inputs handed to ``has_all`` as its first argument.
        types: Classes and/or single-argument callables, unpacked as the
            remaining positional arguments of ``has_all``.
        expected: The exact ``bool`` object the call must return; compared
            with ``is`` so a merely truthy or falsy result fails the test.
    """
    assert has_all(sample, *types) is expected