11
import pytorch_test_common
14
from packaging import version
15
from torch.onnx import _constants, _experimental, verification
16
from torch.testing._internal import common_utils
19
# Test case exercising verification.check_export_model_diff and
# verification._compare_onnx_pytorch_outputs.
# NOTE(review): the bare integer lines look like line-number residue from an
# extraction step, and several statements are truncated (unbalanced
# parentheses, missing method bodies) - confirm against the upstream file.
class TestVerification(pytorch_test_common.ExportTestCase):
# Calls check_export_model_diff on UnexportableModel with two input groups,
# each a pair of randn(2, 3) tensors plus empty kwargs. The visible pattern
# fragments name a first diverging operator, prim::Constant, and the
# former/latter source locations.
# NOTE(review): the statement that consumes `results` is not visible here -
# presumably a regex assertion on the returned text; confirm.
20
def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
21
class UnexportableModel(torch.nn.Module):
22
def forward(self, x, y):
28
((torch.randn(2, 3), torch.randn(2, 3)), {}),
29
((torch.randn(2, 3), torch.randn(2, 3)), {}),
32
results = verification.check_export_model_diff(
33
UnexportableModel(), test_input_groups
38
r"First diverging operator:(.|\n)*"
39
r"prim::Constant(.|\n)*"
40
r"Former source location:(.|\n)*"
41
r"Latter source location:",
# Same check for a forward() that loops over x.size(0). The two input groups
# give x a first dimension of 2 and then 4, and the export options name the
# inputs "x" and "y" and declare axis 0 of "x" as dynamic.
44
def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
47
class UnexportableModel(torch.nn.Module):
48
def forward(self, x, y):
49
for i in range(x.size(0)):
54
((torch.randn(2, 3), torch.randn(2, 3)), {}),
55
((torch.randn(4, 3), torch.randn(2, 3)), {}),
58
export_options = _experimental.ExportOptions(
59
input_names=["x", "y"], dynamic_axes={"x": [0]}
61
results = verification.check_export_model_diff(
62
UnexportableModel(), test_input_groups, export_options
67
r"First diverging operator:(.|\n)*"
68
r"prim::Constant(.|\n)*"
69
r"Latter source location:(.|\n)*",
# For SupportedModel the value returned by check_export_model_diff is
# asserted to be the empty string.
72
def test_check_export_model_diff_returns_empty_when_correct_export(self):
73
class SupportedModel(torch.nn.Module):
74
def forward(self, x, y):
78
((torch.randn(2, 3), torch.randn(2, 3)), {}),
79
((torch.randn(2, 3), torch.randn(2, 3)), {}),
82
results = verification.check_export_model_diff(
83
SupportedModel(), test_input_groups
85
self.assertEqual(results, "")
# ort_outs and pytorch_outs differ in one of their four elements (4.0 vs 1.0)
# and the options carry acceptable_error_percentage=0.3.
# NOTE(review): per the test name the comparison should not raise; the other
# option fields and the call arguments are truncated in this view.
87
def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
90
ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
91
pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
92
options = verification.VerificationOptions(
98
acceptable_error_percentage=0.3,
100
verification._compare_onnx_pytorch_outputs(
# Same tensors with acceptable_error_percentage=None; the comparison is made
# inside assertRaises(AssertionError).
106
def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
109
ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
110
pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
111
options = verification.VerificationOptions(
117
acceptable_error_percentage=None,
119
with self.assertRaises(AssertionError):
120
verification._compare_onnx_pytorch_outputs(
127
# Test case that registers a custom symbolic function named
# incorrect_add_symbolic_function and expects an AssertionError matching
# "Tensor-likes are not close!" for each parametrized ONNX backend.
# NOTE(review): the bare integer lines look like line-number residue from an
# extraction step; the setup/teardown method headers and the call made inside
# the assertRaisesRegex block are not visible here - confirm upstream.
@common_utils.instantiate_parametrized_tests
128
class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
# Replacement symbolic function; its body is not visible in this view.
134
def incorrect_add_symbolic_function(g, self, other, alpha):
# The default ONNX opset is stored once and reused for registration, for the
# test itself, and for unregistration.
137
self.opset_version = _constants.ONNX_DEFAULT_OPSET
# NOTE(review): the op name passed to register_custom_op_symbolic is not
# visible; the matching unregister call below uses "aten::add".
138
torch.onnx.register_custom_op_symbolic(
140
incorrect_add_symbolic_function,
141
opset_version=self.opset_version,
# Removes the custom symbolic for "aten::add" at the stored opset version.
146
torch.onnx.unregister_custom_op_symbolic(
147
"aten::add", opset_version=self.opset_version
# Two backends are parametrized: OnnxBackend.REFERENCE and
# OnnxBackend.ONNX_RUNTIME_CPU. The REFERENCE subtest carries a condition on
# onnx.__version__ being below 1.13 together with a reason string
# (presumably a skip decorator - the wrapping call is not visible).
150
@common_utils.parametrize(
153
common_utils.subtest(
154
verification.OnnxBackend.REFERENCE,
157
version.Version(onnx.__version__) < version.Version("1.13"),
158
reason="Reference Python runtime was introduced in 'onnx' 1.13.",
162
verification.OnnxBackend.ONNX_RUNTIME_CPU,
# Expects AssertionError with "Tensor-likes are not close!" when run with a
# randn(2, 3) input, the stored opset version, and the parametrized backend.
165
def test_verify_found_mismatch_when_export_is_wrong(
166
self, onnx_backend: verification.OnnxBackend
168
class Model(torch.nn.Module):
169
def forward(self, x):
172
with self.assertRaisesRegex(AssertionError, ".*Tensor-likes are not close!.*"):
175
(torch.randn(2, 3),),
176
opset_version=self.opset_version,
177
options=verification.VerificationOptions(backend=onnx_backend),
181
# Test case for verification.find_mismatch and the GraphInfo it returns.
# The class is parameterized with onnx_backend = OnnxBackend.ONNX_RUNTIME_CPU,
# and class_name_func names each generated class "<ClassName>_<backend name>".
# NOTE(review): the bare integer lines look like line-number residue from an
# extraction step; the setup/teardown method headers and the lines creating
# `f` (presumably an io.StringIO) are not visible here - confirm upstream.
@parameterized.parameterized_class(
185
{"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
187
class_name_func=lambda cls, idx, input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}",
189
class TestFindMismatch(pytorch_test_common.ExportTestCase):
# Backend supplied by the class parameterization above.
190
onnx_backend: verification.OnnxBackend
# Result of find_mismatch, assigned during setup and read by the tests.
192
graph_info: verification.GraphInfo
# The default ONNX opset is stored once and reused for registration,
# find_mismatch, and unregistration.
196
self.opset_version = _constants.ONNX_DEFAULT_OPSET
# Symbolic function that emits Add(input, Constant(1.0)).
198
def incorrect_relu_symbolic_function(g, self):
199
return g.op("Add", self, g.op("Constant", value_t=torch.tensor(1.0)))
# NOTE(review): the op name passed to register_custom_op_symbolic is not
# visible; the matching unregister call below uses "aten::relu".
201
torch.onnx.register_custom_op_symbolic(
203
incorrect_relu_symbolic_function,
204
opset_version=self.opset_version,
# Sequential model with Linear(3, 4), Linear(4, 5), Linear(5, 6). The entries
# between the Linear layers are not visible (presumably ReLU, given the
# aten::relu assertions in the tests below - confirm).
207
class Model(torch.nn.Module):
210
self.layers = torch.nn.Sequential(
211
torch.nn.Linear(3, 4),
213
torch.nn.Linear(4, 5),
215
torch.nn.Linear(5, 6),
218
def forward(self, x):
219
return self.layers(x)
# find_mismatch is run with a randn(2, 3) input and the result is kept on
# self.graph_info for the tests.
221
self.graph_info = verification.find_mismatch(
223
(torch.randn(2, 3),),
224
opset_version=self.opset_version,
225
options=verification.VerificationOptions(backend=self.onnx_backend),
# Cleanup: unregister the "aten::relu" symbolic and delete the attributes
# that were set during setup.
230
torch.onnx.unregister_custom_op_symbolic(
231
"aten::relu", opset_version=self.opset_version
233
delattr(self, "opset_version")
234
delattr(self, "graph_info")
# Captures the stdout of graph_info.pretty_print_tree() and compares it with
# the stored expected output via assertExpected.
236
def test_pretty_print_tree_visualizes_mismatch(self):
238
with contextlib.redirect_stdout(f):
239
self.graph_info.pretty_print_tree()
240
self.assertExpected(f.getvalue())
# Requires at least one mismatch leaf, then prints each leaf with graph=True
# under redirected stdout. The visible pattern looks for an aten::relu entry
# followed by a torch/nn/functional.py:<line> location.
# NOTE(review): the assertion consuming the pattern is not visible here.
242
def test_preserve_mismatch_source_location(self):
243
mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
245
self.assertTrue(len(mismatch_leaves) > 0)
247
for leaf_info in mismatch_leaves:
249
with contextlib.redirect_stdout(f):
250
leaf_info.pretty_print_mismatch(graph=True)
253
r"(.|\n)*" r"aten::relu.*/torch/nn/functional.py:[0-9]+(.|\n)*",
# Expects exactly two mismatch leaves, each with an essential node count of 1
# and essential node kinds equal to {"aten::relu"}.
256
def test_find_all_mismatch_operators(self):
257
mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
259
self.assertEqual(len(mismatch_leaves), 2)
261
for leaf_info in mismatch_leaves:
262
self.assertEqual(leaf_info.essential_node_count(), 1)
263
self.assertEqual(leaf_info.essential_node_kinds(), {"aten::relu"})
# Runs find_mismatch on a locally defined Model (forward body not visible)
# under redirected stdout and compares the captured text via assertExpected.
265
def test_find_mismatch_prints_correct_info_when_no_mismatch(self):
268
class Model(torch.nn.Module):
269
def forward(self, x):
273
with contextlib.redirect_stdout(f):
274
verification.find_mismatch(
276
(torch.randn(2, 3),),
277
opset_version=self.opset_version,
278
options=verification.VerificationOptions(backend=self.onnx_backend),
280
self.assertExpected(f.getvalue())
# Exports a repro for the first mismatch leaf into a temporary directory and
# expects validating that repro with the class backend to raise
# AssertionError containing "Tensor-likes are not close!".
282
def test_export_repro_for_mismatch(self):
283
mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
284
self.assertTrue(len(mismatch_leaves) > 0)
285
leaf_info = mismatch_leaves[0]
286
with tempfile.TemporaryDirectory() as temp_dir:
287
repro_dir = leaf_info.export_repro(temp_dir)
289
with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
290
options = verification.VerificationOptions(backend=self.onnx_backend)
291
verification.OnnxTestCaseRepro(repro_dir).validate(options)
294
# Script entry point: call common_utils.run_tests() when this file is executed
# directly rather than imported.
if __name__ == "__main__":
295
common_utils.run_tests()