# Owner(s): ["module: onnx"]

import os
import unittest
from collections import OrderedDict
from typing import List, Mapping, Tuple

import onnx_test_common
import parameterized
import PIL
import pytorch_test_common
import test_models

import torch
import torchvision
from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
from torch import nn
from torch.testing._internal import common_utils
from torchvision import ops
from torchvision.models.detection import (
    faster_rcnn,
    image_list,
    keypoint_rcnn,
    mask_rcnn,
    roi_heads,
    rpn,
    transform,
)


def exportTest(
    self,
    model,
    inputs,
    rtol=1e-2,
    atol=1e-7,
    opset_versions=None,
    acceptable_error_percentage=None,
):
    """Export `model` at each opset in `opset_versions` (defaulting to opsets
    7-14) and compare its outputs against ONNX Runtime. For opsets above 11,
    a scripted copy of the model is also exported when scripting is enabled."""
    opset_versions = opset_versions if opset_versions else [7, 8, 9, 10, 11, 12, 13, 14]

    for opset_version in opset_versions:
        self.opset_version = opset_version
        self.onnx_shape_inference = True
        onnx_test_common.run_model_test(
            self,
            model,
            input_args=inputs,
            rtol=rtol,
            atol=atol,
            acceptable_error_percentage=acceptable_error_percentage,
        )

        if self.is_script_test_enabled and opset_version > 11:
            script_model = torch.jit.script(model)
            onnx_test_common.run_model_test(
                self,
                script_model,
                input_args=inputs,
                rtol=rtol,
                atol=atol,
                acceptable_error_percentage=acceptable_error_percentage,
            )
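
# `exportTest` is injected as a method on the dynamically created test classes
# below, so every test inherited from test_models.TestModels runs through the
# ONNX Runtime comparison. A minimal illustrative call (hypothetical model and
# input, not part of the suite):
#
#     self.exportTest(torchvision.models.resnet18(), torch.randn(1, 3, 224, 224))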


TestModels = type(
    "TestModels",
    (pytorch_test_common.ExportTestCase,),
    dict(
        test_models.TestModels.__dict__,
        is_script_test_enabled=False,
        is_script=False,
        exportTest=exportTest,
    ),
)


# model tests for scripting with new JIT APIs and shape inference
TestModels_new_jit_API = type(
    "TestModels_new_jit_API",
    (pytorch_test_common.ExportTestCase,),
    dict(
        TestModels.__dict__,
        exportTest=exportTest,
        is_script_test_enabled=True,
        is_script=True,
        onnx_shape_inference=True,
    ),
)
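
# The three-argument form of type() above creates each test class dynamically.
# It is equivalent to a plain class statement (illustrative sketch):
#
#     class TestModels_new_jit_API(pytorch_test_common.ExportTestCase):
#         exportTest = exportTest
#         is_script_test_enabled = True
#         is_script = True
#         onnx_shape_inference = True
#         # ... plus every attribute copied from TestModels.__dict__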


def _get_image(rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
    data_dir = os.path.join(os.path.dirname(__file__), "assets")
    path = os.path.join(data_dir, *rel_path.split("/"))
    image = PIL.Image.open(path).convert("RGB").resize(size, PIL.Image.BILINEAR)

    return torchvision.transforms.ToTensor()(image)


def _get_test_images() -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    return (
        [_get_image("grace_hopper_517x606.jpg", (100, 320))],
        [_get_image("rgb_pytorch.png", (250, 380))],
    )


def _get_features(images):
    s0, s1 = images.shape[-2:]
    features = [
        ("0", torch.rand(2, 256, s0 // 4, s1 // 4)),
        ("1", torch.rand(2, 256, s0 // 8, s1 // 8)),
        ("2", torch.rand(2, 256, s0 // 16, s1 // 16)),
        ("3", torch.rand(2, 256, s0 // 32, s1 // 32)),
        ("4", torch.rand(2, 256, s0 // 64, s1 // 64)),
    ]
    features = OrderedDict(features)
    return features
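
# `_get_features` fakes the output of an FPN backbone: five 256-channel levels
# at strides 4, 8, 16, 32, and 64 relative to the input. A quick shape check
# (hypothetical example, not executed by the suite):
#
#     feats = _get_features(torch.rand(2, 3, 128, 128))
#     assert feats["0"].shape == (2, 256, 32, 32)  # stride 4
#     assert feats["4"].shape == (2, 256, 2, 2)    # stride 64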


def _init_test_generalized_rcnn_transform():
    min_size = 100
    max_size = 200
    image_mean = [0.485, 0.456, 0.406]
    image_std = [0.229, 0.224, 0.225]
    return transform.GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
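
# image_mean/image_std are the standard ImageNet normalization statistics that
# torchvision's detection models use by default; the small min_size/max_size
# keep the resized test inputs cheap.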


def _init_test_rpn():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    rpn_anchor_generator = rpn.AnchorGenerator(anchor_sizes, aspect_ratios)
    out_channels = 256
    rpn_head = rpn.RPNHead(
        out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
    )
    rpn_fg_iou_thresh = 0.7
    rpn_bg_iou_thresh = 0.3
    rpn_batch_size_per_image = 256
    rpn_positive_fraction = 0.5
    rpn_pre_nms_top_n = dict(training=2000, testing=1000)
    rpn_post_nms_top_n = dict(training=2000, testing=1000)
    rpn_nms_thresh = 0.7
    rpn_score_thresh = 0.0

    return rpn.RegionProposalNetwork(
        rpn_anchor_generator,
        rpn_head,
        rpn_fg_iou_thresh,
        rpn_bg_iou_thresh,
        rpn_batch_size_per_image,
        rpn_positive_fraction,
        rpn_pre_nms_top_n,
        rpn_post_nms_top_n,
        rpn_nms_thresh,
        score_thresh=rpn_score_thresh,
    )
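
# One anchor size per FPN level combined with three aspect ratios gives
# num_anchors_per_location() == [3, 3, 3, 3, 3], so the RPN head predicts
# 3 anchors at every spatial position of each of the five feature levels.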


def _init_test_roi_heads_faster_rcnn():
    out_channels = 256
    num_classes = 91

    box_fg_iou_thresh = 0.5
    box_bg_iou_thresh = 0.5
    box_batch_size_per_image = 512
    box_positive_fraction = 0.25
    bbox_reg_weights = None
    box_score_thresh = 0.05
    box_nms_thresh = 0.5
    box_detections_per_img = 100

    box_roi_pool = ops.MultiScaleRoIAlign(
        featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
    )

    resolution = box_roi_pool.output_size[0]
    representation_size = 1024
    box_head = faster_rcnn.TwoMLPHead(
        out_channels * resolution**2, representation_size
    )

    representation_size = 1024
    box_predictor = faster_rcnn.FastRCNNPredictor(representation_size, num_classes)

    return roi_heads.RoIHeads(
        box_roi_pool,
        box_head,
        box_predictor,
        box_fg_iou_thresh,
        box_bg_iou_thresh,
        box_batch_size_per_image,
        box_positive_fraction,
        bbox_reg_weights,
        box_score_thresh,
        box_nms_thresh,
        box_detections_per_img,
    )
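
# The box head pools only feature levels "0"-"3"; level "4" from _get_features
# is consumed by the RPN alone. Each pooled RoI is a 256 x 7 x 7 tensor, which
# matches the out_channels * resolution**2 input size of TwoMLPHead above.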


@parameterized.parameterized_class(
    ("is_script",),
    [(True,), (False,)],
    class_name_func=onnx_test_common.parameterize_class_name,
)
class TestModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()  # Faster RCNN model is not scriptable
    def test_faster_rcnn(self):
        model = faster_rcnn.fasterrcnn_resnet50_fpn(
            pretrained=False, pretrained_backbone=True, min_size=200, max_size=300
        )
        model.eval()
        x1 = torch.randn(3, 200, 300, requires_grad=True)
        x2 = torch.randn(3, 200, 300, requires_grad=True)
        self.run_test(model, ([x1, x2],), rtol=1e-3, atol=1e-5)
        self.run_test(
            model,
            ([x1, x2],),
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            rtol=1e-3,
            atol=1e-5,
        )
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        images, test_images = _get_test_images()
        self.run_test(
            model,
            (images,),
            additional_test_inputs=[(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            rtol=1e-3,
            atol=1e-5,
        )
        self.run_test(
            model,
            (dummy_image,),
            additional_test_inputs=[(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["outputs"],
            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
            rtol=1e-3,
            atol=1e-5,
        )
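
    # When dynamic_axes is given as a list of axis indices (as above), torch.onnx
    # marks those axes dynamic and auto-generates their names; the dict form used
    # in test_shufflenet_v2_dynamic_axes below supplies explicit names instead.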

    @unittest.skip("Failing after ONNX 1.13.0")
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_mask_rcnn(self):
        model = mask_rcnn.maskrcnn_resnet50_fpn(
            pretrained=False, pretrained_backbone=True, min_size=200, max_size=300
        )
        images, test_images = _get_test_images()
        self.run_test(model, (images,), rtol=1e-3, atol=1e-5)
        self.run_test(
            model,
            (images,),
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            rtol=1e-3,
            atol=1e-5,
        )
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        self.run_test(
            model,
            (images,),
            additional_test_inputs=[(images,), (test_images,), (dummy_image,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            rtol=1e-3,
            atol=1e-5,
        )
        self.run_test(
            model,
            (dummy_image,),
            additional_test_inputs=[(dummy_image,), (images,)],
            input_names=["images_tensors"],
            output_names=["boxes", "labels", "scores", "masks"],
            dynamic_axes={
                "images_tensors": [0, 1, 2],
                "boxes": [0, 1],
                "labels": [0],
                "scores": [0],
                "masks": [0, 1, 2],
            },
            rtol=1e-3,
            atol=1e-5,
        )

    @unittest.skip("Failing, see https://github.com/pytorch/pytorch/issues/66528")
    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_keypoint_rcnn(self):
        model = keypoint_rcnn.keypointrcnn_resnet50_fpn(
            pretrained=False, pretrained_backbone=False, min_size=200, max_size=300
        )
        images, test_images = _get_test_images()
        self.run_test(model, (images,), rtol=1e-3, atol=1e-5)
        self.run_test(
            model,
            (images,),
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            rtol=1e-3,
            atol=1e-5,
        )
        dummy_images = [torch.ones(3, 100, 100) * 0.3]
        self.run_test(
            model,
            (images,),
            additional_test_inputs=[(images,), (test_images,), (dummy_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            rtol=5e-3,
            atol=1e-5,
        )
        self.run_test(
            model,
            (dummy_images,),
            additional_test_inputs=[(dummy_images,), (test_images,)],
            input_names=["images_tensors"],
            output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
            dynamic_axes={"images_tensors": [0, 1, 2]},
            rtol=5e-3,
            atol=1e-5,
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_roi_heads(self):
        class RoIHeadsModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.transform = _init_test_generalized_rcnn_transform()
                self.rpn = _init_test_rpn()
                self.roi_heads = _init_test_roi_heads_faster_rcnn()

            def forward(self, images, features: Mapping[str, torch.Tensor]):
                original_image_sizes = [
                    (img.shape[-1], img.shape[-2]) for img in images
                ]

                images_m = image_list.ImageList(
                    images, [(i.shape[-1], i.shape[-2]) for i in images]
                )
                proposals, _ = self.rpn(images_m, features)
                detections, _ = self.roi_heads(
                    features, proposals, images_m.image_sizes
                )
                detections = self.transform.postprocess(
                    detections, images_m.image_sizes, original_image_sizes
                )
                return detections

        images = torch.rand(2, 3, 100, 100)
        features = _get_features(images)
        images2 = torch.rand(2, 3, 150, 150)
        test_features = _get_features(images2)

        model = RoIHeadsModule()
        model.eval()
        model(images, features)

        self.run_test(
            model,
            (images, features),
            input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
            dynamic_axes={
                "input1": [0, 1, 2, 3],
                "input2": [0, 1, 2, 3],
                "input3": [0, 1, 2, 3],
                "input4": [0, 1, 2, 3],
                "input5": [0, 1, 2, 3],
                "input6": [0, 1, 2, 3],
            },
            additional_test_inputs=[(images, features), (images2, test_features)],
        )
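
    # Six input names are needed because the exporter flattens the `features`
    # mapping into its five tensor values: input1 is the image batch and
    # input2-input6 are the "0"-"4" pyramid levels from _get_features.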

    @skipScriptTest()  # TODO: #75625
    @skipIfUnsupportedMinOpsetVersion(20)
    def test_transformer_encoder(self):
        class MyModule(torch.nn.Module):
            def __init__(self, ninp, nhead, nhid, dropout, nlayers):
                super().__init__()
                encoder_layers = nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout)
                self.transformer_encoder = nn.TransformerEncoder(
                    encoder_layers, nlayers
                )

            def forward(self, input):
                return self.transformer_encoder(input)

        x = torch.rand(10, 32, 512)
        self.run_test(MyModule(512, 8, 2048, 0.0, 3), (x,), atol=1e-5)
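
    # nn.TransformerEncoderLayer defaults to batch_first=False, so x above is
    # (seq_len=10, batch=32, d_model=512); d_model must equal ninp (512) and be
    # divisible by nhead (8). Dropout is 0.0 for a deterministic comparison.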

    @skipScriptTest()
    def test_mobilenet_v3(self):
        model = torchvision.models.quantization.mobilenet_v3_large(pretrained=False)
        dummy_input = torch.randn(1, 3, 224, 224)
        self.run_test(model, (dummy_input,))
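
    # This is the quantizable variant of MobileNetV3 from
    # torchvision.models.quantization; with pretrained=False and no quantize
    # step it runs in float mode, but still carries the quant/dequant stubs.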

    @skipIfUnsupportedMinOpsetVersion(11)
    @skipScriptTest()
    def test_shufflenet_v2_dynamic_axes(self):
        model = torchvision.models.shufflenet_v2_x0_5(weights=None)
        dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
        test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True)
        self.run_test(
            model,
            (dummy_input,),
            additional_test_inputs=[(dummy_input,), (test_inputs,)],
            input_names=["input_images"],
            output_names=["outputs"],
            dynamic_axes={
                "input_images": {0: "batch_size"},
                # Key must match the declared output name ("outputs", not
                # "output") for the dynamic batch axis to apply to the output.
                "outputs": {0: "batch_size"},
            },
            rtol=1e-3,
            atol=1e-5,
        )


if __name__ == "__main__":
    common_utils.run_tests()