# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist
from os import path

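# TF32 matmuls trade precision for speed on Ampere+ GPUs; keep full FP32
# precision so numerical comparisons in the tests behave deterministically.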
torch.backends.cuda.matmul.allow_tf32 = False

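# torch.distributed is an optional build component; exit cleanly instead of
# erroring when this build of PyTorch does not provide it.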
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

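# Imported only after the availability check above, since these internal test
# helpers assume torch.distributed is importable.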
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    NO_MULTIPROCESSING_SPAWN,
)
from torch.testing._internal.distributed.distributed_test import (
    DistributedTest,
    TestDistBackend,
)

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)

if NO_MULTIPROCESSING_SPAWN:
    print("Spawn not available, skipping tests.", file=sys.stderr)
    sys.exit(0)

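# The spawned tests are configured entirely through environment variables, so
# validate them up front and fail with a copy-pasteable repro command rather
# than an obscure error deep inside the test run.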
_allowed_backends = ("gloo", "nccl", "ucc")
if (
    "BACKEND" not in os.environ
    or "WORLD_SIZE" not in os.environ
    or "TEMP_DIR" not in os.environ
    or not path.exists(path.join(os.environ["TEMP_DIR"], "barrier"))
):
    # TODO can we actually have `run_tests.py` emit the complete instructions when it prints a repro command?
    raise RuntimeError(
        "Missing expected env vars for `test_distributed_spawn.py`. Please make sure to specify the following:\n"
        f"'BACKEND' = one of {_allowed_backends}\n"
        f"'WORLD_SIZE' = int >= 2\n"
        "'TEMP_DIR' specifying a directory containing a barrier file named 'barrier'.\n\n"
        f"e.g.\ntouch /tmp/barrier && TEMP_DIR=/tmp BACKEND='nccl' WORLD_SIZE=2 python {__file__}",
    )

BACKEND = os.environ["BACKEND"]

if BACKEND in _allowed_backends:

    class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
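        # Each test spawns WORLD_SIZE worker processes. The cudnn flags context
        # manager is entered but never exited, presumably to keep cudnn enabled
        # with TF32 off for the lifetime of the process.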
        def setUp(self):
            super().setUp()
            self._spawn_processes()
            torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__()

else:
    print(f"Invalid backend {BACKEND}. Tests will not be run!")

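# run_tests() is the standard PyTorch test entry point; it discovers and runs
# the test methods defined on the class above.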
if __name__ == "__main__":
    run_tests()