1
# Owner(s): ["module: onnx"]
3
from __future__ import annotations
32
import pytorch_test_common
34
from torch import export as torch_export
35
from torch.onnx import _constants, verification
36
from torch.onnx._internal import _beartype
37
from torch.onnx._internal.fx import diagnostics
38
from torch.testing._internal import common_utils
39
from torch.testing._internal.opinfo import core as opinfo_core
40
from torch.types import Number
42
_NumericType = Union[Number, torch.Tensor, np.ndarray]
43
_ModelType = Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
44
_InputArgsType = Optional[
45
Union[torch.Tensor, int, float, bool, Sequence[Any], Mapping[str, Any]]
47
_OutputsType = Sequence[_NumericType]
49
onnx_model_dir = os.path.join(
50
os.path.dirname(os.path.realpath(__file__)),
61
pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")
64
pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")
67
def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
68
options = verification.VerificationOptions()
70
kwargs["opset_version"] = test_suite.opset_version
71
kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs
72
if hasattr(test_suite, "check_shape"):
73
options.check_shape = test_suite.check_shape
74
if hasattr(test_suite, "check_dtype"):
75
options.check_dtype = test_suite.check_dtype
77
names = {f.name for f in dataclasses.fields(options)}
79
for k, v in kwargs.items():
81
setattr(options, k, v)
82
keywords_to_pop.append(k)
83
for k in keywords_to_pop:
86
return verification.verify(*args, options=options, **kwargs)
89
def assert_dynamic_shapes(onnx_program: torch.onnx.ONNXProgram, dynamic_shapes: bool):
90
"""Assert whether the exported model has dynamic shapes or not.
93
onnx_program (torch.onnx.ONNXProgram): The output of torch.onnx.dynamo_export.
94
dynamic_shapes (bool): Whether the exported model has dynamic shapes or not.
95
When True, raises if graph inputs don't have at least one dynamic dimension
96
When False, raises if graph inputs have at least one dynamic dimension.
99
AssertionError: If the exported model has dynamic shapes and dynamic_shapes is False and vice-versa.
102
if dynamic_shapes is None:
105
model_proto = onnx_program.model_proto
106
# Process graph inputs
108
for inp in model_proto.graph.input:
111
for dim in inp.type.tensor_type.shape.dim
112
if dim.dim_value == 0 and dim.dim_param != ""
114
assert dynamic_shapes == (
115
len(dynamic_inputs) > 0
116
), "Dynamic shape check failed for graph inputs"
119
def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
120
"""Combine class name with the parameterized arguments.
122
This function is passed to `parameterized.parameterized_class` as the
123
`class_name_func` argument.
125
suffix = "_".join(f"{k}_{v}" for k, v in input_dicts.items())
126
return f"{cls.__name__}_{suffix}"
129
class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
130
opset_version = _constants.ONNX_DEFAULT_OPSET
131
keep_initializers_as_inputs = True # For IR version 3 type export.
138
onnxruntime.set_seed(0)
139
if torch.cuda.is_available():
140
torch.cuda.manual_seed_all(0)
141
os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0"
142
self.is_script_test_enabled = True
144
# The exported ONNX model may have less inputs than the pytorch model because of const folding.
145
# This mostly happens in unit test, where we widely use torch.size or torch.shape.
146
# So the output is only dependent on the input shape, not value.
147
# remained_onnx_input_idx is used to indicate which pytorch model input idx is remained in ONNX model.
155
do_constant_folding=True,
157
additional_test_inputs=None,
160
fixed_batch_size=False,
161
training=torch.onnx.TrainingMode.EVAL,
162
remained_onnx_input_idx=None,
165
def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True):
166
return run_model_test(
169
input_args=input_args,
170
input_kwargs=input_kwargs,
173
do_constant_folding=do_constant_folding,
174
dynamic_axes=dynamic_axes,
175
additional_test_inputs=additional_test_inputs,
176
input_names=input_names,
177
output_names=output_names,
178
fixed_batch_size=fixed_batch_size,
180
remained_onnx_input_idx=remained_onnx_input_idx,
182
ignore_none=ignore_none,
186
if isinstance(remained_onnx_input_idx, dict):
187
scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"]
188
tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"]
190
scripting_remained_onnx_input_idx = remained_onnx_input_idx
191
tracing_remained_onnx_input_idx = remained_onnx_input_idx
193
is_model_script = isinstance(
194
model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)
197
if self.is_script_test_enabled and self.is_script:
198
script_model = model if is_model_script else torch.jit.script(model)
201
scripting_remained_onnx_input_idx,
205
if not is_model_script and not self.is_script:
206
_run_test(model, tracing_remained_onnx_input_idx)
209
def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
212
input_args: Sequence[_InputArgsType],
214
input_kwargs: Optional[Mapping[str, _InputArgsType]] = None,
215
rtol: Optional[float] = 1e-3,
216
atol: Optional[float] = 1e-7,
217
has_mutation: bool = False,
218
additional_test_inputs: Optional[
221
Tuple[Sequence[_InputArgsType], Mapping[str, _InputArgsType]],
222
Tuple[Sequence[_InputArgsType]],
226
skip_dynamic_shapes_check: bool = False,
228
"""Compare the results of PyTorch model with exported ONNX model
231
model (_ModelType): PyTorch model
232
input_args (Sequence[_InputArgsType]): torch input arguments
233
input_kwargs (Mapping[str, _InputArgsType]): torch input kwargs
234
rtol (float, optional): relative tolerance. Defaults to 1e-3.
235
atol (float, optional): absolute tolerance. Defaults to 1e-7.
236
has_mutation (bool, optional): Whether the model mutates its input or state.
237
`mutation` as `True` incurs extra overhead of cloning the inputs and model.
239
additional_test_inputs: Test the models with another dataset input, which
240
is designed for dynamic axes testing. Defaults to None. It's a list of
241
different input sets in tuples. Inside tuple, the first element is a tuple
242
of args, and the second element is a dict of kwargs. Remember to put comma
243
even if the following element is not provided.
245
additional_test_inputs = [((args1, args2), {"kwargs":1}), ((args1,),), ((), {"kwargs":1})]
246
skip_dynamic_shapes_check: Whether to skip dynamic shape check. Defaults to False.
247
Must be used when tests do not produce dynamic shapes even when dynamic shape feature is enabled.
248
This is needed because Torch Dynamo uses the dynamic_shapes flag as a hint, only.
252
# avoid mutable data structure
253
if input_kwargs is None:
259
!= pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
261
ref_model = _try_clone_model(model)
262
ref_input_args, ref_input_kwargs = _try_clone_inputs(
263
input_args, input_kwargs
267
ref_input_args = input_args
268
ref_input_kwargs = input_kwargs
270
assert isinstance(ref_model, torch.nn.Module) or callable(
272
), "Model must be a torch.nn.Module or callable"
275
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
277
ref_model = torch.export.export(ref_model, args=ref_input_args)
280
): # TODO: Support dynamic shapes for torch.export.ExportedProgram
281
# https://github.com/pytorch/pytorch/issues/113705
283
reason="torch.export.ExportedProgram does not support dynamic shapes"
286
# Feed args and kwargs into exporter.
287
# Note that exporter should flatten kwargs into positional args the exported model;
288
# since ONNX doesn't represent kwargs.
289
export_error: Optional[torch.onnx.OnnxExporterError] = None
291
onnx_program = torch.onnx.dynamo_export(
295
export_options=torch.onnx.ExportOptions(
296
op_level_debug=self.op_level_debug,
297
dynamic_shapes=self.dynamic_shapes,
298
diagnostic_options=torch.onnx.DiagnosticOptions(
299
verbosity_level=logging.DEBUG
303
except torch.onnx.OnnxExporterError as e:
305
onnx_program = e.onnx_program
307
if diagnostics.is_onnx_diagnostics_log_artifact_enabled():
308
onnx_program.save_diagnostics(
309
f"test_report_{self._testMethodName}"
310
f"_op_level_debug_{self.op_level_debug}"
311
f"_dynamic_axes_{self.dynamic_shapes}"
312
f"_model_type_{self.model_type}"
316
if export_error is not None:
319
if not skip_dynamic_shapes_check:
320
assert_dynamic_shapes(onnx_program, self.dynamic_shapes)
322
_compare_pytorch_onnx_with_ort(
329
has_mutation=has_mutation,
331
# This confirms the exported mode accepts different input shapes
332
# when dynamic shape is enabled.
333
if additional_test_inputs and self.dynamic_shapes:
334
for another_input in additional_test_inputs:
335
if len(another_input) > 2:
337
f"test_inputs should only have tuple args and dictionary kwargs. But receives: {len(another_input)}"
339
additional_input_args = another_input[0]
340
additional_input_kwargs = (
342
if len(another_input) == 2 and another_input[1] is not None
345
_compare_pytorch_onnx_with_ort(
348
additional_input_args,
349
additional_input_kwargs,
352
has_mutation=has_mutation,
358
onnx_model: Union[str, torch.onnx.ONNXProgram],
359
pytorch_inputs: Sequence[_InputArgsType],
361
"""Run ORT on the given ONNX model and inputs
363
Used in test_fx_to_onnx_with_onnxruntime.py
366
onnx_model (Union[str, torch.onnx.ONNXProgram]): Converter ONNX model
367
pytorch_inputs (Sequence[_InputArgsType]): The given torch inputs
370
AssertionError: ONNX and PyTorch should have the same input sizes
373
_OutputsType: ONNX model predictions
375
if isinstance(onnx_model, torch.onnx.ONNXProgram):
376
buffer = io.BytesIO()
377
onnx_model.save(buffer)
378
ort_model = buffer.getvalue()
380
ort_model = onnx_model
382
# Suppress floods of warnings from ONNX Runtime
383
session_options = onnxruntime.SessionOptions()
384
session_options.log_severity_level = 3 # Error
385
session = onnxruntime.InferenceSession(
386
ort_model, providers=["CPUExecutionProvider"], sess_options=session_options
388
input_names = [ort_input.name for ort_input in session.get_inputs()]
390
if len(input_names) != len(pytorch_inputs):
391
raise AssertionError(
392
f"Expected {len(input_names)} inputs, got {len(pytorch_inputs)}"
396
k: torch.Tensor.numpy(v, force=True)
397
for k, v in zip(input_names, pytorch_inputs)
399
return session.run(None, ort_input)
403
def _try_clone_model(model: _ModelType) -> _ModelType:
404
"""Used for preserving original model in case forward mutates model states."""
406
return copy.deepcopy(model)
409
"Failed to clone model. Model state might be mutated during verification."
415
def _try_clone_inputs(input_args, input_kwargs):
416
ref_input_args = copy.deepcopy(input_args)
417
ref_input_kwargs = copy.deepcopy(input_kwargs)
418
return ref_input_args, ref_input_kwargs
422
def _compare_pytorch_onnx_with_ort(
423
onnx_program: torch.onnx.ONNXProgram,
425
input_args: Sequence[_InputArgsType],
426
input_kwargs: Mapping[str, _InputArgsType],
427
atol: Optional[float] = None,
428
rtol: Optional[float] = None,
429
has_mutation: bool = False,
432
ref_model = _try_clone_model(model)
433
ref_input_args, ref_input_kwargs = _try_clone_inputs(input_args, input_kwargs)
436
ref_input_args = input_args
437
ref_input_kwargs = input_kwargs
439
# NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
440
# Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
441
# Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
442
# NOTE: `model_with_state_dict=ref_model` is specified to cover runs with FakeTensor support
443
ort_outputs = onnx_program(*input_args, **input_kwargs)
444
ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs)
445
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs)
447
if len(ref_outputs) != len(ort_outputs):
448
raise AssertionError(
449
f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}"
452
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
453
torch.testing.assert_close(
454
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
458
# The min onnx opset version to test for
459
MIN_ONNX_OPSET_VERSION = 9
460
# The max onnx opset version to test for
461
MAX_ONNX_OPSET_VERSION = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
462
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)
464
# The min onnx opset version to test for
465
FX_MIN_ONNX_OPSET_VERSION = 18
466
# The max onnx opset version to test for
467
FX_MAX_ONNX_OPSET_VERSION = 18
468
FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1)
470
BOOL_TYPES = (torch.bool,)
488
# torch.float64, ORT doesn't support
492
# torch.complex32, NOTE: torch.complex32 is experimental in torch
494
# torch.complex128, ORT doesn't support
509
@dataclasses.dataclass
511
"""Information about a test case to skip or xfail.
513
Adapted from functorch: functorch/test/common_utils.py
516
op_name: The name of the operator.
517
variant_name: The name of the OpInfo variant.
518
decorator: The decorator to apply to the test case.
519
opsets: The opsets to apply the decorator to.
520
dtypes: The dtypes to apply the decorator to.
521
reason: The reason for skipping.
522
test_behavior: The behavior of the test case. [skip or xfail]
523
matcher: The matcher to apply to the test case.
524
enabled_if: Whether to enable test behavior. Usually used on onnx/ort version control
525
model_type: The type of the torch model. Defaults to None.
531
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
532
dtypes: Optional[Collection[torch.dtype]]
535
matcher: Optional[Callable[[Any], bool]] = None
536
enabled_if: bool = True
537
model_type: Optional[pytorch_test_common.TorchModelType] = None
539
def contains_opset(self, opset: int) -> bool:
540
if self.opsets is None:
543
opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
544
for opset_spec in self.opsets
550
variant_name: str = "",
553
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
554
dtypes: Optional[Collection[torch.dtype]] = None,
555
matcher: Optional[Callable[[Any], bool]] = None,
556
enabled_if: bool = True,
557
model_type: Optional[pytorch_test_common.TorchModelType] = None,
559
"""Expects a OpInfo test to fail.
562
op_name: The name of the operator.
563
variant_name: The name of the variant.
564
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
565
dtypes: The dtypes to expect the failure.
566
reason: The reason for the failure.
567
matcher: A function that matches the test sample input. It is used only when
568
xfail is in the SKIP_XFAIL_SUBTESTS list.
569
enabled_if: Whether to enable xfail. Usually used on onnx/ort version control
570
model_type: The type of the torch model. Defaults to None.
574
variant_name=variant_name,
575
decorator=unittest.expectedFailure,
578
enabled_if=enabled_if,
581
test_behavior="xfail",
582
model_type=model_type,
588
variant_name: str = "",
591
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
592
dtypes: Optional[Collection[torch.dtype]] = None,
593
matcher: Optional[Callable[[Any], Any]] = None,
594
enabled_if: bool = True,
595
model_type: Optional[pytorch_test_common.TorchModelType] = None,
597
"""Skips a test case in OpInfo that we don't care about.
599
Likely because ONNX does not support the use case or it is by design.
602
op_name: The name of the operator.
603
variant_name: The name of the variant.
604
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
605
dtypes: The dtypes to expect the failure.
606
reason: The reason for the failure.
607
matcher: A function that matches the test sample input. It is used only when
608
skip is in the SKIP_XFAIL_SUBTESTS list.
609
enabled_if: Whether to enable skip. Usually used on onnx/ort version control
610
model_type: The type of the torch model. Defaults to None.
614
variant_name=variant_name,
615
decorator=unittest.skip(f"Skip: {reason}"),
620
enabled_if=enabled_if,
621
test_behavior="skip",
622
model_type=model_type,
628
variant_name: str = "",
631
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
632
dtypes: Optional[Collection[torch.dtype]] = None,
633
matcher: Optional[Callable[[Any], Any]] = None,
634
model_type: Optional[pytorch_test_common.TorchModelType] = None,
636
"""Skips a test case in OpInfo that is too slow.
638
It needs further investigation to understand why it is slow.
641
op_name: The name of the operator.
642
variant_name: The name of the variant.
643
opsets: The opsets to expect the failure. e.g. [9, 10] or [opsets_before(11)]
644
dtypes: The dtypes to expect the failure.
645
reason: The reason for the failure.
646
matcher: A function that matches the test sample input. It is used only when
647
skip is in the SKIP_XFAIL_SUBTESTS list.
648
model_type: The type of the torch model. Defaults to None.
652
variant_name=variant_name,
653
decorator=common_utils.slowTest,
658
enabled_if=not common_utils.TEST_WITH_SLOW,
659
test_behavior="skip",
660
model_type=model_type,
664
def add_decorate_info(
665
all_opinfos: Sequence[opinfo_core.OpInfo],
666
test_class_name: str,
669
skip_or_xfails: Iterable[DecorateMeta],
671
"""Decorates OpInfo tests with decorators based on the skip_or_xfails list.
674
all_opinfos: All OpInfos.
675
test_class_name: The name of the test class.
676
base_test_name: The name of the test method.
677
opset: The opset to decorate for.
678
skip_or_xfails: DecorateMeta's.
680
ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
681
for decorate_meta in skip_or_xfails:
682
if not decorate_meta.contains_opset(opset):
683
# Skip does not apply to this opset
685
opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
688
), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
689
decorators = list(opinfo.decorators)
690
new_decorator = opinfo_core.DecorateInfo(
691
decorate_meta.decorator,
694
dtypes=decorate_meta.dtypes,
695
active_if=decorate_meta.enabled_if,
697
decorators.append(new_decorator)
698
opinfo.decorators = tuple(decorators)
700
# This decorator doesn't modify fn in any way
707
def opsets_before(opset: int) -> Callable[[int], bool]:
708
"""Returns a comparison function that decides if the given opset is before the specified."""
710
def compare(other_opset: int):
711
return other_opset < opset
716
def opsets_after(opset: int) -> Callable[[int], bool]:
717
"""Returns a comparison function that decides if the given opset is after the specified."""
719
def compare(other_opset: int):
720
return other_opset > opset
725
def reason_onnx_script_does_not_support(
726
operator: str, dtypes: Optional[Sequence[str]] = None
728
"""Formats the reason: ONNX script doesn't support the given dtypes."""
729
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX script"
732
def reason_onnx_runtime_does_not_support(
733
operator: str, dtypes: Optional[Sequence[str]] = None
735
"""Formats the reason: ONNX Runtime doesn't support the given dtypes."""
736
return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"
739
def reason_onnx_does_not_support(
740
operator: str, dtypes: Optional[Sequence[str]] = None
742
"""Formats the reason: ONNX doesn't support the given dtypes."""
743
return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"
746
def reason_dynamo_does_not_support(
747
operator: str, dtypes: Optional[Sequence[str]] = None
749
"""Formats the reason: Dynamo doesn't support the given dtypes."""
751
f"{operator} on {dtypes or 'certain dtypes'} not supported by the Dynamo Spec"
755
def reason_jit_tracer_error(info: str) -> str:
756
"""Formats the reason: JIT tracer errors."""
757
return f"JIT tracer error on {info}"
760
def reason_flaky() -> str:
761
"""Formats the reason: test is flaky."""
765
@contextlib.contextmanager
766
def normal_xfail_skip_test_behaviors(
767
test_behavior: Optional[str] = None, reason: Optional[str] = None
769
"""This context manager is used to handle the different behaviors of xfail and skip.
772
test_behavior (optional[str]): From DecorateMeta name, can be 'skip', 'xfail', or None.
773
reason (optional[str]): The reason for the failure or skip.
776
e: Any exception raised by the test case if it's not an expected failure.
779
# We need to skip as soon as possible, as SegFault might also be a case.
780
if test_behavior == "skip":
781
pytest.skip(reason=reason)
785
# We could use `except (AssertionError, RuntimeError, ...) as e:`, but it needs
786
# to go over all test cases to find the right exception type.
787
except Exception as e: # pylint: disable=broad-exception-caught
788
if test_behavior is None:
790
if test_behavior == "xfail":
791
pytest.xfail(reason=reason)
793
if test_behavior == "xfail":
794
pytest.fail("Test unexpectedly passed")