12
from subprocess import CalledProcessError, check_output, STDOUT
21
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
22
from torchvision import io, tv_tensors
23
from torchvision.transforms._functional_tensor import _max_value as get_max_value
24
from torchvision.transforms.v2.functional import to_image, to_pil_image
27
# Environment detection, evaluated once when this module is imported.
# True when CIRCLECI or GITHUB_ACTIONS is set to exactly the string "true".
IN_OSS_CI = any(os.environ.get(env_var) == "true" for env_var in ("CIRCLECI", "GITHUB_ACTIONS"))
# True whenever INSIDE_RE_WORKER is present in the environment, whatever its value (even "").
IN_RE_WORKER = "INSIDE_RE_WORKER" in os.environ
# True only when IN_FBCODE_TORCHVISION is set to exactly "1".
IN_FBCODE = os.getenv("IN_FBCODE_TORCHVISION") == "1"

# Fixed message strings for device/CI related conditions.
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
35
@contextlib.contextmanager
36
def get_tmp_dir(src=None, **kwargs):
37
tmp_dir = tempfile.mkdtemp(**kwargs)
40
shutil.copytree(src, tmp_dir)
44
shutil.rmtree(tmp_dir)
47
def set_rng_seed(seed):
48
torch.manual_seed(seed)
52
class MapNestedTensorObjectImpl:
53
def __init__(self, tensor_map_fn):
54
self.tensor_map_fn = tensor_map_fn
56
def __call__(self, object):
57
if isinstance(object, torch.Tensor):
58
return self.tensor_map_fn(object)
60
elif isinstance(object, dict):
62
for key, value in object.items():
63
mapped_dict[self(key)] = self(value)
66
elif isinstance(object, (list, tuple)):
69
mapped_iter.append(self(iter))
70
return mapped_iter if not isinstance(object, tuple) else tuple(mapped_iter)
76
def map_nested_tensor_object(object, tensor_map_fn):
77
impl = MapNestedTensorObjectImpl(tensor_map_fn)
89
@contextlib.contextmanager
90
def freeze_rng_state():
91
rng_state = torch.get_rng_state()
92
if torch.cuda.is_available():
93
cuda_rng_state = torch.cuda.get_rng_state()
95
if torch.cuda.is_available():
96
torch.cuda.set_rng_state(cuda_rng_state)
97
torch.set_rng_state(rng_state)
101
for idx, obj1 in enumerate(objs):
102
for obj2 in objs[:idx] + objs[idx + 1 :]:
107
return (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
111
return (torch.float32, torch.float64)
114
@contextlib.contextmanager
115
def disable_console_output():
116
with contextlib.ExitStack() as stack, open(os.devnull, "w") as devnull:
117
stack.enter_context(contextlib.redirect_stdout(devnull))
118
stack.enter_context(contextlib.redirect_stderr(devnull))
125
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
128
def cpu_and_cuda_and_mps():
    """Return the device entries of :func:`cpu_and_cuda` followed by an ``"mps"`` entry.

    The appended ``"mps"`` entry is a ``pytest.param`` carrying the ``needs_mps`` mark.
    """
    devices = cpu_and_cuda()
    mps_device = pytest.param("mps", marks=pytest.mark.needs_mps)
    return devices + (mps_device,)
132
def needs_cuda(test_func):
135
return pytest.mark.needs_cuda(test_func)
138
def needs_mps(test_func):
141
return pytest.mark.needs_mps(test_func)
144
def _create_data(height=3, width=3, channels=3, device="cpu"):
145
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
146
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
147
data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
152
pil_img = Image.fromarray(data, mode=mode)
153
return tensor, pil_img
156
def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
157
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
158
batch_tensor = torch.randint(0, 256, (num_samples, channels, height, width), dtype=torch.uint8, device=device)
162
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
164
for i in range(num_videos):
173
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
174
name = os.path.join(tmpdir, f"{i}.mp4")
176
io.write_video(name, data, fps=f)
181
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
182
# FIXME: this is handled automatically by `assert_equal` below. Let's remove this in favor of it
183
np_pil_image = np.array(pil_image)
184
if np_pil_image.ndim == 2:
185
np_pil_image = np_pil_image[:, :, None]
186
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
188
msg = f"tensor:\n{tensor} \ndid not equal PIL tensor:\n{pil_tensor}"
189
assert_equal(tensor.cpu(), pil_tensor, msg=msg)
192
def _assert_approx_equal_tensor_to_pil(
193
tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None
195
# FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it
196
# TODO: we could just merge this into _assert_equal_tensor_to_pil
197
np_pil_image = np.array(pil_image)
198
if np_pil_image.ndim == 2:
199
np_pil_image = np_pil_image[:, :, None]
200
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
202
if allowed_percentage_diff is not None:
203
# Assert that less than a given %age of pixels are different
204
assert (tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
206
# error value can be mean absolute error, max abs error
207
# Convert to float to avoid underflow when computing absolute difference
208
tensor = tensor.to(torch.float)
209
pil_tensor = pil_tensor.to(torch.float)
210
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
211
assert err < tol, f"{err} vs {tol}"
214
def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
215
transformed_batch = fn(batch_tensors, **fn_kwargs)
216
for i in range(len(batch_tensors)):
217
img_tensor = batch_tensors[i, ...]
218
transformed_img = fn(img_tensor, **fn_kwargs)
219
torch.testing.assert_close(transformed_img, transformed_batch[i, ...], rtol=0, atol=1e-6)
221
if scripted_fn_atol >= 0:
222
scripted_fn = torch.jit.script(fn)
223
# scriptable function test
224
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
225
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
229
"""Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite cache size,
230
but this also caches exceptions.
237
def wrapper(*args, **kwargs):
238
key = args + tuple(kwargs.values())
240
out = out_cache.get(key, sentinel)
241
if out is not sentinel:
244
exc_tb = exc_tb_cache.get(key, sentinel)
245
if exc_tb is not sentinel:
246
raise exc_tb[0].with_traceback(exc_tb[1])
249
out = fn(*args, **kwargs)
250
except Exception as exc:
251
# We need to cache the traceback here as well. Otherwise, each re-raise will add the internal pytest
252
# traceback frames anew, but they will only be removed once. Thus, the traceback will be ginormous hiding
253
# the actual information in the noise. See https://github.com/pytest-dev/pytest/issues/10363 for details.
254
exc_tb_cache[key] = exc, exc.__traceback__
263
def combinations_grid(**kwargs):
264
"""Creates a grid of input combinations.
266
Each element in the returned sequence is a dictionary containing one possible combination as values.
269
>>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
271
{'foo': 'bar', 'spam': 'eggs'},
272
{'foo': 'bar', 'spam': 'ham'},
273
{'foo': 'baz', 'spam': 'eggs'},
274
{'foo': 'baz', 'spam': 'ham'}
277
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
280
class ImagePair(TensorLikePair):
289
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
290
actual, expected = [to_image(input) for input in [actual, expected]]
292
super().__init__(actual, expected, **other_parameters)
295
def compare(self) -> None:
296
actual, expected = self.actual, self.expected
298
self._compare_attributes(actual, expected)
299
actual, expected = self._equalize_attributes(actual, expected)
302
if actual.dtype is torch.uint8:
303
actual, expected = actual.to(torch.int), expected.to(torch.int)
304
mae = float(torch.abs(actual - expected).float().mean())
308
f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
311
super()._compare_values(actual, expected)
318
allow_subclasses=True,
329
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
330
__tracebackhide__ = True
332
error_metas = not_close_error_metas(
342
allow_subclasses=allow_subclasses,
346
check_device=check_device,
347
check_dtype=check_dtype,
348
check_layout=check_layout,
349
check_stride=check_stride,
354
raise error_metas[0].to_error(msg)
357
# Exact-equality counterpart of ``assert_close``: both ``rtol`` and ``atol`` are bound to 0.
# Because this is a ``functools.partial``, callers can still override either tolerance by keyword.
assert_equal = functools.partial(assert_close, rtol=0, atol=0)
360
# Default spatial size for the ``make_*`` helpers, ordered (height, width). ``make_image`` lays data out as
# (*batch_dims, num_channels, *size), and ``make_bounding_boxes`` samples y positions against ``canvas_size[0]``.
DEFAULT_SIZE = 17, 11
378
memory_format=torch.contiguous_format,
380
num_channels = NUM_CHANNELS_MAP[color_space]
381
dtype = dtype or torch.uint8
382
max_value = get_max_value(dtype)
383
data = torch.testing.make_tensor(
384
(*batch_dims, num_channels, *size),
389
memory_format=memory_format,
391
if color_space in {"GRAY_ALPHA", "RGBA"}:
392
data[..., -1, :, :] = max_value
394
return tv_tensors.Image(data)
397
def make_image_tensor(*args, **kwargs):
    """Build an image with :func:`make_image` and return it as a plain :class:`torch.Tensor`.

    All arguments are forwarded unchanged to :func:`make_image`. The ``tv_tensors.Image`` wrapper is
    dropped via ``Tensor.as_subclass``.
    """
    image = make_image(*args, **kwargs)
    return image.as_subclass(torch.Tensor)
401
def make_image_pil(*args, **kwargs):
    """Build an image with :func:`make_image` and convert it to a PIL image with ``to_pil_image``.

    All arguments are forwarded unchanged to :func:`make_image`.
    """
    image = make_image(*args, **kwargs)
    return to_pil_image(image)
405
def make_bounding_boxes(
406
canvas_size=DEFAULT_SIZE,
408
format=tv_tensors.BoundingBoxFormat.XYXY,
413
def sample_position(values, max_value):
414
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
415
# However, if we have batch_dims, we need tensors as limits.
416
return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])
418
if isinstance(format, str):
419
format = tv_tensors.BoundingBoxFormat[format]
421
dtype = dtype or torch.float32
423
h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size]
424
y = sample_position(h, canvas_size[0])
425
x = sample_position(w, canvas_size[1])
427
if format is tv_tensors.BoundingBoxFormat.XYWH:
429
elif format is tv_tensors.BoundingBoxFormat.XYXY:
433
parts = (x1, y1, x2, y2)
434
elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
437
parts = (cx, cy, w, h)
439
raise ValueError(f"Format {format} is not supported")
441
return tv_tensors.BoundingBoxes(
442
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
446
def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):
447
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
448
return tv_tensors.Mask(
449
torch.testing.make_tensor(
453
dtype=dtype or torch.bool,
459
def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
460
"""Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
461
return tv_tensors.Mask(
462
torch.testing.make_tensor(
463
(*batch_dims, *size),
466
dtype=dtype or torch.uint8,
472
def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
    """Build a ``tv_tensors.Video`` from :func:`make_image`.

    The frame dimension (``num_frames``) is appended after the caller's ``batch_dims``, and the combined
    tuple is passed to :func:`make_image` as its ``batch_dims``. ``size`` and any extra keyword arguments
    are forwarded to :func:`make_image` unchanged.
    """
    frame_batch_dims = (*batch_dims, num_frames)
    frames = make_image(size, batch_dims=frame_batch_dims, **kwargs)
    return tv_tensors.Video(frames)
476
def make_video_tensor(*args, **kwargs):
    """Build a video with :func:`make_video` and return it as a plain :class:`torch.Tensor`.

    All arguments are forwarded unchanged to :func:`make_video`. The ``tv_tensors.Video`` wrapper is
    dropped via ``Tensor.as_subclass``.
    """
    video = make_video(*args, **kwargs)
    return video.as_subclass(torch.Tensor)
480
def assert_run_python_script(source_code):
481
"""Utility to check assertions in an independent Python subprocess.
483
The script provided in the source code should return 0 and not print
484
anything on stderr or stdout. Modified from scikit-learn test utils.
487
source_code (str): The Python source code to execute.
489
with get_tmp_dir() as root:
490
path = pathlib.Path(root) / "main.py"
491
with open(path, "w") as file:
492
file.write(source_code)
495
out = check_output([sys.executable, str(path)], stderr=STDOUT)
496
except CalledProcessError as e:
497
raise RuntimeError(f"script errored with output:\n{e.output.decode()}")
499
raise AssertionError(out.decode())
502
@contextlib.contextmanager
503
def assert_no_warnings():
504
# The name `catch_warnings` is a misnomer as the context manager does **not** catch any warnings, but rather scopes
505
# the warning filters. All changes that are made to the filters while in this context, will be reset upon exit.
506
with warnings.catch_warnings():
507
warnings.simplefilter("error")
511
@contextlib.contextmanager
512
def ignore_jit_no_profile_information_warning():
513
# Calling a scripted object often triggers a warning like
514
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
515
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
517
with warnings.catch_warnings():
518
warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning)