# Owner(s): ["module: onnx"]

"""Test consistency between the output values of torch.onnx exported operators
and torch operators given the same inputs.

Usage:

    pytest test/onnx/test_op_consistency.py

    To run tests on a specific operator (e.g. torch.ceil):
    pytest test/onnx/test_op_consistency.py -k ceil
    pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention

    Read more on Running and writing tests:
        https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests

Note:

    When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and
    TESTED_OPS lists. See "Modify this section"
"""
25
from __future__ import annotations

import copy
from typing import Optional, Tuple

import torch
from torch.testing._internal import (
    common_device_type,
    common_methods_invocations,
    common_utils,
)

import onnx_test_common

# For readability, these two are allowed to be imported as function
from onnx_test_common import skip, xfail
44
# Work on a deep copy of the shared OpInfo database so that any in-place
# modification of its entries in this module cannot leak back into
# torch.testing's module-level op_db used by other test suites.
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
46
# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1. Add "ceil" to TESTED_OPS then run pytest.
# 2. If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS.

# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.

# Ops to be tested for numerical consistency between onnx and pytorch
# TODO: https://github.com/pytorch/pytorch/issues/102211
57
# Ops under test, identified by OpInfo name (must appear in OPS_DB — see the
# issubset assertion below).
# NOTE(review): this literal appears truncated in this copy — stray integer
# lines, most entries and the closing parenthesis are missing; restore the
# full list from upstream before relying on it.
TESTED_OPS: frozenset[str] = frozenset(
61
# "atleast_1d", # How to support list input?
71
"nn.functional.scaled_dot_product_attention",
86
86
# Turn off black formatting to keep the list compact

# Expected failures for onnx export.
# The list should be sorted alphabetically by op name.
# Q: When should I use fixme vs skip vs xfail?
# A: Prefer xfail over skip when possible.
# 2a. If a test is now failing because of xpass, because some previous errors
#     are now fixed, remove the corresponding xfail.
# 2b. If a test is not failing consistently, use skip.
95
# Op-level skip/xfail decorations; attached to the test classes per opset via
# onnx_test_common.add_decorate_info below.
# NOTE(review): several entries in this literal are truncated in this copy —
# stray integer lines are interleaved, and some DecorateMeta entries are
# missing their ``xfail(``/``skip(`` wrappers or op-name arguments; restore
# from upstream before relying on it.
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
97
"atan", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
98
reason=onnx_test_common.reason_onnx_does_not_support("Atan")
100
xfail("atan", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
102
"atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
103
reason=onnx_test_common.reason_onnx_does_not_support("Atan")
106
"atan2", dtypes=[torch.float64],
107
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])
110
"ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
111
reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
113
skip("hstack", opsets=[onnx_test_common.opsets_before(11)],
114
reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
117
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
118
reason=onnx_test_common.reason_onnx_does_not_support("Log", "bool, int"),
120
skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)], reason="Need Trilu."),
121
skip("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
122
xfail("round", opsets=[onnx_test_common.opsets_before(11)],
123
reason=onnx_test_common.reason_onnx_does_not_support("Round")),
124
xfail("round", variant_name="decimals_0", opsets=[onnx_test_common.opsets_before(11)],
125
reason=onnx_test_common.reason_onnx_does_not_support("Round")),
126
xfail("round", variant_name="decimals_3", opsets=[onnx_test_common.opsets_before(11)],
127
reason=onnx_test_common.reason_onnx_does_not_support("Round")),
128
xfail("round", variant_name="decimals_neg_3", opsets=[onnx_test_common.opsets_before(11)],
129
reason=onnx_test_common.reason_onnx_does_not_support("Round")),
130
skip("scatter_reduce", variant_name="amin", opsets=[onnx_test_common.opsets_before(16)],
131
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
132
skip("scatter_reduce", variant_name="amax", opsets=[onnx_test_common.opsets_before(16)],
133
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
134
skip("scatter_reduce", variant_name="prod", opsets=[onnx_test_common.opsets_before(16)],
135
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
136
xfail("scatter_reduce", variant_name="mean",
137
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction=mean")),
138
skip("scatter_reduce", variant_name="sum", opsets=[onnx_test_common.opsets_before(16)],
139
reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
143
dtypes=(torch.float16,),
144
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
149
dtypes=(torch.float16,),
150
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
155
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
156
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
161
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
162
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
167
reason="ONNX doesn't support reduce='mean' option",
169
skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
170
skip("stft", opsets=[onnx_test_common.opsets_before(17)], reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
172
reason=onnx_test_common.reason_onnx_runtime_does_not_support("STFT", "Regression on ORT=1.15 4 percent difference")),
173
skip("tile", opsets=[onnx_test_common.opsets_before(13)], reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
174
xfail("unflatten", opsets=[onnx_test_common.opsets_before(13)], reason="Helper function is needed to support legacy ops."),
175
skip("vstack", opsets=[onnx_test_common.opsets_before(11)],
176
reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
180
# Per-sample skip/xfail rules, consumed by _should_skip_xfail_test_sample:
# each entry carries an ``op_name`` plus a ``matcher`` predicate over a
# SampleInput.
# NOTE(review): this literal appears truncated in this copy — stray integer
# lines are interleaved, and several entries are missing their
# ``xfail(``/``skip(`` wrappers or op-name arguments; restore from upstream
# before relying on it.
SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
182
"nn.functional.scaled_dot_product_attention",
183
matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
184
reason="dropout is random so the results do not match",
188
reason="Empty repeats value leads to an invalid graph",
189
matcher=lambda sample: not sample.args[0],
193
# ONNX has not include_self parameter and default is include_self=True mode
194
matcher=lambda sample: sample.kwargs.get("include_self") is False,
195
reason="ONNX does't support include_self=False option",
199
reason="ONNX STFT does not support complex results",
200
matcher=lambda sample: sample.kwargs.get("return_complex") is True,
204
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape)
205
or not sample.input.shape,
206
reason="Logic not implemented for size 0 inputs in op.Reshape",
210
reason="Logic not implemented for size 0 inputs in op.Reshape",
211
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
216
# END OF SECTION TO MODIFY #####################################################
218
# Names of ops that have at least one per-sample skip/xfail rule; used as a
# fast membership test in _should_skip_xfail_test_sample before the linear scan.
OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
219
# All op names known to the (copied) OpInfo database.
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
220
# Guard: every name in TESTED_OPS must exist in the OpInfo database, otherwise
# the @ops decorator below would silently test nothing for that name.
221
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
224
class SingleOpModel(torch.nn.Module):
    """Test model to wrap around a single op for export.

    Wrapping the op in an ``nn.Module`` lets the ONNX exporter trace it like
    any other model: positional inputs flow through ``forward`` while the
    keyword arguments are fixed at construction time.
    """

    def __init__(self, op, kwargs):
        # nn.Module requires super().__init__() before any attribute is set.
        super().__init__()
        # `op` is the callable under test (e.g. an OpInfo's op); `kwargs` are
        # the sample's keyword arguments, applied identically on every call.
        self.operator = op
        self.kwargs = kwargs

    def forward(self, *args):
        """Apply the wrapped operator to ``*args`` with the fixed kwargs."""
        return self.operator(*args, **self.kwargs)
236
def _should_skip_xfail_test_sample(
    op_name: str, sample
) -> Tuple[Optional[str], Optional[str]]:
    """Returns a reason if a test sample should be skipped.

    Args:
        op_name: Name of the OpInfo op currently under test.
        sample: The SampleInput being tested, passed to each rule's matcher.

    Returns:
        The matching DecorateMeta's ``(test_behavior, reason)`` when one of
        the rules in SKIP_XFAIL_SUBTESTS applies to this sample; otherwise
        ``(None, None)``.
    """
    # Fast path: most ops have no per-sample rules at all.
    if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    for decorator_meta in SKIP_XFAIL_SUBTESTS:
        # Linear search on SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        if decorator_meta.op_name == op_name:
            assert decorator_meta.matcher is not None, "Matcher must be defined"
            if decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
    # The op has rules, but none of their matchers accepted this sample.
    return None, None
251
def _get_test_class_name(cls, num, params_dict) -> str:
254
return params_dict["name"]
257
# One test class is generated per opset in onnx_test_common.TESTED_OPSETS;
# _get_test_class_name derives each class's name from the "name" param.
# NOTE(review): this decorator and the class body below appear truncated in
# this copy (stray integer lines; several statements — e.g. the sample_inputs
# arguments, the subtest context, and the rtol/atol assignments — and closing
# brackets are missing); compare against upstream before relying on it.
@parameterized.parameterized_class(
260
"name": f"TestOnnxModelOutputConsistency_opset{opset}",
261
"opset_version": opset,
263
for opset in onnx_test_common.TESTED_OPSETS
265
class_name_func=_get_test_class_name,
267
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
268
"""Test output consistency between exported ONNX models and PyTorch eager mode.
270
This is a parameterized test suite.
275
@common_device_type.ops(
276
[op for op in OPS_DB if op.name in TESTED_OPS],
277
allowed_dtypes=onnx_test_common.INT_TYPES
278
+ onnx_test_common.FLOAT_TYPES
279
+ onnx_test_common.BOOL_TYPES,
281
def test_output_match(self, device: str, dtype: torch.dtype, op):
282
"""Test the ONNX exporter."""
283
# device is provided by instantiate_device_type_tests, but we only want to run in cpu.
284
assert device == "cpu"
286
samples = op.sample_inputs(
292
for i, cpu_sample in enumerate(samples):
293
inputs = (cpu_sample.input, *cpu_sample.args)
294
# Provide the repr to subtest because tensors are not serializable in parallel test runs
296
opset=self.opset_version,
299
kwargs=repr(cpu_sample.kwargs),
301
test_behavior, reason = _should_skip_xfail_test_sample(
304
with onnx_test_common.normal_xfail_skip_test_behaviors(
305
test_behavior, reason
307
model = SingleOpModel(op, cpu_sample.kwargs)
310
if dtype == torch.float32:
311
# Relax atol and rtol for float32 based on empirical results
312
# The current most relaxed values are for aten::stft
315
elif dtype == torch.float64:
316
# The current most relaxed values are for aten::stft
323
self.run_test(model, inputs, rtol=rtol, atol=atol)
326
# For each tested opset: attach the op-level skip/xfail decorations to the
# matching parameterized class, then instantiate its device-specific
# (cpu-only, per only_for="cpu") test variants into this module's globals so
# the test runner can discover them.
# NOTE(review): the call arguments appear truncated in this copy (stray
# integer lines; add_decorate_info is missing most of its arguments and both
# calls are missing closing parentheses).
for opset in onnx_test_common.TESTED_OPSETS:
327
# The name needs to match the parameterized_class name.
328
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
329
onnx_test_common.add_decorate_info(
334
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
336
common_device_type.instantiate_device_type_tests(
337
globals()[test_class_name], globals(), only_for="cpu"
341
# Standard PyTorch test-suite entry point: delegates discovery/running to
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    common_utils.run_tests()