# pytorch (форк 0)
# test_export_modes.py — 149 строк · 4.4 Кб
1
# Owner(s): ["module: onnx"]

import io
import os
import shutil
import sys
import tempfile

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.onnx import OperatorExportTypes

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
import pytorch_test_common

from torch.testing._internal import common_utils

# Smoke tests for export methods
class TestExportModes(pytorch_test_common.ExportTestCase):
    """Smoke tests for torch.onnx export: each output format
    (protobuf / zip / compressed zip / directory) and each operator
    export fallback mode is exercised once, without inspecting the
    exported graph.
    """

    class MyModel(nn.Module):
        """Trivial module whose forward swaps dims 0 and 1.

        No ``__init__`` is needed: the original one only called
        ``super().__init__()``, which ``nn.Module`` does by default.
        """

        def forward(self, x):
            return x.transpose(0, 1)

    def _export_to_type(self, destination, export_type):
        """Export ``MyModel`` on a fixed 4-D input to *destination*.

        Args:
            destination: a file-like object or directory path accepted
                by ``torch.onnx._export``.
            export_type: a ``torch.onnx.ExportTypes`` value selecting
                the on-disk format.
        """
        torch_model = TestExportModes.MyModel()
        # torch.autograd.Variable is deprecated; a tensor created with
        # requires_grad=True is the modern equivalent.
        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)
        torch.onnx._export(
            torch_model,
            (fake_input,),  # args must be a tuple; bare parens were misleading
            destination,
            verbose=False,
            export_type=export_type,
        )

    def test_protobuf(self):
        f = io.BytesIO()
        self._export_to_type(f, torch.onnx.ExportTypes.PROTOBUF_FILE)

    def test_zipfile(self):
        f = io.BytesIO()
        self._export_to_type(f, torch.onnx.ExportTypes.ZIP_ARCHIVE)

    def test_compressed_zipfile(self):
        f = io.BytesIO()
        self._export_to_type(f, torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)

    def test_directory(self):
        d = tempfile.mkdtemp()
        try:
            self._export_to_type(d, torch.onnx.ExportTypes.DIRECTORY)
        finally:
            # Clean up the temp directory even if the export raises;
            # the original leaked it on failure.
            shutil.rmtree(d)

    def test_onnx_multiple_return(self):
        """A scripted function returning a tuple exports without error."""

        @torch.jit.script
        def foo(a):
            return (a, a)

        f = io.BytesIO()
        x = torch.ones(3)
        torch.onnx.export(foo, (x,), f)

    @common_utils.skipIfNoCaffe2
    @common_utils.skipIfNoLapack
    def test_caffe2_aten_fallback(self):
        """ATen fallback export for an op with no ONNX lowering
        (``linalg.qr``) on a Caffe2 build."""

        class ModelWithAtenNotONNXOp(nn.Module):
            def forward(self, x, y):
                abcd = x + y
                defg = torch.linalg.qr(abcd)
                return defg

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)
        torch.onnx.export_to_pretty_string(
            ModelWithAtenNotONNXOp(),
            (x, y),
            add_node_names=False,
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )

    @common_utils.skipIfCaffe2
    @common_utils.skipIfNoLapack
    def test_aten_fallback(self):
        """ATen fallback export for ``linalg.qr`` on a non-Caffe2 build."""

        class ModelWithAtenNotONNXOp(nn.Module):
            def forward(self, x, y):
                abcd = x + y
                defg = torch.linalg.qr(abcd)
                return defg

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)
        torch.onnx.export_to_pretty_string(
            ModelWithAtenNotONNXOp(),
            (x, y),
            add_node_names=False,
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
            # support for linalg.qr was added in later op set versions.
            opset_version=9,
        )

    # torch.fmod is used to test ONNX_ATEN.
    # If you plan to remove fmod from aten, or find this test failing,
    # please contact @Rui.
    def test_onnx_aten(self):
        class ModelWithAtenFmod(nn.Module):
            def forward(self, x, y):
                return torch.fmod(x, y)

        x = torch.randn(3, 4, dtype=torch.float32)
        y = torch.randn(3, 4, dtype=torch.float32)
        torch.onnx.export_to_pretty_string(
            ModelWithAtenFmod(),
            (x, y),
            add_node_names=False,
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX_ATEN,
        )
146

147

148
if __name__ == "__main__":
149
    common_utils.run_tests()
150

# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.