import contextlib
import functools
import operator
import os
import pkgutil
import platform
import sys
import warnings
from collections import OrderedDict
from tempfile import TemporaryDirectory
from typing import Any

import pytest
import torch
import torch.fx
import torch.nn as nn
from _utils_internal import get_relative_path
from common_utils import cpu_and_cuda, freeze_rng_state, map_nested_tensor_object, needs_cuda, set_rng_seed
from PIL import Image
from torchvision import models, transforms
from torchvision.models import get_model_builder, list_models


ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
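# EXPECTTEST_ACCEPT=1 makes _assert_expected() regenerate the reference pickles
# under test/expect/ instead of comparing against them; SKIP_BIG_MODEL=0
# re-enables the models listed in skipped_big_models below.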


def list_model_fns(module):
    return [get_model_builder(name) for name in list_models(module)]
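
# Illustrative example (these builder names are real torchvision functions):
#   list_model_fns(models.segmentation) -> [deeplabv3_mobilenet_v3_large, ...]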


def _get_image(input_shape, real_image, device, dtype=None):
    """Load either a real or a random image, depending on the `real_image` argument.

    Currently, the real image is used for the following models:
    - `retinanet_resnet50_fpn`,
    - `retinanet_resnet50_fpn_v2`,
    - `keypointrcnn_resnet50_fpn`,
    - `fasterrcnn_resnet50_fpn`,
    - `fasterrcnn_resnet50_fpn_v2`,
    - `fcos_resnet50_fpn`,
    - `maskrcnn_resnet50_fpn`,
    - `maskrcnn_resnet50_fpn_v2`,
    in `test_classification_model` and `test_detection_model`.
    To do so, a `real_image` keyword argument was added to the above-listed models in `_model_params`.
    """
    if real_image:
        # TODO: Maybe unify file discovery logic with test_image.py
        GRACE_HOPPER = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
        )

        img = Image.open(GRACE_HOPPER)

        original_width, original_height = img.size

        # make the image square
        img = img.crop((0, 0, original_width, original_width))
        img = img.resize(input_shape[1:3])

        convert_tensor = transforms.ToTensor()
        image = convert_tensor(img)
        assert tuple(image.size()) == input_shape
        return image.to(device=device, dtype=dtype)

    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
    return torch.rand(input_shape).to(device=device, dtype=dtype)


@pytest.fixture
def disable_weight_loading(mocker):
    """When testing models, the two slowest operations are downloading the weights to a file and loading them
    into the model. Unless you want to test against specific weights, these steps can be disabled without any
    drawbacks.

    Including this fixture in the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
    through all models in `torchvision.models` and patch all occurrences of the function
    `load_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
    no-ops.

    .. warning:

        Loaded models are still executable as normal, but will always have random weights. Make sure not to use this
        fixture if you want to compare the model output against reference values.

    """
    starting_point = models
    function_name = "load_state_dict_from_url"
    method_name = "load_state_dict"

    module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
    targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
    for name in module_names:
        module = sys.modules.get(name)
        if not module:
            continue

        if function_name in module.__dict__:
            targets.add(f"{module.__name__}.{function_name}")

        targets.update(
            {
                f"{module.__name__}.{obj.__name__}.{method_name}"
                for obj in module.__dict__.values()
                if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
            }
        )

    for target in targets:
        # See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
        with contextlib.suppress(AttributeError):
            mocker.patch(target)
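
# Illustrative sketch, not part of the original suite: opting in to the fixture
# only requires naming it in the test signature. `models.resnet50` is a real
# builder; the test body itself is hypothetical.
#
# def test_foo(disable_weight_loading):
#     model = models.resnet50(weights="DEFAULT")  # nothing is downloaded
#     model(torch.rand(1, 3, 224, 224))  # runs normally, but with random weights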


def _get_expected_file(name=None):
    # Determine expected file based on environment
    expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")

    # Note: for legacy reasons, the reference file names all have "ModelTester.test_" in their names
    # We hardcode it here to avoid having to re-generate the reference files
    expected_file = os.path.join(expected_file_base, "ModelTester.test_" + name)
    expected_file += "_expect.pkl"

    if not ACCEPT and not os.path.exists(expected_file):
        raise RuntimeError(
            f"No expect file exists for {os.path.basename(expected_file)} in {expected_file}; "
            "to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
            "env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
        )

    return expected_file


def _assert_expected(output, name, prec=None, atol=None, rtol=None):
    """Test that a python value matches the recorded contents of a file
    based on a "check" name. The value must be picklable with `torch.save`.
    The file is placed in the 'expect' directory in the same directory as
    the test script. You can automatically update the recorded test output
    by setting the EXPECTTEST_ACCEPT=1 env variable.
    """
    expected_file = _get_expected_file(name)

    if ACCEPT:
        filename = os.path.basename(expected_file)
        print(f"Accepting updated output for {filename}:\n\n{output}")
        torch.save(output, expected_file)
        MAX_PICKLE_SIZE = 50 * 1000  # 50 KB
        binary_size = os.path.getsize(expected_file)
        if binary_size > MAX_PICKLE_SIZE:
            raise RuntimeError(f"The output for {filename} is larger than 50KB - got {binary_size} bytes")
    else:
        expected = torch.load(expected_file, weights_only=True)
        rtol = rtol or prec  # keeping the prec param for legacy reasons, though it could ideally be removed
        atol = atol or prec
        torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False, check_device=False)
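
# Typical regeneration workflow (taken from the error message above, shown here
# for convenience): re-running a failing test with the accept flag set rewrites
# its reference pickle under test/expect/, e.g.:
#
#   EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet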


def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None):
    """Check that an nn.Module's results in TorchScript match eager and that it can be exported"""

    def get_export_import_copy(m):
        """Save and load a TorchScript model"""
        with TemporaryDirectory() as dir:
            path = os.path.join(dir, "script.pt")
            m.save(path)
            imported = torch.jit.load(path)
        return imported

    sm = torch.jit.script(nn_module)
    sm.eval()

    if eager_out is None:
        with torch.no_grad(), freeze_rng_state():
            eager_out = nn_module(*args)

    with torch.no_grad(), freeze_rng_state():
        script_out = sm(*args)
        if unwrapper:
            script_out = unwrapper(script_out)

    torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)

    m_import = get_export_import_copy(sm)
    with torch.no_grad(), freeze_rng_state():
        imported_script_out = m_import(*args)
        if unwrapper:
            imported_script_out = unwrapper(imported_script_out)

    torch.testing.assert_close(script_out, imported_script_out, atol=3e-4, rtol=3e-4)


def _check_fx_compatible(model, inputs, eager_out=None):
    model_fx = torch.fx.symbolic_trace(model)
    if eager_out is None:
        eager_out = model(inputs)
    with torch.no_grad(), freeze_rng_state():
        fx_out = model_fx(inputs)
    torch.testing.assert_close(eager_out, fx_out)


def _check_input_backprop(model, inputs):
    if isinstance(inputs, list):
        requires_grad = list()
        for inp in inputs:
            requires_grad.append(inp.requires_grad)
            inp.requires_grad_(True)
    else:
        requires_grad = inputs.requires_grad
        inputs.requires_grad_(True)

    out = model(inputs)

    if isinstance(out, dict):
        out["out"].sum().backward()
    else:
        if isinstance(out[0], dict):
            out[0]["scores"].sum().backward()
        else:
            out[0].sum().backward()

    if isinstance(inputs, list):
        for i, inp in enumerate(inputs):
            assert inputs[i].grad is not None
            inp.requires_grad_(requires_grad[i])
    else:
        assert inputs.grad is not None
        inputs.requires_grad_(requires_grad)


# If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode
script_model_unwrapper = {
    "googlenet": lambda x: x.logits,
    "inception_v3": lambda x: x.logits,
    "fasterrcnn_resnet50_fpn": lambda x: x[1],
    "fasterrcnn_resnet50_fpn_v2": lambda x: x[1],
    "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
    "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
    "maskrcnn_resnet50_fpn": lambda x: x[1],
    "maskrcnn_resnet50_fpn_v2": lambda x: x[1],
    "keypointrcnn_resnet50_fpn": lambda x: x[1],
    "retinanet_resnet50_fpn": lambda x: x[1],
    "retinanet_resnet50_fpn_v2": lambda x: x[1],
    "ssd300_vgg16": lambda x: x[1],
    "ssdlite320_mobilenet_v3_large": lambda x: x[1],
    "fcos_resnet50_fpn": lambda x: x[1],
}
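
# For example, scripted GoogLeNet/Inception3 always return their *Outputs
# namedtuple (TorchScript cannot have data-dependent return types), while the
# eager eval-mode models return a plain Tensor, hence `.logits`. Scripted
# detection models return a (losses, detections) tuple, hence `x[1]`.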


# The following models exhibit flaky numerics under autocast in the _test_*_model harnesses.
# This may be caused by the harness environment (e.g. num classes, input initialization
# via torch.rand), and does not prove autocast is unsuitable when training with real data
# (autocast has been used successfully with real data for some of these models).
# TODO: investigate why autocast numerics are flaky in the harnesses.
#
# For the following models, the _test_*_model harnesses skip numerical checks on outputs when
# trying autocast. However, they still try an autocasted forward pass, so they still ensure
# autocast coverage suffices to prevent dtype errors in each model.
autocast_flaky_numerics = (
    "inception_v3",
    "resnet101",
    "resnet152",
    "wide_resnet101_2",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
    "deeplabv3_mobilenet_v3_large",
    "fcn_resnet50",
    "fcn_resnet101",
    "lraspp_mobilenet_v3_large",
    "maskrcnn_resnet50_fpn",
    "maskrcnn_resnet50_fpn_v2",
    "keypointrcnn_resnet50_fpn",
)

# The tests for the following quantized models are flaky, possibly due to inconsistent
# rounding errors on different platforms. For this reason the input/output consistency
# tests under test_quantized_classification_model will be skipped for the following models.
quantized_flaky_models = ("inception_v3", "resnet50")

# The tests for the following detection models are flaky.
# We run those tests on float64 to avoid floating point errors.
# FIXME: we shouldn't have to do that :'/
detection_flaky_models = ("keypointrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2")


# The following contains configuration parameters for all models which are used by
# the _test_*_model methods.
_model_params = {
    "inception_v3": {"input_shape": (1, 3, 299, 299), "init_weights": True},
    "retinanet_resnet50_fpn": {
        "num_classes": 20,
        "score_thresh": 0.01,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
        "real_image": True,
    },
    "retinanet_resnet50_fpn_v2": {
        "num_classes": 20,
        "score_thresh": 0.01,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
        "real_image": True,
    },
    "keypointrcnn_resnet50_fpn": {
        "num_classes": 2,
        "min_size": 224,
        "max_size": 224,
        "box_score_thresh": 0.17,
        "input_shape": (3, 224, 224),
        "real_image": True,
    },
    "fasterrcnn_resnet50_fpn": {
        "num_classes": 20,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
        "real_image": True,
    },
    "fasterrcnn_resnet50_fpn_v2": {
        "num_classes": 20,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
        "real_image": True,
    },
    "fcos_resnet50_fpn": {
        "num_classes": 2,
        "score_thresh": 0.05,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
        "real_image": True,
    },
    "maskrcnn_resnet50_fpn": {
        "num_classes": 10,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
        "real_image": True,
    },
    "maskrcnn_resnet50_fpn_v2": {
        "num_classes": 10,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
        "real_image": True,
    },
    "fasterrcnn_mobilenet_v3_large_fpn": {
        "box_score_thresh": 0.02076,
    },
    "fasterrcnn_mobilenet_v3_large_320_fpn": {
        "box_score_thresh": 0.02076,
        "rpn_pre_nms_top_n_test": 1000,
        "rpn_post_nms_top_n_test": 1000,
    },
    "vit_h_14": {
        "image_size": 56,
        "input_shape": (1, 3, 56, 56),
    },
    "mvit_v1_b": {
        "input_shape": (1, 3, 16, 224, 224),
    },
    "mvit_v2_s": {
        "input_shape": (1, 3, 16, 224, 224),
    },
    "s3d": {
        "input_shape": (1, 3, 16, 224, 224),
    },
    "googlenet": {"init_weights": True},
}
# speeding up slow models:
slow_models = [
    "convnext_base",
    "convnext_large",
    "resnext101_32x8d",
    "resnext101_64x4d",
    "wide_resnet101_2",
    "efficientnet_b6",
    "efficientnet_b7",
    "efficientnet_v2_m",
    "efficientnet_v2_l",
    "regnet_y_16gf",
    "regnet_y_32gf",
    "regnet_y_128gf",
    "regnet_x_16gf",
    "regnet_x_32gf",
    "swin_t",
    "swin_s",
    "swin_b",
    "swin_v2_t",
    "swin_v2_s",
    "swin_v2_b",
]
for m in slow_models:
    _model_params[m] = {"input_shape": (1, 3, 64, 64)}


# Skip big models to reduce memory usage on CI tests. We can exclude combinations of (platform-system, device).
skipped_big_models = {
    "vit_h_14": {("Windows", "cpu"), ("Windows", "cuda")},
    "regnet_y_128gf": {("Windows", "cpu"), ("Windows", "cuda")},
    "mvit_v1_b": {("Windows", "cuda"), ("Linux", "cuda")},
    "mvit_v2_s": {("Windows", "cuda"), ("Linux", "cuda")},
}


def is_skippable(model_name, device):
    if model_name not in skipped_big_models:
        return False

    platform_system = platform.system()
    device_name = str(device).split(":")[0]

    return (platform_system, device_name) in skipped_big_models[model_name]


# The following contains configuration and expected values to be used in tests that are model-specific
_model_tests_values = {
    "retinanet_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [36, 46, 65, 78, 88, 89],
    },
    "retinanet_resnet50_fpn_v2": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [44, 74, 131, 170, 200, 203],
    },
    "keypointrcnn_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [48, 58, 77, 90, 100, 101],
    },
    "fasterrcnn_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [30, 40, 59, 72, 82, 83],
    },
    "fasterrcnn_resnet50_fpn_v2": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [50, 80, 137, 176, 206, 209],
    },
    "maskrcnn_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [42, 52, 71, 84, 94, 95],
    },
    "maskrcnn_resnet50_fpn_v2": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [66, 96, 153, 192, 222, 225],
    },
    "fasterrcnn_mobilenet_v3_large_fpn": {
        "max_trainable": 6,
        "n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
    },
    "fasterrcnn_mobilenet_v3_large_320_fpn": {
        "max_trainable": 6,
        "n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
    },
    "ssd300_vgg16": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [45, 51, 57, 63, 67, 71],
    },
    "ssdlite320_mobilenet_v3_large": {
        "max_trainable": 6,
        "n_trn_params_per_layer": [96, 99, 138, 200, 239, 257, 266],
    },
    "fcos_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [54, 64, 83, 96, 106, 107],
    },
}


def _make_sliced_model(model, stop_layer):
    layers = OrderedDict()
    for name, layer in model.named_children():
        layers[name] = layer
        if name == stop_layer:
            break
    new_model = torch.nn.Sequential(layers)
    return new_model
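
# Illustrative: _make_sliced_model(models.resnet50(), stop_layer="layer4")
# keeps everything up to and including layer4 and drops avgpool/fc, so the
# output is the raw (N, 2048, H/32, W/32) feature map (without dilation).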


@pytest.mark.parametrize("model_fn", [models.densenet121, models.densenet169, models.densenet201, models.densenet161])
def test_memory_efficient_densenet(model_fn):
    input_shape = (1, 3, 300, 300)
    x = torch.rand(input_shape)

    model1 = model_fn(num_classes=50, memory_efficient=True)
    params = model1.state_dict()
    num_params = sum(x.numel() for x in model1.parameters())
    model1.eval()
    out1 = model1(x)
    out1.sum().backward()
    num_grad = sum(x.grad.numel() for x in model1.parameters() if x.grad is not None)

    model2 = model_fn(num_classes=50, memory_efficient=False)
    model2.load_state_dict(params)
    model2.eval()
    out2 = model2(x)

    assert num_params == num_grad
    torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)

    _check_input_backprop(model1, x)
    _check_input_backprop(model2, x)


@pytest.mark.parametrize("dilate_layer_2", (True, False))
@pytest.mark.parametrize("dilate_layer_3", (True, False))
@pytest.mark.parametrize("dilate_layer_4", (True, False))
def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4):
    # TODO improve tests to also check that each layer has the right dimensionality
    model = models.resnet50(replace_stride_with_dilation=(dilate_layer_2, dilate_layer_3, dilate_layer_4))
    model = _make_sliced_model(model, stop_layer="layer4")
    model.eval()
    x = torch.rand(1, 3, 224, 224)
    out = model(x)
    f = 2 ** sum((dilate_layer_2, dilate_layer_3, dilate_layer_4))
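    # Each layer whose stride is replaced with dilation keeps its spatial
    # resolution instead of halving it, so the usual 7x7 layer4 output of a
    # 224x224 input grows by a factor of 2 per dilated layer.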
    assert out.shape == (1, 2048, 7 * f, 7 * f)


def test_mobilenet_v2_residual_setting():
    model = models.mobilenet_v2(inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
    model.eval()
    x = torch.rand(1, 3, 224, 224)
    out = model(x)
    assert out.shape[-1] == 1000


@pytest.mark.parametrize("model_fn", [models.mobilenet_v2, models.mobilenet_v3_large, models.mobilenet_v3_small])
def test_mobilenet_norm_layer(model_fn):
    model = model_fn()
    assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules())

    def get_gn(num_channels):
        return nn.GroupNorm(1, num_channels)

    model = model_fn(norm_layer=get_gn)
    assert not any(isinstance(x, nn.BatchNorm2d) for x in model.modules())
    assert any(isinstance(x, nn.GroupNorm) for x in model.modules())


def test_inception_v3_eval():
    kwargs = {}
    kwargs["transform_input"] = True
    kwargs["aux_logits"] = True
    kwargs["init_weights"] = False
    name = "inception_v3"
    model = models.Inception3(**kwargs)
    model.aux_logits = False
    model.AuxLogits = None
    model = model.eval()
    x = torch.rand(1, 3, 299, 299)
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
    _check_input_backprop(model, x)


def test_fasterrcnn_double():
    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None)
    model.double()
    model.eval()
    input_shape = (3, 300, 300)
    x = torch.rand(input_shape, dtype=torch.float64)
    model_input = [x]
    out = model(model_input)
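    # the model must not mutate the caller's input list in place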
    assert model_input[0] is x
    assert len(out) == 1
    assert "boxes" in out[0]
    assert "scores" in out[0]
    assert "labels" in out[0]
    _check_input_backprop(model, model_input)


def test_googlenet_eval():
    kwargs = {}
    kwargs["transform_input"] = True
    kwargs["aux_logits"] = True
    kwargs["init_weights"] = False
    name = "googlenet"
    model = models.GoogLeNet(**kwargs)
    model.aux_logits = False
    model.aux1 = None
    model.aux2 = None
    model = model.eval()
    x = torch.rand(1, 3, 224, 224)
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
    _check_input_backprop(model, x)


@needs_cuda
def test_fasterrcnn_switch_devices():
    def checkOut(out):
        assert len(out) == 1
        assert "boxes" in out[0]
        assert "scores" in out[0]
        assert "labels" in out[0]

    model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None)
    model.cuda()
    model.eval()
    input_shape = (3, 300, 300)
    x = torch.rand(input_shape, device="cuda")
    model_input = [x]
    out = model(model_input)
    assert model_input[0] is x

    checkOut(out)

    with torch.cuda.amp.autocast():
        out = model(model_input)

    checkOut(out)

    _check_input_backprop(model, model_input)

    # now switch to cpu and make sure it works
    model.cpu()
    x = x.cpu()
    out_cpu = model([x])

    checkOut(out_cpu)

    _check_input_backprop(model, [x])


def test_generalizedrcnn_transform_repr():
    min_size, max_size = 224, 299
    image_mean = [0.485, 0.456, 0.406]
    image_std = [0.229, 0.224, 0.225]

    t = models.detection.transform.GeneralizedRCNNTransform(
        min_size=min_size, max_size=max_size, image_mean=image_mean, image_std=image_std
    )

    # Check integrity of object __repr__ attribute
    expected_string = "GeneralizedRCNNTransform("
    _indent = "\n    "
    expected_string += f"{_indent}Normalize(mean={image_mean}, std={image_std})"
    expected_string += f"{_indent}Resize(min_size=({min_size},), max_size={max_size}, "
    expected_string += "mode='bilinear')\n)"
    assert t.__repr__() == expected_string


test_vit_conv_stem_configs = [
    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=64),
    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=128),
    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=128),
    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=256),
    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=1, out_channels=256),
    models.vision_transformer.ConvStemConfig(kernel_size=3, stride=2, out_channels=512),
]


def vitc_b_16(**kwargs: Any):
    return models.VisionTransformer(
        image_size=224,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        conv_stem_configs=test_vit_conv_stem_configs,
        **kwargs,
    )


@pytest.mark.parametrize("model_fn", [vitc_b_16])
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_vitc_models(model_fn, dev):
    test_classification_model(model_fn, dev)


@torch.backends.cudnn.flags(allow_tf32=False)  # see: https://github.com/pytorch/vision/issues/7618
@pytest.mark.parametrize("model_fn", list_model_fns(models))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_classification_model(model_fn, dev):
    set_rng_seed(0)
    defaults = {
        "num_classes": 50,
        "input_shape": (1, 3, 224, 224),
    }
    model_name = model_fn.__name__
    if SKIP_BIG_MODEL and is_skippable(model_name, dev):
        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    num_classes = kwargs.get("num_classes")
    input_shape = kwargs.pop("input_shape")
    real_image = kwargs.pop("real_image", False)

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
    out = model(x)
    # FIXME: this if/else is nasty and only here to please our CI prior to the
    # release. We should rethink these tests altogether.
    if model_name == "resnet101":
        prec = 0.2
    else:
        # FIXME: this is probably still way too high.
        prec = 0.1
    _assert_expected(out.cpu(), model_name, prec=prec)
    assert out.shape[-1] == num_classes
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)

    if dev == "cuda":
        with torch.cuda.amp.autocast():
            out = model(x)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics:
                _assert_expected(out.cpu(), model_name, prec=0.1)
            assert out.shape[-1] == 50

    _check_input_backprop(model, x)


@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_segmentation_model(model_fn, dev):
    set_rng_seed(0)
    defaults = {
        "num_classes": 10,
        "weights_backbone": None,
        "input_shape": (1, 3, 32, 32),
    }
    model_name = model_fn.__name__
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    input_shape = kwargs.pop("input_shape")

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
    x = torch.rand(input_shape).to(device=dev)
    with torch.no_grad(), freeze_rng_state():
        out = model(x)

    def check_out(out):
        prec = 0.01
        try:
            # We first try to assert the entire output if possible. This is not
            # only the best way to assert results but also handles the cases
            # where we need to create a new expected result.
            _assert_expected(out.cpu(), model_name, prec=prec)
        except AssertionError:
            # Unfortunately some segmentation models are flaky with autocast,
            # so instead of validating the probability scores, check that the class
            # predictions match.
            expected_file = _get_expected_file(model_name)
            expected = torch.load(expected_file, weights_only=True)
            torch.testing.assert_close(
                out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec, check_device=False
            )
            return False  # Partial validation performed

        return True  # Full validation performed

    full_validation = check_out(out["out"])

    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)

    if dev == "cuda":
        with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
            out = model(x)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics:
                full_validation &= check_out(out["out"])

    if not full_validation:
        msg = (
            f"The output of {test_segmentation_model.__name__} could only be partially validated. "
            "This is likely due to unit-test flakiness, but you may "
            "want to do additional manual checks if you made "
            "significant changes to the codebase."
        )
        warnings.warn(msg, RuntimeWarning)
        pytest.skip(msg)

    _check_input_backprop(model, x)


@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_detection_model(model_fn, dev):
    set_rng_seed(0)
    defaults = {
        "num_classes": 50,
        "weights_backbone": None,
        "input_shape": (3, 300, 300),
    }
    model_name = model_fn.__name__
    if model_name in detection_flaky_models:
        dtype = torch.float64
    else:
        dtype = torch.get_default_dtype()
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    input_shape = kwargs.pop("input_shape")
    real_image = kwargs.pop("real_image", False)

    model = model_fn(**kwargs)
    model.eval().to(device=dev, dtype=dtype)
    x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, dtype=dtype)
    model_input = [x]
    with torch.no_grad(), freeze_rng_state():
        out = model(model_input)
    assert model_input[0] is x

    def check_out(out):
        assert len(out) == 1

        def compact(tensor):
            tensor = tensor.cpu()
            size = tensor.size()
            elements_per_sample = functools.reduce(operator.mul, size[1:], 1)
            if elements_per_sample > 30:
                return compute_mean_std(tensor)
            else:
                return subsample_tensor(tensor)

        def subsample_tensor(tensor):
            num_elems = tensor.size(0)
            num_samples = 20
            if num_elems <= num_samples:
                return tensor

            ith_index = num_elems // num_samples
            return tensor[ith_index - 1 :: ith_index]

        def compute_mean_std(tensor):
            # can't compute mean of integral tensor
            tensor = tensor.to(torch.double)
            mean = torch.mean(tensor)
            std = torch.std(tensor)
            return {"mean": mean, "std": std}

        output = map_nested_tensor_object(out, tensor_map_fn=compact)
        prec = 0.01
        try:
            # We first try to assert the entire output if possible. This is not
            # only the best way to assert results but also handles the cases
            # where we need to create a new expected result.
            _assert_expected(output, model_name, prec=prec)
        except AssertionError:
            # Unfortunately detection models are flaky due to the unstable sort
            # in NMS. If matching across all outputs fails, use the same approach
            # as in NMSTester.test_nms_cuda to see if this is caused by duplicate
            # scores.
            expected_file = _get_expected_file(model_name)
            expected = torch.load(expected_file, weights_only=True)
            torch.testing.assert_close(
                output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False
            )

            # Note: Fmassa proposed turning off NMS by adapting the threshold
            # and then using the Hungarian algorithm as in DETR to find the
            # best match between output and expected boxes and eliminate some
            # of the flakiness. Worth exploring.
            return False  # Partial validation performed

        return True  # Full validation performed

    full_validation = check_out(out)
    _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)

    if dev == "cuda":
        with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
            out = model(model_input)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics:
                full_validation &= check_out(out)

    if not full_validation:
        msg = (
            f"The output of {test_detection_model.__name__} could only be partially validated. "
            "This is likely due to unit-test flakiness, but you may "
            "want to do additional manual checks if you made "
            "significant changes to the codebase."
        )
        warnings.warn(msg, RuntimeWarning)
        pytest.skip(msg)

    _check_input_backprop(model, model_input)


@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_validation(model_fn):
    set_rng_seed(0)
    model = model_fn(num_classes=50, weights=None, weights_backbone=None)
    input_shape = (3, 300, 300)
    x = [torch.rand(input_shape)]

    # validate that targets are present in training
    with pytest.raises(AssertionError):
        model(x)

    # validate type
    targets = [{"boxes": 0.0}]
    with pytest.raises(AssertionError):
        model(x, targets=targets)

    # validate boxes shape
    for boxes in (torch.rand((4,)), torch.rand((1, 5))):
        targets = [{"boxes": boxes}]
        with pytest.raises(AssertionError):
            model(x, targets=targets)

    # validate that no degenerate boxes are present
    boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
    targets = [{"boxes": boxes}]
    with pytest.raises(AssertionError):
        model(x, targets=targets)


@pytest.mark.parametrize("model_fn", list_model_fns(models.video))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_video_model(model_fn, dev):
    set_rng_seed(0)
    # the default input shape is
    # bs * num_channels * clip_len * h * w
    defaults = {
        "input_shape": (1, 3, 4, 112, 112),
        "num_classes": 50,
    }
    model_name = model_fn.__name__
    if SKIP_BIG_MODEL and is_skippable(model_name, dev):
        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    num_classes = kwargs.get("num_classes")
    input_shape = kwargs.pop("input_shape")
    # test both basicblock and Bottleneck
    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
    x = torch.rand(input_shape).to(device=dev)
    out = model(x)
    _assert_expected(out.cpu(), model_name, prec=0.1)
    assert out.shape[-1] == num_classes
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)
    assert out.shape[-1] == num_classes

    if dev == "cuda":
        with torch.cuda.amp.autocast():
            out = model(x)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics:
                _assert_expected(out.cpu(), model_name, prec=0.1)
            assert out.shape[-1] == num_classes

    _check_input_backprop(model, x)


@pytest.mark.skipif(
    not (
        "fbgemm" in torch.backends.quantized.supported_engines
        and "qnnpack" in torch.backends.quantized.supported_engines
    ),
    reason="This PyTorch build has not been built with fbgemm and qnnpack",
)
@pytest.mark.parametrize("model_fn", list_model_fns(models.quantization))
def test_quantized_classification_model(model_fn):
    set_rng_seed(0)
    defaults = {
        "num_classes": 5,
        "input_shape": (1, 3, 224, 224),
        "quantize": True,
    }
    model_name = model_fn.__name__
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    input_shape = kwargs.pop("input_shape")

    # First check if quantize=True provides models that can run with input data
    model = model_fn(**kwargs)
    model.eval()
    x = torch.rand(input_shape)
    out = model(x)

    if model_name not in quantized_flaky_models:
        _assert_expected(out.cpu(), model_name + "_quantized", prec=2e-2)
        assert out.shape[-1] == 5
        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
        _check_fx_compatible(model, x, eager_out=out)
    else:
        try:
            torch.jit.script(model)
        except Exception as e:
            raise AssertionError("model cannot be scripted.") from e

    kwargs["quantize"] = False
    for eval_mode in [True, False]:
        model = model_fn(**kwargs)
        if eval_mode:
            model.eval()
            model.qconfig = torch.ao.quantization.default_qconfig
        else:
            model.train()
            model.qconfig = torch.ao.quantization.default_qat_qconfig

        model.fuse_model(is_qat=not eval_mode)
        if eval_mode:
            torch.ao.quantization.prepare(model, inplace=True)
        else:
            torch.ao.quantization.prepare_qat(model, inplace=True)
            model.eval()

        torch.ao.quantization.convert(model, inplace=True)


@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
    model_name = model_fn.__name__
    max_trainable = _model_tests_values[model_name]["max_trainable"]
    n_trainable_params = []
    for trainable_layers in range(0, max_trainable + 1):
        model = model_fn(weights=None, weights_backbone="DEFAULT", trainable_backbone_layers=trainable_layers)

        n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad]))
    assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"]


@needs_cuda
@pytest.mark.parametrize("model_fn", list_model_fns(models.optical_flow))
@pytest.mark.parametrize("scripted", (False, True))
def test_raft(model_fn, scripted):
    torch.manual_seed(0)

    # We need very small images, otherwise the pickle size would exceed the 50KB limit.
    # As a result we need to override the correlation pyramid to not downsample
    # too much, otherwise we would get nan values (effective H and W would be
    # reduced to 1)
    corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)

    model = model_fn(corr_block=corr_block).eval().to("cuda")
    if scripted:
        model = torch.jit.script(model)

    bs = 1
    img1 = torch.rand(bs, 3, 80, 72).cuda()
    img2 = torch.rand(bs, 3, 80, 72).cuda()

    preds = model(img1, img2)
    flow_pred = preds[-1]
    # Tolerance is fairly high, but there are 2 * H * W outputs to check
    # The .pkl files were generated on the AWS cluster; on the CI the results are slightly different
    _assert_expected(flow_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1)


if __name__ == "__main__":
    pytest.main([__file__])