# Owner(s): ["module: onnx"]

"""Test consistency between the output values of torch.onnx FX exported operators
and torch operators given the same inputs.

Usage:

    1. Test all operators:

    pytest test/onnx/test_fx_op_consistency.py

    2. To run tests on a specific operator (e.g. torch.ceil):

    pytest test/onnx/test_fx_op_consistency.py -k ceil
    pytest test/onnx/test_fx_op_consistency.py -k nn_functional_scaled_dot_product_attention

    3. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g.

    CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/test_fx_op_consistency.py -k div_mode_int

    NOTE: Read more on Running and writing tests:
        https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests

Note:

    1. Please make sure pytest-subtests is installed. Otherwise, the sub-tests will be ignored.

    2. Install pytest-xdist to run tests in parallel if running all tests is the goal.

    3. When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and
    TESTED_OPS lists. See "Modify this section"

"""

from __future__ import annotations

import copy
import itertools
import os
from typing import Any, Callable, Collection, Mapping, Optional, Tuple, Type, Union

import error_reproduction

import onnx_test_common

import parameterized
import pytest
import pytorch_test_common

import torch
from onnx_test_common import skip, skip_slow, xfail
from torch.onnx._internal.diagnostics import _rules
from torch.testing._internal import (
    common_device_type,
    common_methods_invocations,
    common_utils,
)
from torch.testing._internal.opinfo import core as opinfo_core


# NOTE: For ATen signature modifications that will break ONNX export,
# use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip
# to make the signal apparent for maintainers.
def xfail_torchlib_forward_compatibility(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    github_issue: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], bool]] = None,
    enabled_if: bool = True,
):
    """Prefer using this (xfail) over skip when possible.

    Only skip when the test is not failing consistently.
    """
    return xfail(
        op_name,
        variant_name=variant_name,
        reason=f"{reason}. GitHub Issue: {github_issue}",
        opsets=opsets,
        dtypes=dtypes,
        matcher=matcher,
        enabled_if=enabled_if,
    )


def skip_torchlib_forward_compatibility(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    github_issue: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], Any]] = None,
    enabled_if: bool = True,
):
    """Prefer using xfail_torchlib_forward_compatibility over this (skip) when possible.

    Only skip when the test is not failing consistently.
    """
    return skip(
        op_name,
        variant_name=variant_name,
        reason=f"{reason}. GitHub Issue: {github_issue}",
        opsets=opsets,
        dtypes=dtypes,
        matcher=matcher,
        enabled_if=enabled_if,
    )
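
# A minimal usage sketch of the helpers above (hypothetical op name, dtype, and
# issue URL, for illustration only): flag an op whose new ATen overload is not
# yet supported by torchlib, linking the tracking issue for maintainers.
#
#     xfail_torchlib_forward_compatibility(
#         "some_op",
#         reason="torchlib does not yet implement the new overload",
#         github_issue="https://github.com/microsoft/onnxscript/issues/0000",
#         dtypes=(torch.float16,),
#     )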


# fmt: off
# Turn off black formatting to keep the list compact

# Expected failures for onnx export.
# The list should be sorted alphabetically by op name.
# Q: When should I use fixme vs skip vs xfail?
# A: Prefer xfail over skip when possible.
#     2a. If a test is now failing because of xpass (some previous errors
#     are now fixed), remove the corresponding xfail.
#     2b. If a test is not failing consistently, use skip.
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
    xfail(
        "__getitem__",
        reason="io_adapter doesn't support __getitem__ input slice(0, 3, None)",
    ),
    xfail(
        "__radd__",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "bool"),
    ),
    xfail(
        "__rmatmul__",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "__rpow__",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Pow", "int"),
    ),
    skip(
        "_native_batch_norm_legit",
        dtypes=(torch.float16,),
        reason="fixme: Assertion error: result mismatch and type error",
    ),
    xfail(
        "_softmax_backward_data",
        reason=onnx_test_common.reason_dynamo_does_not_support("assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)")
    ),
    xfail(
        "add", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Add")
    ),
    xfail(
        "add",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "Add", "int8, int16, uint8 have type issue."
        ),
    ),
    xfail(
        "addbmm",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("Addbmm", "complex64")
    ),
    xfail(
        "addmm", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Addmm")
    ),
    xfail(
        "addmm",
        variant_name="decomposed",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Addmm")
    ),
    skip(
        "addmm", dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)")
    ),
    skip(
        "addmm",
        variant_name="decomposed",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)")
    ),
    xfail(
        "addr",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "Addr", "bool"
        ),
    ),
    xfail(
        "addr",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64")
    ),
    xfail(
        "all",
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0")
    ),
    xfail(
        "allclose",
        reason=onnx_test_common.reason_dynamo_does_not_support("Allclose")
    ),
    xfail(
        "amax",
        dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
        reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"),
    ),
    xfail(
        "amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
        reason=onnx_test_common.reason_dynamo_does_not_support("ReduceMin", "bool, int16")
    ),
    xfail(
        "aminmax",
        dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
        reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"),
    ),
    xfail(
        "any",
        reason=onnx_test_common.reason_onnx_does_not_support(
            "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0")
    ),
    xfail(
        "arange",
        dtypes=(torch.uint8,),
        reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8, int8"),
    ),
    xfail(
        "arange",
        dtypes=(torch.int16, torch.int32),
        reason="AssertionError: The values for attribute 'shape' do not match",
    ),
    xfail(
        "argmax",
        dtypes=(
            torch.int16,
            torch.int64,
        ),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "ArgMax", "int16, int64"
        ),
    ),
    xfail(
        "argmin",
        dtypes=(
            torch.uint8,
            torch.int8,
            torch.int16,
            torch.int64,
        ),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "ArgMin", "uint8, int8, int16, int64"
        ),
    ),
    xfail(
        "argwhere",
        reason="fixme: Assertion error: result mismatch",
    ),
    skip(
        "as_strided",
        variant_name="partial_views",
        reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults",
    ),
    xfail(
        "atan2",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "baddbmm",
        dtypes=(
            torch.uint8,
            torch.int8,
            torch.int16,
        ),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Matmul", "uint8, int8, int16"
        ),
    ),
    xfail(
        "baddbmm",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("baddbmm", "complex64")
    ),
    xfail(
        "bernoulli",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "bfloat16",
        reason="fixme: ORT errors with RuntimeError: No corresponding Numpy type for Tensor Type.",
    ),
    xfail(
        "bincount",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"),
    ),
    xfail(
        "bmm",
        dtypes=(
            torch.uint8,
            torch.int8,
            torch.int16,
        ),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Matmul", "uint8, int8, int16"
        ),
    ),
    xfail(
        "broadcast_shapes",
        reason=onnx_test_common.reason_dynamo_does_not_support("output is int"),
    ),
    xfail(
        "cauchy",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    skip(
        "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int")
    ),
    xfail(
        "chalf",
        reason="fixme: ONNX shape type inference error: Invalid tensor data type 0."
    ),
    xfail(
        "chunk", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Chunk", "bool")
    ),
    xfail(
        "chunk",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Chunk", "uint8, int8, int16"
        ),
    ),
    xfail(
        "clamp",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Max", "uint8, int8, int16"
        ),
    ),
    xfail(
        "clamp_max", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool")
    ),
    xfail(
        "clamp_max",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Max", "uint8, int8, int16"
        ),
    ),
    xfail(
        "clamp_min",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Max", "uint8, int8, int16"
        ),
    ),
    xfail(
        "clamp_min", dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool")
    ),
    xfail(
        "constant_pad_nd",
        dtypes=(torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Constant_pad_nd", "int16"
        ),
    ),
    xfail(
        "constant_pad_nd",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support(
            "Constant_pad_nd", "complex64"
        ),
    ),
    xfail(
        "corrcoef",
        reason=onnx_test_common.reason_dynamo_does_not_support(
            "aten.equal.default"
        ),
    ),
    xfail(
        "cov",
        reason=onnx_test_common.reason_dynamo_does_not_support(
            "aten.equal.default"
        ),
    ),
    xfail(
        "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16")
    ),
    xfail(
        "combinations",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"),
    ),
    xfail(
        "cross",
        reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"),
    ),
    xfail(
        "dot", dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16")
    ),
    skip(
        "dot",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("Dot", "complex64(core dump)"),
    ),
    xfail(
        "empty",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason="fixme: kwargs dtype=complex64 is not supported in ONNX."
    ),
    xfail(
        "empty_strided",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "eq",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"),
    ),
    xfail(
        "equal",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default")
    ),
    xfail(
        "exponential",
        reason=onnx_test_common.reason_dynamo_does_not_support("exponential"),
    ),
    xfail(
        "fft.fft",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "fft.fft2",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "fft.fftn",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "fft.ifft",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "fft.ifft2",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "fft.ifftn",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "fft.irfft",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "fft.irfft2",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "fft.irfftn",
        reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"),
    ),
    xfail(
        "fft.rfft",
        reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"),
    ),
    xfail(
        "fft.rfftn",
        reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"),
    ),
    xfail(
        "fft.rfft2",
        reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"),
    ),
    xfail(
        "floor",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"),
    ),
    xfail(
        "floor_divide",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"),
    ),
    xfail(
        "full",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("full", "complex64")
    ),
    xfail(
        "full_like",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64")
    ),
    xfail(
        "geometric",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "heaviside",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"),
    ),
    xfail(
        "index_fill",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64")
    ),
    xfail(
        "index_put",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"),
    ),
    xfail(
        "index_put",
        dtypes=(torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"),
    ),
    xfail(
        "isnan",
        dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("IsNaN", "int, bool"),
    ),
    xfail(
        "istft",
        reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
    ),
    xfail(
        "item",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "lerp",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("lerp", "complex64")
    ),
    xfail(
        "linalg.lstsq",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"),
    ),
    xfail(
        "linalg.lstsq",
        variant_name="grad_oriented",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"),
    ),
    xfail(
        "linalg.norm",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "linalg.norm",
        variant_name="subgradients_at_zero",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "linalg.vecdot",
        reason="fixme: Assertion error: result shape mismatch",
    ),
    xfail(
        "linspace",
        dtypes=(torch.int64, torch.int32,),
        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
    ),
    xfail(
        "linspace",
        variant_name="tensor_overload",
        dtypes=(torch.int64, torch.int32,),
        reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
    ),
    xfail(
        "linspace",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64")
    ),
    xfail(
        "linspace",
        variant_name="tensor_overload",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64")
    ),
    xfail(
        "log_normal",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "log_softmax",
        dtypes=(torch.float16,),
        reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
    ),
    xfail(
        "log_softmax",
        variant_name="with_dtype",
        dtypes=(torch.float16,),
        reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
    ),
    xfail(
        "logcumsumexp",
        reason=onnx_test_common.reason_onnx_does_not_support(
            "Op (ReduceXXX) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0")
    ),
    xfail(
        "logical_and",
        dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("And", "float, int"),
    ),
    xfail(
        "logical_not",
        dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Not", "float, int"),
    ),
    xfail(
        "logical_or",
        dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Or", "float, int"),
    ),
    xfail(
        "logical_xor",
        dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"),
    ),
    xfail(
        "logsumexp",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceLogSumExp", "bool, int"),
    ),
    xfail(
        "masked.logsumexp",
        reason="fixme: https://github.com/onnx/onnx/issues/4986",
    ),
    xfail(
        "masked.amax",
        reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
    ),
    xfail(
        "masked.amin",
        reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
    ),
    xfail(
        "masked.argmin",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,),
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "masked.argmax",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,),
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "masked_fill",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
    ),
    xfail(
        "masked.sum",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
    ),
    xfail(
        "masked.log_softmax",
        dtypes=(torch.float16,),
        reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
    ),
    xfail(
        "masked.mean",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("ReduceMean", "bool"),
    ),
    xfail(
        "masked.norm",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "masked.prod",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
    ),
    xfail(
        "masked_select",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked_select.default"),
    ),
    xfail(
        "max",
        variant_name="reduction_no_dim",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"),
    ),
    xfail(
        "max",
        variant_name="reduction_with_dim",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"),
    ),
    xfail(
        "max",
        variant_name="reduction_with_dim",
        reason="https://github.com/onnx/onnx/issues/4986",
    ),
    xfail(
        "mean",
        reason="(ReduceMean) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
    ),
    xfail(
        "min",
        variant_name="reduction_no_dim",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"),
    ),
    xfail(
        "min",
        variant_name="reduction_with_dim",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.int64,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"),
    ),
    skip(
        "mm",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("MM", "complex64(core dump)"),
    ),
    xfail(
        "multinomial",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nanquantile",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default")
    ),
    xfail(
        "nansum",
        dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("IsNaN", "int, bool"),
    ),
    xfail(
        "narrow",
        reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
    ),
    xfail(
        "native_batch_norm",
        dtypes=(torch.float16,),
        reason="fixme: https://github.com/microsoft/onnxscript/issues/1269",
    ),
    xfail(
        "native_layer_norm",
        dtypes=(torch.float16,),
        reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438",
    ),
    xfail(
        "new_full",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason=onnx_test_common.reason_dynamo_does_not_support("new_full", "complex64")
    ),
    xfail(
        "nn.functional.adaptive_avg_pool2d",
        reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \
            maximum recursion depth exceeded while calling a Python object"),
    ),
    xfail(
        "nn.functional.adaptive_avg_pool3d",
        reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"),
    ),
    xfail(
        "nn.functional.alpha_dropout",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.avg_pool1d",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
    ),
    xfail(
        "nn.functional.avg_pool2d",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
    ),
    xfail(
        "nn.functional.avg_pool3d",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
    ),
    xfail(
        "nn.functional.batch_norm",
        dtypes=(torch.float16,),
        reason="fixme: https://github.com/microsoft/onnxscript/issues/1270",
    ),
    xfail(
        "nn.functional.conv_transpose1d",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"),
    ),
    xfail(
        "nn.functional.conv_transpose2d",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"),
    ),
    xfail(
        "nn.functional.conv_transpose3d",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"),
    ),
    skip(
        "nn.functional.conv_transpose1d",
        reason="fixme: Assertion error: result mismatch",
    ),
    skip(
        "nn.functional.conv_transpose2d",
        reason="fixme: Assertion error: result mismatch",
    ),
    skip(
        "nn.functional.conv_transpose3d",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "nn.functional.conv1d",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"),
    ),
    xfail(
        "nn.functional.conv2d",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"),
    ),
    xfail(
        "nn.functional.conv2d",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "nn.functional.conv3d",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"),
    ),
    xfail(
        "nn.functional.conv3d",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "nn.functional.ctc_loss",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"),
    ),
    xfail(
        "nn.functional.dropout",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.dropout2d",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.dropout3d",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.feature_alpha_dropout",
        variant_name="with_train",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.feature_alpha_dropout",
        variant_name="without_train",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.fractional_max_pool2d",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.fractional_max_pool3d",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.gaussian_nll_loss",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.gaussian_nll_loss"),
    ),
    xfail(
        "nn.functional.grid_sample",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "nn.functional.group_norm",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"),
    ),
    xfail(
        "nn.functional.local_response_norm",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("avgpool", "int64"),
    ),
    xfail(
        "nn.functional.linear",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Gemm", "int"),
    ),
    xfail(
        "nn.functional.max_pool2d",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"),
    ),
    xfail(
        "nn.functional.max_pool3d",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"),
    ),
    xfail(
        "nn.functional.multi_head_attention_forward",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.one_hot",
        reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
    ),
    xfail(
        "nn.functional.pad",
        variant_name="replicate",
        reason="fixme: ORT error: padding size",
    ),
    xfail(
        "nn.functional.pad",
        variant_name="replicate_negative",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "nn.functional.pad",
        variant_name="reflect",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "nn.functional.rrelu",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "nn.functional.rrelu",
        dtypes=(torch.int64,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Relu", "int64"),
    ),
    skip(
        "nn.functional.scaled_dot_product_attention",
        matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
        reason="dropout is random so the results do not match",
    ),
    xfail(
        "nn.functional.scaled_dot_product_attention",
        dtypes=(torch.float16,),
        reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438",
    ),
    xfail(
        "nn.functional.selu",
        reason="fixme: nn.functional.selu is not in torch._decomp.decomposition_table",
    ),
    xfail(
        "nn.functional.soft_margin_loss",
        dtypes=(torch.float16,),
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "nn.functional.tanhshrink",
        dtypes=(torch.float16,),
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "nonzero",
        dtypes=(torch.int8, torch.int16),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"),
    ),
    xfail(
        "normal",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "normal",
        variant_name="in_place",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "normal",
        variant_name="number_mean",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "ones",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason="fixme: kwargs dtype=complex64 is not supported in ONNX."
    ),
    xfail(
        "pca_lowrank",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "quantile",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default")
    ),
    xfail(
        "rand_like",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "randint",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "randint_like",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "randn",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "randn_like",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "resize_",
        reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_")
    ),
    xfail(
        "resize_as_",
        reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_")
    ),
    xfail(
        "round",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Round", "int"),
    ),
    xfail(
        "rsub",
        dtypes=(torch.uint8, torch.int8, torch.int16),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Mul", "uint8, int8, int16"
        ),
    ),
    xfail(
        "scatter_add",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="sum",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="prod",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amin",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amax",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="mean",
        reason="ONNX doesn't support reduce='mean' option",
    ),
    xfail(
        "sign",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"),
    ),
    xfail(
        "signal.windows.kaiser",
        reason=onnx_test_common.reason_dynamo_does_not_support("functionalization"),
    ),
    xfail(
        "softmax",
        dtypes=(torch.float16,),
        reason="ORT error: https://github.com/microsoft/onnxruntime/issues/16438"
    ),
    xfail(
        "sparse.mm",
        variant_name="reduce",
        reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"),
    ),
    xfail(
        "sparse.sampled_addmm",
        reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"),
    ),
    xfail(
        "special.erfcx",
        dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Erf", "int, bool"),
    ),
    xfail(
        "special.erfcx",
        dtypes=onnx_test_common.FLOAT_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"),
    ),
    xfail(
        "special.ndtr",
        dtypes=(torch.float16,),
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "split",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
    ),
    xfail(
        "split",
        variant_name="list_args",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
    ),
    xfail(
        "split_with_sizes",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
    ),
    xfail(
        "square",
        dtypes=(torch.int8, torch.uint8, torch.int16),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"),
    ),
    xfail(
        "squeeze",
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "squeeze",
        variant_name="multiple",
        reason="fixme: https://github.com/microsoft/onnxscript/issues/1264",
    ),
    xfail(
        "svd_lowrank",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "std_mean",
        reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple."
    ),
    xfail(
        "std_mean",
        variant_name="unbiased",
        reason="fixme: NotImplementedError: Type promotion does not support node output of list or tuple."
    ),
    xfail(
        "stft",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"),
    ),
    xfail(
        "sub",
        dtypes=(torch.uint8, torch.int8, torch.int16),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Mul", "uint8, int8, int16"
        ),
    ),
    xfail(
        "take",
        reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
    ),
    xfail(
        "tensor_split",
        reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"),
    ),
    xfail(
        "to",
        dtypes=(torch.int32, torch.int64, torch.float16, torch.float32, torch.bool, torch.complex64),
        # model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
        reason="This op requires torch.dtype as input, which is not supported currently.",
    ),
    xfail(
        "topk",
        dtypes=(torch.int64, torch.int32),
        reason="fixme: Assertion error: result mismatch",
    ),
    xfail(
        "tril",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"),
    ),
    xfail(
        "triu",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"),
    ),
    xfail(
        "trunc",
        dtypes=onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"),
    ),
    xfail(
        "unbind",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
    ),
    xfail(
        "unflatten",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Unflatten")
    ),
    xfail(
        "uniform",
        reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"),
    ),
    xfail(
        "unique",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"),
    ),
    xfail(
        "unique_consecutive",
        reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"),
    ),
    xfail(
        "unravel_index",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"),
    ),
    xfail(
        "unsafe_split",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
    ),
    xfail(
        "unsafe_chunk",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Split, SplitToSequence", "bool"),
    ),
    xfail(
        "where",
        dtypes=onnx_test_common.BOOL_TYPES,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"),
    ),
    xfail(
        "zeros",
        dtypes=onnx_test_common.COMPLEX_TYPES,
        reason="fixme: kwargs dtype=complex64 is not supported in ONNX."
    ),
    # SLOW TESTS (All are xfails if we run them)
    # TODO: https://github.com/pytorch/pytorch/issues/117118
    skip_slow(
        "cdist",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "histogram",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "histogramdd",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "linalg.lu_solve",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "linalg.solve_triangular",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "linalg.svd",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "logspace",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "logspace",
        variant_name="tensor_overload",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "max_pool2d_with_indices_backward",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "nn.functional.interpolate",
        variant_name="bicubic",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "nn.functional.max_unpool1d",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "nn.functional.max_unpool2d",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "nn.functional.max_unpool3d",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "nn.functional.max_pool1d",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "nn.functional.max_pool2d",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "nn.functional.max_pool3d",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "nn.functional.unfold",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "ormqr",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "searchsorted",
        reason="fixme: Test sets are too many.",
    ),
    skip_slow(
        "svd",
        reason="fixme: Test sets are too many.",
    ),
)
# fmt: on

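# Unlike EXPECTED_SKIPS_OR_FAILS above, entries here are applied per sample
# input via `matcher`, so only the matching sub-tests are skipped or xfailed.
# A minimal sketch (hypothetical op and condition, for illustration only):
#
#     xfail(
#         "some_op",
#         matcher=lambda sample: len(sample.input.shape) == 0,
#         reason="fixme: scalar inputs are not supported",
#     )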
SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "_native_batch_norm_legit",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="https://github.com/pytorch/pytorch/issues/115106",
    ),
    xfail(
        "addmm",  # xfail can't only use dtypes to catch all cases
        matcher=lambda sample: sample.input.dtype
        in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Gemm", "uint8, int8, int16, int32, int64"
        ),
    ),
    xfail(
        "addmm",
        matcher=lambda sample: sample.args[0].numel() == 0,
        reason="ONNX Runtime does not support empty tensors multiplication",
    ),
    xfail(
        "addmm",
        variant_name="decomposed",
        matcher=lambda sample: sample.args[0].numel() == 0,
        reason="ONNX Runtime does not support empty tensors multiplication",
    ),
    xfail(
        "amax",
        matcher=lambda sample: len(sample.input.shape) == 0
        and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()),
        reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
    ),
    xfail(
        "amin",
        matcher=lambda sample: len(sample.input.shape) == 0
        and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()),
        reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
    ),
    xfail(
        "aminmax",
        matcher=lambda sample: len(sample.input.shape) == 0
        and sample.kwargs.get("dim") is not None,
        reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
    ),
    skip(
        "cat",
        matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
        reason="core dump - cat does not support zero-dim tensors yet",
    ),
    xfail(
        "index_add",
        matcher=lambda sample: len(sample.input.shape) < 2,
        reason="fixme: https://github.com/microsoft/onnxscript/issues/1212",
    ),
    xfail(
        "index_add",
        matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1,
        reason="fixme: aten::index_put indices contains None when dim is -1",
    ),
    xfail(
        "index_copy",
        matcher=lambda sample: len(sample.input.shape) < 2,
        reason="fixme: https://github.com/microsoft/onnxscript/issues/1212",
    ),
    xfail(
        "index_copy",
        matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1,
        reason="fixme: aten::index_put indices contains None when dim is -1",
    ),
    xfail(
        "index_put",
        matcher=lambda sample: (sample.args[0][0].dtype == torch.bool)
        and (sample.kwargs.get("accumulate") is False),
        reason=onnx_test_common.reason_dynamo_does_not_support(
            "https://github.com/pytorch/pytorch/issues/101150"
        ),
    ),
    skip(
        "linalg.multi_dot",
        matcher=lambda sample: sum([torch.numel(input) for input in sample.input]) == 0,
        reason="fixme: Undefined",
    ),
    skip(
        "log_softmax",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: LogSoftMax does not support empty tensor as input",
    ),
    skip(
        "log_softmax",
        variant_name="with_dtype",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: LogSoftMax does not support empty tensor as input",
    ),
    xfail(
        "logsumexp",
        matcher=lambda sample: isinstance(sample.input, torch.Tensor)
        and len(sample.input.shape) == 0,
        reason="fixme: IsScalar",
    ),
    skip(
        "masked.log_softmax",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: LogSoftMax does not support empty tensor as input",
    ),
    skip(
        "matmul",
        matcher=lambda sample: torch.numel(sample.input) == 0,
        reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
    ),
    xfail(
        "min",
        variant_name="reduction_with_dim",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: https://github.com/onnx/onnx/issues/4986",
    ),
    skip(
        "mm",
        matcher=lambda sample: torch.numel(sample.input) == 0,
        reason="values of matmul of [m, 0] and [0, n] matrices are undefined",
    ),
    xfail(
        "native_batch_norm",
        matcher=lambda sample: sample.args[-3] is True
        and any(arg is not None for arg in sample.args[2:4]),
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="https://github.com/pytorch/pytorch/issues/115106",
    ),
    xfail(
        "nn.functional.avg_pool1d",
        matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True)
        and (
            sample.kwargs.get("count_include_pad") is True
            or sample.input.shape[2]
            % (
                sample.args[0][0]
                if isinstance(sample.args[0], tuple)
                else sample.args[0]
            )
            != 0
        ),
        reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
    ),
    xfail(
        "nn.functional.avg_pool2d",
        matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
        or (sample.kwargs.get("divisor_override") is not None),
        reason="ONNX doesn't support divisor_override argument",
    ),
    xfail(
        "nn.functional.avg_pool3d",
        matcher=lambda sample: sample.kwargs.get("ceil_mode") is True,
        reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
    ),
    xfail(
        "nn.functional.avg_pool3d",
        matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
        or (sample.kwargs.get("divisor_override") is not None),
        reason="ONNX doesn't support divisor_override argument",
    ),
    xfail(
        "nn.functional.batch_norm",
        matcher=lambda sample: sample.kwargs.get("training") is True
        and any(arg is not None for arg in sample.args[2:4]),
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106",
    ),
    xfail(
        "nn.functional.conv2d",
        matcher=lambda sample: sample.kwargs.get("padding") == "valid",
        reason="fixme: https://github.com/pytorch/pytorch/issues/117054",
    ),
    xfail(
        "nn.functional.conv3d",
        matcher=lambda sample: sample.kwargs.get("padding") == "valid",
        reason="fixme: https://github.com/pytorch/pytorch/issues/117054",
    ),
    skip(
        "nn.functional.cross_entropy",
        matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int),
        reason="ONNX SoftmaxCrossEntropyLoss op only accepts argument[weight] of int type",
    ),
    xfail(
        "nn.functional.embedding",
        matcher=lambda sample: sample.kwargs.get("max_norm") is not None,
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason="https://github.com/pytorch/pytorch/issues/115106",
    ),
    skip_torchlib_forward_compatibility(
        "nn.functional.embedding_bag",
        matcher=lambda sample: sample.kwargs.get("padding_idx") is not None or True,
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. "
            "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided"
        ),
        github_issue="https://github.com/microsoft/onnxscript/issues/1056",
    ),
    xfail(
        "nn.functional.group_norm",
        matcher=lambda sample: torch.numel(sample.input) == 0,
        reason=onnx_test_common.reason_onnx_runtime_does_not_support(
            "Reshape", "empty tensor"
        ),
    ),
    xfail(
        "nn.functional.instance_norm",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        matcher=lambda sample: sample.kwargs.get("running_mean") is not None
        or sample.input.dtype in (torch.float16,),
        reason="fixme: KeyError: 'self___kwargs__running_mean'",
    ),
    xfail(
        "nn.functional.max_pool3d",
        matcher=lambda sample: sample.kwargs.get("ceil_mode") is True
        and sample.kwargs.get("padding") == 1,
        reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed",
    ),
    xfail(
        "nonzero",
        matcher=lambda sample: len(sample.input.shape) == 0
        and sample.kwargs.get("as_tuple", False) is False,
        reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).",
        model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
    ),
    xfail(
        "nonzero",
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
        reason=onnx_test_common.reason_onnx_script_does_not_support(
            "aten::_assert_async.msg",
            "https://github.com/pytorch/pytorch/issues/112443",
        ),
    ),
    xfail(
        "scatter_add",
        matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: Rank(0) input will lead ORT to fail due to different rank(result) in if-else branch",
    ),
    skip(
        "scatter_reduce",
        variant_name="amax",
        # ONNX has no include_self parameter; its default behavior is include_self=True
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX doesn't support include_self=False option",
    ),
    skip(
1582
        "scatter_reduce",
1583
        variant_name="amin",
1584
        # ONNX has not include_self parameter and default is include_self=True mode
1585
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
1586
        reason="ONNX does't support include_self=False option",
1587
    ),
1588
    skip(
1589
        "scatter_reduce",
1590
        variant_name="prod",
1591
        # ONNX has not include_self parameter and default is include_self=True mode
1592
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
1593
        reason="ONNX does't support include_self=False option",
1594
    ),
1595
    skip(
1596
        "scatter_reduce",
1597
        variant_name="sum",
1598
        # ONNX has not include_self parameter and default is include_self=True mode
1599
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
1600
        reason="ONNX does't support include_self=False option",
1601
    ),
1602
    skip(
1603
        "softmax",
1604
        matcher=lambda sample: len(sample.input.shape) == 0,
1605
        reason="fixme: LogSoftMax does not support empty tensor as input",
1606
    ),
1607
    xfail(
1608
        "t",
1609
        matcher=lambda sample: isinstance(sample.input, torch.Tensor)
1610
        and len(sample.input.shape) < 2,
1611
        reason="fixme: IsScalar",
1612
    ),
1613
    xfail(
1614
        "unflatten",
1615
        reason="Logic not implemented for size 0 inputs in op.Reshape",
1616
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
1617
    ),
1618
    skip(
1619
        "signal.windows.hamming",
1620
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1621
        reason="does not match node name",
1622
    ),
1623
    skip(
1624
        "signal.windows.general_hamming",
1625
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1626
        reason="does not match node name",
1627
    ),
1628
    skip(
1629
        "signal.windows.blackman",
1630
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1631
        reason="does not match node name",
1632
    ),
1633
    skip(
1634
        "signal.windows.general_cosine",
1635
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1636
        reason="does not match node name",
1637
    ),
1638
    skip(
1639
        "signal.windows.hann",
1640
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1641
        reason="does not match node name",
1642
    ),
1643
    skip(
1644
        "signal.windows.nuttall",
1645
        model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
1646
        reason="does not match node name",
1647
    ),
1648
)
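
# NOTE on how the entries above are applied (see _should_skip_xfail_test_sample
# below): an entry with a matcher skips/xfails only the sample inputs for which
# the matcher returns True, while an entry with a model_type and no matcher
# skips/xfails every sample for that model type.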

OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)


class SingleOpModel(torch.nn.Module):
    """Test model to wrap around a single op for export."""

    def __init__(self, op, kwargs):
        super().__init__()
        self.operator = op
        self.kwargs = kwargs

    def forward(self, *args):
        return self.operator(*args, **self.kwargs)


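# A minimal usage sketch (hypothetical op and kwargs), kept as a comment so it
# does not run at import time: the wrapper turns a plain callable into an
# exportable nn.Module whose forward applies the op with fixed kwargs.
#
#   model = SingleOpModel(torch.add, {"alpha": 2})
#   model(torch.ones(2), torch.ones(2))  # same as torch.add(x, y, alpha=2)

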
def _should_skip_xfail_test_sample(
    op_name: str,
    variant_test_name: str,
    sample,
    model_type: pytorch_test_common.TorchModelType,
) -> Tuple[Optional[str], Optional[str]]:
    """Check if the test sample should be skipped or xfailed.

    If the xfail/skip decorator meta is matched with its op_name and model_type,
    return the test_behavior and reason. Otherwise, return None, None. Note that
    if the matcher is None, the decorator_meta skips/xfails the whole test for
    the matched model type; if model_type is None, it applies to all model types.

    Args:
        op_name: The name of the op.
        variant_test_name: The variant name of the op.
        sample: The test sample.
        model_type: The model type of the test.

    Returns:
        A tuple of (test_behavior, reason). test_behavior is either "skip" or "xfail".
        reason is the reason for the test_behavior.
    """

    if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    for decorator_meta in SKIP_XFAIL_SUBTESTS:
        # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        # NOTE: If model_type is None, the decorator_meta is meant to skip/xfail all model types.
        if (
            decorator_meta.op_name == op_name
            and decorator_meta.variant_name == variant_test_name
        ) and (
            model_type == decorator_meta.model_type or decorator_meta.model_type is None
        ):
            if decorator_meta.matcher is None and decorator_meta.model_type is None:
                raise TypeError(
                    "Either matcher or model_type must be defined in the sub-test xfail/skip."
                )
            if decorator_meta.matcher is not None and decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
            elif decorator_meta.matcher is None:
                # xfail/skip the whole test for this model type when no matcher is given
                return decorator_meta.test_behavior, decorator_meta.reason
    return None, None


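# For illustration: given the real entry above
#   skip("mm", matcher=lambda sample: torch.numel(sample.input) == 0, reason="..."),
# a call _should_skip_xfail_test_sample("mm", "", sample, model_type) returns
# ("skip", "...") only for samples whose input has zero elements; every other
# "mm" sample still runs.

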
def _compare_onnx_and_torch_exported_program(
    torch_exported_program,
    onnx_exported_program,
    input_args,
    input_kwargs=None,
    test_name=None,
    sample_num=None,
    sample_kwargs=None,
    rtol=1e-03,
    atol=1e-07,
    only_check_shape=False,
):
    # avoid mutable default argument
    if input_kwargs is None:
        input_kwargs = {}

    # NOTE: ONNXProgram holds a reference (not a copy) to the original ref_model,
    # including its state_dict. Thus, ONNXProgram() must run before ref_model() so
    # that ref_model.forward() cannot mutate buffers in the state_dict that
    # ONNXProgram.__call__() would then consume.
    onnx_outputs = onnx_exported_program(*input_args, **input_kwargs)
    torch_outputs = torch_exported_program(*input_args, **input_kwargs)
    torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx(
        torch_outputs
    )
    if len(torch_outputs_onnx_format) != len(onnx_outputs):
        raise AssertionError(
            f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}"
        )

    for j, (torch_output, onnx_output) in enumerate(
        zip(torch_outputs_onnx_format, onnx_outputs)
    ):
        if only_check_shape:
            assert torch_output.shape == onnx_output.shape
        else:
            try:
                torch.testing.assert_close(
                    torch.tensor(onnx_output),
                    torch_output,
                    rtol=rtol,
                    atol=atol,
                    equal_nan=True,
                )
            except AssertionError as e:
                if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
                    error_reproduction.create_mismatch_report(
                        test_name,
                        sample_num,
                        onnx_exported_program.model_proto,
                        input_args,
                        sample_kwargs,
                        torch.tensor(onnx_output),
                        torch_output,
                        e,
                    )
                if len(torch_outputs_onnx_format) > 1:
                    raise AssertionError(f"Output {j} mismatch") from e
                raise


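# A minimal call sketch (hypothetical tensors), assuming `onnx_program` was
# produced by torch.onnx.dynamo_export from the eager module `model`:
#
#   _compare_onnx_and_torch_exported_program(
#       model, onnx_program, (torch.randn(2, 3),), rtol=1e-3, atol=1e-7
#   )

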
def _run_test_output_match(
    test_suite: onnx_test_common._TestONNXRuntime,
    device: str,
    dtype: torch.dtype,
    op: opinfo_core.OpInfo,
):
    # device is provided by instantiate_device_type_tests, but we only want to run on CPU.
    assert device == "cpu"
    samples = op.sample_inputs(
        device,
        dtype,
        requires_grad=False,
    )
    for i, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
        # Provide the repr to subtest because tensors are not serializable in parallel test runs

        with test_suite.subTest(
            opset=test_suite.opset_version,
            sample_num=i,
            inputs=repr(inputs),
            kwargs=repr(cpu_sample.kwargs),
        ):
            test_behavior, reason = _should_skip_xfail_test_sample(
                op.name, op.variant_test_name, cpu_sample, test_suite.model_type
            )
            with onnx_test_common.normal_xfail_skip_test_behaviors(
                test_behavior, reason
            ):
                model = SingleOpModel(op.op, cpu_sample.kwargs)
                model.eval()

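                # Tolerance resolution order for the comparison below:
                #   1. fp32 per-op overrides (fp32_low_precision_dict)
                #   2. generic fp32 defaults (rtol=1e-5, atol=2e-5)
                #   3. fp16 per-(op, variant) overrides (fp16_low_precision_variant_dict)
                #   4. fp16 per-op overrides (fp16_low_precision_dict)
                #   5. rtol/atol = None, letting torch.testing.assert_close pick
                #      its dtype-based defaults.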
                if (
                    dtype == torch.float32
                    and op.name in test_suite.fp32_low_precision_dict
                ):
                    rtol = test_suite.fp32_low_precision_dict[op.name][0]
                    atol = test_suite.fp32_low_precision_dict[op.name][1]
                elif dtype == torch.float32:
                    # Relax atol and rtol for float32 based on empirical results
                    rtol = 1e-5
                    atol = 2e-5
                elif (
                    dtype == torch.float16
                    and (op.name, op.variant_test_name)
                    in test_suite.fp16_low_precision_variant_dict
                ):
                    rtol = test_suite.fp16_low_precision_variant_dict[
                        (op.name, op.variant_test_name)
                    ][0]
                    atol = test_suite.fp16_low_precision_variant_dict[
                        (op.name, op.variant_test_name)
                    ][1]
                elif (
                    dtype == torch.float16
                    and op.name in test_suite.fp16_low_precision_dict
                ):
                    rtol = test_suite.fp16_low_precision_dict[op.name][0]
                    atol = test_suite.fp16_low_precision_dict[op.name][1]
                else:
                    rtol = None
                    atol = None

                if (
                    test_suite.model_type
                    == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
                ):
                    try:
                        model = torch.export.export(model, inputs)
                    except AssertionError as e:
                        # NOTE: avoid fake_mode detection bug in torch.export.export
                        pytest.xfail(
                            onnx_test_common.reason_dynamo_does_not_support(str(e))
                        )

                try:
                    onnx_program = torch.onnx.dynamo_export(
                        model,
                        *inputs,
                    )
                except torch.onnx.OnnxExporterError as e:
                    # NOTE: If the model has unsupported nodes, we skip the test
                    # with a non-strict xfail. Otherwise, we re-raise the error.
                    if hasattr(
                        e.__cause__, "diagnostic"
                    ) and e.__cause__.diagnostic.rule in (
                        _rules._POERules.no_symbolic_function_for_call_function,
                        _rules._POERules.unsupported_fx_node_analysis,
                    ):
                        pytest.xfail(
                            onnx_test_common.reason_onnx_script_does_not_support(str(e))
                        )
                    else:
                        raise e
                _compare_onnx_and_torch_exported_program(
                    model,
                    onnx_program,
                    inputs,
                    test_name=test_suite.id(),
                    sample_num=i,
                    sample_kwargs=cpu_sample.kwargs,
                    rtol=rtol,
                    atol=atol,
                    only_check_shape=(op.name in test_suite.only_shape_check_list),
                )


def _parameterized_class_attrs_and_values():
    input_values = []
    input_values.extend(
        itertools.product(
            (opset for opset in onnx_test_common.FX_TESTED_OPSETS),
            (
                pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
                pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
            ),
        )
    )
    return {
        "attrs": ["opset_version", "model_type"],
        "input_values": input_values,
    }


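# For illustration (hypothetical opset value): if FX_TESTED_OPSETS were (18,),
# the function above would return
#   {"attrs": ["opset_version", "model_type"],
#    "input_values": [(18, TorchModelType.TORCH_NN_MODULE),
#                     (18, TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM)]}
# i.e. one parameterized test class per (opset, model_type) pair.

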
def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
    """Combine the class name with the parameterized arguments.

    This function is passed to `parameterized.parameterized_class` as the
    `class_name_func` argument.
    """
    suffixes = []
    for k, v in input_dicts.items():
        suffixes.append(f"{k}_{v}")
    return f"{cls.__name__}_{'_'.join(suffixes)}"


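# For example (hypothetical opset value), the parameterization below yields
# class names such as
#   TestOnnxModelOutputConsistency_opset_version_18_model_type_TorchModelType.TORCH_NN_MODULE
# which is the exact format reconstructed in the loop at the bottom of this file.
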
@parameterized.parameterized_class(
    **_parameterized_class_attrs_and_values(),
    class_name_func=_parameterize_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

    This is a parameterized test suite.
    """

    opset_version = -1
    op_level_debug: bool = False
    dynamic_shapes: bool = False
    model_type: pytorch_test_common.TorchModelType = (
        pytorch_test_common.TorchModelType.TORCH_NN_MODULE
    )

    # NOTE: Follow torchlib settings in ops_test_data.py
    only_shape_check_list = [
        "empty",
        "empty_like",
        "empty_strided",
        "new_empty",
        "new_empty_strided",
    ]

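    # The dicts below map op name (or (op name, variant name)) to [rtol, atol]
    # overrides consumed by _run_test_output_match.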
    fp32_low_precision_dict = {
        "native_layer_norm": [2e-4, 7e-4],
    }

    fp16_low_precision_dict = {
        "addbmm": [2e-1, 2e-2],
        "addcdiv": [3e-2, 1e-3],
        "addcmul": [3e-2, 1e-3],
        "addmv": [5e-2, 3e-2],
        "addr": [3e-3, 4e-3],
        "baddbmm": [3e-2, 1e-3],
        "cumulative_trapezoid": [3e-2, 1e-3],
        "diff": [1e-2, 5e-2],
        "gradient": [3e-3, 4e-3],
        "linalg.multi_dot": [3e-2, 1e-3],
        "linalg.vecdot": [1e-2, 2e-2],
        "linspace": [2e-2, 2e-3],
        "masked.std": [2e-2, 2e-3],
        "masked.var": [2e-2, 2e-2],
        "matmul": [2e-2, 6e-2],
        "nn.functional.batch_norm": [3e-2, 1e-3],
        "nn.functional.binary_cross_entropy": [3e-2, 1e-3],
        "nn.functional.binary_cross_entropy_with_logits": [3e-2, 1e-3],
        "nn.functional.cosine_similarity": [3e-2, 1e-3],
        "nn.functional.cosine_embedding_loss": [1e-2, 1e-3],
        "nn.functional.hardsigmoid": [1e-3, 5e-3],
        "nn.functional.hardswish": [1e-3, 5e-3],
        "nn.functional.hinge_embedding_loss": [4e-1, 3e-3],
        "nn.functional.instance_norm": [1e-2, 1e-3],
        "nn.functional.interpolate": [1e-2, 1e-3],
        "nn.functional.kl_div": [2e-3, 2e-4],
        "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],
        "nn.functional.local_response_norm": [1e-2, 5e-3],
        "nn.functional.poisson_nll_loss": [3e-2, 1e-3],
        "native_batch_norm": [3e-2, 1e-3],
        "dot": [3e-2, 1e-3],
        "logit": [3e-2, 1e-3],
        "rsub": [3e-2, 1e-3],
        "sinc": [2e-1, 6e-4],
        "sub": [3e-2, 1e-3],
        "trapezoid": [1e-3, 7e-3],
        "trapz": [1e-3, 7e-3],
    }

    fp16_low_precision_variant_dict = {
        ("nn.functional.interpolate", "trilinear"): [3e-2, 3e-3],
        ("nn.functional.interpolate", "linear"): [3e-2, 3e-3],
    }

    @common_device_type.ops(
        [op for op in OPS_DB if op.name in ALL_OPS_IN_DB],
        allowed_dtypes=onnx_test_common.TESTED_DTYPES,
    )
    def test_output_match(self, device: str, dtype: torch.dtype, op):
        """Test the ONNX exporter."""
        _run_test_output_match(self, device, dtype, op)


for opset in onnx_test_common.FX_TESTED_OPSETS:
    for model_type in pytorch_test_common.TorchModelType:
        # The name needs to match the parameterized_class name.
        test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}"
        onnx_test_common.add_decorate_info(
            OPS_DB,
            test_class_name,
            "test_output_match",
            opset=opset,
            skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
        )

        common_device_type.instantiate_device_type_tests(
            globals()[test_class_name], globals(), only_for="cpu"
        )

if __name__ == "__main__":
    common_utils.run_tests()