vision

Форк
0
/
test_transforms_v2_utils.py 
92 строки · 4.0 Кб
1
import PIL.Image
2
import pytest
3

4
import torch
5

6
import torchvision.transforms.v2._utils
7
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image
8

9
from torchvision import tv_tensors
10
from torchvision.transforms.v2._utils import has_all, has_any
11
from torchvision.transforms.v2.functional import to_pil_image
12

13

14
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
15
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
16
MASK = make_detection_masks(DEFAULT_SIZE)
17

18

19
@pytest.mark.parametrize(
20
    ("sample", "types", "expected"),
21
    [
22
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
23
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
24
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
25
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
26
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
27
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
28
        ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
29
        ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
30
        ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
31
        (
32
            (IMAGE, BOUNDING_BOX, MASK),
33
            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
34
            True,
35
        ),
36
        ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
37
        ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
38
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
39
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
40
        ((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
41
        (
42
            (torch.Tensor(IMAGE),),
43
            (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
44
            True,
45
        ),
46
        (
47
            (to_pil_image(IMAGE),),
48
            (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
49
            True,
50
        ),
51
    ],
52
)
53
def test_has_any(sample, types, expected):
54
    assert has_any(sample, *types) is expected
55

56

57
@pytest.mark.parametrize(
58
    ("sample", "types", "expected"),
59
    [
60
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
61
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
62
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
63
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
64
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
65
        ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
66
        (
67
            (IMAGE, BOUNDING_BOX, MASK),
68
            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
69
            True,
70
        ),
71
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
72
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
73
        ((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
74
        (
75
            (IMAGE, BOUNDING_BOX, MASK),
76
            (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
77
            True,
78
        ),
79
        ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
80
        ((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
81
        ((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
82
        (
83
            (IMAGE, BOUNDING_BOX, MASK),
84
            (lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
85
            True,
86
        ),
87
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
88
        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
89
    ],
90
)
91
def test_has_all(sample, types, expected):
92
    assert has_all(sample, *types) is expected
93

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.