# Owner(s): ["module: onnx"]

from __future__ import annotations

import contextlib

import copy
import dataclasses
import io
import logging
import os
import unittest
import warnings
from typing import (
    Any,
    Callable,
    Collection,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import numpy as np

import onnxruntime
import pytest
import pytorch_test_common
import torch
from torch import export as torch_export
from torch.onnx import _constants, verification
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import diagnostics
from torch.testing._internal import common_utils
from torch.testing._internal.opinfo import core as opinfo_core
from torch.types import Number

_NumericType = Union[Number, torch.Tensor, np.ndarray]
_ModelType = Union[torch.nn.Module, Callable, torch_export.ExportedProgram]
_InputArgsType = Optional[
    Union[torch.Tensor, int, float, bool, Sequence[Any], Mapping[str, Any]]
]
_OutputsType = Sequence[_NumericType]

onnx_model_dir = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    os.pardir,
    "repos",
    "onnx",
    "onnx",
    "backend",
    "test",
    "data",
)


pytorch_converted_dir = os.path.join(onnx_model_dir, "pytorch-converted")


pytorch_operator_dir = os.path.join(onnx_model_dir, "pytorch-operator")


def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
    options = verification.VerificationOptions()

    kwargs["opset_version"] = test_suite.opset_version
    kwargs["keep_initializers_as_inputs"] = test_suite.keep_initializers_as_inputs
    if hasattr(test_suite, "check_shape"):
        options.check_shape = test_suite.check_shape
    if hasattr(test_suite, "check_dtype"):
        options.check_dtype = test_suite.check_dtype

    names = {f.name for f in dataclasses.fields(options)}
    keywords_to_pop = []
    for k, v in kwargs.items():
        if k in names:
            setattr(options, k, v)
            keywords_to_pop.append(k)
    for k in keywords_to_pop:
        kwargs.pop(k)

    return verification.verify(*args, options=options, **kwargs)
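

# Illustrative note (not part of the original file): any kwarg that names a
# `VerificationOptions` field is routed onto `options` above, and everything
# else is forwarded to `verification.verify`. For example, `flatten` and
# `ignore_none` (passed by `run_test` below) are consumed by the options object:
#
#     run_model_test(self, model, input_args=(x,), flatten=False, ignore_none=False)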


def assert_dynamic_shapes(onnx_program: torch.onnx.ONNXProgram, dynamic_shapes: bool):
    """Assert whether the exported model has dynamic shapes or not.

    Args:
        onnx_program (torch.onnx.ONNXProgram): The output of torch.onnx.dynamo_export.
        dynamic_shapes (bool): Whether the exported model has dynamic shapes or not.
            When True, raises if graph inputs don't have at least one dynamic dimension.
            When False, raises if graph inputs have at least one dynamic dimension.

    Raises:
        AssertionError: If the exported model has dynamic shapes and dynamic_shapes is False, and vice versa.
    """

    if dynamic_shapes is None:
        return

    model_proto = onnx_program.model_proto
    # Process graph inputs
    dynamic_inputs = []
    for inp in model_proto.graph.input:
        dynamic_inputs += [
            dim
            for dim in inp.type.tensor_type.shape.dim
            if dim.dim_value == 0 and dim.dim_param != ""
        ]
    assert dynamic_shapes == (
        len(dynamic_inputs) > 0
    ), "Dynamic shape check failed for graph inputs"
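

# Illustrative sketch (not part of the original file): how a test would drive
# `assert_dynamic_shapes`. The model and input shapes are hypothetical.
def _example_assert_dynamic_shapes():
    model = torch.nn.Linear(4, 2)
    x = torch.randn(3, 4)
    onnx_program = torch.onnx.dynamo_export(
        model, x, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
    )
    # Raises AssertionError if no graph input has at least one dynamic dimension.
    assert_dynamic_shapes(onnx_program, dynamic_shapes=True)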


def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
    """Combine class name with the parameterized arguments.

    This function is passed to `parameterized.parameterized_class` as the
    `class_name_func` argument.
    """
    suffix = "_".join(f"{k}_{v}" for k, v in input_dicts.items())
    return f"{cls.__name__}_{suffix}"
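

# Illustrative sketch (not part of the original file): with the hypothetical
# parameters below, `parameterized_class` would generate test classes named
# `TestONNXRuntime_is_script_True` and `TestONNXRuntime_is_script_False`.
#
#     from parameterized import parameterized_class
#
#     @parameterized_class(
#         [{"is_script": True}, {"is_script": False}],
#         class_name_func=parameterize_class_name,
#     )
#     class TestONNXRuntime(_TestONNXRuntime):
#         ...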


class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
    opset_version = _constants.ONNX_DEFAULT_OPSET
    keep_initializers_as_inputs = True  # For IR version 3 type export.
    is_script = False
    check_shape = True
    check_dtype = True

    def setUp(self):
        super().setUp()
        onnxruntime.set_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0"
        self.is_script_test_enabled = True

    # The exported ONNX model may have fewer inputs than the PyTorch model because of constant folding.
    # This mostly happens in unit tests, where we widely use torch.size or torch.shape,
    # so the output depends only on the input shape, not its value.
    # remained_onnx_input_idx indicates which PyTorch model input indices remain in the ONNX model.
    def run_test(
        self,
        model,
        input_args,
        input_kwargs=None,
        rtol=1e-3,
        atol=1e-7,
        do_constant_folding=True,
        dynamic_axes=None,
        additional_test_inputs=None,
        input_names=None,
        output_names=None,
        fixed_batch_size=False,
        training=torch.onnx.TrainingMode.EVAL,
        remained_onnx_input_idx=None,
        verbose=False,
    ):
        def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True):
            return run_model_test(
                self,
                m,
                input_args=input_args,
                input_kwargs=input_kwargs,
                rtol=rtol,
                atol=atol,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                additional_test_inputs=additional_test_inputs,
                input_names=input_names,
                output_names=output_names,
                fixed_batch_size=fixed_batch_size,
                training=training,
                remained_onnx_input_idx=remained_onnx_input_idx,
                flatten=flatten,
                ignore_none=ignore_none,
                verbose=verbose,
            )

        if isinstance(remained_onnx_input_idx, dict):
            scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"]
            tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"]
        else:
            scripting_remained_onnx_input_idx = remained_onnx_input_idx
            tracing_remained_onnx_input_idx = remained_onnx_input_idx

        is_model_script = isinstance(
            model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)
        )

        if self.is_script_test_enabled and self.is_script:
            script_model = model if is_model_script else torch.jit.script(model)
            _run_test(
                script_model,
                scripting_remained_onnx_input_idx,
                flatten=False,
                ignore_none=False,
            )
        if not is_model_script and not self.is_script:
            _run_test(model, tracing_remained_onnx_input_idx)

    @_beartype.beartype
    def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
        self,
        model: _ModelType,
        input_args: Sequence[_InputArgsType],
        *,
        input_kwargs: Optional[Mapping[str, _InputArgsType]] = None,
        rtol: Optional[float] = 1e-3,
        atol: Optional[float] = 1e-7,
        has_mutation: bool = False,
        additional_test_inputs: Optional[
            List[
                Union[
                    Tuple[Sequence[_InputArgsType], Mapping[str, _InputArgsType]],
                    Tuple[Sequence[_InputArgsType]],
                ]
            ]
        ] = None,
        skip_dynamic_shapes_check: bool = False,
    ):
        """Compare the results of a PyTorch model with its exported ONNX model.

        Args:
            model (_ModelType): PyTorch model
            input_args (Sequence[_InputArgsType]): torch input arguments
            input_kwargs (Mapping[str, _InputArgsType]): torch input kwargs
            rtol (float, optional): relative tolerance. Defaults to 1e-3.
            atol (float, optional): absolute tolerance. Defaults to 1e-7.
            has_mutation (bool, optional): Whether the model mutates its input or state.
                Setting it to True incurs the extra overhead of cloning the inputs and
                model. Defaults to False.
            additional_test_inputs: Extra input sets to test the model with, designed
                for dynamic-axes testing. Defaults to None. It is a list of input sets,
                each a tuple whose first element is a tuple of args and whose optional
                second element is a dict of kwargs. Remember to keep the trailing comma
                even when the second element is omitted.
                For example,
                additional_test_inputs = [((args1, args2), {"kwargs":1}), ((args1,),), ((), {"kwargs":1})]
            skip_dynamic_shapes_check: Whether to skip the dynamic shape check. Defaults
                to False. Use it when a test does not produce dynamic shapes even with
                the dynamic-shapes feature enabled; Torch Dynamo treats the
                dynamic_shapes flag only as a hint.

        """

        # avoid mutable data structure
        if input_kwargs is None:
            input_kwargs = {}

        if (
            has_mutation
            and self.model_type
            != pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
        ):
            ref_model = _try_clone_model(model)
            ref_input_args, ref_input_kwargs = _try_clone_inputs(
                input_args, input_kwargs
            )
        else:
            ref_model = model
            ref_input_args = input_args
            ref_input_kwargs = input_kwargs

        assert isinstance(ref_model, torch.nn.Module) or callable(
            ref_model
        ), "Model must be a torch.nn.Module or callable"
        if (
            self.model_type
            == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
        ):
            ref_model = torch.export.export(ref_model, args=ref_input_args)
            if (
                self.dynamic_shapes
            ):  # TODO: Support dynamic shapes for torch.export.ExportedProgram
                #       https://github.com/pytorch/pytorch/issues/113705
                pytest.xfail(
                    reason="torch.export.ExportedProgram does not support dynamic shapes"
                )

        # Feed args and kwargs into the exporter.
        # Note that the exporter flattens kwargs into positional args in the exported
        # model, since ONNX doesn't represent kwargs.
        export_error: Optional[torch.onnx.OnnxExporterError] = None
        try:
            onnx_program = torch.onnx.dynamo_export(
                ref_model,
                *ref_input_args,
                **ref_input_kwargs,
                export_options=torch.onnx.ExportOptions(
                    op_level_debug=self.op_level_debug,
                    dynamic_shapes=self.dynamic_shapes,
                    diagnostic_options=torch.onnx.DiagnosticOptions(
                        verbosity_level=logging.DEBUG
                    ),
                ),
            )
        except torch.onnx.OnnxExporterError as e:
            export_error = e
            onnx_program = e.onnx_program

        if diagnostics.is_onnx_diagnostics_log_artifact_enabled():
            onnx_program.save_diagnostics(
                f"test_report_{self._testMethodName}"
                f"_op_level_debug_{self.op_level_debug}"
                f"_dynamic_axes_{self.dynamic_shapes}"
                f"_model_type_{self.model_type}"
                ".sarif"
            )

        if export_error is not None:
            raise export_error

        if not skip_dynamic_shapes_check:
            assert_dynamic_shapes(onnx_program, self.dynamic_shapes)

        _compare_pytorch_onnx_with_ort(
            onnx_program,
            ref_model,
            input_args,
            input_kwargs,
            atol,
            rtol,
            has_mutation=has_mutation,
        )
        # This confirms the exported model accepts different input shapes
        # when dynamic shapes are enabled.
        if additional_test_inputs and self.dynamic_shapes:
            for another_input in additional_test_inputs:
                if len(another_input) > 2:
                    raise ValueError(
                        f"test_inputs should only have tuple args and dictionary kwargs. But received: {len(another_input)}"
                    )
                additional_input_args = another_input[0]
                additional_input_kwargs = (
                    another_input[1]
                    if len(another_input) == 2 and another_input[1] is not None
                    else {}
                )
                _compare_pytorch_onnx_with_ort(
                    onnx_program,
                    ref_model,
                    additional_input_args,
                    additional_input_kwargs,
                    atol,
                    rtol,
                    has_mutation=has_mutation,
                )
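

# Illustrative sketch (not part of the original file): a typical dynamic-shapes
# test built on the method above. The test class name, model, and shapes are
# hypothetical placeholders.
#
#     class TestFxToOnnx(_TestONNXRuntime):
#         def test_relu_dynamic_batch(self):
#             x = torch.randn(2, 4)
#             self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
#                 torch.nn.ReLU(),
#                 (x,),
#                 additional_test_inputs=[((torch.randn(8, 4),),)],
#             )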


@_beartype.beartype
def run_ort(
    onnx_model: Union[str, torch.onnx.ONNXProgram],
    pytorch_inputs: Sequence[_InputArgsType],
) -> _OutputsType:
    """Run ORT on the given ONNX model and inputs

    Used in test_fx_to_onnx_with_onnxruntime.py

    Args:
        onnx_model (Union[str, torch.onnx.ONNXProgram]): Converted ONNX model
        pytorch_inputs (Sequence[_InputArgsType]): The given torch inputs

    Raises:
        AssertionError: ONNX and PyTorch should have the same input sizes

    Returns:
        _OutputsType: ONNX model predictions
    """
    if isinstance(onnx_model, torch.onnx.ONNXProgram):
        buffer = io.BytesIO()
        onnx_model.save(buffer)
        ort_model = buffer.getvalue()
    else:
        ort_model = onnx_model

    # Suppress floods of warnings from ONNX Runtime
    session_options = onnxruntime.SessionOptions()
    session_options.log_severity_level = 3  # Error
    session = onnxruntime.InferenceSession(
        ort_model, providers=["CPUExecutionProvider"], sess_options=session_options
    )
    input_names = [ort_input.name for ort_input in session.get_inputs()]

    if len(input_names) != len(pytorch_inputs):
        raise AssertionError(
            f"Expected {len(input_names)} inputs, got {len(pytorch_inputs)}"
        )

    ort_input = {
        k: torch.Tensor.numpy(v, force=True)
        for k, v in zip(input_names, pytorch_inputs)
    }
    return session.run(None, ort_input)
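

# Illustrative sketch (not part of the original file): exporting a toy model and
# running it through ORT with `run_ort`. Assumes the exporter keeps the model
# parameters as initializers, so the graph has a single input.
def _example_run_ort():
    model = torch.nn.Linear(4, 2)
    x = torch.randn(1, 4)
    onnx_program = torch.onnx.dynamo_export(model, x)
    # Positional torch inputs are matched to the ONNX graph inputs in order.
    return run_ort(onnx_program, (x,))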


@_beartype.beartype
def _try_clone_model(model: _ModelType) -> _ModelType:
    """Used for preserving the original model in case forward mutates model state."""
    try:
        return copy.deepcopy(model)
    except Exception:
        warnings.warn(
            "Failed to clone model. Model state might be mutated during verification."
        )
        return model


@_beartype.beartype
def _try_clone_inputs(input_args, input_kwargs):
    ref_input_args = copy.deepcopy(input_args)
    ref_input_kwargs = copy.deepcopy(input_kwargs)
    return ref_input_args, ref_input_kwargs


@_beartype.beartype
def _compare_pytorch_onnx_with_ort(
    onnx_program: torch.onnx.ONNXProgram,
    model: _ModelType,
    input_args: Sequence[_InputArgsType],
    input_kwargs: Mapping[str, _InputArgsType],
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
    has_mutation: bool = False,
):
    if has_mutation:
        ref_model = _try_clone_model(model)
        ref_input_args, ref_input_kwargs = _try_clone_inputs(input_args, input_kwargs)
    else:
        ref_model = model
        ref_input_args = input_args
        ref_input_kwargs = input_kwargs

    # NOTE: ONNXProgram holds a reference (not a copy) to the original ref_model,
    # including its state_dict. Thus, onnx_program() must run before ref_model() to
    # prevent ref_model.forward() from changing the state_dict; otherwise ref_model
    # could change buffers in the state_dict that onnx_program.__call__() relies on.
    # NOTE: `model_with_state_dict=ref_model` is specified to cover runs with FakeTensor support
    ort_outputs = onnx_program(*input_args, **input_kwargs)
    ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs)
    ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs)

    if len(ref_outputs) != len(ort_outputs):
        raise AssertionError(
            f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}"
        )

    for ref_output, ort_output in zip(ref_outputs, ort_outputs):
        torch.testing.assert_close(
            ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
        )


# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET
TESTED_OPSETS = range(MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION + 1)

# The min onnx opset version to test for with the FX exporter
FX_MIN_ONNX_OPSET_VERSION = 18
# The max onnx opset version to test for with the FX exporter
FX_MAX_ONNX_OPSET_VERSION = 18
FX_TESTED_OPSETS = range(FX_MIN_ONNX_OPSET_VERSION, FX_MAX_ONNX_OPSET_VERSION + 1)

BOOL_TYPES = (torch.bool,)

INT_TYPES = (
    # torch.int8,
    # torch.int16,
    torch.int32,
    torch.int64,
    # torch.uint8,
)

QINT_TYPES = (
    torch.qint8,
    torch.quint8,
)

FLOAT_TYPES = (
    torch.float16,
    torch.float32,
    # torch.float64,  ORT doesn't support
)

COMPLEX_TYPES = (
    # torch.complex32,  NOTE: torch.complex32 is experimental in torch
    torch.complex64,
    # torch.complex128,  ORT doesn't support
)

TESTED_DTYPES = (
    # Boolean
    torch.bool,
    # Integers
    *INT_TYPES,
    # Floating types
    *FLOAT_TYPES,
    # Complex types
    *COMPLEX_TYPES,
)


@dataclasses.dataclass
class DecorateMeta:
    """Information about a test case to skip or xfail.

    Adapted from functorch: functorch/test/common_utils.py

    Attributes:
        op_name: The name of the operator.
        variant_name: The name of the OpInfo variant.
        decorator: The decorator to apply to the test case.
        opsets: The opsets to apply the decorator to.
        dtypes: The dtypes to apply the decorator to.
        reason: The reason for skipping.
        test_behavior: The behavior of the test case. [skip or xfail]
        matcher: The matcher to apply to the test case.
        enabled_if: Whether to enable the test behavior. Usually used for ONNX/ORT version gating.
        model_type: The type of the torch model. Defaults to None.
    """

    op_name: str
    variant_name: str
    decorator: Callable
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]]
    dtypes: Optional[Collection[torch.dtype]]
    reason: str
    test_behavior: str
    matcher: Optional[Callable[[Any], bool]] = None
    enabled_if: bool = True
    model_type: Optional[pytorch_test_common.TorchModelType] = None

    def contains_opset(self, opset: int) -> bool:
        if self.opsets is None:
            return True
        return any(
            opset == opset_spec if isinstance(opset_spec, int) else opset_spec(opset)
            for opset_spec in self.opsets
        )
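

# Illustrative sketch (not part of the original file): `opsets` entries may mix
# plain ints with predicates such as `opsets_before` (defined later in this
# module), so one DecorateMeta can cover an opset range. Built here with the
# `xfail` helper defined just below:
#
#     meta = xfail("add", reason="example", opsets=[9, opsets_before(13)])
#     meta.contains_opset(9)   # True: exact match
#     meta.contains_opset(11)  # True: 11 < 13
#     meta.contains_opset(18)  # False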


def xfail(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], bool]] = None,
    enabled_if: bool = True,
    model_type: Optional[pytorch_test_common.TorchModelType] = None,
):
    """Expects an OpInfo test to fail.

    Args:
        op_name: The name of the operator.
        variant_name: The name of the variant.
        opsets: The opsets in which to expect the failure. e.g. [9, 10] or [opsets_before(11)]
        dtypes: The dtypes in which to expect the failure.
        reason: The reason for the failure.
        matcher: A function that matches the test sample input. It is used only when
            the xfail is in the SKIP_XFAIL_SUBTESTS list.
        enabled_if: Whether to enable the xfail. Usually used for ONNX/ORT version gating.
        model_type: The type of the torch model. Defaults to None.
    """
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.expectedFailure,
        opsets=opsets,
        dtypes=dtypes,
        enabled_if=enabled_if,
        matcher=matcher,
        reason=reason,
        test_behavior="xfail",
        model_type=model_type,
    )
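

# Illustrative sketch (not part of the original file): a hypothetical xfail
# entry as it would appear in a skip/xfail list consumed by `add_decorate_info`
# below. The operator name and dtypes are placeholders.
#
#     xfail(
#         "nn.functional.conv3d",
#         reason=reason_onnx_runtime_does_not_support("Conv3d", ["float16"]),
#         dtypes=[torch.float16],
#     )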


def skip(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], Any]] = None,
    enabled_if: bool = True,
    model_type: Optional[pytorch_test_common.TorchModelType] = None,
):
    """Skips an OpInfo test case that we don't need to cover.

    Typically because ONNX does not support the use case, or the behavior is by design.

    Args:
        op_name: The name of the operator.
        variant_name: The name of the variant.
        opsets: The opsets in which to skip. e.g. [9, 10] or [opsets_before(11)]
        dtypes: The dtypes in which to skip.
        reason: The reason for skipping.
        matcher: A function that matches the test sample input. It is used only when
            the skip is in the SKIP_XFAIL_SUBTESTS list.
        enabled_if: Whether to enable the skip. Usually used for ONNX/ORT version gating.
        model_type: The type of the torch model. Defaults to None.
    """
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.skip(f"Skip: {reason}"),
        opsets=opsets,
        dtypes=dtypes,
        reason=reason,
        matcher=matcher,
        enabled_if=enabled_if,
        test_behavior="skip",
        model_type=model_type,
    )


def skip_slow(
    op_name: str,
    variant_name: str = "",
    *,
    reason: str,
    opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
    dtypes: Optional[Collection[torch.dtype]] = None,
    matcher: Optional[Callable[[Any], Any]] = None,
    model_type: Optional[pytorch_test_common.TorchModelType] = None,
):
    """Skips an OpInfo test case that is too slow.

    Why it is slow needs further investigation.

    Args:
        op_name: The name of the operator.
        variant_name: The name of the variant.
        opsets: The opsets in which to skip. e.g. [9, 10] or [opsets_before(11)]
        dtypes: The dtypes in which to skip.
        reason: The reason for skipping.
        matcher: A function that matches the test sample input. It is used only when
            the skip is in the SKIP_XFAIL_SUBTESTS list.
        model_type: The type of the torch model. Defaults to None.
    """
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=common_utils.slowTest,
        opsets=opsets,
        dtypes=dtypes,
        reason=reason,
        matcher=matcher,
        enabled_if=not common_utils.TEST_WITH_SLOW,
        test_behavior="skip",
        model_type=model_type,
    )


def add_decorate_info(
    all_opinfos: Sequence[opinfo_core.OpInfo],
    test_class_name: str,
    base_test_name: str,
    opset: int,
    skip_or_xfails: Iterable[DecorateMeta],
):
    """Decorates OpInfo tests with decorators based on the skip_or_xfails list.

    Args:
        all_opinfos: All OpInfos.
        test_class_name: The name of the test class.
        base_test_name: The name of the test method.
        opset: The opset to decorate for.
        skip_or_xfails: DecorateMetas.
    """
    ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos}
    for decorate_meta in skip_or_xfails:
        if not decorate_meta.contains_opset(opset):
            # The skip/xfail does not apply to this opset
            continue
        opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name))
        assert (
            opinfo is not None
        ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?"
        decorators = list(opinfo.decorators)
        new_decorator = opinfo_core.DecorateInfo(
            decorate_meta.decorator,
            test_class_name,
            base_test_name,
            dtypes=decorate_meta.dtypes,
            active_if=decorate_meta.enabled_if,
        )
        decorators.append(new_decorator)
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn

    return wrapped
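

# Illustrative sketch (not part of the original file): `add_decorate_info` as a
# decorator factory over an OpInfo-driven test. `OPS_DB`,
# `EXPECTED_SKIPS_OR_FAILS`, and the test names are hypothetical placeholders.
#
#     @add_decorate_info(
#         OPS_DB,
#         "TestOutputConsistency",
#         "test_output_match",
#         opset=18,
#         skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
#     )
#     def test_output_match(self, device, dtype, op):
#         ...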


def opsets_before(opset: int) -> Callable[[int], bool]:
    """Returns a comparison function that decides if the given opset is before the specified one."""

    def compare(other_opset: int):
        return other_opset < opset

    return compare


def opsets_after(opset: int) -> Callable[[int], bool]:
    """Returns a comparison function that decides if the given opset is after the specified one."""

    def compare(other_opset: int):
        return other_opset > opset

    return compare
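

# Illustrative examples (not part of the original file) of the predicates above:
#
#     opsets_before(13)(12)  # True
#     opsets_before(13)(13)  # False
#     opsets_after(13)(18)   # True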


def reason_onnx_script_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Formats the reason: ONNX script doesn't support the given dtypes."""
    return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX script"


def reason_onnx_runtime_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Formats the reason: ONNX Runtime doesn't support the given dtypes."""
    return f"{operator} on {dtypes or 'dtypes'} not supported by ONNX Runtime"


def reason_onnx_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Formats the reason: ONNX doesn't support the given dtypes."""
    return f"{operator} on {dtypes or 'certain dtypes'} not supported by the ONNX Spec"


def reason_dynamo_does_not_support(
    operator: str, dtypes: Optional[Sequence[str]] = None
) -> str:
    """Formats the reason: Dynamo doesn't support the given dtypes."""
    return (
        f"{operator} on {dtypes or 'certain dtypes'} not supported by the Dynamo Spec"
    )


def reason_jit_tracer_error(info: str) -> str:
    """Formats the reason: JIT tracer errors."""
    return f"JIT tracer error on {info}"


def reason_flaky() -> str:
    """Formats the reason: test is flaky."""
    return "flaky test"


@contextlib.contextmanager
def normal_xfail_skip_test_behaviors(
    test_behavior: Optional[str] = None, reason: Optional[str] = None
):
    """Context manager that handles the different behaviors of xfail and skip.

    Args:
        test_behavior (Optional[str]): The DecorateMeta test_behavior; can be
            'skip', 'xfail', or None.
        reason (Optional[str]): The reason for the failure or skip.

    Raises:
        e: Any exception raised by the test case if it's not an expected failure.
    """

    # Skip as early as possible, because the test might also crash (e.g. segfault).
    if test_behavior == "skip":
        pytest.skip(reason=reason)

    try:
        yield
    # We could use `except (AssertionError, RuntimeError, ...) as e:`, but that would
    # require going over all test cases to find the right exception types.
    except Exception as e:  # pylint: disable=broad-exception-caught
        if test_behavior is None:
            raise e
        if test_behavior == "xfail":
            pytest.xfail(reason=reason)
    else:
        if test_behavior == "xfail":
            pytest.fail("Test unexpectedly passed")
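

# Illustrative sketch (not part of the original file): applying a sample-level
# skip/xfail inside a test. `_find_skip_or_xfail` is a hypothetical lookup that
# returns (test_behavior, reason) from a DecorateMeta list.
#
#     test_behavior, reason = _find_skip_or_xfail(op.name, sample)
#     with normal_xfail_skip_test_behaviors(test_behavior, reason):
#         self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(model, (x,))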