pytorch
1# Owner(s): ["module: onnx"]
2
3import subprocess
4import sys
5import tempfile
6
7import pytorch_test_common
8
9from torch.testing._internal import common_utils
10
11
class TestLazyONNXPackages(pytorch_test_common.ExportTestCase):
    """Check that importing ``torch.onnx`` does not eagerly import optional packages."""

    def _test_package_is_lazily_imported(self, pkg, torch_pkg="torch.onnx"):
        """Assert that importing ``torch_pkg`` does not pull in ``pkg``.

        A fresh interpreter is started with ``-X importtime`` so that every
        module imported while executing ``import <torch_pkg>`` is reported on
        stderr; the report is then searched for ``pkg``.

        Args:
            pkg: Name of the package that must not appear in the import report.
            torch_pkg: Dotted module name to import in the subprocess.

        Raises:
            subprocess.CalledProcessError: If the subprocess exits with a
                non-zero status (``check=True``).
        """
        # Run from an empty temporary directory so nothing in the current
        # working directory can be picked up by the import.
        with tempfile.TemporaryDirectory() as wd:
            r = subprocess.run(
                # Honor ``torch_pkg`` instead of always importing ``torch.onnx``.
                [sys.executable, "-Ximporttime", "-c", f"import {torch_pkg}"],
                capture_output=True,
                text=True,
                cwd=wd,
                check=True,
            )

        # The extra space makes sure we're checking the package, not any package containing its name.
        # ``assertTrue`` is kept (rather than ``assertNotIn``) so the full report
        # is printed only once, via the custom message.
        self.assertTrue(
            f" {pkg}" not in r.stderr,
            f"`{pkg}` should not be imported, full importtime: {r.stderr}",
        )

    def test_onnxruntime_is_lazily_imported(self):
        """``onnxruntime`` must not be imported by ``import torch.onnx``."""
        self._test_package_is_lazily_imported("onnxruntime")

    def test_onnxscript_is_lazily_imported(self):
        """``onnxscript`` must not be imported by ``import torch.onnx``."""
        self._test_package_is_lazily_imported("onnxscript")
34
35
if __name__ == "__main__":
    # Executed directly as a script: hand control to `common_utils.run_tests()`.
    common_utils.run_tests()
38