9
from collections import OrderedDict
10
from tempfile import TemporaryDirectory
17
from _utils_internal import get_relative_path
18
from common_utils import cpu_and_cuda, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
20
from torchvision import models, transforms
21
from torchvision.models import get_model_builder, list_models
24
# EXPECTTEST_ACCEPT=1 makes _assert_expected record the current outputs as the new expected
# files instead of comparing against them.
ACCEPT = os.environ.get("EXPECTTEST_ACCEPT", "0") == "1"
# By default, models listed in skipped_big_models are skipped on their listed
# (platform, device) pairs; set SKIP_BIG_MODEL=0 to run them anyway.
SKIP_BIG_MODEL = os.environ.get("SKIP_BIG_MODEL", "1") == "1"
28
def list_model_fns(module):
    """Return the builder function of every model that ``list_models`` reports for ``module``."""
    return list(map(get_model_builder, list_models(module)))
32
def _get_image(input_shape, real_image, device, dtype=None):
33
"""This routine loads a real or random image based on `real_image` argument.
34
Currently, the real image is utilized for the following list of models:
35
- `retinanet_resnet50_fpn`,
36
- `retinanet_resnet50_fpn_v2`,
37
- `keypointrcnn_resnet50_fpn`,
38
- `fasterrcnn_resnet50_fpn`,
39
- `fasterrcnn_resnet50_fpn_v2`,
40
- `fcos_resnet50_fpn`,
41
- `maskrcnn_resnet50_fpn`,
42
- `maskrcnn_resnet50_fpn_v2`,
43
in `test_classification_model` and `test_detection_model`.
44
To do so, a keyword argument `real_image` was added to the above-listed models in `_model_params`
47
# TODO: Maybe unify file discovery logic with test_image.py
48
GRACE_HOPPER = os.path.join(
49
os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
52
img = Image.open(GRACE_HOPPER)
54
original_width, original_height = img.size
56
# make the image square
57
img = img.crop((0, 0, original_width, original_width))
58
img = img.resize(input_shape[1:3])
60
convert_tensor = transforms.ToTensor()
61
image = convert_tensor(img)
62
assert tuple(image.size()) == input_shape
63
return image.to(device=device, dtype=dtype)
65
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
66
return torch.rand(input_shape).to(device=device, dtype=dtype)
70
def disable_weight_loading(mocker):
71
"""When testing models, the two slowest operations are the downloading of the weights to a file and loading them
72
into the model. Unless you want to test against specific weights, these steps can be disabled without any
75
Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
76
through all models in `torchvision.models` and will patch all occurrences of the function
77
`download_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
82
Loaded models are still executable as normal, but will always have random weights. Make sure to not use this
83
fixture if you want to compare the model output against reference values.
86
starting_point = models
87
function_name = "load_state_dict_from_url"
88
method_name = "load_state_dict"
90
module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
91
targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
92
for name in module_names:
93
module = sys.modules.get(name)
97
if function_name in module.__dict__:
98
targets.add(f"{module.__name__}.{function_name}")
102
f"{module.__name__}.{obj.__name__}.{method_name}"
103
for obj in module.__dict__.values()
104
if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
108
for target in targets:
109
# See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
110
with contextlib.suppress(AttributeError):
114
def _get_expected_file(name=None):
115
# Determine expected file based on environment
116
expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
118
# Note: for legacy reasons, the reference file names all contain "ModelTester.test_"
119
# We hardcode it here to avoid having to re-generate the reference files
120
expected_file = os.path.join(expected_file_base, "ModelTester.test_" + name)
121
expected_file += "_expect.pkl"
123
if not ACCEPT and not os.path.exists(expected_file):
125
f"No expect file exists for {os.path.basename(expected_file)} in {expected_file}; "
126
"to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
127
"env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
133
def _assert_expected(output, name, prec=None, atol=None, rtol=None):
    """Test that a python value matches the recorded contents of a file
    based on a "check" name. The value must be
    picklable with `torch.save`. This file
    is placed in the 'expect' directory in the same directory
    as the test script. You can automatically update the recorded test
    output using an EXPECTTEST_ACCEPT=1 env variable.

    Args:
        output: value to compare against (or, when accepting, to record).
        name: check name; resolved to a file path by ``_get_expected_file``.
        prec: legacy tolerance, used as the fallback for both ``rtol`` and ``atol``.
        atol, rtol: tolerances forwarded to ``torch.testing.assert_close``.

    Raises:
        RuntimeError: when accepting and the recorded file is larger than 50 KB.
        AssertionError: when ``output`` does not match the recorded value.
    """
    expected_file = _get_expected_file(name)

    if ACCEPT:
        # A plain string: wrapping this in braces would build a one-element set and the
        # messages below would print "{'...'}".
        filename = os.path.basename(expected_file)
        print(f"Accepting updated output for {filename}:\n\n{output}")
        torch.save(output, expected_file)
        MAX_PICKLE_SIZE = 50 * 1000  # 50 KB
        binary_size = os.path.getsize(expected_file)
        if binary_size > MAX_PICKLE_SIZE:
            # os.path.getsize() returns a size in bytes, so report it as such.
            raise RuntimeError(f"The output for {filename}, is larger than 50kb - got {binary_size} bytes")
    else:
        expected = torch.load(expected_file, weights_only=True)
        rtol = rtol or prec  # keeping prec param for legacy reason, but could be removed ideally
        # assert_close requires rtol and atol to be specified together (or both omitted).
        atol = atol or prec
        torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False, check_device=False)
158
def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None):
159
"""Check that a nn.Module's results in TorchScript match eager and that it can be exported"""
161
def get_export_import_copy(m):
162
"""Save and load a TorchScript model"""
163
with TemporaryDirectory() as dir:
164
path = os.path.join(dir, "script.pt")
166
imported = torch.jit.load(path)
169
sm = torch.jit.script(nn_module)
172
if eager_out is None:
173
with torch.no_grad(), freeze_rng_state():
174
eager_out = nn_module(*args)
176
with torch.no_grad(), freeze_rng_state():
177
script_out = sm(*args)
179
script_out = unwrapper(script_out)
181
torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)
183
m_import = get_export_import_copy(sm)
184
with torch.no_grad(), freeze_rng_state():
185
imported_script_out = m_import(*args)
187
imported_script_out = unwrapper(imported_script_out)
189
torch.testing.assert_close(script_out, imported_script_out, atol=3e-4, rtol=3e-4)
192
def _check_fx_compatible(model, inputs, eager_out=None):
    """Check that an FX symbolic trace of ``model`` reproduces the eager output.

    Args:
        model: module to trace with ``torch.fx.symbolic_trace``.
        inputs: single positional argument passed to both the eager module and the traced module.
        eager_out: precomputed eager output to compare against; computed here when ``None``.

    Raises:
        AssertionError: (from ``torch.testing.assert_close``) when the outputs differ.
    """
    model_fx = torch.fx.symbolic_trace(model)
    if eager_out is None:
        # Mirror _check_jit_scriptable: compute the reference under the same
        # no_grad()/freeze_rng_state() context as the traced run below, so both passes see the
        # same RNG handling and no autograd graph is built for the reference.
        with torch.no_grad(), freeze_rng_state():
            eager_out = model(inputs)
    with torch.no_grad(), freeze_rng_state():
        fx_out = model_fx(inputs)
    torch.testing.assert_close(eager_out, fx_out)
201
def _check_input_backprop(model, inputs):
202
if isinstance(inputs, list):
203
requires_grad = list()
205
requires_grad.append(inp.requires_grad)
206
inp.requires_grad_(True)
208
requires_grad = inputs.requires_grad
209
inputs.requires_grad_(True)
213
if isinstance(out, dict):
214
out["out"].sum().backward()
216
if isinstance(out[0], dict):
217
out[0]["scores"].sum().backward()
219
out[0].sum().backward()
221
if isinstance(inputs, list):
222
for i, inp in enumerate(inputs):
223
assert inputs[i].grad is not None
224
inp.requires_grad_(requires_grad[i])
226
assert inputs.grad is not None
227
inputs.requires_grad_(requires_grad)
230
# If 'unwrapper' is provided it will be called with the script model outputs
231
# before they are compared to the eager model outputs. This is useful if the
232
# model outputs are different between TorchScript / Eager mode
233
script_model_unwrapper = {
234
"googlenet": lambda x: x.logits,
235
"inception_v3": lambda x: x.logits,
236
"fasterrcnn_resnet50_fpn": lambda x: x[1],
237
"fasterrcnn_resnet50_fpn_v2": lambda x: x[1],
238
"fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
239
"fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
240
"maskrcnn_resnet50_fpn": lambda x: x[1],
241
"maskrcnn_resnet50_fpn_v2": lambda x: x[1],
242
"keypointrcnn_resnet50_fpn": lambda x: x[1],
243
"retinanet_resnet50_fpn": lambda x: x[1],
244
"retinanet_resnet50_fpn_v2": lambda x: x[1],
245
"ssd300_vgg16": lambda x: x[1],
246
"ssdlite320_mobilenet_v3_large": lambda x: x[1],
247
"fcos_resnet50_fpn": lambda x: x[1],
251
# The following models exhibit flaky numerics under autocast in _test_*_model harnesses.
252
# This may be caused by the harness environment (e.g. num classes, input initialization
253
# via torch.rand), and does not prove autocast is unsuitable when training with real data
254
# (autocast has been used successfully with real data for some of these models).
255
# TODO: investigate why autocast numerics are flaky in the harnesses.
257
# For the following models, _test_*_model harnesses skip numerical checks on outputs when
258
# trying autocast. However, they still try an autocasted forward pass, so they still ensure
259
# autocast coverage suffices to prevent dtype errors in each model.
260
autocast_flaky_numerics = (
265
"deeplabv3_resnet50",
266
"deeplabv3_resnet101",
267
"deeplabv3_mobilenet_v3_large",
270
"lraspp_mobilenet_v3_large",
271
"maskrcnn_resnet50_fpn",
272
"maskrcnn_resnet50_fpn_v2",
273
"keypointrcnn_resnet50_fpn",
276
# Quantized tests for these models are flaky, possibly because rounding differs between
# platforms, so test_quantized_classification_model skips its input/output consistency
# check for them.
quantized_flaky_models = (
    "inception_v3",
    "resnet50",
)
281
# Detection tests for these models are flaky, so test_detection_model runs them in
# float64 to keep floating-point error down.
# FIXME: we shouldn't have to do that :'/
detection_flaky_models = (
    "keypointrcnn_resnet50_fpn",
    "maskrcnn_resnet50_fpn",
    "maskrcnn_resnet50_fpn_v2",
)
287
# The following contains configuration parameters for all models which are used by
288
# the _test_*_model methods.
290
"inception_v3": {"input_shape": (1, 3, 299, 299), "init_weights": True},
291
"retinanet_resnet50_fpn": {
293
"score_thresh": 0.01,
296
"input_shape": (3, 224, 224),
299
"retinanet_resnet50_fpn_v2": {
301
"score_thresh": 0.01,
304
"input_shape": (3, 224, 224),
307
"keypointrcnn_resnet50_fpn": {
311
"box_score_thresh": 0.17,
312
"input_shape": (3, 224, 224),
315
"fasterrcnn_resnet50_fpn": {
319
"input_shape": (3, 224, 224),
322
"fasterrcnn_resnet50_fpn_v2": {
326
"input_shape": (3, 224, 224),
329
"fcos_resnet50_fpn": {
331
"score_thresh": 0.05,
334
"input_shape": (3, 224, 224),
337
"maskrcnn_resnet50_fpn": {
341
"input_shape": (3, 224, 224),
344
"maskrcnn_resnet50_fpn_v2": {
348
"input_shape": (3, 224, 224),
351
"fasterrcnn_mobilenet_v3_large_fpn": {
352
"box_score_thresh": 0.02076,
354
"fasterrcnn_mobilenet_v3_large_320_fpn": {
355
"box_score_thresh": 0.02076,
356
"rpn_pre_nms_top_n_test": 1000,
357
"rpn_post_nms_top_n_test": 1000,
361
"input_shape": (1, 3, 56, 56),
364
"input_shape": (1, 3, 16, 224, 224),
367
"input_shape": (1, 3, 16, 224, 224),
370
"input_shape": (1, 3, 16, 224, 224),
372
"googlenet": {"init_weights": True},
374
# speeding up slow models:
398
_model_params[m] = {"input_shape": (1, 3, 64, 64)}
401
# skip big models to reduce memory usage on CI test. We can exclude combinations of (platform-system, device).
402
skipped_big_models = {
403
"vit_h_14": {("Windows", "cpu"), ("Windows", "cuda")},
404
"regnet_y_128gf": {("Windows", "cpu"), ("Windows", "cuda")},
405
"mvit_v1_b": {("Windows", "cuda"), ("Linux", "cuda")},
406
"mvit_v2_s": {("Windows", "cuda"), ("Linux", "cuda")},
410
def is_skippable(model_name, device):
411
if model_name not in skipped_big_models:
414
platform_system = platform.system()
415
device_name = str(device).split(":")[0]
417
return (platform_system, device_name) in skipped_big_models[model_name]
420
# The following contains configuration and expected values to be used tests that are model specific
421
_model_tests_values = {
422
"retinanet_resnet50_fpn": {
424
"n_trn_params_per_layer": [36, 46, 65, 78, 88, 89],
426
"retinanet_resnet50_fpn_v2": {
428
"n_trn_params_per_layer": [44, 74, 131, 170, 200, 203],
430
"keypointrcnn_resnet50_fpn": {
432
"n_trn_params_per_layer": [48, 58, 77, 90, 100, 101],
434
"fasterrcnn_resnet50_fpn": {
436
"n_trn_params_per_layer": [30, 40, 59, 72, 82, 83],
438
"fasterrcnn_resnet50_fpn_v2": {
440
"n_trn_params_per_layer": [50, 80, 137, 176, 206, 209],
442
"maskrcnn_resnet50_fpn": {
444
"n_trn_params_per_layer": [42, 52, 71, 84, 94, 95],
446
"maskrcnn_resnet50_fpn_v2": {
448
"n_trn_params_per_layer": [66, 96, 153, 192, 222, 225],
450
"fasterrcnn_mobilenet_v3_large_fpn": {
452
"n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
454
"fasterrcnn_mobilenet_v3_large_320_fpn": {
456
"n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
460
"n_trn_params_per_layer": [45, 51, 57, 63, 67, 71],
462
"ssdlite320_mobilenet_v3_large": {
464
"n_trn_params_per_layer": [96, 99, 138, 200, 239, 257, 266],
466
"fcos_resnet50_fpn": {
468
"n_trn_params_per_layer": [54, 64, 83, 96, 106, 107],
473
def _make_sliced_model(model, stop_layer):
474
layers = OrderedDict()
475
for name, layer in model.named_children():
477
if name == stop_layer:
479
new_model = torch.nn.Sequential(layers)
483
@pytest.mark.parametrize("model_fn", [models.densenet121, models.densenet169, models.densenet201, models.densenet161])
484
def test_memory_efficient_densenet(model_fn):
485
input_shape = (1, 3, 300, 300)
486
x = torch.rand(input_shape)
488
model1 = model_fn(num_classes=50, memory_efficient=True)
489
params = model1.state_dict()
490
num_params = sum(x.numel() for x in model1.parameters())
493
out1.sum().backward()
494
num_grad = sum(x.grad.numel() for x in model1.parameters() if x.grad is not None)
496
model2 = model_fn(num_classes=50, memory_efficient=False)
497
model2.load_state_dict(params)
501
assert num_params == num_grad
502
torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
504
_check_input_backprop(model1, x)
505
_check_input_backprop(model2, x)
508
@pytest.mark.parametrize("dilate_layer_2", (True, False))
509
@pytest.mark.parametrize("dilate_layer_3", (True, False))
510
@pytest.mark.parametrize("dilate_layer_4", (True, False))
511
def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4):
512
# TODO improve tests to also check that each layer has the right dimensionality
513
model = models.resnet50(replace_stride_with_dilation=(dilate_layer_2, dilate_layer_3, dilate_layer_4))
514
model = _make_sliced_model(model, stop_layer="layer4")
516
x = torch.rand(1, 3, 224, 224)
518
f = 2 ** sum((dilate_layer_2, dilate_layer_3, dilate_layer_4))
519
assert out.shape == (1, 2048, 7 * f, 7 * f)
522
def test_mobilenet_v2_residual_setting():
523
model = models.mobilenet_v2(inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
525
x = torch.rand(1, 3, 224, 224)
527
assert out.shape[-1] == 1000
530
@pytest.mark.parametrize("model_fn", [models.mobilenet_v2, models.mobilenet_v3_large, models.mobilenet_v3_small])
531
def test_mobilenet_norm_layer(model_fn):
533
assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules())
535
def get_gn(num_channels):
536
return nn.GroupNorm(1, num_channels)
538
model = model_fn(norm_layer=get_gn)
539
assert not (any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
540
assert any(isinstance(x, nn.GroupNorm) for x in model.modules())
543
def test_inception_v3_eval():
545
kwargs["transform_input"] = True
546
kwargs["aux_logits"] = True
547
kwargs["init_weights"] = False
548
name = "inception_v3"
549
model = models.Inception3(**kwargs)
550
model.aux_logits = False
551
model.AuxLogits = None
553
x = torch.rand(1, 3, 299, 299)
554
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
555
_check_input_backprop(model, x)
558
def test_fasterrcnn_double():
559
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None)
562
input_shape = (3, 300, 300)
563
x = torch.rand(input_shape, dtype=torch.float64)
565
out = model(model_input)
566
assert model_input[0] is x
568
assert "boxes" in out[0]
569
assert "scores" in out[0]
570
assert "labels" in out[0]
571
_check_input_backprop(model, model_input)
574
def test_googlenet_eval():
576
kwargs["transform_input"] = True
577
kwargs["aux_logits"] = True
578
kwargs["init_weights"] = False
580
model = models.GoogLeNet(**kwargs)
581
model.aux_logits = False
585
x = torch.rand(1, 3, 224, 224)
586
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
587
_check_input_backprop(model, x)
591
def test_fasterrcnn_switch_devices():
594
assert "boxes" in out[0]
595
assert "scores" in out[0]
596
assert "labels" in out[0]
598
model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None)
601
input_shape = (3, 300, 300)
602
x = torch.rand(input_shape, device="cuda")
604
out = model(model_input)
605
assert model_input[0] is x
609
with torch.cuda.amp.autocast():
610
out = model(model_input)
614
_check_input_backprop(model, model_input)
616
# now switch to cpu and make sure it works
623
_check_input_backprop(model, [x])
626
def test_generalizedrcnn_transform_repr():
628
min_size, max_size = 224, 299
629
image_mean = [0.485, 0.456, 0.406]
630
image_std = [0.229, 0.224, 0.225]
632
t = models.detection.transform.GeneralizedRCNNTransform(
633
min_size=min_size, max_size=max_size, image_mean=image_mean, image_std=image_std
636
# Check integrity of object __repr__ attribute
637
expected_string = "GeneralizedRCNNTransform("
639
expected_string += f"{_indent}Normalize(mean={image_mean}, std={image_std})"
640
expected_string += f"{_indent}Resize(min_size=({min_size},), max_size={max_size}, "
641
expected_string += "mode='bilinear')\n)"
642
assert t.__repr__() == expected_string
645
test_vit_conv_stem_configs = [
646
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
647
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
648
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
649
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
650
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
651
models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
655
def vitc_b_16(**kwargs: Any):
656
return models.VisionTransformer(
663
conv_stem_configs=test_vit_conv_stem_configs,
668
@pytest.mark.parametrize("model_fn", [vitc_b_16])
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_vitc_models(model_fn, dev):
    """Run the generic classification harness on the conv-stem ViT builder (``vitc_b_16``)."""
    test_classification_model(model_fn=model_fn, dev=dev)
674
@torch.backends.cudnn.flags(allow_tf32=False) # see: https://github.com/pytorch/vision/issues/7618
675
@pytest.mark.parametrize("model_fn", list_model_fns(models))
676
@pytest.mark.parametrize("dev", cpu_and_cuda())
677
def test_classification_model(model_fn, dev):
681
"input_shape": (1, 3, 224, 224),
683
model_name = model_fn.__name__
684
if SKIP_BIG_MODEL and is_skippable(model_name, dev):
685
pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
686
kwargs = {**defaults, **_model_params.get(model_name, {})}
687
num_classes = kwargs.get("num_classes")
688
input_shape = kwargs.pop("input_shape")
689
real_image = kwargs.pop("real_image", False)
691
model = model_fn(**kwargs)
692
model.eval().to(device=dev)
693
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
695
# FIXME: this if/else is nasty and only here to please our CI prior to the
696
# release. We should rethink these tests altogether.
697
if model_name == "resnet101":
700
# FIXME: this is probably still way too high.
702
_assert_expected(out.cpu(), model_name, prec=prec)
703
assert out.shape[-1] == num_classes
704
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
705
_check_fx_compatible(model, x, eager_out=out)
708
with torch.cuda.amp.autocast():
710
# See autocast_flaky_numerics comment at top of file.
711
if model_name not in autocast_flaky_numerics:
712
_assert_expected(out.cpu(), model_name, prec=0.1)
713
assert out.shape[-1] == 50
715
_check_input_backprop(model, x)
718
@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))
719
@pytest.mark.parametrize("dev", cpu_and_cuda())
720
def test_segmentation_model(model_fn, dev):
724
"weights_backbone": None,
725
"input_shape": (1, 3, 32, 32),
727
model_name = model_fn.__name__
728
kwargs = {**defaults, **_model_params.get(model_name, {})}
729
input_shape = kwargs.pop("input_shape")
731
model = model_fn(**kwargs)
732
model.eval().to(device=dev)
733
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
734
x = torch.rand(input_shape).to(device=dev)
735
with torch.no_grad(), freeze_rng_state():
741
# We first try to assert the entire output if possible. This is not
742
# only the best way to assert results but also handles the cases
743
# where we need to create a new expected result.
744
_assert_expected(out.cpu(), model_name, prec=prec)
745
except AssertionError:
746
# Unfortunately some segmentation models are flaky with autocast
747
# so instead of validating the probability scores, check that the class
749
expected_file = _get_expected_file(model_name)
750
expected = torch.load(expected_file, weights_only=True)
751
torch.testing.assert_close(
752
out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec, check_device=False
754
return False # Partial validation performed
756
return True # Full validation performed
758
full_validation = check_out(out["out"])
760
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
761
_check_fx_compatible(model, x, eager_out=out)
764
with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
766
# See autocast_flaky_numerics comment at top of file.
767
if model_name not in autocast_flaky_numerics:
768
full_validation &= check_out(out["out"])
770
if not full_validation:
772
f"The output of {test_segmentation_model.__name__} could only be partially validated. "
773
"This is likely due to unit-test flakiness, but you may "
774
"want to do additional manual checks if you made "
775
"significant changes to the codebase."
777
warnings.warn(msg, RuntimeWarning)
780
_check_input_backprop(model, x)
783
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
784
@pytest.mark.parametrize("dev", cpu_and_cuda())
785
def test_detection_model(model_fn, dev):
789
"weights_backbone": None,
790
"input_shape": (3, 300, 300),
792
model_name = model_fn.__name__
793
if model_name in detection_flaky_models:
794
dtype = torch.float64
796
dtype = torch.get_default_dtype()
797
kwargs = {**defaults, **_model_params.get(model_name, {})}
798
input_shape = kwargs.pop("input_shape")
799
real_image = kwargs.pop("real_image", False)
801
model = model_fn(**kwargs)
802
model.eval().to(device=dev, dtype=dtype)
803
x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, dtype=dtype)
805
with torch.no_grad(), freeze_rng_state():
806
out = model(model_input)
807
assert model_input[0] is x
813
tensor = tensor.cpu()
815
elements_per_sample = functools.reduce(operator.mul, size[1:], 1)
816
if elements_per_sample > 30:
817
return compute_mean_std(tensor)
819
return subsample_tensor(tensor)
821
def subsample_tensor(tensor):
822
num_elems = tensor.size(0)
824
if num_elems <= num_samples:
827
ith_index = num_elems // num_samples
828
return tensor[ith_index - 1 :: ith_index]
830
def compute_mean_std(tensor):
831
# can't compute mean of integral tensor
832
tensor = tensor.to(torch.double)
833
mean = torch.mean(tensor)
834
std = torch.std(tensor)
835
return {"mean": mean, "std": std}
837
output = map_nested_tensor_object(out, tensor_map_fn=compact)
840
# We first try to assert the entire output if possible. This is not
841
# only the best way to assert results but also handles the cases
842
# where we need to create a new expected result.
843
_assert_expected(output, model_name, prec=prec)
844
except AssertionError:
845
# Unfortunately detection models are flaky due to the unstable sort
846
# in NMS. If matching across all outputs fails, use the same approach
847
# as in NMSTester.test_nms_cuda to see if this is caused by duplicate
849
expected_file = _get_expected_file(model_name)
850
expected = torch.load(expected_file, weights_only=True)
851
torch.testing.assert_close(
852
output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False
855
# Note: Fmassa proposed turning off NMS by adapting the threshold
856
# and then using the Hungarian algorithm as in DETR to find the
857
# best match between output and expected boxes and eliminate some
858
# of the flakiness. Worth exploring.
859
return False # Partial validation performed
861
return True # Full validation performed
863
full_validation = check_out(out)
864
_check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
867
with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
868
out = model(model_input)
869
# See autocast_flaky_numerics comment at top of file.
870
if model_name not in autocast_flaky_numerics:
871
full_validation &= check_out(out)
873
if not full_validation:
875
f"The output of {test_detection_model.__name__} could only be partially validated. "
876
"This is likely due to unit-test flakiness, but you may "
877
"want to do additional manual checks if you made "
878
"significant changes to the codebase."
880
warnings.warn(msg, RuntimeWarning)
883
_check_input_backprop(model, model_input)
886
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
887
def test_detection_model_validation(model_fn):
889
model = model_fn(num_classes=50, weights=None, weights_backbone=None)
890
input_shape = (3, 300, 300)
891
x = [torch.rand(input_shape)]
893
# validate that targets are present in training
894
with pytest.raises(AssertionError):
898
targets = [{"boxes": 0.0}]
899
with pytest.raises(AssertionError):
900
model(x, targets=targets)
902
# validate boxes shape
903
for boxes in (torch.rand((4,)), torch.rand((1, 5))):
904
targets = [{"boxes": boxes}]
905
with pytest.raises(AssertionError):
906
model(x, targets=targets)
908
# validate that no degenerate boxes are present
909
boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
910
targets = [{"boxes": boxes}]
911
with pytest.raises(AssertionError):
912
model(x, targets=targets)
915
@pytest.mark.parametrize("model_fn", list_model_fns(models.video))
916
@pytest.mark.parametrize("dev", cpu_and_cuda())
917
def test_video_model(model_fn, dev):
919
# the default input shape is
920
# bs * num_channels * clip_len * h * w
922
"input_shape": (1, 3, 4, 112, 112),
925
model_name = model_fn.__name__
926
if SKIP_BIG_MODEL and is_skippable(model_name, dev):
927
pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
928
kwargs = {**defaults, **_model_params.get(model_name, {})}
929
num_classes = kwargs.get("num_classes")
930
input_shape = kwargs.pop("input_shape")
931
# test both basicblock and Bottleneck
932
model = model_fn(**kwargs)
933
model.eval().to(device=dev)
934
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
935
x = torch.rand(input_shape).to(device=dev)
937
_assert_expected(out.cpu(), model_name, prec=0.1)
938
assert out.shape[-1] == num_classes
939
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
940
_check_fx_compatible(model, x, eager_out=out)
941
assert out.shape[-1] == num_classes
944
with torch.cuda.amp.autocast():
946
# See autocast_flaky_numerics comment at top of file.
947
if model_name not in autocast_flaky_numerics:
948
_assert_expected(out.cpu(), model_name, prec=0.1)
949
assert out.shape[-1] == num_classes
951
_check_input_backprop(model, x)
956
"fbgemm" in torch.backends.quantized.supported_engines
957
and "qnnpack" in torch.backends.quantized.supported_engines
959
reason="This Pytorch Build has not been built with fbgemm and qnnpack",
961
@pytest.mark.parametrize("model_fn", list_model_fns(models.quantization))
962
def test_quantized_classification_model(model_fn):
966
"input_shape": (1, 3, 224, 224),
969
model_name = model_fn.__name__
970
kwargs = {**defaults, **_model_params.get(model_name, {})}
971
input_shape = kwargs.pop("input_shape")
973
# First check if quantize=True provides models that can run with input data
974
model = model_fn(**kwargs)
976
x = torch.rand(input_shape)
979
if model_name not in quantized_flaky_models:
980
_assert_expected(out.cpu(), model_name + "_quantized", prec=2e-2)
981
assert out.shape[-1] == 5
982
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
983
_check_fx_compatible(model, x, eager_out=out)
986
torch.jit.script(model)
987
except Exception as e:
988
raise AssertionError("model cannot be scripted.") from e
990
kwargs["quantize"] = False
991
for eval_mode in [True, False]:
992
model = model_fn(**kwargs)
995
model.qconfig = torch.ao.quantization.default_qconfig
998
model.qconfig = torch.ao.quantization.default_qat_qconfig
1000
model.fuse_model(is_qat=not eval_mode)
1002
torch.ao.quantization.prepare(model, inplace=True)
1004
torch.ao.quantization.prepare_qat(model, inplace=True)
1007
torch.ao.quantization.convert(model, inplace=True)
1010
@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
1011
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
1012
model_name = model_fn.__name__
1013
max_trainable = _model_tests_values[model_name]["max_trainable"]
1014
n_trainable_params = []
1015
for trainable_layers in range(0, max_trainable + 1):
1016
model = model_fn(weights=None, weights_backbone="DEFAULT", trainable_backbone_layers=trainable_layers)
1018
n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
1019
assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]
1023
@pytest.mark.parametrize("model_fn", list_model_fns(models.optical_flow))
1024
@pytest.mark.parametrize("scripted", (False, True))
1025
def test_raft(model_fn, scripted):
1027
torch.manual_seed(0)
1029
# We need very small images, otherwise the pickle size would exceed the 50KB limit.
1030
# As a result we need to override the correlation pyramid to not downsample
1031
# too much, otherwise we would get nan values (effective H and W would be
1033
corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)
1035
model = model_fn(corr_block=corr_block).eval().to("cuda")
1037
model = torch.jit.script(model)
1040
img1 = torch.rand(bs, 3, 80, 72).cuda()
1041
img2 = torch.rand(bs, 3, 80, 72).cuda()
1043
preds = model(img1, img2)
1044
flow_pred = preds[-1]
1045
# Tolerance is fairly high, but there are 2 * H * W outputs to check
1046
# The .pkl files were generated on the AWS cluster; on the CI the results appear to be slightly different
1047
_assert_expected(flow_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1)
1050
if __name__ == "__main__":
    # pytest.main() returns the session's exit code rather than raising; forward it so that
    # running this file directly exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__]))