# Owner(s): ["oncall: distributed"]

import copy
import logging
import math
import operator
import os
import random
import sys
import tempfile
from functools import reduce

import torch
import torch.distributed as c10d


if not c10d.is_available() or not c10d.is_ucc_available():
    print("c10d UCC not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import test_c10d_common
from test_c10d_common import (
    gpus_for_rank,
    ModuleForDdpCommHook,
    SparseGradientModule,
    Task,
)

import torch.distributed as dist
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_ucc,
    skip_if_lt_x_gpu,
    verify_ddp_error_logged,
)
from torch.testing._internal.common_utils import (
    retry_on_connect_failures,
    run_tests,
    skip_but_pass_in_sandcastle,
    TestCase,
)


def simple_reduce_tests(rank, world_size):
    tests = [
        (
            c10d.ReduceOp.SUM,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(world_size * (world_size + 1) / 2)]),
        ),
        (
            c10d.ReduceOp.PRODUCT,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(math.factorial(world_size))]),
        ),
        (
            c10d.ReduceOp.MIN,
            torch.tensor([rank + 1.0]),
            torch.tensor([1.0]),
        ),
        (
            c10d.ReduceOp.MAX,
            torch.tensor([rank + 1.0]),
            torch.tensor([world_size]),
        ),
    ]

    # Generate tests for BAND.
    # The bit that is set changes in every iteration to check
    # that the output changes accordingly.
    for i in range(4):
        vin = rank | (1 << i)
        vout = 1 << i
        tests.append(
            (
                c10d.ReduceOp.BAND,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    # Generate tests for BOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-OR'ed.
    for i in range(1, 5):
        vin = reduce(operator.or_, [rank * i + j for j in range(i)])
        vout = reduce(operator.or_, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    # Generate tests for XOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-XOR'ed.
    for i in range(1, 5):
        vin = reduce(operator.xor, [rank * i + j for j in range(i)])
        vout = reduce(operator.xor, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BXOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    return tests


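# A quick sanity check of the expected values produced above, assuming world_size == 2:
#   SUM:      rank 0 contributes 1.0, rank 1 contributes 2.0 -> 3.0 == 2 * 3 / 2
#   PRODUCT:  1.0 * 2.0 -> 2.0 == factorial(2)
#   BOR, i=2: rank 0 pre-ORs {0, 1} -> 1, rank 1 pre-ORs {2, 3} -> 3; 1 | 3 == 3 == OR of range(4)
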
class RendezvousEnvTest(TestCase):
    @requires_ucc()
    @retry_on_connect_failures
    def test_logging_init(self):
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())
        os.environ["RANK"] = "0"

        previous_handlers = logging.root.handlers

        c10d.init_process_group(backend="ucc", init_method="env://")

        current_handlers = logging.root.handlers
        self.assertEqual(len(previous_handlers), len(current_handlers))
        for current, previous in zip(current_handlers, previous_handlers):
            self.assertEqual(current, previous)

        c10d.destroy_process_group()


class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
    @requires_ucc()
    @retry_on_connect_failures
    def test_default_store_timeout_ucc(self):
        self._test_default_store_timeout("ucc")


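# The tests below construct a ProcessGroupUCC directly from a FileStore rather than
# going through c10d.init_process_group, so each test exercises the UCC process group
# API (broadcast/allreduce/allgather/reduce/send/recv/barrier) without a default group.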
class ProcessGroupUCCTest(MultiProcessTestCase):
    def _create_process_group_ucc(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        return c10d.ProcessGroupUCC(store, self.rank, self.world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    def test_empty_tensors(self):
        pg = self._create_process_group_ucc()

        xs = [torch.FloatTensor([])]
        fut = pg.broadcast(xs).get_future()
        fut.wait()
        output = fut.value()
        self.assertEqual(0, output[0].numel())
        self.assertEqual(xs[0], output[0], exact_dtype=False)

    # TODO: add error check testing

    def _test_broadcast_basics(self, fn):
        pg = self._create_process_group_ucc()

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            fut = pg.broadcast(xs, opts).get_future()
            fut.wait()
            return fut.value()

        # Every rank is root once
        for i in range(self.world_size):
            # Run with 1 input tensor
            x = fn(torch.tensor([self.rank]))
            output = broadcast([x], i, 0)
            self.assertEqual(torch.tensor([i]), output[0], exact_dtype=False)

            # TODO: UCC currently does not support multi tensor input

        # Test overloaded convenience function
        x = torch.tensor([self.rank + 1.0])
        fut = pg.broadcast(x, root=0).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(torch.tensor([1.0]), result[0])

    @requires_ucc()
    def test_broadcast_basics(self):
        self._test_broadcast_basics(lambda t: t.clone())

    # TODO: test_broadcast_basics_cuda times out locally

    def _test_allreduce_basics(self, fn):
        pg = self._create_process_group_ucc()

        # Single input tests
        tests = simple_reduce_tests(self.rank, self.world_size)
        for op, input, expected in tests:
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            tensor = fn(input)
            fut = pg.allreduce([tensor], opts).get_future()
            fut.wait()
            result = fut.value()
            self.assertEqual(expected, result[0], exact_dtype=False)

        # TODO: UCC currently does not support multi tensor input

        # Test overloaded convenience function (defaults to using sum)
        x = fn(torch.tensor([self.rank + 1.0]))
        fut = pg.allreduce(x).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(
            torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
            result[0],
        )

    @requires_ucc()
    def test_allreduce_basics(self):
        self._test_allreduce_basics(lambda t: t.clone())

    # TODO: test_allreduce_basics_cuda times out locally

    def _test_allgather_basics(self, fn):
        pg = self._create_process_group_ucc()

        # TODO: Run with N input tensor per rank; for now, UCC only supports single tensor input so N=1
        for n in [1]:
            input = [fn(torch.tensor([n * self.rank + i])) for i in range(n)]
            output = [
                [fn(torch.tensor([-1])) for _ in range(n * self.world_size)]
                for _ in range(n)
            ]
            expected_output = [
                [fn(torch.tensor([i])) for i in range(n * self.world_size)]
                for _ in range(n)
            ]
            fut = pg.allgather(output, input).get_future()
            fut.wait()
            result = fut.value()
            if n == 1:
                result = [result]
            self.assertEqual(expected_output, result)

    def test_allgather_basics(self):
        self._test_allgather_basics(lambda t: t.clone())

    def _test_reduce_basics(self, fn):
        pg = self._create_process_group_ucc()
        for op, input, output in simple_reduce_tests(self.rank, self.world_size):
            for root in range(self.world_size):
                opts = c10d.ReduceOptions()
                opts.reduceOp = op
                opts.rootRank = root
                tmp = fn(input)
                fut = pg.reduce([tmp], opts).get_future()
                fut.wait()
                result = fut.value()
                if root == self.rank:
                    self.assertEqual(output, result[0], exact_dtype=False)

    @requires_ucc()
    def test_reduce_basics(self):
        self._test_reduce_basics(lambda t: t.clone())

    # TODO: test_reduce_basics_cuda times out locally

    @requires_ucc()
    def test_send_recv_all_to_all(self):
        pg = self._create_process_group_ucc()

        # Preallocate tensors for input/output
        inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)]
        outputs = [torch.tensor([-1]) for _ in range(self.world_size)]

        # Issue sends
        send_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            send_work.append(pg.send([inputs[i]], i, 0))

        # Issue recvs
        recv_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            recv_work.append(pg.recv([outputs[i]], i, 0))

        # Wait for sends to complete
        for work in send_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Wait for recvs to complete
        for work in recv_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Test that every output other than our own contains the respective rank
        for i in range(self.world_size):
            if i == self.rank:
                continue
            self.assertEqual(torch.tensor([i]), outputs[i])

    # TODO: test_barrier_implies_wait fails with numerical mismatch, will investigate later
    @skip_but_pass_in_sandcastle("fails with numerical mismatch, skip for now")
    @requires_ucc()
    def test_barrier_implies_wait(self):
        pg = self._create_process_group_ucc()

        # Kick off allreduce operations
        size = (100, 100)
        num = 16
        tensors = [torch.full(size, float(i)) for i in range(num)]
        for tensor in tensors:
            # Note: leak the returned work handle
            pg.allreduce(tensor)

        # Barrier should ensure all previous work has completed
        pg.barrier().get_future().wait()

        for i, tensor in enumerate(tensors):
            self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def _get_process_group(self):
        store = self._get_store()
        c10d.init_process_group(
            "ucc", store=store, rank=self.rank, world_size=self.world_size
        )
        return c10d.distributed_c10d._get_default_group()

    def _test_ucc_backend(
        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
    ):
        process_group = self._get_process_group()
        self._test_ddp_with_process_group(
            process_group, devices, device_ids, multi_device, gradient_as_bucket_view
        )

    @requires_ucc()
    def test_ucc_backend_cpu_module(self):
        self._test_ucc_backend([torch.device("cpu")], None)

    @requires_ucc()
    def test_ucc_backend_cpu_module_grad_is_view(self):
        self._test_ucc_backend(
            [torch.device("cpu")], None, gradient_as_bucket_view=True
        )

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_integer_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, int_devices)

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_torch_device_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, devices)

    # TODO: test_ucc_backend_2gpu_module and test_ucc_backend_4gpu_module
    # require broadcast_coalesced which is not supported by ucc currently
    @skip_but_pass_in_sandcastle(
        "requires broadcast coalesced, which is not supported by ucc currently"
    )
    @requires_ucc()
    @skip_if_lt_x_gpu(4)
    def test_ucc_backend_2gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, None, multi_device=True)

    @skip_but_pass_in_sandcastle(
        "requires broadcast coalesced, which is not supported by ucc currently"
    )
    @requires_ucc()
    @skip_if_lt_x_gpu(8)
    def test_ucc_backend_4gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, None, multi_device=True)

    def _test_global_local_unused_params_grad(
        self, gradient_as_bucket_view=False, static_graph=False
    ):
        """
        By simulating multi-task training, this test makes sure that:
        1) DDP does not touch the grad of globally unused parameters.
        2) DDP does update the grad of locally unused parameters.
        """

        class GlobalLocalUnusedParamModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()
                self.task_unused = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p, self.task_unused.p)

            def forward(self, x, rank):
                return self.t0(x) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            t0_p, t1_p, task_unused_p = model.module.task_parameters()
            self.assertIsNone(t0_p.grad)
            self.assertIsNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

            # Run backward
            output.mean().backward()

            # Now the locally unused parameters should have their grads updated on all ranks.
            # However, the globally unused parameter should still have a None grad.
            self.assertIsNotNone(t0_p.grad)
            self.assertIsNotNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(gpu_model)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad(self):
        self._test_global_local_unused_params_grad()

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_grad_is_view(self):
        self._test_global_local_unused_params_grad(gradient_as_bucket_view=True)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_static_graph(self):
        self._test_global_local_unused_params_grad(static_graph=True)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_find_unused_parameters_when_unused_parameters_empty(self):
        """
        An empty unused_parameters array does not imply find_unused_parameters =
        false. This test makes sure that DDP allreduces unused parameters
        accordingly when the forward pass in some processes uses all parameters.
        It creates a module that uses all parameters in rank 0 and has unused
        parameters in the other ranks.
        """

        class FindUnusedParamModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p)

            def forward(self, x, rank):
                return self.t1(self.t0(x)) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            [self.assertIsNone(t_p.grad) for t_p in model.module.task_parameters()]

            # Run backward
            output.mean().backward()

            # Now the locally unused parameters should have their grads updated on all ranks.
            [self.assertIsNotNone(t_p.grad) for t_p in model.module.task_parameters()]

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            FindUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            FindUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(gpu_model)

    @requires_ucc()
    def test_ignored_output(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called.
        """
        process_group = self._get_process_group()

        class IgnoredOutput(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutput().float(),
            process_group=process_group,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    @requires_ucc()
    def test_ignored_output_with_unused_parameters(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called, if not all model
        parameters participated in computing the model output.
        """
        process_group = self._get_process_group()

        class IgnoredOutputWithUnusedParameters(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.fc3 = nn.Linear(4, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutputWithUnusedParameters().float(),
            process_group=process_group,
            find_unused_parameters=True,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
        mult = 2
        batch_size = mult * self.world_size
        criterion = nn.CrossEntropyLoss()
        input = torch.randint(0, 10, [batch_size, 2])
        target = torch.randint(0, 10, [batch_size])

        # Run with entire batch against single process version
        criterion(vanilla_model(input), target).backward()

        # Run with partial batch against multi process version
        partial_input = input.split(mult)[self.rank]
        partial_target = target.split(mult)[self.rank]
        criterion(ddp_model(partial_input), partial_target).backward()

        # Check that the gradients are sparse and identical
        vanilla_parameter = next(vanilla_model.parameters())
        ddp_parameter = next(ddp_model.parameters())
        self.assertEqual(
            vanilla_parameter.grad.coalesce(), ddp_parameter.grad.coalesce()
        )

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_save_load_checkpoint(self):
        dist.init_process_group(
            "ucc",
            init_method=f"file://{self.file_name}",
            world_size=self.world_size,
            rank=self.rank,
        )

        class TestModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        def train_loop(model, optimizer, iterations):
            for _ in range(iterations):
                optimizer.zero_grad()
                output = model(input)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        device_id = gpus_for_rank(self.world_size)[self.rank][0]

        model_withload = TestModel().float().to(device_id)
        model_withoutload = TestModel().float().to(device_id)

        ddp_withload = DistributedDataParallel(
            model_withload,
            device_ids=[device_id],
        )
        ddp_withoutload = DistributedDataParallel(
            model_withoutload,
            device_ids=[device_id],
        )

        # Ensure that all three models start with the same set of parameters.
        # By default they are randomized on construction.
        for p in ddp_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in model_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in ddp_withoutload.parameters():
            with torch.no_grad():
                p.zero_()

        batch_size = 4
        criterion = nn.CrossEntropyLoss()

        optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001)
        optimizer_non_ddp_withload = torch.optim.SGD(
            model_withload.parameters(), lr=0.001
        )
        optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001)

        input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
            device_id
        )

        # Run the model for 6 iterations, with a checkpoint in the middle.
        train_loop(ddp_withload, optimizer_withload, 3)

        # Zero out parameters of both the DDP and non-DDP models and reload them from the DDP state dict.
        checkpoint_path = tempfile.gettempdir() + "/model.checkpoint"
        if self.rank == 0:
            torch.save(ddp_withload.state_dict(), checkpoint_path)

        dist.barrier()
        map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
        ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)

        for model in [ddp_withload, model_withload]:
            for p in ddp_withload.parameters():
                with torch.no_grad():
                    p.zero_()
        ddp_withload.load_state_dict(ddp_state_dict)
        # The non-DDP model first needs the "module." prefix removed from the DDP state dict.
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        model_withload.load_state_dict(ddp_state_dict)

        train_loop(ddp_withload, optimizer_withload, 3)
        train_loop(model_withload, optimizer_non_ddp_withload, 3)

        # Re-run the model with the same inputs for 6 iterations with no checkpoint.
        train_loop(ddp_withoutload, optimizer_withoutload, 6)

        for p_withload, p_withoutload, p_non_ddp_withload in zip(
            ddp_withload.parameters(),
            ddp_withoutload.parameters(),
            model_withload.parameters(),
        ):
            self.assertEqual(p_withload, p_withoutload)
            self.assertEqual(p_non_ddp_withload, p_withoutload)

    def _test_sparse_gradients(self, gradient_as_bucket_view=False):
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)

    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients(self):
        self._test_sparse_gradients()

    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients_grad_is_view(self):
        self._test_sparse_gradients(gradient_as_bucket_view=True)

    @requires_ucc()
    def test_ddp_comm_hook_future_passing_cpu(self):
        """
        This unit test verifies whether the Future object is passed properly.
        The callback function creates a Future object and sets a value to it.
        """
        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().cpu(), process_group=process_group
        )

        # Register DDP Communication Hook
        cpu_model.register_comm_hook(None, self._simple_hook)

        # Check whether the grads are equal to what the then() callback returns.
        # Without the comm_hook, the result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2))

    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a DDP communication hook if any.
        if hook is not None:
            gpu_model.register_comm_hook(state, hook)

        return gpu_model

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_future_passing_gpu_ucc(self):
        """
        This unit test verifies whether the Future object is passed properly using the ucc backend.
        The hook callback function creates a Future object and sets a value to it.
        """
        process_group = self._get_process_group()

        # Get GPU model with simple_hook registered.
        gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook)

        # Check whether the grads are equal to what simple_hook's then() callback returns.
        # Without the comm_hook, the result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2))

    @requires_ucc()
    def test_ddp_invalid_comm_hook_init(self):
        """
        This unit test makes sure that register_comm_hook properly checks the format
        of the hook defined by the user. The Python hook must be callable. This test
        also checks whether the bucket annotation is validated if defined.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        with self.assertRaisesRegex(TypeError, "Communication hook must be callable."):
            model.register_comm_hook(state=None, hook=1)

        with self.assertRaisesRegex(
            ValueError, "bucket annotation should be dist.GradBucket."
        ):

            def comm_hook(
                state: object, bucket: int
            ) -> torch.futures.Future[torch.Tensor]:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

    @requires_ucc()
    def test_ddp_invalid_comm_hook_return_type(self):
        """
        This test checks whether the return annotation is validated if defined. It also
        checks whether an internal error is thrown if the return type is incorrect and
        the user hasn't specified any return type annotation.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        expected_err = (
            "Communication hook: return annotation should be torch.futures.Future"
        )
        with self.assertRaisesRegex(
            ValueError,
            expected_err,
        ):

            def comm_hook(state: object, bucket: dist.GradBucket) -> int:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

        verify_ddp_error_logged(model, expected_err)

        with self.assertRaisesRegex(
            RuntimeError,
            "callback must return a torch.futures.Future object, but got",
        ):

            def comm_hook(state: object, bucket: dist.GradBucket):
                return 1

            model.register_comm_hook(state=None, hook=comm_hook)

            # Run forward
            output = model(8, self.rank)

            # Run backward
            output.mean().backward()

    @requires_ucc()
    def test_ddp_comm_hook_register_just_once(self):
        """
        A DDP communication hook can only be registered once. This test validates that
        the error is thrown properly when register_comm_hook is called more than once.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        def dummy_hook(state, bucket):
            fut = torch.futures.Future()
            fut.set_result([bucket.buffer()])
            return fut

        model.register_comm_hook(None, dummy_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "register_comm_hook or register_builtin_comm_hook can only be called once.",
        ):
            model.register_comm_hook(None, dummy_hook)

    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_ddp_comm_hook_sparse_gradients(self):
        """
        Runs the "test_sparse_gradients" unit test with a DDP communication hook. We
        define a simple hook that does allreduce and works with the ucc backend for
        this test.
        """
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
        )

        def allreduce_hook_ucc(
            state: object, bucket: dist.GradBucket
        ) -> torch.futures.Future[torch.Tensor]:
            def div_by_world_size(fut):
                # Divide the result by world_size.
                return fut.wait()[0] / self.world_size

            # Prepare allreduced grad bucket tensors by running an async work.
            fut = process_group.allreduce([bucket.buffer()]).get_future()
            return fut.then(div_by_world_size)

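        # Registering this hook replaces DDP's built-in gradient allreduce for each
        # bucket: the hook sums via a UCC allreduce and then divides by world_size to
        # average, so the gradients should match the single-process baseline checked
        # in _run_and_verify_sparse_gradients.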
        ddp_model.register_comm_hook(None, allreduce_hook_ucc)

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)


class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
    @property
    def device(self):
        return "cpu"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_default_pg_ucc(self):
        self._test_sequence_num_set_default_pg(backend="ucc")

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_ucc_new_group(self):
        self._test_sequence_num_set_new_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_default(self):
        self._test_sequence_num_incremented_default_group("ucc")

    @skip_if_lt_x_gpu(4)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_subgroup(self):
        if self.world_size < 4:
            return skip_but_pass_in_sandcastle("Test requires world_size of at least 4")
        self._test_sequence_num_incremented_subgroup("ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    def test_ucc_barrier_device_ids(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="ucc", rank=self.rank, world_size=self.world_size, store=store
        )

        with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
            c10d.barrier(device_ids=[self.rank])

    @skip_but_pass_in_sandcastle("Fails on M60")
    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_warn_not_in_group(self):
        self._test_warn_not_in_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_rank_membership(self):
        self._test_rank_membership(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_mismatch(self):
        self._test_tensor_dtype_mismatch(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_complex(self):
        self._test_tensor_dtype_complex(backend="ucc")


class UccProcessGroupWithDispatchedCollectivesTests(
    test_c10d_common.ProcessGroupWithDispatchedCollectivesTests
):
    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_collectives(self):
        # includes reduce, broadcast, all_reduce, all_gather, reduce_scatter, barrier, all_to_all, scatter
        self._test_collectives(backend="ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_allgather_base(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "ucc",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda"
        tensor = torch.ones(10, 10, device=torch.device(device))
        output_tensor = torch.zeros(10, 10, device=torch.device(device))
        dist.all_gather_into_tensor(output_tensor, tensor)
        self.assertEqual(output_tensor, tensor)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()