# pytorch — 179 lines · 5.3 KB
# Owner(s): ["oncall: distributed"]

import os
import sys
from functools import partial, wraps

import torch
import torch.distributed as dist


# Bail out early (exit code 0 == "skipped") when torch was built without
# distributed support — nothing below can run in that case.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

# Prefer NCCL when CUDA is available, otherwise fall back to Gloo on CPU.
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
# World size clamped to [2, 4], driven by the number of visible GPUs.
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))
29
def with_comms(func=None):
    """Decorator that wraps a test method with process-group setup/teardown.

    Usable both bare (``@with_comms``) and called (``@with_comms()``).
    Before running the test it skips when the backend is NCCL but fewer
    GPUs are visible than ``self.world_size`` requires, then calls
    ``self.dist_init()``; after the test it calls ``self.destroy_comms()``.

    Args:
        func: the test method to wrap, or ``None`` when used as ``@with_comms()``.

    Returns:
        The wrapped test method (or a partially-applied decorator).
    """
    if func is None:
        # Invoked as ``@with_comms()`` — return the decorator itself.
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            # Exit with the well-known skip code so the harness reports a skip.
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        # Bug fix: forward *args/**kwargs — the original signature accepted
        # them but silently dropped them when invoking the wrapped test.
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper
45
46
class TestObjectCollectives(MultiProcessTestCase):
    """Exercises the object-based collectives (all_gather_object,
    gather_object, send/recv/broadcast/scatter_object_list) over the
    world process group and over 2-rank sub process groups."""

    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        """Per-rank CUDA device under NCCL, CPU otherwise."""
        if BACKEND == dist.Backend.NCCL:
            return torch.device(self.rank)
        return torch.device("cpu")

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def dist_init(self):
        """Initialize the default process group via a file:// rendezvous."""
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    @with_comms()
    def test_all_gather_object(self):
        """Each rank contributes its rank id; every rank sees [0..world-1]."""
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(object_list=gathered, obj=self.rank)

        for expected, actual in enumerate(gathered):
            self.assertEqual(expected, actual, f"rank: {self.rank}")

    @with_comms()
    def test_gather_object(self):
        """Only the destination rank (0) receives the gathered rank ids."""
        gather_list = [None] * dist.get_world_size() if self.rank == 0 else None
        dist.gather_object(obj=self.rank, object_gather_list=gather_list)

        if self.rank == 0:
            for expected, actual in enumerate(gather_list):
                self.assertEqual(expected, actual, f"rank: {self.rank}")

    @with_comms()
    def test_send_recv_object_list(self):
        """Rank 0 sends 99 to rank 1; all other ranks keep None."""
        payload = 99 if self.rank == 0 else None
        object_list = [payload] * dist.get_world_size()
        if self.rank == 0:
            dist.send_object_list(object_list, 1)
        if self.rank == 1:
            dist.recv_object_list(object_list, 0)

        if self.rank < 2:
            self.assertEqual(99, object_list[0])
        else:
            self.assertEqual(None, object_list[0])

    @with_comms()
    def test_broadcast_object_list(self):
        """The source rank's 99 is broadcast to every rank's list."""
        payload = 99 if self.rank == 0 else None
        object_list = [payload] * dist.get_world_size()
        # TODO test with broadcast_object_list's device argument
        dist.broadcast_object_list(object_list=object_list)

        self.assertEqual(99, object_list[0])

    @with_comms()
    def test_scatter_object_list(self):
        """Rank 0 scatters range(world); each rank receives its own rank id."""
        to_scatter = list(range(dist.get_world_size())) if self.rank == 0 else None
        received = [None]
        dist.scatter_object_list(
            scatter_object_output_list=received, scatter_object_input_list=to_scatter
        )

        self.assertEqual(self.rank, received[0])

    # Test Object Collectives With Sub Pg

    def setup_sub_pg(self):
        """Pair up ranks (0,1), (2,3), ... into 2-rank sub process groups."""
        rank = dist.get_rank()
        base_rank = rank - (rank % 2)
        ranks = [base_rank, base_rank + 1]
        my_pg = dist.new_group(ranks, use_local_synchronization=True)
        return rank, ranks, my_pg

    @with_comms()
    def test_subpg_scatter_object(self):
        rank, ranks, sub_pg = self.setup_sub_pg()
        received = [None]
        dist.scatter_object_list(received, ranks, src=ranks[0], group=sub_pg)
        self.assertEqual(rank, received[0])

    @with_comms()
    def test_subpg_all_gather_object(self):
        rank, ranks, sub_pg = self.setup_sub_pg()
        gathered = [None] * len(ranks)
        dist.all_gather_object(gathered, rank, group=sub_pg)
        self.assertEqual(ranks, gathered)

    @with_comms()
    def test_subpg_gather_object(self):
        rank, ranks, sub_pg = self.setup_sub_pg()
        gather_list = [None] * len(ranks) if rank == ranks[0] else None
        dist.gather_object(rank, gather_list, dst=ranks[0], group=sub_pg)
        if rank == ranks[0]:
            self.assertEqual(ranks, gather_list)

    @with_comms()
    def test_subpg_broadcast_object(self):
        rank, ranks, sub_pg = self.setup_sub_pg()
        broadcast_list = [None]
        if rank == ranks[0]:
            broadcast_list[0] = rank
        dist.broadcast_object_list(broadcast_list, src=ranks[0], group=sub_pg)
        self.assertEqual(ranks[0], broadcast_list[0])
176
177
# Entry point: spawn the multi-process test suite.
if __name__ == "__main__":
    run_tests()