vision

Форк
0
/
common_utils.py 
519 строк · 15.8 Кб
1
import contextlib
2
import functools
3
import itertools
4
import os
5
import pathlib
6
import random
7
import re
8
import shutil
9
import sys
10
import tempfile
11
import warnings
12
from subprocess import CalledProcessError, check_output, STDOUT
13

14
import numpy as np
15
import PIL.Image
16
import pytest
17
import torch
18
import torch.testing
19
from PIL import Image
20

21
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
22
from torchvision import io, tv_tensors
23
from torchvision.transforms._functional_tensor import _max_value as get_max_value
24
from torchvision.transforms.v2.functional import to_image, to_pil_image
25

26

27
# Environment detection flags.
IN_OSS_CI = "true" in (os.getenv("CIRCLECI"), os.getenv("GITHUB_ACTIONS"))
IN_RE_WORKER = "INSIDE_RE_WORKER" in os.environ
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"

# Skip / failure messages shared across tests.
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
33

34

35
@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
    """Create a temporary directory, optionally pre-populated from ``src``, and delete it on exit.

    Args:
        src: Optional path of a directory whose contents are copied into the temporary
            directory before it is yielded.
        **kwargs: Forwarded unchanged to :func:`tempfile.mkdtemp` (e.g. ``prefix``, ``dir``).

    Yields:
        str: Path of the temporary directory.

    The directory is removed even when copying ``src`` fails partway through or when the
    ``with`` body raises.
    """
    tmp_dir = tempfile.mkdtemp(**kwargs)
    try:
        if src is not None:
            # By default shutil.copytree refuses an existing destination, so drop the empty
            # directory mkdtemp created and let copytree recreate it at the same path.
            os.rmdir(tmp_dir)
            shutil.copytree(src, tmp_dir)
        yield tmp_dir
    finally:
        # A failed copy may leave a partially-populated directory behind, or none at all
        # (if it failed before recreating it), so only remove what actually exists.
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
45

46

47
def set_rng_seed(seed):
    """Seed both Python's ``random`` module and torch's global generator with ``seed``."""
    random.seed(seed)
    torch.manual_seed(seed)
50

51

52
class MapNestedTensorObjectImpl:
    """Callable that recursively applies ``tensor_map_fn`` to every tensor in a nested structure.

    Tensors are replaced by ``tensor_map_fn(tensor)``. Dicts are rebuilt with both keys and
    values mapped. Lists come back as lists. Tuples (including tuple subclasses) come back as
    plain ``tuple``. Any other object is returned unchanged.
    """

    def __init__(self, tensor_map_fn):
        self.tensor_map_fn = tensor_map_fn

    def __call__(self, object):
        if isinstance(object, torch.Tensor):
            return self.tensor_map_fn(object)

        if isinstance(object, dict):
            result = {}
            for k, v in object.items():
                result[self(k)] = self(v)
            return result

        if isinstance(object, (list, tuple)):
            converted = [self(item) for item in object]
            return tuple(converted) if isinstance(object, tuple) else converted

        return object
74

75

76
def map_nested_tensor_object(object, tensor_map_fn):
    """Return a copy of the nested ``object`` with ``tensor_map_fn`` applied to each tensor inside it."""
    return MapNestedTensorObjectImpl(tensor_map_fn)(object)
79

80

81
def is_iterable(obj):
    """Return whether ``iter(obj)`` succeeds (``False`` if it raises ``TypeError``)."""
    try:
        iter(obj)
    except TypeError:
        return False
    else:
        return True
87

88

89
@contextlib.contextmanager
def freeze_rng_state():
    """Snapshot torch's CPU (and, if available, CUDA) RNG state and restore it on exit.

    The state is restored even if the ``with`` body raises, so a failing test cannot leak
    a modified RNG state into the tests that run after it.
    """
    cuda_available = torch.cuda.is_available()
    rng_state = torch.get_rng_state()
    if cuda_available:
        cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        if cuda_available:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
98

99

100
def cycle_over(objs):
    """Yield every ordered pair ``(a, b)`` of elements of ``objs`` taken from different positions.

    Pairs are ordered by the position of ``a``, then by the position of ``b``.
    """
    yield from itertools.permutations(objs, 2)
104

105

106
def int_dtypes():
    """Return the integer dtypes exercised by the tests: ``uint8`` followed by the signed types by width."""
    unsigned = (torch.uint8,)
    signed = (torch.int8, torch.int16, torch.int32, torch.int64)
    return unsigned + signed
108

109

110
def float_dtypes():
    """Return the floating-point dtypes exercised by the tests (single, then double precision)."""
    single, double = torch.float32, torch.float64
    return single, double
112

113

114
@contextlib.contextmanager
def disable_console_output():
    """Redirect ``sys.stdout`` and ``sys.stderr`` to ``os.devnull`` for the duration of the ``with`` body."""
    with open(os.devnull, "w") as sink, contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
        yield
120

121

122
def cpu_and_cuda():
    """Return the device parametrization ``("cpu", "cuda")``; the CUDA entry carries the ``needs_cuda`` mark."""
    import pytest  # noqa

    cuda = pytest.param("cuda", marks=pytest.mark.needs_cuda)
    return "cpu", cuda
126

127

128
def cpu_and_cuda_and_mps():
    """Return the ``cpu_and_cuda()`` parametrization extended by an ``"mps"`` entry marked ``needs_mps``."""
    devices = cpu_and_cuda()
    return (*devices, pytest.param("mps", marks=pytest.mark.needs_mps))
130

131

132
def needs_cuda(test_func):
    """Decorator: attach the ``needs_cuda`` pytest mark to ``test_func`` and return it."""
    import pytest  # noqa

    mark = pytest.mark.needs_cuda
    return mark(test_func)
136

137

138
def needs_mps(test_func):
    """Decorator: attach the ``needs_mps`` pytest mark to ``test_func`` and return it."""
    import pytest  # noqa

    mark = pytest.mark.needs_mps
    return mark(test_func)
142

143

144
def _create_data(height=3, width=3, channels=3, device="cpu"):
    """Return a random uint8 CHW tensor and a PIL image holding the same pixels.

    With ``channels == 1`` the image is built in mode ``"L"``; otherwise mode ``"RGB"`` is requested.
    """
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
    # PIL expects channels-last data on the CPU.
    hwc = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
    if channels == 1:
        mode, hwc = "L", hwc[..., 0]
    else:
        mode = "RGB"
    return tensor, Image.fromarray(hwc, mode=mode)
154

155

156
def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
157
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
158
    batch_tensor = torch.randint(0, 256, (num_samples, channels, height, width), dtype=torch.uint8, device=device)
159
    return batch_tensor
160

161

162
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
    """Write ``num_videos`` random ``.mp4`` files into ``tmpdir`` and return their paths.

    Video ``idx`` is stored as ``"{idx}.mp4"``. It contains ``sizes[idx]`` frames (default
    ``5 * (idx + 1)``) of random uint8 data shaped ``(300, 400, 3)`` and is written at
    ``fps[idx]`` frames per second (default 5).
    """
    paths = []
    for idx in range(num_videos):
        num_frames = 5 * (idx + 1) if sizes is None else sizes[idx]
        frame_rate = 5 if fps is None else fps[idx]
        frames = torch.randint(0, 256, (num_frames, 300, 400, 3), dtype=torch.uint8)
        path = os.path.join(tmpdir, f"{idx}.mp4")
        paths.append(path)
        io.write_video(path, frames, fps=frame_rate)
    return paths
179

180

181
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
    """Assert that ``tensor`` exactly equals ``pil_image`` once the image is turned into a CHW tensor."""
    # FIXME: this is handled automatically by `assert_equal` below. Let's remove this in favor of it
    array = np.array(pil_image)
    if array.ndim == 2:
        # Single-channel images arrive without a channel axis; append one.
        array = array[..., np.newaxis]
    expected = torch.as_tensor(array.transpose((2, 0, 1)))
    if msg is None:
        msg = f"tensor:\n{tensor} \ndid not equal PIL tensor:\n{expected}"
    assert_equal(tensor.cpu(), expected, msg=msg)
190

191

192
def _assert_approx_equal_tensor_to_pil(
193
    tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None
194
):
195
    # FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it
196
    # TODO: we could just merge this into _assert_equal_tensor_to_pil
197
    np_pil_image = np.array(pil_image)
198
    if np_pil_image.ndim == 2:
199
        np_pil_image = np_pil_image[:, :, None]
200
    pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
201

202
    if allowed_percentage_diff is not None:
203
        # Assert that less than a given %age of pixels are different
204
        assert (tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
205

206
    # error value can be mean absolute error, max abs error
207
    # Convert to float to avoid underflow when computing absolute difference
208
    tensor = tensor.to(torch.float)
209
    pil_tensor = pil_tensor.to(torch.float)
210
    err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
211
    assert err < tol, f"{err} vs {tol}"
212

213

214
def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
215
    transformed_batch = fn(batch_tensors, **fn_kwargs)
216
    for i in range(len(batch_tensors)):
217
        img_tensor = batch_tensors[i, ...]
218
        transformed_img = fn(img_tensor, **fn_kwargs)
219
        torch.testing.assert_close(transformed_img, transformed_batch[i, ...], rtol=0, atol=1e-6)
220

221
    if scripted_fn_atol >= 0:
222
        scripted_fn = torch.jit.script(fn)
223
        # scriptable function test
224
        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
225
        torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
226

227

228
def cache(fn):
    """Similar to :func:`functools.cache` (Python >= 3.9) or :func:`functools.lru_cache` with infinite cache size,
    but this also caches exceptions.

    Results are keyed on the positional arguments plus the keyword arguments *including their
    names*. So ``f(a=1)`` and ``f(b=1)`` are cached separately, and keyword order does not
    matter. All arguments must be hashable. A cached exception is re-raised with the traceback
    it had the first time.
    """
    sentinel = object()
    # Separates positional arguments from keyword items in the key, so that a positional argument
    # that happens to look like a ``(name, value)`` pair cannot collide with a keyword argument.
    kwargs_marker = object()
    out_cache = {}
    exc_tb_cache = {}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        key = args
        if kwargs:
            # Keyword names are unique, so sorting never needs to compare the values.
            key = args + (kwargs_marker,) + tuple(sorted(kwargs.items()))

        out = out_cache.get(key, sentinel)
        if out is not sentinel:
            return out

        exc_tb = exc_tb_cache.get(key, sentinel)
        if exc_tb is not sentinel:
            raise exc_tb[0].with_traceback(exc_tb[1])

        try:
            out = fn(*args, **kwargs)
        except Exception as exc:
            # We need to cache the traceback here as well. Otherwise, each re-raise will add the internal pytest
            # traceback frames anew, but they will only be removed once. Thus, the traceback will be ginormous hiding
            # the actual information in the noise. See https://github.com/pytest-dev/pytest/issues/10363 for details.
            exc_tb_cache[key] = exc, exc.__traceback__
            raise

        out_cache[key] = out
        return out

    return wrapper
261

262

263
def combinations_grid(**kwargs):
    """Build the Cartesian product of the keyword arguments' values as a list of dicts.

    Each returned dict maps every keyword name to one of its values, and the dicts cover all
    combinations. Later keywords vary fastest.

    Example:
        >>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
        [
            {'foo': 'bar', 'spam': 'eggs'},
            {'foo': 'bar', 'spam': 'ham'},
            {'foo': 'baz', 'spam': 'eggs'},
            {'foo': 'baz', 'spam': 'ham'}
        ]
    """
    names = list(kwargs)
    value_lists = [kwargs[name] for name in names]
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]
278

279

280
class ImagePair(TensorLikePair):
    """Pair type for :func:`assert_close` that additionally accepts two PIL images.

    When both inputs are PIL images they are converted with ``to_image`` before comparison.
    With ``mae=True`` the inputs are compared by the mean absolute error of their values
    against ``atol`` instead of element-wise closeness.
    """

    def __init__(
        self,
        actual,
        expected,
        *,
        mae=False,
        **other_parameters,
    ):
        if all(isinstance(image, PIL.Image.Image) for image in [actual, expected]):
            actual, expected = [to_image(image) for image in [actual, expected]]

        super().__init__(actual, expected, **other_parameters)
        self.mae = mae

    def compare(self) -> None:
        actual, expected = self.actual, self.expected

        self._compare_attributes(actual, expected)
        actual, expected = self._equalize_attributes(actual, expected)

        if self.mae:
            # Integer and bool dtypes cannot represent the difference of two values safely:
            # unsigned / narrow integers wrap around and bool subtraction is rejected by torch.
            # Widen them to int64 before subtracting.
            if not (actual.dtype.is_floating_point or actual.dtype.is_complex):
                actual, expected = actual.to(torch.int64), expected.to(torch.int64)
            mae = float(torch.abs(actual - expected).float().mean())
            if mae > self.atol:
                self._fail(
                    AssertionError,
                    f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
                )
        else:
            super()._compare_values(actual, expected)
312

313

314
def assert_close(
    actual,
    expected,
    *,
    allow_subclasses=True,
    rtol=None,
    atol=None,
    equal_nan=False,
    check_device=True,
    check_dtype=True,
    check_layout=True,
    check_stride=False,
    msg=None,
    **kwargs,
):
    """Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
    __tracebackhide__ = True

    # ImagePair is listed ahead of TensorLikePair so that it takes precedence for the inputs it accepts.
    pair_types = (NonePair, BooleanPair, NumberPair, ImagePair, TensorLikePair)
    error_metas = not_close_error_metas(
        actual,
        expected,
        pair_types=pair_types,
        allow_subclasses=allow_subclasses,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=check_device,
        check_dtype=check_dtype,
        check_layout=check_layout,
        check_stride=check_stride,
        **kwargs,
    )

    if not error_metas:
        return
    # Only the first mismatch is reported, optionally with the caller's message.
    raise error_metas[0].to_error(msg)
355

356

357
# Exact-equality variant of ``assert_close``: both tolerances are pinned to zero.
assert_equal = functools.partial(assert_close, atol=0, rtol=0)
358

359

360
# Default spatial size, ordered (height, width), used by the ``make_*`` factories.
DEFAULT_SIZE = (17, 11)

# Number of channels for each supported color-space name.
NUM_CHANNELS_MAP = dict(
    GRAY=1,
    GRAY_ALPHA=2,
    RGB=3,
    RGBA=4,
)
369

370

371
def make_image(
    size=DEFAULT_SIZE,
    *,
    color_space="RGB",
    batch_dims=(),
    dtype=None,
    device="cpu",
    memory_format=torch.contiguous_format,
):
    """Make a random ``tv_tensors.Image`` of shape ``(*batch_dims, C, *size)``.

    ``C`` comes from ``NUM_CHANNELS_MAP[color_space]`` and ``dtype`` defaults to ``torch.uint8``.
    Values are drawn by ``torch.testing.make_tensor`` between 0 and the dtype's max value. For
    color spaces with alpha, the last channel is set to that max value (presumably to make the
    images fully opaque).
    """
    num_channels = NUM_CHANNELS_MAP[color_space]
    dtype = dtype or torch.uint8
    max_value = get_max_value(dtype)

    shape = (*batch_dims, num_channels, *size)
    data = torch.testing.make_tensor(
        shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
    )

    has_alpha = color_space in {"GRAY_ALPHA", "RGBA"}
    if has_alpha:
        data[..., -1, :, :] = max_value

    return tv_tensors.Image(data)
395

396

397
def make_image_tensor(*args, **kwargs):
    """Same as ``make_image`` but returns a plain :class:`torch.Tensor` instead of a ``tv_tensors.Image``."""
    image = make_image(*args, **kwargs)
    return image.as_subclass(torch.Tensor)
399

400

401
def make_image_pil(*args, **kwargs):
    """Same as ``make_image`` but converted to a PIL image via ``to_pil_image``."""
    image = make_image(*args, **kwargs)
    return to_pil_image(image)
403

404

405
def make_bounding_boxes(
    canvas_size=DEFAULT_SIZE,
    *,
    format=tv_tensors.BoundingBoxFormat.XYXY,
    num_boxes=1,
    dtype=None,
    device="cpu",
):
    """Make ``num_boxes`` random boxes that fit inside ``canvas_size`` (height, width).

    Args:
        canvas_size: ``(height, width)`` of the canvas the boxes must lie in.
        format: ``tv_tensors.BoundingBoxFormat`` member or its name; XYXY, XYWH and CXCYWH are supported.
        num_boxes: Number of boxes to generate; ``0`` yields an empty ``(0, 4)`` result.
        dtype: Output dtype, ``torch.float32`` by default.
        device: Output device.

    Returns:
        tv_tensors.BoundingBoxes: Tensor of shape ``(num_boxes, 4)``.

    Raises:
        ValueError: If ``format`` is not one of the supported formats.
    """

    def sample_position(values, max_value):
        # Draw one offset per box in [0, max_value - extent) so that offset + extent stays inside
        # the canvas. torch.randint only accepts scalar bounds, so each box is sampled individually.
        positions = [torch.randint(max_value - v, ()) for v in values.tolist()]
        if not positions:
            # torch.stack rejects an empty list; num_boxes=0 must still produce a valid (empty) result.
            return torch.empty(0, dtype=torch.int64)
        return torch.stack(positions)

    if isinstance(format, str):
        format = tv_tensors.BoundingBoxFormat[format]

    dtype = dtype or torch.float32

    h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size]
    y = sample_position(h, canvas_size[0])
    x = sample_position(w, canvas_size[1])

    if format is tv_tensors.BoundingBoxFormat.XYWH:
        parts = (x, y, w, h)
    elif format is tv_tensors.BoundingBoxFormat.XYXY:
        x1, y1 = x, y
        x2 = x1 + w
        y2 = y1 + h
        parts = (x1, y1, x2, y2)
    elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
        cx = x + w / 2
        cy = y + h / 2
        parts = (cx, cy, w, h)
    else:
        raise ValueError(f"Format {format} is not supported")

    return tv_tensors.BoundingBoxes(
        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
    )
444

445

446
def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, batch_dims=(), dtype=None, device="cpu"):
    """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks

    Args:
        size: Spatial size ``(H, W)`` of each mask.
        num_masks: Number of object masks ``N``.
        batch_dims: Leading batch dimensions ``*``; empty by default, which gives ``(N, H, W)``.
        dtype: Mask dtype, ``torch.bool`` by default.
        device: Output device.

    Returns:
        tv_tensors.Mask: Random mask of shape ``(*batch_dims, num_masks, *size)``.
    """
    return tv_tensors.Mask(
        torch.testing.make_tensor(
            (*batch_dims, num_masks, *size),
            low=0,
            high=2,
            dtype=dtype or torch.bool,
            device=device,
        )
    )
457

458

459
def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
    """Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value

    Values are drawn by ``torch.testing.make_tensor`` with ``low=0`` and ``high=num_categories``.
    The dtype defaults to ``torch.uint8``.
    """
    shape = (*batch_dims, *size)
    labels = torch.testing.make_tensor(
        shape,
        low=0,
        high=num_categories,
        dtype=dtype or torch.uint8,
        device=device,
    )
    return tv_tensors.Mask(labels)
470

471

472
def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
    """Make a random ``tv_tensors.Video``: an image batch whose innermost batch dimension is ``num_frames``."""
    frame_batch_dims = (*batch_dims, num_frames)
    frames = make_image(size, batch_dims=frame_batch_dims, **kwargs)
    return tv_tensors.Video(frames)
474

475

476
def make_video_tensor(*args, **kwargs):
    """Same as ``make_video`` but returns a plain :class:`torch.Tensor` instead of a ``tv_tensors.Video``."""
    video = make_video(*args, **kwargs)
    return video.as_subclass(torch.Tensor)
478

479

480
def assert_run_python_script(source_code):
    """Utility to check assertions in an independent Python subprocess.

    The script provided in the source code should return 0 and not print
    anything on stderr or stdout. Modified from scikit-learn test utils.

    Args:
        source_code (str): The Python source code to execute.

    Raises:
        RuntimeError: If the script exits with a non-zero status; the message contains its
            combined stdout/stderr output.
        AssertionError: If the script exits successfully but writes anything to stdout or stderr.
    """
    with get_tmp_dir() as root:
        path = pathlib.Path(root) / "main.py"
        # The interpreter decodes source files as UTF-8 by default, so write the script as UTF-8
        # instead of relying on the platform's locale encoding.
        path.write_text(source_code, encoding="utf-8")

        try:
            out = check_output([sys.executable, str(path)], stderr=STDOUT)
        except CalledProcessError as e:
            # errors="replace" keeps the diagnostic readable even if the output is not valid UTF-8.
            raise RuntimeError(f"script errored with output:\n{e.output.decode(errors='replace')}") from e
        if out != b"":
            raise AssertionError(out.decode(errors="replace"))
500

501

502
@contextlib.contextmanager
def assert_no_warnings():
    """Turn every warning emitted inside the ``with`` body into an exception.

    ``warnings.catch_warnings`` does not catch anything. It only scopes the warning filters, so
    the "error" filter installed here is discarded again when the block exits.
    """
    scoped_filters = warnings.catch_warnings()
    with scoped_filters:
        warnings.simplefilter(action="error")
        yield
509

510

511
@contextlib.contextmanager
def ignore_jit_no_profile_information_warning():
    """Silence TorchScript's "does not have profile information" ``UserWarning`` inside the ``with`` body.

    Calling a scripted object often triggers a warning like
    `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
    with varying `INT1` and `INT2`. They are uninteresting and only clutter the test summary.
    """
    # ``message`` is a regex matched against the start of the warning text, hence the escaping.
    pattern = re.escape("operator() profile_node %")
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, message=pattern)
        yield
520

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.