12
import pytorch_test_common
13
from packaging import version
16
from torch.onnx import _constants, _experimental, verification
17
from torch.testing._internal import common_utils
20
class TestVerification(pytorch_test_common.ExportTestCase):
21
def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
22
class UnexportableModel(torch.nn.Module):
23
def forward(self, x, y):
29
((torch.randn(2, 3), torch.randn(2, 3)), {}),
30
((torch.randn(2, 3), torch.randn(2, 3)), {}),
33
results = verification.check_export_model_diff(
34
UnexportableModel(), test_input_groups
39
r"First diverging operator:(.|\n)*"
40
r"prim::Constant(.|\n)*"
41
r"Former source location:(.|\n)*"
42
r"Latter source location:",
45
def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
48
class UnexportableModel(torch.nn.Module):
49
def forward(self, x, y):
50
for i in range(x.size(0)):
55
((torch.randn(2, 3), torch.randn(2, 3)), {}),
56
((torch.randn(4, 3), torch.randn(2, 3)), {}),
59
export_options = _experimental.ExportOptions(
60
input_names=["x", "y"], dynamic_axes={"x": [0]}
62
results = verification.check_export_model_diff(
63
UnexportableModel(), test_input_groups, export_options
68
r"First diverging operator:(.|\n)*"
69
r"prim::Constant(.|\n)*"
70
r"Latter source location:(.|\n)*",
73
def test_check_export_model_diff_returns_empty_when_correct_export(self):
74
class SupportedModel(torch.nn.Module):
75
def forward(self, x, y):
79
((torch.randn(2, 3), torch.randn(2, 3)), {}),
80
((torch.randn(2, 3), torch.randn(2, 3)), {}),
83
results = verification.check_export_model_diff(
84
SupportedModel(), test_input_groups
86
self.assertEqual(results, "")
88
def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
91
ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
92
pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
93
options = verification.VerificationOptions(
99
acceptable_error_percentage=0.3,
101
verification._compare_onnx_pytorch_outputs(
107
def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
110
ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
111
pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
112
options = verification.VerificationOptions(
118
acceptable_error_percentage=None,
120
with self.assertRaises(AssertionError):
121
verification._compare_onnx_pytorch_outputs(
128
@common_utils.instantiate_parametrized_tests
129
class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase):
135
def incorrect_add_symbolic_function(g, self, other, alpha):
138
self.opset_version = _constants.ONNX_DEFAULT_OPSET
139
torch.onnx.register_custom_op_symbolic(
141
incorrect_add_symbolic_function,
142
opset_version=self.opset_version,
147
torch.onnx.unregister_custom_op_symbolic(
148
"aten::add", opset_version=self.opset_version
151
@common_utils.parametrize(
154
common_utils.subtest(
155
verification.OnnxBackend.REFERENCE,
158
version.Version(onnx.__version__) < version.Version("1.13"),
159
reason="Reference Python runtime was introduced in 'onnx' 1.13.",
163
verification.OnnxBackend.ONNX_RUNTIME_CPU,
166
def test_verify_found_mismatch_when_export_is_wrong(
167
self, onnx_backend: verification.OnnxBackend
169
class Model(torch.nn.Module):
170
def forward(self, x):
173
with self.assertRaisesRegex(AssertionError, ".*Tensor-likes are not close!.*"):
176
(torch.randn(2, 3),),
177
opset_version=self.opset_version,
178
options=verification.VerificationOptions(backend=onnx_backend),
182
@parameterized.parameterized_class(
186
{"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU},
188
class_name_func=lambda cls,
190
input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}",
192
class TestFindMismatch(pytorch_test_common.ExportTestCase):
193
onnx_backend: verification.OnnxBackend
195
graph_info: verification.GraphInfo
199
self.opset_version = _constants.ONNX_DEFAULT_OPSET
201
def incorrect_relu_symbolic_function(g, self):
202
return g.op("Add", self, g.op("Constant", value_t=torch.tensor(1.0)))
204
torch.onnx.register_custom_op_symbolic(
206
incorrect_relu_symbolic_function,
207
opset_version=self.opset_version,
210
class Model(torch.nn.Module):
211
def __init__(self) -> None:
213
self.layers = torch.nn.Sequential(
214
torch.nn.Linear(3, 4),
216
torch.nn.Linear(4, 5),
218
torch.nn.Linear(5, 6),
221
def forward(self, x):
222
return self.layers(x)
224
self.graph_info = verification.find_mismatch(
226
(torch.randn(2, 3),),
227
opset_version=self.opset_version,
228
options=verification.VerificationOptions(backend=self.onnx_backend),
233
torch.onnx.unregister_custom_op_symbolic(
234
"aten::relu", opset_version=self.opset_version
236
delattr(self, "opset_version")
237
delattr(self, "graph_info")
239
def test_pretty_print_tree_visualizes_mismatch(self):
241
with contextlib.redirect_stdout(f):
242
self.graph_info.pretty_print_tree()
243
self.assertExpected(f.getvalue())
245
def test_preserve_mismatch_source_location(self):
246
mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
248
self.assertTrue(len(mismatch_leaves) > 0)
250
for leaf_info in mismatch_leaves:
252
with contextlib.redirect_stdout(f):
253
leaf_info.pretty_print_mismatch(graph=True)
256
r"(.|\n)*" r"aten::relu.*/torch/nn/functional.py:[0-9]+(.|\n)*",
259
def test_find_all_mismatch_operators(self):
    """Each mismatching leaf subgraph should isolate one ``aten::relu`` node.

    The class fixture registers a deliberately wrong symbolic for relu
    (``incorrect_relu_symbolic_function``) and unregisters ``aten::relu`` on
    teardown. This test expects ``find_mismatch`` to report exactly two
    mismatch leaves. Each leaf should hold a single essential node whose kind
    is ``aten::relu``.
    """
    mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
    expected_kinds = {"aten::relu"}

    self.assertEqual(len(mismatch_leaves), 2)

    for leaf_info in mismatch_leaves:
        # Bisection should narrow every mismatch down to the faulty op alone.
        self.assertEqual(leaf_info.essential_node_count(), 1)
        self.assertEqual(leaf_info.essential_node_kinds(), expected_kinds)
268
def test_find_mismatch_prints_correct_info_when_no_mismatch(self):
271
class Model(torch.nn.Module):
272
def forward(self, x):
276
with contextlib.redirect_stdout(f):
277
verification.find_mismatch(
279
(torch.randn(2, 3),),
280
opset_version=self.opset_version,
281
options=verification.VerificationOptions(backend=self.onnx_backend),
283
self.assertExpected(f.getvalue())
285
def test_export_repro_for_mismatch(self):
    """An exported repro of a mismatching leaf should reproduce the mismatch.

    Exports the first mismatch leaf subgraph into a temporary directory, then
    validates the repro. Validation is expected to fail with the
    "Tensor-likes are not close!" assertion.
    """
    mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info()
    # assertGreater reports the actual count on failure, whereas
    # assertTrue(len(...) > 0) only reports "False is not true".
    self.assertGreater(len(mismatch_leaves), 0)
    leaf_info = mismatch_leaves[0]
    with tempfile.TemporaryDirectory() as temp_dir:
        repro_dir = leaf_info.export_repro(temp_dir)
        # Build options outside assertRaisesRegex so that only the repro
        # validation itself is expected to raise.
        options = verification.VerificationOptions(backend=self.onnx_backend)
        # Validate while temp_dir still exists; repro_dir is presumably
        # created under it by export_repro (verify against export_repro).
        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            verification.OnnxTestCaseRepro(repro_dir).validate(options)
297
# Allow running this module directly; delegates to PyTorch's shared test runner.
if __name__ == "__main__":
    common_utils.run_tests()