# Owner(s): ["oncall: distributed"]
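
# End-to-end smoke test for the torch.distributed.launch command-line entry
# point: it spins up a small single-node job and runs a user script on every
# worker process.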

import os
import sys
from contextlib import closing

import torch.distributed as dist
import torch.distributed.launch as launch
from torch.distributed.elastic.utils import get_socket_with_port


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# Imported only after the availability check so the early exit above happens
# before any test-framework machinery is pulled in.
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)


def path(script):
    # Build an absolute path to a helper script that lives next to this file.
    return os.path.join(os.path.dirname(__file__), script)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr
    )
    sys.exit(0)


class TestDistributedLaunch(TestCase):
    def test_launch_user_script(self):
        # Launch a single-node job with four local workers and make sure the
        # user script runs to completion on each of them.
        nnodes = 1
        nproc_per_node = 4
        world_size = nnodes * nproc_per_node  # informational: 4 workers total
        # Reserve a free port for the rendezvous endpoint, then close the
        # socket so the launcher itself can bind to that port.
        sock = get_socket_with_port()
        with closing(sock):
            master_port = sock.getsockname()[1]
        args = [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--monitor-interval=1",
            "--start-method=spawn",
            "--master-addr=localhost",
            f"--master-port={master_port}",
            "--node-rank=0",
            "--use-env",
            path("bin/test_script.py"),
        ]
        launch.main(args)
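
# For reference, launch.main(args) above is the programmatic equivalent of
# invoking the launcher from a shell, roughly:
#
#   python -m torch.distributed.launch --nnodes=1 --nproc-per-node=4 \
#       --monitor-interval=1 --start-method=spawn --master-addr=localhost \
#       --master-port=<free port> --node-rank=0 --use-env bin/test_script.py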


if __name__ == "__main__":
    run_tests()