pytorch

Форк
0
/
test_models.py 
286 строк · 10.7 Кб
1
# Owner(s): ["module: onnx"]
2

3
import unittest
4

5
import pytorch_test_common
6

7
import torch
8
from model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init
9
from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2
10
from model_defs.mnist import MNIST
11
from model_defs.op_test import ConcatNet, DummyNet, FakeQuantNet, PermuteNet, PReluNet
12
from model_defs.squeezenet import SqueezeNet
13
from model_defs.srresnet import SRResNet
14
from model_defs.super_resolution import SuperResolutionNet
15
from pytorch_test_common import skipIfUnsupportedMinOpsetVersion, skipScriptTest
16
from torch.ao import quantization
17
from torch.autograd import Variable
18
from torch.onnx import OperatorExportTypes
19
from torch.testing._internal import common_utils
20
from torch.testing._internal.common_utils import skipIfNoLapack
21
from torchvision.models import shufflenet_v2_x1_0
22
from torchvision.models.alexnet import alexnet
23
from torchvision.models.densenet import densenet121
24
from torchvision.models.googlenet import googlenet
25
from torchvision.models.inception import inception_v3
26
from torchvision.models.mnasnet import mnasnet1_0
27
from torchvision.models.mobilenet import mobilenet_v2
28
from torchvision.models.resnet import resnet50
29
from torchvision.models.segmentation import deeplabv3_resnet101, fcn_resnet101
30
from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
31
from torchvision.models.video import mc3_18, r2plus1d_18, r3d_18
32
from verify import verify
33

34
# Checked once at import time rather than on every call.
_USE_CUDA = torch.cuda.is_available()


def toC(x):
    """Return ``x.cuda()`` when CUDA is available, otherwise ``x`` unchanged.

    Works for anything exposing a ``.cuda()`` method (the tests pass both
    tensors and modules).
    """
    return x.cuda() if _USE_CUDA else x
43

44

45
# Leading (batch) dimension used for most of the sample inputs in this module.
BATCH_SIZE = 2
46

47

48
class TestModels(pytorch_test_common.ExportTestCase):
    """Export a collection of models to ONNX and verify the results.

    Every ``test_*`` method builds a model plus a sample input and hands both
    to :meth:`exportTest`. Most inputs are ``randn`` tensors overwritten with
    ones via ``fill_(1.0)``.
    """

    opset_version = 9  # Caffe2 doesn't support the default.
    # Not read anywhere in this class; presumably consumed by the base class
    # or by code that reuses these tests -- TODO confirm.
    keep_initializers_as_inputs = False

    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, **kwargs):
        """Trace ``model`` in eval mode, lint the graph and verify the export.

        Args:
            model: Module to export.
            inputs: Example input(s) passed to the tracer and to ``verify``.
            rtol: Relative tolerance forwarded to ``verify``.
            atol: Absolute tolerance forwarded to ``verify``.
            **kwargs: Accepted but not used by this implementation (for
                example ``acceptable_error_percentage`` passed by
                ``test_inception``).
        """
        # Imported inside the method, so Caffe2 is only needed when a test
        # actually runs.
        import caffe2.python.onnx.backend as backend

        with torch.onnx.select_model_mode_for_export(
            model, torch.onnx.TrainingMode.EVAL
        ):
            graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
            torch._C._jit_pass_lint(graph)
            verify(
                model,
                inputs,
                backend,
                rtol=rtol,
                atol=atol,
                opset_version=self.opset_version,
            )

    def _calibrated_qat_resnet50(self, weight_fake_quant, calibration_input):
        """Build a ResNet50 prepared for QAT and run one calibration pass.

        Args:
            weight_fake_quant: Fake-quantize configuration used for weights;
                activations always use ``quantization.default_fake_quant``.
            calibration_input: Input for the single forward pass that is run
                while observers are enabled.

        Returns:
            The prepared model with observers disabled again.
        """
        qat_resnet50 = resnet50()
        qat_resnet50.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=weight_fake_quant,
        )
        quantization.prepare_qat(qat_resnet50, inplace=True)
        qat_resnet50.apply(quantization.enable_observer)
        qat_resnet50.apply(quantization.enable_fake_quant)

        # Single forward pass with observers enabled, before qparams are
        # calculated and the observers are switched off.
        _ = qat_resnet50(calibration_input)
        for module in qat_resnet50.modules():
            if isinstance(module, quantization.FakeQuantize):
                module.calculate_qparams()
        qat_resnet50.apply(quantization.disable_observer)
        return qat_resnet50

    def test_ops(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(DummyNet()), toC(x))

    def test_prelu(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(PReluNet(), x)

    @skipScriptTest()
    def test_concat(self):
        input_a = Variable(torch.randn(BATCH_SIZE, 3))
        input_b = Variable(torch.randn(BATCH_SIZE, 3))
        # The model receives both tensors as a single tuple argument.
        inputs = ((toC(input_a), toC(input_b)),)
        self.exportTest(toC(ConcatNet()), inputs)

    def test_permute(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
        self.exportTest(PermuteNet(), x)

    @skipScriptTest()
    def test_embedding_sequential_1(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork1(), x)

    @skipScriptTest()
    def test_embedding_sequential_2(self):
        x = Variable(torch.randint(0, 10, (BATCH_SIZE, 3)))
        self.exportTest(EmbeddingNetwork2(), x)

    @unittest.skip("This model takes too much memory")
    def test_srresnet(self):
        x = Variable(torch.randn(1, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x)
        )

    @skipIfNoLapack
    def test_super_resolution(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0))
        self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)

    def test_alexnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(alexnet()), toC(x))

    def test_mnist(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
        self.exportTest(toC(MNIST()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16(self):
        # VGG 16-layer model (configuration "D")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg16_bn(self):
        # VGG 16-layer model (configuration "D") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16_bn()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19(self):
        # VGG 19-layer model (configuration "E")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_vgg19_bn(self):
        # VGG 19-layer model (configuration "E") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19_bn()), toC(x))

    def test_resnet(self):
        # ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

    # This test is numerically unstable. Sporadic single element mismatch occurs occasionally.
    def test_inception(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299))
        self.exportTest(toC(inception_v3()), toC(x), acceptable_error_percentage=0.01)

    def test_squeezenet(self):
        # SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and
        # <0.5MB model size
        # NOTE(review): despite the variable name, this model is built with
        # version=1.1, so SqueezeNet 1.0 is never exported here. Left as is
        # because it is unclear whether 1.0 exports at this opset -- confirm.
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        sqnet_v1_0 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_0), toC(x))

        # SqueezeNet 1.1 has 2.4x less computation and slightly fewer params
        # than SqueezeNet 1.0, without sacrificing accuracy.
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        sqnet_v1_1 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_1), toC(x))

    def test_densenet(self):
        # Densenet-121 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(densenet121()), toC(x), rtol=1e-2, atol=1e-5)

    @skipScriptTest()
    def test_dcgan_netD(self):
        netD = _netD(1)
        netD.apply(weights_init)
        # Normally distributed sample with three channels.
        images = Variable(torch.empty(bsz, 3, imgsz, imgsz).normal_(0, 1))
        self.exportTest(toC(netD), toC(images))

    @skipScriptTest()
    def test_dcgan_netG(self):
        netG = _netG(1)
        netG.apply(weights_init)
        # Normally distributed sample of shape (bsz, nz, 1, 1).
        noise = Variable(torch.empty(bsz, nz, 1, 1).normal_(0, 1))
        self.exportTest(toC(netG), toC(noise))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_fake_quant(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(FakeQuantNet()), toC(x))

    @skipIfUnsupportedMinOpsetVersion(10)
    def test_qat_resnet_pertensor(self):
        # Quantize ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))

        # Use per tensor for weight. Per channel support will come with opset 13
        qat_resnet50 = self._calibrated_qat_resnet50(
            quantization.default_fake_quant, x
        )

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipIfUnsupportedMinOpsetVersion(13)
    def test_qat_resnet_per_channel(self):
        # Quantize ResNet50 model
        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)

        qat_resnet50 = self._calibrated_qat_resnet50(
            quantization.default_per_channel_weight_fake_quant, x
        )

        self.exportTest(toC(qat_resnet50), toC(x))

    @skipScriptTest(skip_before_opset_version=15, reason="None type in outputs")
    def test_googlenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mnasnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mobilenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

    @skipScriptTest()  # prim_data
    def test_shufflenet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_fcn(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(fcn_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    @skipIfUnsupportedMinOpsetVersion(11)
    def test_deeplab(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(
            toC(deeplabv3_resnet101(weights=None, weights_backbone=None)),
            toC(x),
            rtol=1e-3,
            atol=1e-5,
        )

    def test_r3d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r3d_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_mc3_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(mc3_18()), toC(x), rtol=1e-3, atol=1e-5)

    def test_r2plus1d_18_video(self):
        x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
        self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5)
283

284

285
if __name__ == "__main__":
    # Run this module's tests through torch.testing's common_utils runner
    # when the file is executed directly.
    common_utils.run_tests()
287

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.