pytorch_test_common.py
# Owner(s): ["module: onnx"]
from __future__ import annotations

import functools
import os
import random
import sys
import unittest
from enum import auto, Enum
from typing import Optional

import numpy as np
import packaging.version
import pytest

import torch
from torch.autograd import function
from torch.onnx._internal import diagnostics
from torch.testing._internal import common_utils


pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.insert(-1, pytorch_test_dir)

torch.set_default_dtype(torch.float)

BATCH_SIZE = 2

RNN_BATCH_SIZE = 7
RNN_SEQUENCE_LENGTH = 11
RNN_INPUT_SIZE = 5
RNN_HIDDEN_SIZE = 3


class TorchModelType(Enum):
    TORCH_NN_MODULE = auto()
    TORCH_EXPORT_EXPORTEDPROGRAM = auto()


def _skipper(condition, reason):
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if condition():
                raise unittest.SkipTest(reason)
            return f(*args, **kwargs)

        return wrapper

    return decorator


skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(), "CUDA is not available")

skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"), "Skip In Travis")

skipIfNoBFloat16Cuda = _skipper(
    lambda: not torch.cuda.is_bf16_supported(), "BFloat16 CUDA is not available"
)

skipIfQuantizationBackendQNNPack = _skipper(
    lambda: torch.backends.quantized.engine == "qnnpack",
    "Not compatible with QNNPack quantization backend",
)
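# Illustrative usage sketch (not part of the original file): the environment-based
# decorators above attach directly to test methods; the class and test names below
# are hypothetical and shown only for illustration.
#
#     class TestCudaOnlyExample(ExportTestCase):
#         @skipIfNoCuda
#         def test_export_on_gpu(self):
#             model = torch.nn.Linear(4, 2).cuda()
#             ...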
# Skips tests for all opset versions below `min_opset_version`.
# Add this wrapper to prevent a test from running for opset versions
# smaller than `min_opset_version`.
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version < min_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version: {self.opset_version} < {min_opset_version}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


# Skips tests for all opset versions above `max_opset_version`.
# Add this wrapper to prevent a test from running for opset versions
# higher than `max_opset_version`.
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version > max_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version: {self.opset_version} > {max_opset_version}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
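# Illustrative usage sketch (assumes a test class that defines `opset_version`,
# which the wrappers above read; the class and test names are hypothetical):
#
#     class TestOpsetGatedExample(ExportTestCase):
#         opset_version = 13
#
#         @skipIfUnsupportedMinOpsetVersion(11)  # skipped when opset_version < 11
#         def test_new_operator(self):
#             ...
#
#         @skipIfUnsupportedMaxOpsetVersion(12)  # skipped when opset_version > 12
#         def test_legacy_behavior(self):
#             ...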
# Skips tests for all opset versions.
def skipForAllOpsetVersions():
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version:
                raise unittest.SkipTest(
                    "Skip verify test for unsupported opset_version"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def skipTraceTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip the tracing test for opset versions lower than `skip_before_opset_version`.

    Args:
        skip_before_opset_version: The opset version before which to skip the tracing test.
            If None, the tracing test is always skipped.
        reason: The reason for skipping the tracing test.

    Returns:
        A decorator for skipping the tracing test.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if skip_before_opset_version is not None:
                self.skip_this_opset = self.opset_version < skip_before_opset_version
            else:
                self.skip_this_opset = True
            if self.skip_this_opset and not self.is_script:
                raise unittest.SkipTest(f"Skip verify test for torch trace. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def skipScriptTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip the scripting test for opset versions lower than `skip_before_opset_version`.

    Args:
        skip_before_opset_version: The opset version before which to skip the scripting test.
            If None, the scripting test is always skipped.
        reason: The reason for skipping the scripting test.

    Returns:
        A decorator for skipping the scripting test.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if skip_before_opset_version is not None:
                self.skip_this_opset = self.opset_version < skip_before_opset_version
            else:
                self.skip_this_opset = True
            if self.skip_this_opset and self.is_script:
                raise unittest.SkipTest(f"Skip verify test for TorchScript. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
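# Illustrative usage sketch (assumes a test class exposing `opset_version` and
# `is_script`, which the wrappers above read; names are hypothetical):
#
#     class TestScriptVsTraceExample(ExportTestCase):
#         opset_version = 14
#         is_script = True
#
#         @skipScriptTest(skip_before_opset_version=16, reason="op not scriptable before 16")
#         def test_scripting_sensitive_op(self):
#             ...
#
#         @skipTraceTest(reason="tracing loses dynamic control flow")
#         def test_control_flow(self):
#             ...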
# NOTE: This decorator is currently unused, but we may want to use it in the future when
# we have more tests that are not supported in released ORT.
def skip_min_ort_version(reason: str, version: str, dynamic_only: bool = False):
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if (
                packaging.version.parse(self.ort_version).release
                < packaging.version.parse(version).release
            ):
                if dynamic_only and not self.dynamic_shapes:
                    return func(self, *args, **kwargs)

                raise unittest.SkipTest(
                    f"ONNX Runtime version: {self.ort_version} is older than required version {version}. "
                    f"Reason: {reason}."
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
def xfail_dynamic_fx_test(
    error_message: str,
    model_type: Optional[TorchModelType] = None,
    reason: Optional[str] = None,
):
    """Xfail dynamic exporting test.

    Args:
        error_message: The error message expected when the test fails.
        model_type (TorchModelType): The model type to xfail dynamic exporting test for.
            When None, model type is not used to xfail dynamic tests.
        reason: The reason for xfailing dynamic exporting test.

    Returns:
        A decorator for xfailing dynamic exporting test.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.dynamic_shapes and (
                not model_type or self.model_type == model_type
            ):
                return xfail(error_message, reason)(func)(self, *args, **kwargs)
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def skip_dynamic_fx_test(reason: str, model_type: Optional[TorchModelType] = None):
    """Skip dynamic exporting test.

    Args:
        reason: The reason for skipping dynamic exporting test.
        model_type (TorchModelType): The model type to skip dynamic exporting test for.
            When None, model type is not used to skip dynamic tests.

    Returns:
        A decorator for skipping dynamic exporting test.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.dynamic_shapes and (
                not model_type or self.model_type == model_type
            ):
                raise unittest.SkipTest(
                    f"Skip verify dynamic shapes test for FX. {reason}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
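# Illustrative usage sketch (assumes a test class exposing `dynamic_shapes` and
# `model_type`, which the wrappers above inspect; names and messages are hypothetical):
#
#     class TestDynamicShapesExample(ExportTestCase):
#         dynamic_shapes = True
#         model_type = TorchModelType.TORCH_NN_MODULE
#
#         @skip_dynamic_fx_test("dynamic axes unsupported for this op")
#         def test_static_only_op(self):
#             ...
#
#         @xfail_dynamic_fx_test(
#             error_message="Unsupported dynamic dim",
#             model_type=TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
#         )
#         def test_known_dynamic_failure(self):
#             ...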
def skip_in_ci(reason: str):
    """Skip test in CI.

    Args:
        reason: The reason for skipping test in CI.

    Returns:
        A decorator for skipping test in CI.
    """

    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if os.getenv("CI"):
                raise unittest.SkipTest(f"Skip test in CI. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec
def xfail(error_message: str, reason: Optional[str] = None):
    """Expect failure.

    Args:
        error_message: The error message expected to appear in the raised exception.
        reason: The reason for expected failure.

    Returns:
        A decorator for expecting test failure.
    """

    def wrapper(func):
        @functools.wraps(func)
        def inner(self, *args, **kwargs):
            try:
                func(self, *args, **kwargs)
            except Exception as e:
                if isinstance(e, torch.onnx.OnnxExporterError):
                    # diagnostic message is in the cause of the exception
                    assert (
                        error_message in str(e.__cause__)
                    ), f"Expected error message: {error_message} NOT in {str(e.__cause__)}"
                else:
                    assert error_message in str(
                        e
                    ), f"Expected error message: {error_message} NOT in {str(e)}"
                pytest.xfail(reason if reason else f"Expected failure: {error_message}")
            else:
                pytest.fail("Unexpected success!")

        return inner

    return wrapper
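# Illustrative usage sketch (the test name and message are hypothetical): the
# wrapped test must raise an exception whose text contains `error_message`; an
# unexpected pass is reported as a failure.
#
#     @xfail(
#         error_message="Unsupported ONNX opset",
#         reason="exporter does not yet handle this op",
#     )
#     def test_known_export_failure(self):
#         ...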
# Skips tests for the opset versions listed in `unsupported_opset_versions`.
# Add this wrapper if the PyTorch test cannot be run for a specific opset version
# (for example, an op was modified but the change is not supported in PyTorch).
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
    def skip_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version in unsupported_opset_versions:
                raise unittest.SkipTest(
                    "Skip verify test for unsupported opset_version"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return skip_dec


def skipShapeChecking(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.check_shape = False
        return func(self, *args, **kwargs)

    return wrapper


def skipDtypeChecking(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.check_dtype = False
        return func(self, *args, **kwargs)

    return wrapper
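# Illustrative usage sketch: `skipShapeChecking` and `skipDtypeChecking` simply
# clear the `check_shape` / `check_dtype` flags on the test instance before the
# test body runs (the test name below is hypothetical):
#
#     @skipDtypeChecking
#     def test_op_with_known_dtype_mismatch(self):
#         # verification still compares values, but not output dtypes
#         ...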
def xfail_if_model_type_is_exportedprogram(
    error_message: str, reason: Optional[str] = None
):
    """Xfail tests that run models using ExportedProgram as input.

    Args:
        error_message: The error message to raise when the test is xfailed.
        reason: The reason for xfailing the ONNX export test.

    Returns:
        A decorator for xfailing tests.
    """

    def xfail_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
                return xfail(error_message, reason)(func)(self, *args, **kwargs)
            return func(self, *args, **kwargs)

        return wrapper

    return xfail_dec


def xfail_if_model_type_is_not_exportedprogram(
    error_message: str, reason: Optional[str] = None
):
    """Xfail tests that run models not using ExportedProgram as input.

    Args:
        error_message: The error message to raise when the test is xfailed.
        reason: The reason for xfailing the ONNX export test.

    Returns:
        A decorator for xfailing tests.
    """

    def xfail_dec(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.model_type != TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
                return xfail(error_message, reason)(func)(self, *args, **kwargs)
            return func(self, *args, **kwargs)

        return wrapper

    return xfail_dec
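# Illustrative usage sketch (assumes a test class exposing `model_type`, which
# the wrappers above check; the test name and message are hypothetical):
#
#     @xfail_if_model_type_is_exportedprogram(
#         error_message="Expected exporter failure",
#         reason="op not supported for ExportedProgram input",
#     )
#     def test_op_exportedprogram_limitation(self):
#         ...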
def flatten(x):
    return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))


def set_rng_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


class ExportTestCase(common_utils.TestCase):
    """Test case for ONNX export.

    Any test case that tests functionalities under torch.onnx should inherit from this class.
    """

    def setUp(self):
        super().setUp()
        # TODO(#88264): Flaky test failures after changing seed.
        set_rng_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        diagnostics.engine.clear()
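# Illustrative usage sketch (a minimal test built on this module; the class and
# test names are hypothetical):
#
#     class TestSimpleExport(ExportTestCase):
#         def test_linear_export(self):
#             # setUp() above has already seeded the RNGs and cleared diagnostics.
#             model = torch.nn.Linear(3, 2)
#             torch.onnx.export(model, (torch.randn(1, 3),), "linear.onnx")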