# Owner(s): ["oncall: distributed"]

import copy
import os
import sys
import tempfile

import test_c10d_spawn
import torch
import torch.distributed as c10d
import torch.nn as nn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_gloo, \
    create_device, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import TestCase, run_tests, skip_but_pass_in_sandcastle_if, TEST_WITH_DEV_DBG_ASAN

# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):
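    # Exercises collective ops (broadcast, allreduce, allgather) over a Gloo
    # process group on tensors that are shared with spawned worker processes.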
    class ProcessGroupShareTensorTest(test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase):

        @classmethod
        def opts(cls, threads=2):
            opts = c10d.ProcessGroupGloo._Options()
            opts._timeout = 5.0
            opts._devices = [create_device(interface='lo')]
            opts._threads = threads
            return opts

        @classmethod
        def _init_pg_gloo(cls, rank, filename, world_size):
            store = c10d.FileStore(filename, world_size)
            backend = c10d.ProcessGroupGloo(
                store, rank, world_size, ProcessGroupShareTensorTest.opts())
            # set process group backends manually
            c10d.init_process_group(backend="gloo", store=store, rank=rank, world_size=world_size)
            pg = c10d.distributed_c10d._get_default_group()
            pg._register_backend(torch.device("cpu"), c10d.ProcessGroup.BackendType.GLOO, backend)
            pg._register_backend(torch.device("cuda"), c10d.ProcessGroup.BackendType.GLOO, backend)

            return pg

        @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
        def test_shared_broadcast_gloo(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_broadcast_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_gloo,
                1)

        @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
        def test_shared_allreduce_gloo(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allreduce_process,
                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_gloo,
                1)

        @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
        def test_shared_allgather_gloo(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allgather_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_gloo,
                self.world_size)

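        # Worker body for test_shared_allgather_chunk_gloo: each rank
        # all-gathers its chunk of the shared tensor and ships both the
        # expected chunks and the gathered results back to the parent.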
        @classmethod
        def _test_allgather_chunk_process(
                cls, rank, filename, shared_tensor, world_size, init_pg, c2p, p2c):
            pg = init_pg(rank, filename, world_size)
            chunks = torch.chunk(shared_tensor, world_size, dim=0)
            x = chunks[rank]
            ys = [torch.zeros_like(x) for _ in range(world_size)]
            pg.allgather(ys, x).wait()
            c2p.put((rank, chunks[0].to("cpu"), ys[0].to("cpu")))
            c2p.put((rank, chunks[1].to("cpu"), ys[1].to("cpu")))
            p2c.get()

        @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
        def test_shared_allgather_chunk_gloo(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allgather_chunk_process,
                torch.tensor(range(4)).reshape(2, 2),
                ProcessGroupShareTensorTest._init_pg_gloo,
                self.world_size)


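# Compares a DistributedDataParallel-wrapped copy of a module against the
# plain module when trained in lockstep within a single process (world_size=1).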
class DistributedDataParallelSingleProcessTest(TestCase):
    def setUp(self):
        self.rank = 0
        self.world_size = 1
        self.file = tempfile.NamedTemporaryFile(delete=False)  # noqa: P201

    def tearDown(self):
        try:
            os.remove(self.file.name)
        except OSError:
            pass

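    # Trains the original module and its DDP-wrapped copy side by side with
    # separate Adam optimizers and (optionally) asserts that their parameters
    # remain allclose after training.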
    def _test_base(self, net, inp, check_allclose=True):
        store = c10d.FileStore(self.file.name, self.world_size)
        c10d.init_process_group(backend="gloo", store=store, rank=self.rank, world_size=self.world_size)
        process_group = c10d.distributed_c10d._get_default_group()
        if inp[0].is_cuda:
            device_ids = [torch.cuda.current_device()]
        else:
            device_ids = None

        ddp = nn.parallel.DistributedDataParallel(
            copy.deepcopy(net),
            device_ids=device_ids,
            process_group=process_group
        )

        net_opt = torch.optim.Adam(net.parameters(), lr=0.001)
        ddp_opt = torch.optim.Adam(ddp.parameters(), lr=0.001)

        for i, j in zip(ddp.parameters(), net.parameters()):
            self.assertTrue(i.allclose(j))

        for _ in range(10):
            net_out = net(*inp)
            ddp_out = ddp(*inp)

            net_out.sum().backward()
            ddp_out.sum().backward()

            net_opt.step()
            ddp_opt.step()

        if check_allclose:
            for i, j in zip(ddp.parameters(), net.parameters()):
                self.assertTrue(i.allclose(j))

    @requires_gloo()
    def test_cpu(self):
        self._test_base(nn.Linear(2, 2), [torch.randn(30, 2)])

    @requires_gloo()
    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "At least 1 CUDA GPU needed")
    def test_cuda(self):
        self._test_base(nn.Linear(2, 2).to(0), [torch.randn(30, 2).to(0)])

    @requires_gloo()
    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "At least 1 CUDA GPU needed")
    def test_rnn(self):
        # This test is inspired by the bug reported in
        # https://github.com/pytorch/pytorch/issues/36268
        BATCH_SIZE = 12  # Divisible by 2, 3, 4
        INPUT_DIM = 256
        OUTPUT_DIM = 256
        HIDDEN_DIM = 256
        N_LAYERS = 3
        SEQ_LEN = 100

        class Net(nn.Module):
            def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
                super().__init__()
                self.input_dim = input_dim
                self.hidden_dim = hidden_dim
                self.output_dim = output_dim
                self.hidden_layers = hidden_layers

                self.lstm = nn.LSTM(input_dim, hidden_dim, hidden_layers, batch_first=True)
                self.h2o = nn.Linear(hidden_dim, output_dim)

            def forward(self, x, y):
                self.lstm.flatten_parameters()
                h_t, _ = self.lstm(x)
                output = self.h2o(h_t)
                loss = nn.functional.mse_loss(output, y)
                return loss

        net = Net(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS).to(0)
        inp = [
            torch.randn((BATCH_SIZE, SEQ_LEN, INPUT_DIM)).to(0),
            torch.rand((BATCH_SIZE, SEQ_LEN, OUTPUT_DIM)).to(0)
        ]

        # Not checking that the results are allclose, as the parameter
        # inconsistency existed prior to this change. See #37079.
        self._test_base(net, inp, check_allclose=False)


# Skip dev-asan as torch + multiprocessing spawn have known issues
if not TEST_WITH_DEV_DBG_ASAN:
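    # Exercises the autograd-aware collectives from torch.distributed.nn on
    # the Gloo backend. The common ops come from TestDistributedNNFunctions;
    # gather and scatter are tested here because only Gloo supports them.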
    class TestDistributedNNFunctionsGloo(TestDistributedNNFunctions):
        # Test Common Ops First.
        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_broadcast(self):
            self._test_broadcast("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_reduce(self):
            self._test_reduce("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_allreduce(self):
            self._test_allreduce("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_gather(self):
            self._test_all_gather("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_to_all(self):
            self._test_all_to_all("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_to_all_single(self):
            self._test_all_to_all_single("gloo")

        # Test Ops only supported in GLOO.
        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_gather(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call into torch.distributed
            # directly and need the default process group to be initialized.
            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
            device = torch.device(f"cuda:{self.rank}")
            x = torch.ones(5, 5, device=device) + self.rank
            x.requires_grad = True
            tensors = torch.distributed.nn.gather(x, 1)
            if self.rank == 1:
                for i, t in enumerate(tensors):
                    self.assertEqual(t, torch.ones(5, 5, device=device) + i)
            elif self.rank == 0:
                for i, t in enumerate(tensors):
                    zeros = torch.zeros(5, 5, device=device)
                    self.assertEqual(t, zeros)
            y = torch.sum(torch.stack(tensors), axis=0)
            z = y.sin().sum()
            z.backward()

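            # With two ranks the gathered inputs are 1s (rank 0) and 2s
            # (rank 1), so the summed output on the destination rank is all
            # 3s; the gradient propagated back to each rank's input is cos(3).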
            # Test gradient
            x_s = 3 * torch.ones(5, 5, device=device)
            self.assertEqual(x.grad, x_s.cos())

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_scatter(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call into torch.distributed
            # directly and need the default process group to be initialized.
            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
            device = torch.device(f"cuda:{self.rank}")
            x0 = torch.ones(5, 5, device=device)
            x1 = torch.ones(5, 5, device=device) + 1
            x0.requires_grad = True
            x1.requires_grad = True

            y = torch.distributed.nn.scatter([x0, x1], 1)
            if self.rank == 1:
                self.assertEqual(y, 1 + torch.ones(5, 5, device=device))
            elif self.rank == 0:
                self.assertEqual(y, torch.ones(5, 5, device=device))
            z = y.sin().sum()
            z.backward()

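            # Gradients flow back only to the scatter source (rank 1):
            # dz/dx0 = cos(1) and dz/dx1 = cos(2); on rank 0 the inputs
            # receive zero gradients.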
            # Test gradient
            if self.rank == 1:
                x0_s = torch.ones(5, 5, device=device).cos()
                x1_s = (2 * torch.ones(5, 5, device=device)).cos()
                self.assertEqual(x0.grad, x0_s)
                self.assertEqual(x1.grad, x1_s)
            if self.rank == 0:
                self.assertEqual(x0.grad, torch.zeros(5, 5, device=device))


if __name__ == '__main__':
    run_tests()