1
# Owner(s): ["module: onnx"]
3
"""Test consistency between the output values of torch.onnx FX exported operators
4
and torch operators given the same inputs.
10
pytest test/onnx/test_fx_op_consistency.py
12
2. To run tests on a specific operator (e.g. torch.ceil):
14
pytest test/onnx/test_fx_op_consistency.py -k ceil
15
pytest test/onnx/test_fx_op_consistency.py -k nn_functional_scaled_dot_product_attention
17
3. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g.
19
CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/test_fx_op_consistency.py -k div_mode_int
21
NOTE: Read more on Running and writing tests:
22
https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
26
1. Please make sure pytest-subtests is installed. Otherwise, the sub-tests will be ignored.
28
2. Install pytest-xdist to run tests in parallel if running all tests is the goal.
30
3. When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and
31
TESTED_OPS lists. See "Modify this section"
35
from __future__ import annotations
40
from typing import Any, Callable, Collection, Mapping, Optional, Tuple, Type, Union
42
import error_reproduction
44
import onnx_test_common
48
import pytorch_test_common
51
from onnx_test_common import skip, skip_slow, xfail
52
from torch.onnx._internal.diagnostics import _rules
53
from torch.testing._internal import (
55
common_methods_invocations,
58
from torch.testing._internal.opinfo import core as opinfo_core
61
# NOTE: For ATen signature modifications that will break ONNX export,
62
# use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip
63
# to make the signal apparent for maintainers.
64
def xfail_torchlib_forward_compatibility(
66
variant_name: str = "",
70
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
71
dtypes: Optional[Collection[torch.dtype]] = None,
72
matcher: Optional[Callable[[Any], bool]] = None,
73
enabled_if: bool = True,
75
"""Prefer using this (xfail) over skip when possible.
77
Only skip when the test is not failing consistently.
81
variant_name=variant_name,
82
reason=f"{reason}. GitHub Issue: {github_issue}",
86
enabled_if=enabled_if,
90
def skip_torchlib_forward_compatibility(
92
variant_name: str = "",
96
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
97
dtypes: Optional[Collection[torch.dtype]] = None,
98
matcher: Optional[Callable[[Any], Any]] = None,
99
enabled_if: bool = True,
101
"""Prefer using xfail_torchlib_forward_compatibility over this (skip) when possible.
103
Only skip when the test is not failing consistently.
107
variant_name=variant_name,
108
reason=f"{reason}. GitHub Issue: {github_issue}",
112
enabled_if=enabled_if,
117
# Turn off black formatting to keep the list compact
119
# Expected failures for onnx export.
120
# The list should be sorted alphabetically by op name.
121
# Q: When should I use fixme vs skip vs xfail?
122
# A: Prefer xfail over skip when possible.
123
# 2a. If a test is now failing because of xpass, because some previous errors
124
# are now fixed, remove the corresponding xfail.
125
# 2b. If a test is not failing consistently, use skip.
126
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
129
reason="io_adaper doesn't support __getitem__ input slice(0, 3, None)",
133
dtypes=onnx_test_common.BOOL_TYPES,
134
reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "bool"),
138
reason="fixme: Assertion error: result mismatch",
142
dtypes=onnx_test_common.INT_TYPES,
143
reason=onnx_test_common.reason_onnx_does_not_support("Pow", "int"),
146
"_native_batch_norm_legit",
147
dtypes=(torch.float16,),
148
reason="fixme: Assertion error: result mismatch and type error",
151
"_softmax_backward_data",
152
reason=onnx_test_common.reason_dynamo_does_not_support("assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)")
155
"add", dtypes=onnx_test_common.BOOL_TYPES,
156
reason=onnx_test_common.reason_onnx_does_not_support("Add")
160
dtypes=(torch.uint8, torch.int8, torch.int16,),
161
reason=onnx_test_common.reason_onnx_script_does_not_support(
162
"Add", "int8, int16, uint8 have type issue."
167
dtypes=onnx_test_common.COMPLEX_TYPES,
168
reason=onnx_test_common.reason_dynamo_does_not_support("Addbmm", "complex64")
171
"addmm", dtypes=onnx_test_common.BOOL_TYPES,
172
reason=onnx_test_common.reason_onnx_does_not_support("Addmm")
176
variant_name="decomposed",
177
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
178
reason=onnx_test_common.reason_onnx_does_not_support("Addmm")
181
"addmm", dtypes=onnx_test_common.COMPLEX_TYPES,
182
reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)")
186
variant_name="decomposed",
187
dtypes=onnx_test_common.COMPLEX_TYPES,
188
reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)")
192
dtypes=onnx_test_common.BOOL_TYPES,
193
reason=onnx_test_common.reason_onnx_script_does_not_support(
199
dtypes=onnx_test_common.COMPLEX_TYPES,
200
reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64")
204
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
205
"Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0")
209
reason=onnx_test_common.reason_dynamo_does_not_support("Allclose")
213
dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
214
reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"),
217
"amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
218
reason=onnx_test_common.reason_dynamo_does_not_support("ReduceMin", "bool, int16")
222
dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
223
reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"),
227
reason=onnx_test_common.reason_onnx_does_not_support(
228
"Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0")
232
dtypes=(torch.uint8,),
233
reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8, int8"),
237
dtypes=(torch.int16, torch.int32),
238
reason="AssertionError: The values for attribute 'shape' do not match",
246
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
247
"ArgMax", "int16, int64"
258
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
259
"ArgMin", "uint8, int8, int16, int64"
264
reason="fixme: Assertion error: result mismatch",
268
variant_name="partial_views",
269
reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults",
273
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
274
reason="fixme: Assertion error: result mismatch",
283
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
284
"Matmul", "uint8, int8, int16"
289
dtypes=onnx_test_common.COMPLEX_TYPES,
290
reason=onnx_test_common.reason_dynamo_does_not_support("baddbmm", "complex64")
294
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
298
reason="fixme: ORT errors with RuntimeError: No corresponding Numpy type for Tensor Type.",
302
reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"),
311
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
312
"Matmul", "uint8, int8, int16"
317
reason=onnx_test_common.reason_dynamo_does_not_support("output is int"),
321
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
324
"ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
325
reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int")
329
reason="fixme: ONNX shape type inference error: Invalid tensor data type 0."
332
"chunk", dtypes=onnx_test_common.BOOL_TYPES,
333
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Chunk", "bool")
337
dtypes=(torch.uint8, torch.int8, torch.int16,),
338
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
339
"Chunk", "uint8, int8, int16"
344
dtypes=(torch.uint8, torch.int8, torch.int16,),
345
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
346
"Max", "uint8, int8, int16"
350
"clamp_max", dtypes=onnx_test_common.BOOL_TYPES,
351
reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool")
355
dtypes=(torch.uint8, torch.int8, torch.int16,),
356
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
357
"Max", "uint8, int8, int16"
362
dtypes=(torch.uint8, torch.int8, torch.int16,),
363
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
364
"Max", "uint8, int8, int16"
368
"clamp_min", dtypes=onnx_test_common.BOOL_TYPES,
369
reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool")
373
dtypes=(torch.int16,),
374
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
375
"Constant_pad_nd", "int16"
380
dtypes=onnx_test_common.COMPLEX_TYPES,
381
reason=onnx_test_common.reason_dynamo_does_not_support(
382
"Constant_pad_nd", "complex64"
387
reason=onnx_test_common.reason_dynamo_does_not_support(
393
reason=onnx_test_common.reason_dynamo_does_not_support(
398
"cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,),
399
reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16")
403
reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"),
407
reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"),
410
"dot", dtypes=(torch.uint8, torch.int8, torch.int16,),
411
reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16")
415
dtypes=onnx_test_common.COMPLEX_TYPES,
416
reason=onnx_test_common.reason_dynamo_does_not_support("Dot", "complex64(core dump)"),
420
dtypes=onnx_test_common.COMPLEX_TYPES,
421
reason="fixme: kwargs dtpye=complex64 is not supported in ONNX."
425
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
429
dtypes=(torch.uint8, torch.int8, torch.int16,),
430
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"),
434
reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default")
438
reason=onnx_test_common.reason_dynamo_does_not_support("exponential"),
442
reason="fixme: Assertion error: result mismatch",
446
reason="fixme: Assertion error: result mismatch",
450
reason="fixme: Assertion error: result mismatch",
454
reason="fixme: Assertion error: result mismatch",
458
reason="fixme: Assertion error: result mismatch",
462
reason="fixme: Assertion error: result mismatch",
466
reason="fixme: Assertion error: result mismatch",
470
reason="fixme: Assertion error: result mismatch",
474
reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"),
478
reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"),
482
reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"),
486
reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"),
490
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
491
reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"),
495
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
496
reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"),
500
dtypes=onnx_test_common.COMPLEX_TYPES,
501
reason=onnx_test_common.reason_dynamo_does_not_support("full", "complex64")
505
dtypes=onnx_test_common.COMPLEX_TYPES,
506
reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64")
510
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
514
dtypes=onnx_test_common.BOOL_TYPES,
515
reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"),
519
dtypes=onnx_test_common.COMPLEX_TYPES,
520
reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64")
524
dtypes=onnx_test_common.BOOL_TYPES,
525
reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"),
529
dtypes=(torch.uint8, torch.int8, torch.int16,),
530
reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"),
534
dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES,
535
reason=onnx_test_common.reason_onnx_does_not_support("IsNaN", "int, bool"),
539
reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
543
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
547
dtypes=onnx_test_common.COMPLEX_TYPES,
548
reason=onnx_test_common.reason_dynamo_does_not_support("lerp", "complex64")
552
reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"),
556
variant_name="grad_oriented",
557
reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"),
561
reason="fixme: Assertion error: result mismatch",
565
variant_name="subgradients_at_zero",
566
reason="fixme: Assertion error: result mismatch",
570
reason="fixme: Assertion error: result shape mismatch",
574
dtypes=(torch.int64, torch.int32,),
575
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
579
variant_name="tensor_overload",
580
dtypes=(torch.int64, torch.int32,),
581
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
585
dtypes=onnx_test_common.COMPLEX_TYPES,
586
reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64")
590
variant_name="tensor_overload",
591
dtypes=onnx_test_common.COMPLEX_TYPES,
592
reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64")
596
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
600
dtypes=(torch.float16,),
601
reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
605
variant_name="with_dtype",
606
dtypes=(torch.float16,),
607
reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
611
reason=onnx_test_common.reason_onnx_does_not_support(
612
"Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0")
616
dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES,
617
reason=onnx_test_common.reason_onnx_script_does_not_support("And", "float, int"),
621
dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES,
622
reason=onnx_test_common.reason_onnx_script_does_not_support("Not", "float, int"),
626
dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES,
627
reason=onnx_test_common.reason_onnx_script_does_not_support("Or", "float, int"),
631
dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES,
632
reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"),
636
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
637
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceLogSumExp", "bool, int"),
641
reason="fixme: https://github.com/onnx/onnx/issues/4986",
645
reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
649
reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
653
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,),
654
reason="fixme: Assertion error: result mismatch",
658
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,),
659
reason="fixme: Assertion error: result mismatch",
663
dtypes=onnx_test_common.BOOL_TYPES,
664
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
668
dtypes=onnx_test_common.BOOL_TYPES,
669
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
672
"masked.log_softmax",
673
dtypes=(torch.float16,),
674
reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
678
dtypes=onnx_test_common.BOOL_TYPES,
679
reason=onnx_test_common.reason_onnx_does_not_support("ReduceMean", "bool"),
683
reason="fixme: Assertion error: result mismatch",
687
dtypes=onnx_test_common.BOOL_TYPES,
688
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
692
reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked_select.default"),
696
variant_name="reduction_no_dim",
697
dtypes=onnx_test_common.BOOL_TYPES,
698
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"),
702
variant_name="reduction_with_dim",
703
dtypes=onnx_test_common.BOOL_TYPES,
704
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"),
708
variant_name="reduction_with_dim",
709
reason="https://github.com/onnx/onnx/issues/4986",
713
reason="(ReduceMean) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
717
variant_name="reduction_no_dim",
718
dtypes=onnx_test_common.BOOL_TYPES,
719
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"),
723
variant_name="reduction_with_dim",
724
dtypes=onnx_test_common.BOOL_TYPES + (torch.int64,),
725
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"),
729
dtypes=onnx_test_common.COMPLEX_TYPES,
730
reason=onnx_test_common.reason_dynamo_does_not_support("MM", "complex64(core dump)"),
734
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
738
reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default")
742
dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES,
743
reason=onnx_test_common.reason_onnx_runtime_does_not_support("IsNaN", "int, bool"),
747
reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
751
dtypes=(torch.float16,),
752
reason="fixme: https://github.com/microsoft/onnxscript/issues/1269",
756
dtypes=(torch.float16,),
757
reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
761
dtypes=onnx_test_common.COMPLEX_TYPES,
762
reason=onnx_test_common.reason_dynamo_does_not_support("new_full", "complex64")
765
"nn.functional.adaptive_avg_pool2d",
766
reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \
767
maximum recursion depth exceeded while calling a Python object"),
770
"nn.functional.adaptive_avg_pool3d",
771
reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"),
774
"nn.functional.alpha_dropout",
775
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
778
"nn.functional.avg_pool1d",
779
dtypes=onnx_test_common.INT_TYPES,
780
reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
783
"nn.functional.avg_pool2d",
784
dtypes=onnx_test_common.INT_TYPES,
785
reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
788
"nn.functional.avg_pool3d",
789
dtypes=onnx_test_common.INT_TYPES,
790
reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
793
"nn.functional.batch_norm",
794
dtypes=(torch.float16,),
795
reason="fixme: https://github.com/microsoft/onnxscript/issues/1270",
798
"nn.functional.conv_transpose1d",
799
dtypes=(torch.int64,),
800
reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"),
803
"nn.functional.conv_transpose2d",
804
dtypes=(torch.int64,),
805
reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"),
808
"nn.functional.conv_transpose3d",
809
dtypes=(torch.int64,),
810
reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"),
813
"nn.functional.conv_transpose1d",
814
reason="fixme: Assertion error: result mismatch",
817
"nn.functional.conv_transpose2d",
818
reason="fixme: Assertion error: result mismatch",
821
"nn.functional.conv_transpose3d",
822
reason="fixme: Assertion error: result mismatch",
825
"nn.functional.conv1d",
826
dtypes=(torch.int64,),
827
reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"),
830
"nn.functional.conv2d",
831
dtypes=(torch.int64,),
832
reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"),
835
"nn.functional.conv2d",
836
dtypes=onnx_test_common.COMPLEX_TYPES,
837
reason="fixme: Assertion error: result mismatch",
840
"nn.functional.conv3d",
841
dtypes=(torch.int64,),
842
reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"),
845
"nn.functional.conv3d",
846
dtypes=onnx_test_common.COMPLEX_TYPES,
847
reason="fixme: Assertion error: result mismatch",
850
"nn.functional.ctc_loss",
851
reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"),
854
"nn.functional.dropout",
855
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
858
"nn.functional.dropout2d",
859
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
862
"nn.functional.dropout3d",
863
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
866
"nn.functional.feature_alpha_dropout",
867
variant_name="with_train",
868
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
871
"nn.functional.feature_alpha_dropout",
872
variant_name="without_train",
873
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
876
"nn.functional.fractional_max_pool2d",
877
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
880
"nn.functional.fractional_max_pool3d",
881
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
884
"nn.functional.gaussian_nll_loss",
885
reason=onnx_test_common.reason_dynamo_does_not_support("aten.gaussian_nll_loss"),
888
"nn.functional.grid_sample",
889
reason="fixme: Assertion error: result mismatch",
892
"nn.functional.group_norm",
893
dtypes=(torch.float16,),
894
reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"),
897
"nn.functional.local_response_norm",
898
dtypes=(torch.int64,),
899
reason=onnx_test_common.reason_onnx_runtime_does_not_support("avgpool", "int64"),
902
"nn.functional.linear",
903
dtypes=onnx_test_common.INT_TYPES,
904
reason=onnx_test_common.reason_onnx_does_not_support("Gemm", "int"),
907
"nn.functional.max_pool2d",
908
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
909
reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"),
912
"nn.functional.max_pool3d",
913
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
914
reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"),
917
"nn.functional.multi_head_attention_forward",
918
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
921
"nn.functional.one_hot",
922
reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
926
variant_name="replicate",
927
reason="fixme: ORT error: padding size",
931
variant_name="replicate_negative",
932
reason="fixme: Assertion error: result mismatch",
936
variant_name="reflect",
937
reason="fixme: Assertion error: result mismatch",
940
"nn.functional.rrelu",
941
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
944
"nn.functional.rrelu",
945
dtypes=(torch.int64,),
946
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Relu", "int64"),
949
"nn.functional.scaled_dot_product_attention",
950
matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
951
reason="dropout is random so the results do not match",
954
"nn.functional.scaled_dot_product_attention",
955
dtypes=(torch.float16,),
956
reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
959
"nn.functional.selu",
960
reason="fixme: nn.functional.selu is not in torch._decomp.decomposition_table",
963
"nn.functional.soft_margin_loss",
964
dtypes=(torch.float16,),
965
reason="fixme: Assertion error: result mismatch",
968
"nn.functional.tanhshrink",
969
dtypes=(torch.float16,),
970
reason="fixme: Assertion error: result mismatch",
974
dtypes=(torch.int8, torch.int16),
975
reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"),
979
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
983
variant_name="in_place",
984
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
988
variant_name="number_mean",
989
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
993
dtypes=onnx_test_common.COMPLEX_TYPES,
994
reason="fixme: kwargs dtpye=complex64 is not supported in ONNX."
998
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
1002
reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default")
1006
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
1010
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
1014
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
1018
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
1022
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
1026
reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_")
1030
reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_")
1034
dtypes=onnx_test_common.INT_TYPES,
1035
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Round", "int"),
1039
dtypes=(torch.uint8, torch.int8, torch.int16),
1040
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
1041
"Mul", "uint8, int8, int16"
1046
dtypes=(torch.float16,),
1047
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
1052
dtypes=(torch.float16,),
1053
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
1057
variant_name="prod",
1058
dtypes=(torch.float16,),
1059
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
1063
variant_name="amin",
1064
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
1065
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
1069
variant_name="amax",
1070
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
1071
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
1075
variant_name="mean",
1076
reason="ONNX doesn't support reduce='mean' option",
1080
dtypes=onnx_test_common.BOOL_TYPES,
1081
reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"),
1084
"signal.windows.kaiser",
1085
reason=onnx_test_common.reason_dynamo_does_not_support("functionalization"),
1089
dtypes=(torch.float16,),
1090
reason="ORT error: https://github.com/microsoft/onnxruntime/issues/16438"
1094
variant_name="reduce",
1095
reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"),
1098
"sparse.sampled_addmm",
1099
reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"),
1103
dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES,
1104
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Erf", "int, bool"),
1108
dtypes=onnx_test_common.FLOAT_TYPES,
1109
reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"),
1113
dtypes=(torch.float16,),
1114
reason="fixme: Assertion error: result mismatch",
1118
dtypes=onnx_test_common.BOOL_TYPES,
1119
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
1123
variant_name="list_args",
1124
dtypes=onnx_test_common.BOOL_TYPES,
1125
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
1129
dtypes=onnx_test_common.BOOL_TYPES,
1130
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
1134
dtypes=(torch.int8, torch.uint8, torch.int16),
1135
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"),
1139
reason="fixme: Assertion error: result mismatch",
1143
variant_name="multiple",
1144
reason="fixme: https://github.com/microsoft/onnxscript/issues/1264",
1148
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
1152
reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple."
1156
variant_name="unbiased",
1157
reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple."
1161
reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"),
1165
dtypes=(torch.uint8, torch.int8, torch.int16),
1166
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
1167
"Mul", "uint8, int8, int16"
1172
reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
1176
reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
1180
dtypes=(torch.int32, torch.int64, torch.float16, torch.float32, torch.bool, torch.complex64),
1181
# model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1182
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
1183
reason="This op requires torch.dtype as input, which is not supported currently.",
1187
dtypes=(torch.int64, torch.int32),
1188
reason="fixme: Assertion error: result mismatch",
1192
dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,),
1193
reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"),
1197
dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,),
1198
reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"),
1202
dtypes=onnx_test_common.INT_TYPES,
1203
reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"),
1207
dtypes=onnx_test_common.BOOL_TYPES,
1208
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
1212
dtypes=onnx_test_common.BOOL_TYPES,
1213
reason=onnx_test_common.reason_onnx_does_not_support("Unflatten")
1217
reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
1221
reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"),
1224
"unique_consecutive",
1225
reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"),
1229
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
1230
reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"),
1234
dtypes=onnx_test_common.BOOL_TYPES,
1235
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
1239
dtypes=onnx_test_common.BOOL_TYPES,
1240
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
1244
dtypes=onnx_test_common.BOOL_TYPES,
1245
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
1249
dtypes=onnx_test_common.COMPLEX_TYPES,
1250
reason="fixme: kwargs dtpye=complex64 is not supported in ONNX."
1252
# SLOW TESTS (All are xfails if we run them)
1253
# TODO: https://github.com/pytorch/pytorch/issues/117118
1256
reason="fixme: Test sets are too many.",
1260
reason="fixme: Test sets are too many.",
1264
reason="fixme: Test sets are too many.",
1268
reason="fixme: Test sets are too many.",
1271
"linalg.solve_triangular",
1272
reason="fixme: Test sets are too many.",
1276
reason="fixme: Test sets are too many.",
1280
reason="fixme: Test sets are too many.",
1284
variant_name="tensor_overload",
1285
reason="fixme: Test sets are too many.",
1288
"max_pool2d_with_indices_backward",
1289
reason="fixme: Test sets are too many.",
1292
"nn.functional.interpolate",
1293
variant_name="bicubic",
1294
reason="fixme: Test sets are too many.",
1297
"nn.functional.max_unpool1d",
1298
reason="fixme: Test sets are too many.",
1301
"nn.functional.max_unpool2d",
1302
reason="fixme: Test sets are too many.",
1305
"nn.functional.max_unpool3d",
1306
reason="fixme: Test sets are too many.",
1309
"nn.functional.max_pool1d",
1310
reason="fixme: Test sets are too many.",
1313
"nn.functional.max_pool2d",
1314
reason="fixme: Test sets are too many.",
1317
"nn.functional.max_pool3d",
1318
reason="fixme: Test sets are too many.",
1321
"nn.functional.unfold",
1322
reason="fixme: Test sets are too many.",
1326
reason="fixme: Test sets are too many.",
1330
reason="fixme: Test sets are too many.",
1334
reason="fixme: Test sets are too many.",
1339
SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
1341
"_native_batch_norm_legit",
1342
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1343
reason="https://github.com/pytorch/pytorch/issues/115106",
1346
"addmm", # xfail can't only use dtypes to catch all cases
1347
matcher=lambda sample: sample.input.dtype
1348
in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
1349
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
1350
"Gemm", "uint8, int8, int16, int32, int64"
1355
matcher=lambda sample: sample.args[0].numel() == 0,
1356
reason="ONNX Runtime does not support empty tensors multiplication",
1360
variant_name="decomposed",
1361
matcher=lambda sample: sample.args[0].numel() == 0,
1362
reason="ONNX Runtime does not support empty tensors multiplication",
1366
matcher=lambda sample: len(sample.input.shape) == 0
1367
and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()),
1368
reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
1372
matcher=lambda sample: len(sample.input.shape) == 0
1373
and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()),
1374
reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
1378
matcher=lambda sample: len(sample.input.shape) == 0
1379
and sample.kwargs.get("dim") is not None,
1380
reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
1384
matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
1385
reason="core dump - cat does not support zero-dim tensors yet",
1389
matcher=lambda sample: len(sample.input.shape) < 2,
1390
reason="fixme: https://github.com/microsoft/onnxscript/issues/1212",
1394
matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1,
1395
reason="fixme: aten::index_put indices contains None when dim is -1",
1399
matcher=lambda sample: len(sample.input.shape) < 2,
1400
reason="fixme: https://github.com/microsoft/onnxscript/issues/1212",
1404
matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1,
1405
reason="fixme: aten::index_put indices contains None when dim is -1",
1409
matcher=lambda sample: (sample.args[0][0].dtype == torch.bool)
1410
and (sample.kwargs.get("accumulate") is False),
1411
reason=onnx_test_common.reason_dynamo_does_not_support(
1412
"https://github.com/pytorch/pytorch/issues/101150"
1417
matcher=lambda sample: sum([torch.numel(input) for input in sample.input]) == 0,
1418
reason="fixme: Undefined",
1422
matcher=lambda sample: len(sample.input.shape) == 0,
1423
reason="fixme: LogSoftMax does not support empty tensor as input",
1427
variant_name="with_dtype",
1428
matcher=lambda sample: len(sample.input.shape) == 0,
1429
reason="fixme: LogSoftMax does not support empty tensor as input",
1433
matcher=lambda sample: isinstance(sample.input, torch.Tensor)
1434
and len(sample.input.shape) == 0,
1435
reason="fixme: IsScalar",
1438
"masked.log_softmax",
1439
matcher=lambda sample: len(sample.input.shape) == 0,
1440
reason="fixme: LogSoftMax does not support empty tensor as input",
1444
matcher=lambda sample: torch.numel(sample.input) == 0,
1445
reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
1449
variant_name="reduction_with_dim",
1450
matcher=lambda sample: len(sample.input.shape) == 0,
1451
reason="fixme: https://github.com/onnx/onnx/issues/4986",
1455
matcher=lambda sample: torch.numel(sample.input) == 0,
1456
reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
1459
"native_batch_norm",
1460
matcher=lambda sample: sample.args[-3] is True
1461
and any(arg is not None for arg in sample.args[2:4]),
1462
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1463
reason="https://github.com/pytorch/pytorch/issues/115106",
1466
"nn.functional.avg_pool1d",
1467
matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True)
1469
sample.kwargs.get("count_include_pad") is True
1470
or sample.input.shape[2]
1473
if isinstance(sample.args[0], tuple)
1478
reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
1481
"nn.functional.avg_pool2d",
1482
matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
1483
or (sample.kwargs.get("divisor_override") is not None),
1484
reason="ONNX doesn't support divisor_override argument",
1487
"nn.functional.avg_pool3d",
1488
matcher=lambda sample: sample.kwargs.get("ceil_mode") is True,
1489
reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
1492
"nn.functional.avg_pool3d",
1493
matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
1494
or (sample.kwargs.get("divisor_override") is not None),
1495
reason="ONNX doesn't support divisor_override argument",
1498
"nn.functional.batch_norm",
1499
matcher=lambda sample: sample.kwargs.get("training") is True
1500
and any(arg is not None for arg in sample.args[2:4]),
1501
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1502
reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106",
1505
"nn.functional.conv2d",
1506
matcher=lambda sample: sample.kwargs.get("padding") == "valid",
1507
reason="fixme: https://github.com/pytorch/pytorch/issues/117054",
1510
"nn.functional.conv3d",
1511
matcher=lambda sample: sample.kwargs.get("padding") == "valid",
1512
reason="fixme: https://github.com/pytorch/pytorch/issues/117054",
1515
"nn.functional.cross_entropy",
1516
matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int),
1517
reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type",
1520
"nn.functional.embedding",
1521
matcher=lambda sample: sample.kwargs.get("max_norm") is not None,
1522
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1523
reason="https://github.com/pytorch/pytorch/issues/115106",
1525
skip_torchlib_forward_compatibility(
1526
"nn.functional.embedding_bag",
1527
matcher=lambda sample: sample.kwargs.get("padding_idx") is not None or True,
1528
reason=onnx_test_common.reason_onnx_script_does_not_support(
1529
"'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. "
1530
"'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided"
1532
github_issue="https://github.com/microsoft/onnxscript/issues/1056",
1535
"nn.functional.group_norm",
1536
matcher=lambda sample: torch.numel(sample.input) == 0,
1537
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
1538
"Reshape", "empty tensor"
1542
"nn.functional.instance_norm",
1543
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1544
matcher=lambda sample: sample.kwargs.get("running_mean") is not None
1545
or sample.input.dtype in (torch.float16,),
1546
reason="fixme: KeyError: 'self___kwargs__running_mean'",
1549
"nn.functional.max_pool3d",
1550
matcher=lambda sample: sample.kwargs.get("ceil_mode") is True
1551
and sample.kwargs.get("padding") == 1,
1552
reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed",
1556
matcher=lambda sample: len(sample.input.shape) == 0
1557
and sample.kwargs.get("as_tuple", False) is False,
1558
reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).",
1559
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
1563
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1564
reason=onnx_test_common.reason_onnx_script_does_not_support(
1565
"aten::_assert_async.msg",
1566
"https://github.com/pytorch/pytorch/issues/112443",
1571
matcher=lambda sample: len(sample.input.shape) == 0,
1572
reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch",
1576
variant_name="amax",
1577
# ONNX has not include_self parameter and default is include_self=True mode
1578
matcher=lambda sample: sample.kwargs.get("include_self") is False,
1579
reason="ONNX does't support include_self=False option",
1583
variant_name="amin",
1584
# ONNX has not include_self parameter and default is include_self=True mode
1585
matcher=lambda sample: sample.kwargs.get("include_self") is False,
1586
reason="ONNX does't support include_self=False option",
1590
variant_name="prod",
1591
# ONNX has not include_self parameter and default is include_self=True mode
1592
matcher=lambda sample: sample.kwargs.get("include_self") is False,
1593
reason="ONNX does't support include_self=False option",
1598
# ONNX has not include_self parameter and default is include_self=True mode
1599
matcher=lambda sample: sample.kwargs.get("include_self") is False,
1600
reason="ONNX does't support include_self=False option",
1604
matcher=lambda sample: len(sample.input.shape) == 0,
1605
reason="fixme: LogSoftMax does not support empty tensor as input",
1609
matcher=lambda sample: isinstance(sample.input, torch.Tensor)
1610
and len(sample.input.shape) < 2,
1611
reason="fixme: IsScalar",
1615
reason="Logic not implemented for size 0 inputs in op.Reshape",
1616
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
1619
"signal.windows.hamming",
1620
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1621
reason="does not match node name",
1624
"signal.windows.general_hamming",
1625
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1626
reason="does not match node name",
1629
"signal.windows.blackman",
1630
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1631
reason="does not match node name",
1634
"signal.windows.general_cosine",
1635
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1636
reason="does not match node name",
1639
"signal.windows.hann",
1640
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1641
reason="does not match node name",
1644
"signal.windows.nuttall",
1645
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1646
reason="does not match node name",
1650
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
1651
OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
1652
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
1655
class SingleOpModel(torch.nn.Module):
1656
"""Test model to wrap around a single op for export."""
1658
def __init__(self, op, kwargs):
1661
self.kwargs = kwargs
1663
def forward(self, *args):
1664
return self.operator(*args, **self.kwargs)
1667
def _should_skip_xfail_test_sample(
1669
variant_test_name: str,
1671
model_type: pytorch_test_common.TorchModelType,
1672
) -> Tuple[Optional[str], Optional[str]]:
1673
"""Check if the test sample should be skipped or xfailed.
1675
If the xfail/skip decorator meta is matched with its op_name and model_type,
1676
return the test_behavior and reason. Otherwise, return None, None. Note that
1677
if the matcher is None, the test is decorator_meta is meant to skip/xfail all model types.
1680
op_name: The name of the op.
1681
sample: The test sample.
1682
model_type: The model type of the test.
1685
A tuple of (test_behavior, reason). test_behavior is either "skip" or "xfail".
1686
reason is the reason for the test_behavior.
1689
if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
1691
for decorator_meta in SKIP_XFAIL_SUBTESTS:
1692
# Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
1693
# NOTE: If model_type is None, the test is decorator_meta is meant to skip/xfail all model types.
1695
decorator_meta.op_name == op_name
1696
and decorator_meta.variant_name == variant_test_name
1698
model_type == decorator_meta.model_type or decorator_meta.model_type is None
1700
if decorator_meta.matcher is None and decorator_meta.model_type is None:
1702
"Either Matcher or model_type must be defined in sub xfail and skip."
1704
if decorator_meta.matcher is not None and decorator_meta.matcher(sample):
1705
return decorator_meta.test_behavior, decorator_meta.reason
1706
elif decorator_meta.matcher is None:
1707
# xfail/skip the whole test of the model type without matcher
1708
return decorator_meta.test_behavior, decorator_meta.reason
1712
def _compare_onnx_and_torch_exported_program(
1713
torch_exported_program,
1714
onnx_exported_program,
1722
only_check_shape=False,
1724
# avoid mutable default argument
1725
if input_kwargs is None:
1728
# NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
1729
# Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
1730
# Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
1731
onnx_outputs = onnx_exported_program(*input_args, **input_kwargs)
1732
torch_outputs = torch_exported_program(*input_args, **input_kwargs)
1733
torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx(
1736
if len(torch_outputs_onnx_format) != len(onnx_outputs):
1737
raise AssertionError(
1738
f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}"
1741
for j, (torch_output, onnx_output) in enumerate(
1742
zip(torch_outputs_onnx_format, onnx_outputs)
1744
if only_check_shape:
1745
assert torch_output.shape == onnx_output.shape
1748
torch.testing.assert_close(
1749
torch.tensor(onnx_output),
1755
except AssertionError as e:
1756
if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
1757
error_reproduction.create_mismatch_report(
1760
onnx_exported_program.model_proto,
1763
torch.tensor(onnx_output),
1767
if len(torch_outputs_onnx_format) > 1:
1768
raise AssertionError(f"Output {j} mismatch") from e
1772
def _run_test_output_match(
1773
test_suite: onnx_test_common._TestONNXRuntime,
1776
op: opinfo_core.OpInfo,
1778
# device is provided by instantiate_device_type_tests, but we only want to run in cpu.
1779
assert device == "cpu"
1780
samples = op.sample_inputs(
1783
requires_grad=False,
1785
for i, cpu_sample in enumerate(samples):
1786
inputs = (cpu_sample.input, *cpu_sample.args)
1787
# Provide the repr to subtest because tensors are not serializable in parallel test runs
1789
with test_suite.subTest(
1790
opset=test_suite.opset_version,
1792
inputs=repr(inputs),
1793
kwargs=repr(cpu_sample.kwargs),
1795
test_behavior, reason = _should_skip_xfail_test_sample(
1796
op.name, op.variant_test_name, cpu_sample, test_suite.model_type
1798
with onnx_test_common.normal_xfail_skip_test_behaviors(
1799
test_behavior, reason
1801
model = SingleOpModel(op.op, cpu_sample.kwargs)
1805
dtype == torch.float32
1806
and op.name in test_suite.fp32_low_precision_dict
1808
rtol = test_suite.fp32_low_precision_dict[op.name][0]
1809
atol = test_suite.fp32_low_precision_dict[op.name][1]
1810
elif dtype == torch.float32:
1811
# Relax atol and rtol for float32 based on empirical results
1815
dtype == torch.float16
1816
and (op.name, op.variant_test_name)
1817
in test_suite.fp16_low_precision_variant_dict
1819
rtol = test_suite.fp16_low_precision_variant_dict[
1820
(op.name, op.variant_test_name)
1822
atol = test_suite.fp16_low_precision_variant_dict[
1823
(op.name, op.variant_test_name)
1826
dtype == torch.float16
1827
and op.name in test_suite.fp16_low_precision_dict
1829
rtol = test_suite.fp16_low_precision_dict[op.name][0]
1830
atol = test_suite.fp16_low_precision_dict[op.name][1]
1836
test_suite.model_type
1837
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
1840
model = torch.export.export(model, inputs)
1841
except AssertionError as e:
1842
# NOTE: avoid fake_mode detection bug in torch.export.export
1844
onnx_test_common.reason_dynamo_does_not_support(str(e))
1848
onnx_program = torch.onnx.dynamo_export(
1852
except torch.onnx.OnnxExporterError as e:
1853
# NOTE: If the model has unsupported nodes, we will skip the test
1854
# with non-strict xfail. Otherwise, we will raise the error.
1856
e.__cause__, "diagnostic"
1857
) and e.__cause__.diagnostic.rule in (
1858
_rules._POERules.no_symbolic_function_for_call_function,
1859
_rules._POERules.unsupported_fx_node_analysis,
1862
onnx_test_common.reason_onnx_script_does_not_support(str(e))
1866
_compare_onnx_and_torch_exported_program(
1870
test_name=test_suite.id(),
1872
sample_kwargs=cpu_sample.kwargs,
1875
only_check_shape=(op.name in test_suite.only_shape_check_list),
1879
def _parameterized_class_attrs_and_values():
1881
input_values.extend(
1883
(opset for opset in onnx_test_common.FX_TESTED_OPSETS),
1885
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
1886
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1891
"attrs": ["opset_version", "model_type"],
1892
"input_values": input_values,
1896
def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
1897
"""Combine class name with the parameterized arguments.
1899
This function is passed to `parameterized.parameterized_class` as the
1900
`class_name_func` argument.
1903
for k, v in input_dicts.items():
1904
suffixes.append(f"{k}_{v}")
1905
return f"{cls.__name__}_{'_'.join(suffixes)}"
1908
@parameterized.parameterized_class(
1909
**_parameterized_class_attrs_and_values(),
1910
class_name_func=_parameterize_class_name,
1912
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
1913
"""Test output consistency between exported ONNX models and PyTorch eager mode.
1915
This is a parameterized test suite.
1919
op_level_debug: bool = False
1920
dynamic_shapes: bool = False
1921
model_type: pytorch_test_common.TorchModelType = (
1922
pytorch_test_common.TorchModelType.TORCH_NN_MODULE
1925
# NOTE: Follow torchlib settings in ops_test_data.py
1926
only_shape_check_list = [
1931
"new_empty_strided",
1934
fp32_low_precision_dict = {
1935
"native_layer_norm": [2e-4, 7e-4],
1938
fp16_low_precision_dict = {
1939
"addbmm": [2e-1, 2e-2],
1940
"addcdiv": [3e-2, 1e-3],
1941
"addcmul": [3e-2, 1e-3],
1942
"addmv": [5e-2, 3e-2],
1943
"addr": [3e-3, 4e-3],
1944
"baddbmm": [3e-2, 1e-3],
1945
"cumulative_trapezoid": [3e-2, 1e-3],
1946
"diff": [1e-2, 5e-2],
1947
"gradient": [3e-3, 4e-3],
1948
"linalg.multi_dot": [3e-2, 1e-3],
1949
"linalg.vecdot": [1e-2, 2e-2],
1950
"linspace": [2e-2, 2e-3],
1951
"masked.std": [2e-2, 2e-3],
1952
"masked.var": [2e-2, 2e-2],
1953
"matmul": [2e-2, 6e-2],
1954
"nn.functional.batch_norm": [3e-2, 1e-3],
1955
"nn.functional.binary_cross_entropy": [3e-2, 1e-3],
1956
"nn.functional.binary_cross_entropy_with_logits": [3e-2, 1e-3],
1957
"nn.functional.cosine_similarity": [3e-2, 1e-3],
1958
"nn.functional.cosine_embedding_loss": [1e-2, 1e-3],
1959
"nn.functional.hardsigmoid": [1e-3, 5e-3],
1960
"nn.functional.hardswish": [1e-3, 5e-3],
1961
"nn.functional.hinge_embedding_loss": [4e-1, 3e-3],
1962
"nn.functional.instance_norm": [1e-2, 1e-3],
1963
"nn.functional.interpolate": [1e-2, 1e-3],
1964
"nn.functional.kl_div": [2e-3, 2e-4],
1965
"nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],
1966
"nn.functional.local_response_norm": [1e-2, 5e-3],
1967
"nn.functional.poisson_nll_loss": [3e-2, 1e-3],
1968
"native_batch_norm": [3e-2, 1e-3],
1969
"dot": [3e-2, 1e-3],
1970
"logit": [3e-2, 1e-3],
1971
"rsub": [3e-2, 1e-3],
1972
"sinc": [2e-1, 6e-4],
1973
"sub": [3e-2, 1e-3],
1974
"trapezoid": [1e-3, 7e-3],
1975
"trapz": [1e-3, 7e-3],
1978
fp16_low_precision_variant_dict = {
1979
("nn.functional.interpolate", "trilinear"): [3e-2, 3e-3],
1980
("nn.functional.interpolate", "linear"): [3e-2, 3e-3],
1983
@common_device_type.ops(
1984
[op for op in OPS_DB if op.name in ALL_OPS_IN_DB],
1985
allowed_dtypes=onnx_test_common.TESTED_DTYPES,
1987
def test_output_match(self, device: str, dtype: torch.dtype, op):
1988
"""Test the ONNX exporter."""
1989
_run_test_output_match(self, device, dtype, op)
1992
for opset in onnx_test_common.FX_TESTED_OPSETS:
1993
for model_type in pytorch_test_common.TorchModelType:
1994
# The name needs to match the parameterized_class name.
1995
test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}"
1996
onnx_test_common.add_decorate_info(
1999
"test_output_match",
2001
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
2004
common_device_type.instantiate_device_type_tests(
2005
globals()[test_class_name], globals(), only_for="cpu"
2008
if __name__ == "__main__":
2009
common_utils.run_tests()