1
# Owner(s): ["module: onnx"]
8
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
9
from torch.onnx._internal import _exporter_legacy
10
from torch.onnx._internal._exporter_legacy import (
11
ONNXProgramSerializer,
12
ResolvedExportOptions,
14
from torch.testing._internal import common_utils
17
class SampleModel(torch.nn.Module):
24
class SampleModelTwoInputs(torch.nn.Module):
25
def forward(self, x, b):
31
class SampleModelForDynamicShapes(torch.nn.Module):
32
def forward(self, x, b):
33
return x.relu(), b.sigmoid()
36
class TestExportOptionsAPI(common_utils.TestCase):
37
def test_dynamic_shapes_default(self):
38
options = ResolvedExportOptions(ExportOptions())
39
self.assertFalse(options.dynamic_shapes)
41
def test_dynamic_shapes_explicit(self):
42
options = ResolvedExportOptions(ExportOptions(dynamic_shapes=None))
43
self.assertFalse(options.dynamic_shapes)
44
options = ResolvedExportOptions(ExportOptions(dynamic_shapes=True))
45
self.assertTrue(options.dynamic_shapes)
46
options = ResolvedExportOptions(ExportOptions(dynamic_shapes=False))
47
self.assertFalse(options.dynamic_shapes)
50
class TestDynamoExportAPI(common_utils.TestCase):
51
def test_default_export(self):
52
output = dynamo_export(SampleModel(), torch.randn(1, 1, 2))
53
self.assertIsInstance(output, ONNXProgram)
54
self.assertIsInstance(output.model_proto, onnx.ModelProto)
56
def test_export_with_options(self):
57
self.assertIsInstance(
61
export_options=ExportOptions(
68
def test_save_to_file_default_serializer(self):
69
with common_utils.TemporaryFileName() as path:
70
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(path)
73
def test_save_to_existing_buffer_default_serializer(self):
75
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
78
def test_save_to_file_using_specified_serializer(self):
79
expected_buffer = "I am not actually ONNX"
81
class CustomSerializer(ONNXProgramSerializer):
83
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
85
destination.write(expected_buffer.encode())
87
with common_utils.TemporaryFileName() as path:
88
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
89
path, serializer=CustomSerializer()
91
with open(path) as fp:
92
self.assertEqual(fp.read(), expected_buffer)
94
def test_save_to_file_using_specified_serializer_without_inheritance(self):
95
expected_buffer = "I am not actually ONNX"
97
# NOTE: Inheritance from `ONNXProgramSerializer` is not required.
98
# Because `ONNXProgramSerializer` is a Protocol class.
99
class CustomSerializer:
101
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
103
destination.write(expected_buffer.encode())
105
with common_utils.TemporaryFileName() as path:
106
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
107
path, serializer=CustomSerializer()
109
with open(path) as fp:
110
self.assertEqual(fp.read(), expected_buffer)
112
def test_save_sarif_log_to_file_with_successful_export(self):
113
with common_utils.TemporaryFileName(suffix=".sarif") as path:
114
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path)
115
self.assertTrue(os.path.exists(path))
117
def test_save_sarif_log_to_file_with_failed_export(self):
118
class ModelWithExportError(torch.nn.Module):
119
def forward(self, x):
120
raise RuntimeError("Export error")
122
with self.assertRaises(RuntimeError):
123
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
125
os.path.exists(_exporter_legacy._DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH)
128
def test_onnx_program_accessible_from_exception_when_export_failed(self):
129
class ModelWithExportError(torch.nn.Module):
130
def forward(self, x):
131
raise RuntimeError("Export error")
133
with self.assertRaises(torch.onnx.OnnxExporterError) as cm:
134
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
135
self.assertIsInstance(cm.exception, torch.onnx.OnnxExporterError)
136
self.assertIsInstance(cm.exception.onnx_program, ONNXProgram)
138
def test_access_onnx_program_model_proto_raises_when_onnx_program_is_emitted_from_failed_export(
141
class ModelWithExportError(torch.nn.Module):
142
def forward(self, x):
143
raise RuntimeError("Export error")
145
with self.assertRaises(torch.onnx.OnnxExporterError) as cm:
146
dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2))
147
onnx_program = cm.exception.onnx_program
148
with self.assertRaises(RuntimeError):
149
onnx_program.model_proto
151
def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_is_true(
154
with self.assertRaises(torch.onnx.OnnxExporterError):
157
torch.randn(1, 1, 2),
158
export_options=ExportOptions(
159
diagnostic_options=torch.onnx.DiagnosticOptions(
160
warnings_as_errors=True
166
if __name__ == "__main__":
167
common_utils.run_tests()