pytorch

Форк
0
/
test_exporter_api.py 
167 строк · 5.9 Кб
1
# Owner(s): ["module: onnx"]
2
import io
3
import os
4

5
import onnx
6

7
import torch
8
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
9
from torch.onnx._internal import _exporter_legacy
10
from torch.onnx._internal._exporter_legacy import (
11
    ONNXProgramSerializer,
12
    ResolvedExportOptions,
13
)
14
from torch.testing._internal import common_utils
15

16

17
class SampleModel(torch.nn.Module):
    """Add one to the input and also return the ReLU of the shifted tensor."""

    def forward(self, x):
        shifted = x + 1
        return shifted, shifted.relu()
22

23

24
class SampleModelTwoInputs(torch.nn.Module):
    """Sum two tensors and also return the ReLU of that sum."""

    def forward(self, x, b):
        total = x + b
        return total, total.relu()
29

30

31
class SampleModelForDynamicShapes(torch.nn.Module):
    """Two independent elementwise branches: ReLU on ``x`` and sigmoid on ``b``."""

    def forward(self, x, b):
        rectified = x.relu()
        squashed = b.sigmoid()
        return rectified, squashed
34

35

36
class TestExportOptionsAPI(common_utils.TestCase):
    """Check how ``ResolvedExportOptions`` resolves the ``dynamic_shapes`` option."""

    def test_dynamic_shapes_default(self):
        resolved = ResolvedExportOptions(ExportOptions())
        # Leaving the option unset must resolve to a falsy value.
        self.assertFalse(resolved.dynamic_shapes)

    def test_dynamic_shapes_explicit(self):
        # (value passed to ExportOptions, whether the resolved value is truthy),
        # checked in this exact order.
        expectations = (
            (None, False),
            (True, True),
            (False, False),
        )
        for requested, should_be_truthy in expectations:
            resolved = ResolvedExportOptions(ExportOptions(dynamic_shapes=requested))
            check = self.assertTrue if should_be_truthy else self.assertFalse
            check(resolved.dynamic_shapes)
48

49

50
class TestDynamoExportAPI(common_utils.TestCase):
    """Tests for the public ``dynamo_export`` entry point and ``ONNXProgram``."""

    def test_default_export(self):
        output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
        self.assertIsInstance(output, ONNXProgram)
        self.assertIsInstance(output.model_proto, onnx.ModelProto)

    def test_export_with_options(self):
        self.assertIsInstance(
            dynamo_export(
                SampleModel(),
                torch.randn(1, 1, 2),
                export_options=ExportOptions(
                    dynamic_shapes=True,
                ),
            ),
            ONNXProgram,
        )

    def test_save_to_file_default_serializer(self):
        with common_utils.TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
            model_proto = onnx.load(path)
            # SampleModel has an add and a relu, so a real export has nodes.
            self.assertGreater(len(model_proto.graph.node), 0)

    def test_save_to_existing_buffer_default_serializer(self):
        buffer = io.BytesIO()
        dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
        # Writing moves the BytesIO position to the end of the data, and
        # `onnx.load` reads from the current position. Without rewinding it
        # would parse an empty byte string into an empty ModelProto, so the
        # test would pass without verifying anything.
        buffer.seek(0)
        model_proto = onnx.load(buffer)
        self.assertGreater(len(model_proto.graph.node), 0)

    def test_save_to_file_using_specified_serializer(self):
        expected_content = "I am not actually ONNX"

        class CustomSerializer(ONNXProgramSerializer):
            def serialize(
                self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
            ) -> None:
                destination.write(expected_content.encode())

        with common_utils.TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
                path, serializer=CustomSerializer()
            )
            # Decode with the codec `str.encode()` used above (UTF-8), not the
            # platform's locale default.
            with open(path, encoding="utf-8") as fp:
                self.assertEqual(fp.read(), expected_content)

    def test_save_to_file_using_specified_serializer_without_inheritance(self):
        expected_content = "I am not actually ONNX"

        # NOTE: Inheritance from `ONNXProgramSerializer` is not required.
        # Because `ONNXProgramSerializer` is a Protocol class.
        class CustomSerializer:
            def serialize(
                self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
            ) -> None:
                destination.write(expected_content.encode())

        with common_utils.TemporaryFileName() as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
                path, serializer=CustomSerializer()
            )
            with open(path, encoding="utf-8") as fp:
                self.assertEqual(fp.read(), expected_content)

    def test_save_sarif_log_to_file_with_successful_export(self):
        with common_utils.TemporaryFileName(suffix=".sarif") as path:
            dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path)
            self.assertTrue(os.path.exists(path))

    def test_save_sarif_log_to_file_with_failed_export(self):
        class ModelWithExportError(torch.nn.Module):
            def forward(self, x):
                raise RuntimeError("Export error")

        sarif_path = _exporter_legacy._DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH
        # Delete any log left behind by an earlier run so the existence check
        # below proves that *this* failed export wrote it. Also remove it
        # afterwards so the test does not leave the file behind.
        self._remove_file_if_exists(sarif_path)
        self.addCleanup(self._remove_file_if_exists, sarif_path)

        with self.assertRaises(RuntimeError):
            dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
        self.assertTrue(os.path.exists(sarif_path))

    @staticmethod
    def _remove_file_if_exists(path):
        """Delete ``path`` if it exists; do nothing otherwise."""
        if os.path.exists(path):
            os.remove(path)

    def test_onnx_program_accessible_from_exception_when_export_failed(self):
        class ModelWithExportError(torch.nn.Module):
            def forward(self, x):
                raise RuntimeError("Export error")

        with self.assertRaises(torch.onnx.OnnxExporterError) as cm:
            dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
        self.assertIsInstance(cm.exception, torch.onnx.OnnxExporterError)
        self.assertIsInstance(cm.exception.onnx_program, ONNXProgram)

    def test_access_onnx_program_model_proto_raises_when_onnx_program_is_emitted_from_failed_export(
        self,
    ):
        class ModelWithExportError(torch.nn.Module):
            def forward(self, x):
                raise RuntimeError("Export error")

        with self.assertRaises(torch.onnx.OnnxExporterError) as cm:
            dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
        onnx_program = cm.exception.onnx_program
        # The program from a failed export has no usable model; reading
        # `model_proto` is expected to raise.
        with self.assertRaises(RuntimeError):
            onnx_program.model_proto

    def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_is_true(
        self,
    ):
        with self.assertRaises(torch.onnx.OnnxExporterError):
            dynamo_export(
                SampleModel(),
                torch.randn(1, 1, 2),
                export_options=ExportOptions(
                    diagnostic_options=torch.onnx.DiagnosticOptions(
                        warnings_as_errors=True
                    )
                ),
            )
164

165

166
if __name__ == "__main__":
    # Run this module's tests through PyTorch's shared test-runner entry point.
    common_utils.run_tests()
168

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.