pytorch

Форк
0
/
test_op_consistency.py 
342 строки · 13.1 Кб
1
# Owner(s): ["module: onnx"]

"""Test consistency between the output values of torch.onnx exported operators
and torch operators given the same inputs.

Usage:

    pytest test/onnx/test_op_consistency.py

    To run tests on a specific operator (e.g. torch.ceil):

    pytest test/onnx/test_op_consistency.py -k ceil
    pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention

    Read more on running and writing tests:
        https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests

Note:

    When new ops are supported, please scroll down and update the
    TESTED_OPS and EXPECTED_SKIPS_OR_FAILS lists in the "Modify this section"
    block.

"""
24

25
from __future__ import annotations
26

27
import copy
28
from typing import Optional, Tuple
29

30
import onnx_test_common
31
import parameterized
32

33
# For readability, these two are allowed to be imported as function
34
from onnx_test_common import skip, xfail
35

36
import torch
37
from torch.testing._internal import (
38
    common_device_type,
39
    common_methods_invocations,
40
    common_utils,
41
)
42

43

44
# Private deep copy of the shared OpInfo database. OPS_DB is later handed to
# onnx_test_common.add_decorate_info() together with the skip/xfail lists;
# presumably that helper attaches decorators to these OpInfo objects, and the
# deep copy keeps common_methods_invocations.op_db itself unmodified — confirm.
OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
45

46
# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1.  Add "ceil" to TESTED_OPS then run pytest.
# 2.  If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS.

# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.
# Ops to be tested for numerical consistency between onnx and pytorch.
# Each name must match an `OpInfo.name` in OPS_DB; a module-level assert further
# down rejects unknown names at import time. Commented-out entries are not
# tested yet.
# TODO: https://github.com/pytorch/pytorch/issues/102211
TESTED_OPS: frozenset[str] = frozenset(
    [
        "atan",
        "atan2",
        # "atleast_1d",  # How to support list input?
        # "atleast_2d",
        # "atleast_3d",
        "broadcast_to",
        "ceil",
        "expand",
        "flatten",
        "hstack",
        "logical_not",
        # "logit",
        "nn.functional.scaled_dot_product_attention",
        "repeat",
        "round",
        # "scatter_add",
        # "scatter_reduce",
        "sqrt",
        "stft",
        "t",
        "tile",
        "unflatten",
        "vstack",
    ]
)
84

85
# fmt: off
# Turn off black formatting to keep the list compact

# Expected failures for onnx export.
# The list should be sorted alphabetically by op name, and each
# (op, variant, behavior) combination should be listed only once.
# Q: When should I use skip vs xfail?
# A: Prefer xfail over skip when possible.
#     a. If a test now fails because of an unexpected pass (xpass), i.e. a
#        previous error has been fixed, remove the corresponding xfail.
#     b. If a test is not failing consistently, use skip.
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "atan", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan")
    ),
    xfail("atan", dtypes=[torch.float64], reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])),
    skip(
        "atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan")
    ),
    xfail(
        "atan2", dtypes=[torch.float64],
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"])
    ),
    xfail(
        "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Ceil")
    ),
    skip("hstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
    xfail(
        "logit",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Log", "bool, int"),
    ),
    skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)], reason="Need Trilu."),
    skip("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
    xfail("round", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_0", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_neg_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    skip("scatter_reduce", variant_name="amin", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="amax", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="prod", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    # Single entry for reduce="mean" (a duplicate xfail used to be listed further down).
    xfail("scatter_reduce", variant_name="mean",
          reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction=mean")),
    skip("scatter_reduce", variant_name="sum", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    xfail(
        "scatter_reduce",
        variant_name="sum",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="prod",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amin",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amax",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
    ),
    skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES, reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
    skip("stft", opsets=[onnx_test_common.opsets_before(17)], reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
    xfail("stft",
          reason=onnx_test_common.reason_onnx_runtime_does_not_support("STFT", "Regression on ORT=1.15 4 percent difference")),
    skip("tile", opsets=[onnx_test_common.opsets_before(13)], reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
    xfail("unflatten", opsets=[onnx_test_common.opsets_before(13)], reason="Helper function is needed to support legacy ops."),
    skip("vstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
)
# fmt: on
179

180
# Per-sample skips. Each entry's `matcher` is called with one OpInfo sample and
# returns True when that particular sample should get the entry's behavior;
# _should_skip_xfail_test_sample() scans this tuple for every sample.
SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "nn.functional.scaled_dot_product_attention",
        # A sample without a `dropout_p` kwarg runs with the op's default of 0.0
        # (no dropout), so only samples that explicitly request dropout are skipped.
        matcher=lambda sample: sample.kwargs.get("dropout_p", 0.0) != 0.0,
        reason="dropout is random so the results do not match",
    ),
    skip(
        "repeat",
        reason="Empty repeats value leads to an invalid graph",
        matcher=lambda sample: not sample.args[0],
    ),
    skip(
        "scatter_reduce",
        # ONNX has no include_self parameter; it always behaves like include_self=True.
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX does't support include_self=False option",
    ),
    skip(
        "stft",
        reason="ONNX STFT does not support complex results",
        matcher=lambda sample: sample.kwargs.get("return_complex") is True,
    ),
    skip(
        "tile",
        # Skip inputs that have a zero-sized dimension or are 0-d (empty shape).
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape)
        or not sample.input.shape,
        reason="Logic not implemented for size 0 inputs in op.Reshape",
    ),
    skip(
        "unflatten",
        reason="Logic not implemented for size 0 inputs in op.Reshape",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
    ),
)
214

215

216
# END OF SECTION TO MODIFY #####################################################

# Ops with at least one per-sample rule; lets _should_skip_xfail_test_sample()
# return early for every other op.
OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in TESTED_OPS are in the OPS_DB, so a misspelled op name fails
# at import time instead of silently testing nothing.
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
222

223

224
class SingleOpModel(torch.nn.Module):
    """Wrap one operator in an ``nn.Module`` so it can be exported on its own.

    Args:
        op: Callable invoked by ``forward`` (in this file, an OpInfo from OPS_DB).
        kwargs: Keyword arguments passed to ``op`` on every call.
    """

    def __init__(self, op, kwargs):
        super().__init__()
        # Positional inputs arrive through forward(); keyword arguments are
        # fixed when the model is built.
        self.operator, self.kwargs = op, kwargs

    def forward(self, *args):
        fn = self.operator
        call_kwargs = self.kwargs
        return fn(*args, **call_kwargs)
234

235

236
def _should_skip_xfail_test_sample(
    op_name: str, sample
) -> Tuple[Optional[str], Optional[str]]:
    """Look up the per-sample skip/xfail rule that applies to ``sample``.

    Entries of SKIP_XFAIL_SUBTESTS whose ``op_name`` equals ``op_name`` are
    tried in declaration order; the first whose ``matcher`` accepts ``sample``
    wins.

    Returns:
        ``(test_behavior, reason)`` taken from the first matching entry, or
        ``(None, None)`` when no entry applies.
    """
    # Cheap set lookup first: most ops have no per-sample rules at all.
    if op_name in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        # A linear scan is fine because SKIP_XFAIL_SUBTESTS is small.
        candidates = (meta for meta in SKIP_XFAIL_SUBTESTS if meta.op_name == op_name)
        for meta in candidates:
            assert meta.matcher is not None, "Matcher must be defined"
            if meta.matcher(sample):
                return meta.test_behavior, meta.reason
    return None, None
249

250

251
def _get_test_class_name(cls, num, params_dict) -> str:
252
    del cls  # unused
253
    del num  # unused
254
    return params_dict["name"]
255

256

257
@parameterized.parameterized_class(
    [
        {
            "name": f"TestOnnxModelOutputConsistency_opset{opset}",
            "opset_version": opset,
        }
        for opset in onnx_test_common.TESTED_OPSETS
    ],
    class_name_func=_get_test_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

    This is a parameterized test suite: one class named
    ``TestOnnxModelOutputConsistency_opset{N}`` is generated for each opset in
    ``onnx_test_common.TESTED_OPSETS``, with ``opset_version`` set to ``N``.
    """

    # Placeholder; each parameterized class sets its own opset_version.
    opset_version = -1

    @common_device_type.ops(
        [op for op in OPS_DB if op.name in TESTED_OPS],
        allowed_dtypes=onnx_test_common.INT_TYPES
        + onnx_test_common.FLOAT_TYPES
        + onnx_test_common.BOOL_TYPES,
    )
    def test_output_match(self, device: str, dtype: torch.dtype, op):
        """Test the ONNX exporter on every OpInfo sample of ``op``.

        Each sample runs in its own ``subTest``. Per-sample skips/xfails come
        from ``_should_skip_xfail_test_sample``.
        """
        # device is provided by instantiate_device_type_tests, but we only want to run in cpu.
        assert device == "cpu"

        # Relax atol and rtol for float32/float64 based on empirical results.
        # The current most relaxed values are for aten::stft. For every other
        # dtype, None is passed and the tolerance choice is left to run_test.
        # The choice depends only on dtype, so it is made once, outside the loop.
        if dtype in (torch.float32, torch.float64):
            rtol, atol = 1e-5, 2e-5
        else:
            rtol = atol = None

        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=False,
        )

        for i, cpu_sample in enumerate(samples):
            inputs = (cpu_sample.input, *cpu_sample.args)
            # Provide the repr to subtest because tensors are not serializable in parallel test runs
            with self.subTest(
                opset=self.opset_version,
                sample_num=i,
                inputs=repr(inputs),
                kwargs=repr(cpu_sample.kwargs),
            ):
                # (None, None) unless a SKIP_XFAIL_SUBTESTS matcher selects this sample.
                test_behavior, reason = _should_skip_xfail_test_sample(
                    op.name, cpu_sample
                )
                with onnx_test_common.normal_xfail_skip_test_behaviors(
                    test_behavior, reason
                ):
                    model = SingleOpModel(op, cpu_sample.kwargs)
                    model.eval()
                    # Run the test
                    self.run_test(model, inputs, rtol=rtol, atol=atol)
324

325

326
for opset in onnx_test_common.TESTED_OPSETS:
    # Must agree with the class names produced by parameterized_class above.
    test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
    # Register this opset's entries from EXPECTED_SKIPS_OR_FAILS against the
    # OpInfos in OPS_DB for "<test_class_name>.test_output_match" ...
    onnx_test_common.add_decorate_info(
        OPS_DB, test_class_name, "test_output_match",
        opset=opset, skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
    )
    # ... then generate the concrete CPU-only test classes into this module's
    # namespace. The generic class is looked up inline rather than bound to a
    # module-level name.
    common_device_type.instantiate_device_type_tests(
        globals()[test_class_name], globals(), only_for="cpu",
    )


if __name__ == "__main__":
    common_utils.run_tests()
343

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.