2
from __future__ import annotations
9
from enum import auto, Enum
10
from typing import Optional
13
import packaging.version
17
from torch.autograd import function
18
from torch.onnx._internal import diagnostics
19
from torch.testing._internal import common_utils
22
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
23
sys.path.insert(-1, pytorch_test_dir)
25
torch.set_default_dtype(torch.float)
30
RNN_SEQUENCE_LENGTH = 11
35
class TorchModelType(Enum):
    """Model representation a test runs with.

    The decorators in this module compare ``self.model_type`` against these
    members. Values are generated by ``auto()`` and should not be relied upon.
    """

    TORCH_NN_MODULE = auto()
    TORCH_EXPORT_EXPORTEDPROGRAM = auto()


def _skipper(condition, reason):
    """Build a decorator that skips a test when ``condition()`` is truthy.

    Args:
        condition: Zero-argument callable, evaluated each time the decorated
            function is called (not when it is decorated).
        reason: Message attached to the raised ``unittest.SkipTest``.

    Returns:
        A decorator usable on plain functions and methods alike.
    """

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if condition():
                raise unittest.SkipTest(reason)
            return f(*args, **kwargs)

        return wrapper

    return decorator


# Conditions are lambdas so they are checked when the test runs, not at import.
skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(), "CUDA is not available")

skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"), "Skip In Travis")

skipIfNoBFloat16Cuda = _skipper(
    lambda: not torch.cuda.is_bf16_supported(), "BFloat16 CUDA is not available"
)

skipIfQuantizationBackendQNNPack = _skipper(
    lambda: torch.backends.quantized.engine == "qnnpack",
    "Not compatible with QNNPack quantization backend",
)


def skipIfUnsupportedMinOpsetVersion(min_opset_version):
    """Skip a test method when ``self.opset_version < min_opset_version``.

    Args:
        min_opset_version: Lowest opset version the test supports.

    Returns:
        A decorator for test methods whose instance has ``opset_version``.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version < min_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version: {self.opset_version} < {min_opset_version}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
    """Skip a test method when ``self.opset_version > max_opset_version``.

    Args:
        max_opset_version: Highest opset version the test supports.

    Returns:
        A decorator for test methods whose instance has ``opset_version``.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version > max_opset_version:
                raise unittest.SkipTest(
                    f"Unsupported opset_version: {self.opset_version} > {max_opset_version}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def skipForAllOpsetVersions():
    """Skip a test method whenever ``self.opset_version`` is set.

    The check is a truthiness test: a falsy ``opset_version`` (e.g. ``None``
    or ``0``) lets the test run; any other value skips it.

    Returns:
        A decorator for test methods whose instance has ``opset_version``.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version:
                raise unittest.SkipTest(
                    "Skip verify test for unsupported opset_version"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def skipTraceTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip the tracing variant of a test for opset versions below a threshold.

    Args:
        skip_before_opset_version: Opset version before which the tracing test
            is skipped. If None, the tracing test is always skipped.
        reason: Why the tracing test is skipped; appended to the skip message.

    Returns:
        A decorator for test methods. The decorated method sets
        ``self.skip_this_opset`` on every call and raises ``unittest.SkipTest``
        only when that flag is set and ``self.is_script`` is falsy.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if skip_before_opset_version is not None:
                self.skip_this_opset = self.opset_version < skip_before_opset_version
            else:
                self.skip_this_opset = True
            if self.skip_this_opset and not self.is_script:
                raise unittest.SkipTest(f"Skip verify test for torch trace. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def skipScriptTest(skip_before_opset_version: Optional[int] = None, reason: str = ""):
    """Skip the TorchScript variant of a test for opset versions below a threshold.

    Args:
        skip_before_opset_version: Opset version before which the scripting test
            is skipped. If None, the scripting test is always skipped.
        reason: Why the scripting test is skipped; appended to the skip message.

    Returns:
        A decorator for test methods. The decorated method sets
        ``self.skip_this_opset`` on every call and raises ``unittest.SkipTest``
        only when that flag is set and ``self.is_script`` is truthy.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if skip_before_opset_version is not None:
                self.skip_this_opset = self.opset_version < skip_before_opset_version
            else:
                self.skip_this_opset = True
            if self.skip_this_opset and self.is_script:
                raise unittest.SkipTest(f"Skip verify test for TorchScript. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def skip_min_ort_version(reason: str, version: str, dynamic_only: bool = False):
    """Skip a test when the installed ONNX Runtime is older than ``version``.

    Args:
        reason: Why the newer ONNX Runtime is required; included in the skip message.
        version: Minimum required ONNX Runtime version string.
        dynamic_only: If True, an outdated runtime only skips the test when
            ``self.dynamic_shapes`` is truthy; otherwise the test still runs.

    Returns:
        A decorator for test methods whose instance has ``ort_version`` (and
        ``dynamic_shapes`` when ``dynamic_only`` is set).
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Compare only the release segments (e.g. (1, 16, 0)), ignoring
            # pre-/post-/dev-release and local suffixes.
            if (
                packaging.version.parse(self.ort_version).release
                < packaging.version.parse(version).release
            ):
                if dynamic_only and not self.dynamic_shapes:
                    return func(self, *args, **kwargs)

                # Report the installed version; the old message printed the
                # required version twice.
                raise unittest.SkipTest(
                    f"ONNX Runtime version: {self.ort_version} is older than required version {version}. "
                    f"Reason: {reason}."
                )
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def xfail_dynamic_fx_test(
    error_message: str,
    model_type: Optional[TorchModelType] = None,
    reason: Optional[str] = None,
):
    """Xfail a test when it runs with dynamic shapes.

    Args:
        error_message: Substring expected in the failure message; forwarded to ``xfail``.
        model_type: Only xfail when ``self.model_type`` equals this value.
            When None, the model type is not used to decide.
        reason: The reason for the expected failure; forwarded to ``xfail``.

    Returns:
        A decorator for test methods whose instance has ``dynamic_shapes`` and
        ``model_type``. Static-shape runs execute the test normally.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.dynamic_shapes and (
                model_type is None or self.model_type == model_type
            ):
                return xfail(error_message, reason)(func)(self, *args, **kwargs)
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def skip_dynamic_fx_test(reason: str, model_type: Optional[TorchModelType] = None):
    """Skip a test when it runs with dynamic shapes.

    Args:
        reason: Why the dynamic-shape test is skipped; appended to the skip message.
        model_type: Only skip when ``self.model_type`` equals this value.
            When None, the model type is not used to decide.

    Returns:
        A decorator for test methods whose instance has ``dynamic_shapes`` and
        ``model_type``.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.dynamic_shapes and (
                model_type is None or self.model_type == model_type
            ):
                raise unittest.SkipTest(
                    f"Skip verify dynamic shapes test for FX. {reason}"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def skip_in_ci(reason: str):
    """Skip a test when running in CI.

    Args:
        reason: The reason for skipping the test in CI; appended to the skip message.

    Returns:
        A decorator for test methods.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # NOTE(review): CI is detected via a non-empty ``CI`` environment
            # variable; confirm this matches the CI system in use.
            if os.getenv("CI"):
                raise unittest.SkipTest(f"Skip test in CI. {reason}")
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def xfail(error_message: str, reason: Optional[str] = None):
    """Expect the decorated test to fail with a specific error message.

    Args:
        error_message: Substring that must appear in the failure message. For
            ``torch.onnx.OnnxExporterError`` the message of its ``__cause__``
            is checked instead of the exporter error itself.
        reason: Reason passed to ``pytest.xfail``. Defaults to
            ``"Expected failure: <error_message>"``.

    Returns:
        A decorator for test methods. The decorated test:

        * calls ``pytest.xfail`` when it raises an ``Exception`` whose message
          (or cause's message, see above) contains ``error_message``;
        * raises ``AssertionError`` when the message does not match;
        * calls ``pytest.fail("Unexpected success!")`` when it does not raise.
    """

    def wrapper(func):
        @functools.wraps(func)
        def inner(self, *args, **kwargs):
            try:
                func(self, *args, **kwargs)
            except Exception as e:
                if isinstance(e, torch.onnx.OnnxExporterError):
                    # The exporter wraps the original failure; match its cause.
                    actual = str(e.__cause__)
                else:
                    actual = str(e)
                # Raised explicitly rather than via ``assert`` so the check is
                # not stripped under ``python -O``.
                if error_message not in actual:
                    raise AssertionError(
                        f"Expected error message: {error_message} NOT in {actual}"
                    )
                pytest.xfail(reason if reason else f"Expected failure: {error_message}")
            else:
                pytest.fail("Unexpected success!")

        return inner

    return wrapper


def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
    """Skip a test method when ``self.opset_version`` is in a given collection.

    Args:
        unsupported_opset_versions: Container of opset versions to skip
            (anything supporting ``in``).

    Returns:
        A decorator for test methods whose instance has ``opset_version``.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.opset_version in unsupported_opset_versions:
                raise unittest.SkipTest(
                    "Skip verify test for unsupported opset_version"
                )
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def skipShapeChecking(func):
    """Disable shape checking for the decorated test method.

    Sets ``self.check_shape = False`` before calling the test; the flag is not
    restored afterwards.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.check_shape = False
        return func(self, *args, **kwargs)

    return wrapper


def skipDtypeChecking(func):
    """Disable dtype checking for the decorated test method.

    Sets ``self.check_dtype = False`` before calling the test; the flag is not
    restored afterwards.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.check_dtype = False
        return func(self, *args, **kwargs)

    return wrapper


def xfail_if_model_type_is_exportedprogram(
    error_message: str, reason: Optional[str] = None
):
    """Xfail a test when it runs with an ExportedProgram model.

    Args:
        error_message: Substring expected in the failure message; forwarded to ``xfail``.
        reason: The reason for the expected failure; forwarded to ``xfail``.

    Returns:
        A decorator for test methods whose instance has ``model_type``. Tests
        with any other model type run normally.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.model_type == TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
                return xfail(error_message, reason)(func)(self, *args, **kwargs)
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def xfail_if_model_type_is_not_exportedprogram(
    error_message: str, reason: Optional[str] = None
):
    """Xfail a test unless it runs with an ExportedProgram model.

    Args:
        error_message: Substring expected in the failure message; forwarded to ``xfail``.
        reason: The reason for the expected failure; forwarded to ``xfail``.

    Returns:
        A decorator for test methods whose instance has ``model_type``. Tests
        using an ExportedProgram run normally.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.model_type != TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM:
                return xfail(error_message, reason)(func)(self, *args, **kwargs)
            return func(self, *args, **kwargs)

        return wrapper

    return decorator

399
return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
402
def set_rng_seed(seed):
    """Seed torch's random number generator with ``seed``.

    NOTE(review): only ``torch.manual_seed`` is called here; if Python's
    ``random`` or NumPy generators also need seeding for reproducible tests,
    add them.
    """
    torch.manual_seed(seed)


class ExportTestCase(common_utils.TestCase):
409
"""Test case for ONNX export.
411
Any test case that tests functionalities under torch.onnx should inherit from this class.
418
if torch.cuda.is_available():
419
torch.cuda.manual_seed_all(0)
420
diagnostics.engine.clear()