# Owner(s): ["oncall: distributed"]

import copy
import os
import sys
import tempfile

import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions

import torch
import torch.distributed as c10d
import torch.nn as nn
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
    create_device,
    requires_gloo,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)

# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):

    class ProcessGroupShareTensorTest(
        test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
    ):
        @classmethod
        def opts(cls, threads=2):
            opts = c10d.ProcessGroupGloo._Options()
            opts._timeout = 5.0
            opts._devices = [create_device(interface="lo")]
            opts._threads = threads
            return opts

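        # Build the default process group on a FileStore, then register the
        # same Gloo backend for both CPU and CUDA device types, so that
        # collectives on either kind of tensor are routed through Gloo.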
        @classmethod
        def _init_pg_gloo(cls, rank, filename, world_size):
            store = c10d.FileStore(filename, world_size)
            backend = c10d.ProcessGroupGloo(
                store, rank, world_size, ProcessGroupShareTensorTest.opts()
            )
            # set process group backends manually
            c10d.init_process_group(
                backend="gloo", store=store, rank=rank, world_size=world_size
            )
            pg = c10d.distributed_c10d._get_default_group()
            pg._register_backend(
                torch.device("cpu"), c10d.ProcessGroup.BackendType.GLOO, backend
            )
            pg._register_backend(
                torch.device("cuda"), c10d.ProcessGroup.BackendType.GLOO, backend
            )

            return pg

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUs needed"
        )
        def test_shared_broadcast_gloo(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_broadcast_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_gloo,
                1,
            )

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUs needed"
        )
        def test_shared_allreduce_gloo(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allreduce_process,
                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_gloo,
                1,
            )

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUs needed"
        )
        def test_shared_allgather_gloo(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allgather_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_gloo,
                self.world_size,
            )

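        # Each rank all-gathers its row chunk of the shared tensor, then sends
        # both gathered chunks (paired with the expected originals) back to
        # the parent via the c2p queue for verification.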
        @classmethod
        def _test_allgather_chunk_process(
            cls, rank, filename, shared_tensor, world_size, init_pg, c2p, p2c
        ):
            pg = init_pg(rank, filename, world_size)
            chunks = torch.chunk(shared_tensor, world_size, dim=0)
            x = chunks[rank]
            ys = [torch.zeros_like(x) for _ in range(world_size)]
            pg.allgather(ys, x).wait()
            c2p.put((rank, chunks[0].to("cpu"), ys[0].to("cpu")))
            c2p.put((rank, chunks[1].to("cpu"), ys[1].to("cpu")))
            p2c.get()

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUs needed"
        )
        def test_shared_allgather_chunk_gloo(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allgather_chunk_process,
                torch.tensor(range(4)).reshape(2, 2),
                ProcessGroupShareTensorTest._init_pg_gloo,
                self.world_size,
            )


class DistributedDataParallelSingleProcessTest(TestCase):
    def setUp(self):
        self.rank = 0
        self.world_size = 1
        self.file = tempfile.NamedTemporaryFile(delete=False)  # noqa: P201

    def tearDown(self):
        try:
            os.remove(self.file.name)
        except OSError:
            pass

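    # Trains `net` and a DDP-wrapped deep copy of it side by side for ten
    # steps with identical Adam optimizers, and (optionally) checks that the
    # two parameter sets stay allclose throughout.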
    def _test_base(self, net, inp, check_allclose=True):
        store = c10d.FileStore(self.file.name, self.world_size)
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        process_group = c10d.distributed_c10d._get_default_group()
        if inp[0].is_cuda:
            device_ids = [torch.cuda.current_device()]
        else:
            device_ids = None

        ddp = nn.parallel.DistributedDataParallel(
            copy.deepcopy(net), device_ids=device_ids, process_group=process_group
        )

        net_opt = torch.optim.Adam(net.parameters(), lr=0.001)
        ddp_opt = torch.optim.Adam(ddp.parameters(), lr=0.001)

        for i, j in zip(ddp.parameters(), net.parameters()):
            self.assertTrue(i.allclose(j))

        for _ in range(10):
            net_out = net(*inp)
            ddp_out = ddp(*inp)

            net_out.sum().backward()
            ddp_out.sum().backward()

            net_opt.step()
            ddp_opt.step()

        if check_allclose:
            for i, j in zip(ddp.parameters(), net.parameters()):
                self.assertTrue(i.allclose(j))

    @requires_gloo()
    def test_cpu(self):
        self._test_base(nn.Linear(2, 2), [torch.randn(30, 2)])

    @requires_gloo()
    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "At least 1 CUDA GPU needed")
    def test_cuda(self):
        self._test_base(nn.Linear(2, 2).to(0), [torch.randn(30, 2).to(0)])

    @requires_gloo()
    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "At least 1 CUDA GPU needed")
    def test_rnn(self):
        # This test is inspired by the bug reported in
        # https://github.com/pytorch/pytorch/issues/36268
        BATCH_SIZE = 12  # Divisible by 2, 3, 4
        INPUT_DIM = 256
        OUTPUT_DIM = 256
        HIDDEN_DIM = 256
        N_LAYERS = 3
        SEQ_LEN = 100

        class Net(nn.Module):
            def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
                super().__init__()
                self.input_dim = input_dim
                self.hidden_dim = hidden_dim
                self.output_dim = output_dim
                self.hidden_layers = hidden_layers

                self.lstm = nn.LSTM(
                    input_dim, hidden_dim, hidden_layers, batch_first=True
                )
                self.h2o = nn.Linear(hidden_dim, output_dim)

            def forward(self, x, y):
                self.lstm.flatten_parameters()
                h_t, _ = self.lstm(x)
                output = self.h2o(h_t)
                loss = nn.functional.mse_loss(output, y)
                return loss

        net = Net(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS).to(0)
        inp = [
            torch.randn((BATCH_SIZE, SEQ_LEN, INPUT_DIM)).to(0),
            torch.rand((BATCH_SIZE, SEQ_LEN, OUTPUT_DIM)).to(0),
        ]

        # Not checking that the results are allclose, as the parameter
        # inconsistency existed prior to this change. See #37079
        self._test_base(net, inp, check_allclose=False)


# Skip dev-asan as torch + multiprocessing spawn have known issues
if not TEST_WITH_DEV_DBG_ASAN:

    class TestDistributedNNFunctionsGloo(TestDistributedNNFunctions):
        # Test Common Ops First.
        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_broadcast(self):
            self._test_broadcast("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_reduce(self):
            self._test_reduce("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_allreduce(self):
            self._test_allreduce("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_all_gather(self):
            self._test_all_gather("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_all_to_all(self):
            self._test_all_to_all("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_all_to_all_single(self):
            self._test_all_to_all_single("gloo")

        # Test Ops only supported in GLOO.
        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_gather(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call directly into
            # torch.distributed, which needs the default process group to be
            # initialized.
            c10d.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
            )
            device = torch.device(f"cuda:{self.rank}")
            x = torch.ones(5, 5, device=device) + self.rank
            x.requires_grad = True
            tensors = torch.distributed.nn.gather(x, 1)
            if self.rank == 1:
                for i, t in enumerate(tensors):
                    self.assertEqual(t, torch.ones(5, 5, device=device) + i)
            elif self.rank == 0:
                for i, t in enumerate(tensors):
                    zeros = torch.zeros(5, 5, device=device)
                    self.assertEqual(t, zeros)
            y = torch.sum(torch.stack(tensors), axis=0)
            z = y.sin().sum()
            z.backward()

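            # Gradient derivation: on the dst rank, y = x_rank0 + x_rank1 =
            # 3 * ones, so dz/dx = cos(3) for each source tensor; the
            # differentiable gather's backward routes each gradient slice
            # back to the rank that contributed it.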
            # Test gradient
            x_s = 3 * torch.ones(5, 5, device=device)
            self.assertEqual(x.grad, x_s.cos())

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_scatter(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call directly into
            # torch.distributed, which needs the default process group to be
            # initialized.
            c10d.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
            )
            device = torch.device(f"cuda:{self.rank}")
            x0 = torch.ones(5, 5, device=device)
            x1 = torch.ones(5, 5, device=device) + 1
            x0.requires_grad = True
            x1.requires_grad = True

            y = torch.distributed.nn.scatter([x0, x1], 1)
            if self.rank == 1:
                self.assertEqual(y, 1 + torch.ones(5, 5, device=device))
            elif self.rank == 0:
                self.assertEqual(y, torch.ones(5, 5, device=device))
            z = y.sin().sum()
            z.backward()

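            # Gradient derivation: rank 1 is the scatter source, so dz/dy =
            # cos(y) flows back to its inputs from both ranks: x0.grad =
            # cos(1) and x1.grad = cos(2). Rank 0 contributed no inputs, so
            # its x0.grad stays zero.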
            # Test gradient
            if self.rank == 1:
                x0_s = torch.ones(5, 5, device=device).cos()
                x1_s = (2 * torch.ones(5, 5, device=device)).cos()
                self.assertEqual(x0.grad, x0_s)
                self.assertEqual(x1.grad, x1_s)
            if self.rank == 0:
                self.assertEqual(x0.grad, torch.zeros(5, 5, device=device))


if __name__ == "__main__":
    run_tests()