# Owner(s): ["oncall: distributed"]

import copy
import logging
import math
import operator
import os
import random
import sys
import tempfile
from functools import reduce

import torch
import torch.distributed as c10d

if not c10d.is_available() or not c10d.is_ucc_available():
    print("c10d UCC not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import test_c10d_common
import torch.distributed as dist
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from test_c10d_common import (
    gpus_for_rank,
    Task,
    ModuleForDdpCommHook,
    SparseGradientModule,
)
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_ucc,
    skip_if_lt_x_gpu,
    verify_ddp_error_logged,
)
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    retry_on_connect_failures,
    skip_but_pass_in_sandcastle,
)


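# simple_reduce_tests() builds (ReduceOp, input, expected) triples for the
# reduction tests below. As a worked example for SUM with world_size == 2:
# rank 0 contributes 1.0 and rank 1 contributes 2.0, so the expected result is
# 1.0 + 2.0 == 3.0 == world_size * (world_size + 1) / 2.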
def simple_reduce_tests(rank, world_size):
    tests = [
        (
            c10d.ReduceOp.SUM,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(world_size * (world_size + 1) / 2)]),
        ),
        (
            c10d.ReduceOp.PRODUCT,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(math.factorial(world_size))]),
        ),
        (
            c10d.ReduceOp.MIN,
            torch.tensor([rank + 1.0]),
            torch.tensor([1.0]),
        ),
        (
            c10d.ReduceOp.MAX,
            torch.tensor([rank + 1.0]),
            torch.tensor([world_size]),
        ),
    ]

    # Generate tests for BAND.
    # The bit that is set changes in every iteration to check
    # that the output changes accordingly.
    for i in range(4):
        vin = rank | (1 << i)
        vout = 1 << i
        tests.append(
            (
                c10d.ReduceOp.BAND,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    # Generate tests for BOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-OR'ed.
    for i in range(1, 5):
        vin = reduce(operator.or_, [rank * i + j for j in range(i)])
        vout = reduce(operator.or_, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    # Generate tests for XOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-XOR'ed.
    for i in range(1, 5):
        vin = reduce(operator.xor, [rank * i + j for j in range(i)])
        vout = reduce(operator.xor, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BXOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    return tests


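# RendezvousEnvTest checks that initializing the default "ucc" process group
# through the env:// rendezvous (WORLD_SIZE / MASTER_ADDR / MASTER_PORT / RANK)
# does not install or remove any root logging handlers as a side effect.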
class RendezvousEnvTest(TestCase):
    @requires_ucc()
    @retry_on_connect_failures
    def test_logging_init(self):
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())
        os.environ["RANK"] = "0"

        previous_handlers = logging.root.handlers

        c10d.init_process_group(backend="ucc", init_method="env://")

        current_handlers = logging.root.handlers
        self.assertEqual(len(previous_handlers), len(current_handlers))
        for current, previous in zip(current_handlers, previous_handlers):
            self.assertEqual(current, previous)

        c10d.destroy_process_group()


class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
    @requires_ucc()
    @retry_on_connect_failures
    def test_default_store_timeout_ucc(self):
        self._test_default_store_timeout("ucc")


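# ProcessGroupUCCTest exercises UCC collectives directly: each test builds a
# ProcessGroupUCC on a shared FileStore and drives broadcast / allreduce /
# allgather / reduce / send-recv / barrier through the work's future
# (get_future().wait() followed by value()).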
class ProcessGroupUCCTest(MultiProcessTestCase):
    def _create_process_group_ucc(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        return c10d.ProcessGroupUCC(store, self.rank, self.world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    def test_empty_tensors(self):
        pg = self._create_process_group_ucc()

        xs = [torch.FloatTensor([])]
        fut = pg.broadcast(xs).get_future()
        fut.wait()
        output = fut.value()
        self.assertEqual(0, output[0].numel())
        self.assertEqual(xs[0], output[0], exact_dtype=False)

    # TODO: add error check testing

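    # In _test_broadcast_basics every rank takes a turn as the root; after each
    # broadcast all ranks must hold the root's tensor. The convenience overload
    # pg.broadcast(x, root=0) is checked separately at the end.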
    def _test_broadcast_basics(self, fn):
        pg = self._create_process_group_ucc()

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            fut = pg.broadcast(xs, opts).get_future()
            fut.wait()
            return fut.value()

        # Every rank is root once
        for i in range(self.world_size):
            # Run with 1 input tensor
            x = fn(torch.tensor([self.rank]))
            output = broadcast([x], i, 0)
            self.assertEqual(torch.tensor([i]), output[0], exact_dtype=False)

            # TODO: UCC currently does not support multi tensor input

        # Test overloaded convenience function
        x = torch.tensor([self.rank + 1.0])
        fut = pg.broadcast(x, root=0).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(torch.tensor([1.0]), result[0])

    @requires_ucc()
    def test_broadcast_basics(self):
        self._test_broadcast_basics(lambda t: t.clone())

    # TODO: test_broadcast_basics_cuda times out locally

    def _test_allreduce_basics(self, fn):
        pg = self._create_process_group_ucc()

        # Single input tests
        tests = simple_reduce_tests(self.rank, self.world_size)
        for (op, input, expected) in tests:
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            tensor = fn(input)
            fut = pg.allreduce([tensor], opts).get_future()
            fut.wait()
            result = fut.value()
            self.assertEqual(expected, result[0], exact_dtype=False)

        # TODO: UCC currently does not support multi tensor input

        # Test overloaded convenience function (defaults to using sum)
        x = fn(torch.tensor([self.rank + 1.0]))
        fut = pg.allreduce(x).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(
            torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
            result[0],
        )

    @requires_ucc()
    def test_allreduce_basics(self):
        self._test_allreduce_basics(lambda t: t.clone())

    # TODO: test_allreduce_basics_cuda times out locally

    def _test_allgather_basics(self, fn):
        pg = self._create_process_group_ucc()

        # TODO: Run with N input tensor per rank; for now, UCC only supports single tensor input so N=1
        for n in [1]:
            input = [fn(torch.tensor([n * self.rank + i])) for i in range(n)]
            output = [
                [fn(torch.tensor([-1])) for _ in range(n * self.world_size)]
                for _ in range(n)
            ]
            expected_output = [
                [fn(torch.tensor([i])) for i in range(n * self.world_size)]
                for _ in range(n)
            ]
            fut = pg.allgather(output, input).get_future()
            fut.wait()
            result = fut.value()
            if n == 1:
                result = [result]
            self.assertEqual(expected_output, result)

    def test_allgather_basics(self):
        self._test_allgather_basics(lambda t: t.clone())

    def _test_reduce_basics(self, fn):
        pg = self._create_process_group_ucc()
        for (op, input, output) in simple_reduce_tests(self.rank, self.world_size):
            for root in range(self.world_size):
                opts = c10d.ReduceOptions()
                opts.reduceOp = op
                opts.rootRank = root
                tmp = fn(input)
                fut = pg.reduce([tmp], opts).get_future()
                fut.wait()
                result = fut.value()
                if root == self.rank:
                    self.assertEqual(output, result[0], exact_dtype=False)

    @requires_ucc()
    def test_reduce_basics(self):
        self._test_reduce_basics(lambda t: t.clone())

    # TODO: test_reduce_basics_cuda times out locally

    @requires_ucc()
    def test_send_recv_all_to_all(self):
        pg = self._create_process_group_ucc()

        # Preallocate tensors for input/output
        inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)]
        outputs = [torch.tensor([-1]) for _ in range(self.world_size)]

        # Issue sends
        send_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            send_work.append(pg.send([inputs[i]], i, 0))

        # Issue recvs
        recv_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            recv_work.append(pg.recv([outputs[i]], i, 0))

        # Wait for sends to complete
        for work in send_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Wait for recvs to complete
        for work in recv_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Test that every output other than our own contains the respective rank
        for i in range(self.world_size):
            if i == self.rank:
                continue
            self.assertEqual(torch.tensor([i]), outputs[i])

    # TODO: test_barrier_implies_wait fails with numerical mismatch, will investigate later
    @skip_but_pass_in_sandcastle("fails with numerical mismatch, skip for now")
    @requires_ucc()
    def test_barrier_implies_wait(self):
        pg = self._create_process_group_ucc()

        # Kick off allreduce operations
        size = (100, 100)
        num = 16
        tensors = [torch.full(size, float(i)) for i in range(num)]
        for tensor in tensors:
            # Note: leak the returned work handle
            pg.allreduce(tensor)

        # Barrier should ensure all previous work has completed
        pg.barrier().get_future().wait()

        for i, tensor in enumerate(tensors):
            self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)


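# DistributedDataParallelTest runs the shared DDP tests from test_c10d_common on
# top of a default "ucc" process group: backend selection on CPU/GPU modules,
# globally/locally unused parameters, checkpoint save/load, sparse gradients,
# and DDP communication hooks.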
class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def _get_process_group(self):
        store = self._get_store()
        c10d.init_process_group(
            "ucc", store=store, rank=self.rank, world_size=self.world_size
        )
        return c10d.distributed_c10d._get_default_group()

    def _test_ucc_backend(
        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
    ):
        process_group = self._get_process_group()
        self._test_ddp_with_process_group(
            process_group, devices, device_ids, multi_device, gradient_as_bucket_view
        )

    @requires_ucc()
    def test_ucc_backend_cpu_module(self):
        self._test_ucc_backend([torch.device("cpu")], None)

    @requires_ucc()
    def test_ucc_backend_cpu_module_grad_is_view(self):
        self._test_ucc_backend(
            [torch.device("cpu")], None, gradient_as_bucket_view=True
        )

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_integer_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, int_devices)

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_torch_device_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, devices)

    # TODO: test_ucc_backend_2gpu_module and test_ucc_backend_4gpu_module
    # require broadcast_coalesced which is not supported by ucc currently
    @skip_but_pass_in_sandcastle(
        "requires broadcast coalesced, which is not supported by ucc currently"
    )
    @requires_ucc()
    @skip_if_lt_x_gpu(4)
    def test_ucc_backend_2gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, None, multi_device=True)

    @skip_but_pass_in_sandcastle(
        "requires broadcast coalesced, which is not supported by ucc currently"
    )
    @requires_ucc()
    @skip_if_lt_x_gpu(8)
    def test_ucc_backend_4gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, None, multi_device=True)

    def _test_global_local_unused_params_grad(
        self, gradient_as_bucket_view=False, static_graph=False
    ):
        """
        By simulating multi-task training, this test makes sure that:
        1) DDP does not touch the grad of globally unused parameters.
        2) DDP does update the grad of locally unused parameters.
        """

        class GlobalLocalUnusedParamModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()
                self.task_unused = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p, self.task_unused.p)

            def forward(self, x, rank):
                return self.t0(x) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            t0_p, t1_p, task_unused_p = model.module.task_parameters()
            self.assertIsNone(t0_p.grad)
            self.assertIsNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

            # Run backward
            output.mean().backward()

            # The locally unused parameters should now have their grads updated on
            # all ranks, while the globally unused parameter should still have a
            # None grad.
            self.assertIsNotNone(t0_p.grad)
            self.assertIsNotNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(gpu_model)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad(self):
        self._test_global_local_unused_params_grad()

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_grad_is_view(self):
        self._test_global_local_unused_params_grad(gradient_as_bucket_view=True)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_static_graph(self):
        self._test_global_local_unused_params_grad(static_graph=True)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_find_unused_parameters_when_unused_parameters_empty(self):
        """
        An empty unused_parameters array does not imply find_unused_parameters =
        False. This test makes sure that DDP allreduces unused parameters
        accordingly when the forward pass in some process uses all parameters.
        The module used here has all parameters participate on rank 0 and leaves
        some parameters unused on the other ranks.
        """

        class FindUnusedParamModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p)

            def forward(self, x, rank):
                return self.t1(self.t0(x)) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            [self.assertIsNone(t_p.grad) for t_p in model.module.task_parameters()]

            # Run backward
            output.mean().backward()

            # The locally unused parameters should now have their grads updated
            # on all ranks.
            [self.assertIsNotNone(t_p.grad) for t_p in model.module.task_parameters()]

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            FindUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            FindUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(gpu_model)

    @requires_ucc()
    def test_ignored_output(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called.
        """
        process_group = self._get_process_group()

        class IgnoredOutput(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutput().float(),
            process_group=process_group,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    @requires_ucc()
    def test_ignored_output_with_unused_parameters(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called, if not all model
        parameters participated in computing the model output.
        """
        process_group = self._get_process_group()

        class IgnoredOutputWithUnusedParameters(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.fc3 = nn.Linear(4, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutputWithUnusedParameters().float(),
            process_group=process_group,
            find_unused_parameters=True,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

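    # _run_and_verify_sparse_gradients runs the full batch through the
    # single-process model and a per-rank shard (of size `mult`) through the DDP
    # model, then checks that the coalesced sparse gradients of the first
    # parameter match.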
    def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
        mult = 2
        batch_size = mult * self.world_size
        criterion = nn.CrossEntropyLoss()
        input = torch.randint(0, 10, [batch_size, 2])
        target = torch.randint(0, 10, [batch_size])

        # Run with entire batch against single process version
        criterion(vanilla_model(input), target).backward()

        # Run with partial batch against multi process version
        partial_input = input.split(mult)[self.rank]
        partial_target = target.split(mult)[self.rank]
        criterion(ddp_model(partial_input), partial_target).backward()

        # Check that the gradients are sparse and identical
        vanilla_parameter = next(vanilla_model.parameters())
        ddp_parameter = next(ddp_model.parameters())
        self.assertEqual(
            vanilla_parameter.grad.coalesce(), ddp_parameter.grad.coalesce()
        )

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_save_load_checkpoint(self):
        dist.init_process_group(
            "ucc",
            init_method=f"file://{self.file_name}",
            world_size=self.world_size,
            rank=self.rank,
        )

        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        def train_loop(model, optimizer, iterations):
            for _ in range(iterations):
                optimizer.zero_grad()
                output = model(input)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        device_id = gpus_for_rank(self.world_size)[self.rank][0]

        model_withload = TestModel().float().to(device_id)
        model_withoutload = TestModel().float().to(device_id)

        ddp_withload = DistributedDataParallel(
            model_withload,
            device_ids=[device_id],
        )
        ddp_withoutload = DistributedDataParallel(
            model_withoutload,
            device_ids=[device_id],
        )

        # Ensure that all three models start with the same set of parameters;
        # by default they are randomized on construction.
        for p in ddp_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in model_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in ddp_withoutload.parameters():
            with torch.no_grad():
                p.zero_()

        batch_size = 4
        criterion = nn.CrossEntropyLoss()

        optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001)
        optimizer_non_ddp_withload = torch.optim.SGD(
            model_withload.parameters(), lr=0.001
        )
        optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001)

        input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
            device_id
        )

        # Run the model for 6 iterations, with a checkpoint in the middle.
        train_loop(ddp_withload, optimizer_withload, 3)

        # Zero out the parameters of both the DDP and non-DDP models and reload
        # them from the DDP state dict.
        checkpoint_path = tempfile.gettempdir() + "/model.checkpoint"
        if self.rank == 0:
            torch.save(ddp_withload.state_dict(), checkpoint_path)

        dist.barrier()
        map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
        ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)

        for model in [ddp_withload, model_withload]:
            for p in ddp_withload.parameters():
                with torch.no_grad():
                    p.zero_()
        ddp_withload.load_state_dict(ddp_state_dict)
        # The non-DDP model needs to first remove the "module." prefix from the
        # DDP state dict.
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        model_withload.load_state_dict(ddp_state_dict)

        train_loop(ddp_withload, optimizer_withload, 3)
        train_loop(model_withload, optimizer_non_ddp_withload, 3)

        # Re-run the model with the same inputs for 6 iterations with no checkpoint.
        train_loop(ddp_withoutload, optimizer_withoutload, 6)

        for p_withload, p_withoutload, p_non_ddp_withload in zip(
            ddp_withload.parameters(),
            ddp_withoutload.parameters(),
            model_withload.parameters(),
        ):
            self.assertEqual(p_withload, p_withoutload)
            self.assertEqual(p_non_ddp_withload, p_withoutload)

    def _test_sparse_gradients(self, gradient_as_bucket_view=False):
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)

    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients(self):
        self._test_sparse_gradients()

    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients_grad_is_view(self):
        self._test_sparse_gradients(gradient_as_bucket_view=True)

    @requires_ucc()
    def test_ddp_comm_hook_future_passing_cpu(self):
        """
        This unit test verifies whether the Future object is passed properly.
        The callback function creates a Future object and sets a value to it.
        """
        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().cpu(), process_group=process_group
        )

        # Register DDP Communication Hook
        cpu_model.register_comm_hook(None, self._simple_hook)

        # Check whether the grads are equal to what the then() callback returns.
        # Without the comm_hook, the result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2))

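    # _gpu_model_with_ddp_comm_hook wraps ModuleForDdpCommHook in DDP on this
    # rank's first GPU and, when a hook is given, registers it together with the
    # optional hook state.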
    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a DDP communication hook if any.
        if hook is not None:
            gpu_model.register_comm_hook(state, hook)

        return gpu_model

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_future_passing_gpu_ucc(self):
        """
        This unit test verifies whether the Future object is passed properly using
        the ucc backend. The hook callback function creates a Future object and
        sets a value to it.
        """
        process_group = self._get_process_group()

        # Get GPU model with simple_hook registered.
        gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook)

        # Check whether the grads are equal to what simple_hook's then() callback returns.
        # Without the comm_hook, the result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2))

    @requires_ucc()
    def test_ddp_invalid_comm_hook_init(self):
        """
        This unit test makes sure that register_comm_hook properly checks the
        format of the hook defined by the user. The Python hook must be callable.
        This test also checks whether the bucket annotation is checked properly,
        if defined.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        with self.assertRaisesRegex(TypeError, "Communication hook must be callable."):
            model.register_comm_hook(state=None, hook=1)

        with self.assertRaisesRegex(
            ValueError, "bucket annotation should be dist.GradBucket."
        ):

            def comm_hook(
                state: object, bucket: int
            ) -> torch.futures.Future[torch.Tensor]:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

    @requires_ucc()
    def test_ddp_invalid_comm_hook_return_type(self):
        """
        This test checks whether the return annotation is checked properly, if
        defined. It also checks whether an internal error is thrown if the return
        type is incorrect and the user hasn't specified any return type annotation.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        expected_err = (
            "Communication hook: return annotation should be torch.futures.Future"
        )
        with self.assertRaisesRegex(
            ValueError,
            expected_err,
        ):

            def comm_hook(state: object, bucket: dist.GradBucket) -> int:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

        verify_ddp_error_logged(model, expected_err)

        with self.assertRaisesRegex(
            RuntimeError,
            "callback must return a torch.futures.Future object, but got",
        ):

            def comm_hook(state: object, bucket: dist.GradBucket):
                return 1

            model.register_comm_hook(state=None, hook=comm_hook)

            # Run forward
            output = model(8, self.rank)

            # Run backward
            output.mean().backward()

    @requires_ucc()
    def test_ddp_comm_hook_register_just_once(self):
        """
        A DDP communication hook can only be registered once. This test validates
        that the error is thrown properly when register_comm_hook is called more
        than once.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        def dummy_hook(state, bucket):
            fut = torch.futures.Future()
            fut.set_result([bucket.buffer()])
            return fut

        model.register_comm_hook(None, dummy_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "register_comm_hook or register_builtin_comm_hook can only be called once.",
        ):
            model.register_comm_hook(None, dummy_hook)

    # TODO: backward pass: input tensor must be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_ddp_comm_hook_sparse_gradients(self):
        """
        Runs the "test_sparse_gradients" unit test with a DDP communication hook.
        We define a simple hook that does allreduce and works with the ucc backend
        for this test.
        """
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
        )

        def allreduce_hook_ucc(
            state: object, bucket: dist.GradBucket
        ) -> torch.futures.Future[torch.Tensor]:
            def div_by_world_size(fut):
                # Divide the allreduced result by world_size.
                return fut.wait()[0] / self.world_size

            # Prepare allreduced grad bucket tensors by running an async work.
            fut = process_group.allreduce([bucket.buffer()]).get_future()
            return fut.then(div_by_world_size)

        ddp_model.register_comm_hook(None, allreduce_hook_ucc)

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)


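# CommTest reuses the backend-agnostic checks from test_c10d_common.AbstractCommTest
# against "ucc": sequence numbers for the default group and subgroups, barrier
# device_ids handling, warnings for ranks outside a group, rank membership, and
# tensor dtype mismatch/complex handling.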
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
    @property
    def device(self):
        return "cpu"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_default_pg_ucc(self):
        self._test_sequence_num_set_default_pg(backend="ucc")

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_ucc_new_group(self):
        self._test_sequence_num_set_new_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_default(self):
        self._test_sequence_num_incremented_default_group("ucc")

    @skip_if_lt_x_gpu(4)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_subgroup(self):
        if self.world_size < 4:
            return skip_but_pass_in_sandcastle("Test requires world_size of at least 4")
        self._test_sequence_num_incremented_subgroup("ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    def test_ucc_barrier_device_ids(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="ucc", rank=self.rank, world_size=self.world_size, store=store
        )

        with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
            c10d.barrier(device_ids=[self.rank])

    @skip_but_pass_in_sandcastle("Fails on M60")
    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_warn_not_in_group(self):
        self._test_warn_not_in_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_rank_membership(self):
        self._test_rank_membership(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_mismatch(self):
        self._test_tensor_dtype_mismatch(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_complex(self):
        self._test_tensor_dtype_complex(backend="ucc")


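# CompilerTest runs the shared test_c10d_common.CompilerTest cases over a default
# "ucc" group of size 2: waiting on the work from allreduce / allgather /
# broadcast, nested comm-tensor wrapping, and consecutive collectives, on CPU
# tensors and (with at least 2 GPUs) CUDA tensors.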
class CompilerTest(test_c10d_common.CompilerTest):
    @property
    def world_size(self):
        return 2

    def _get_default_group(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="ucc",
            rank=self.rank,
            world_size=self.world_size,
            store=store,
        )
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    def test_allreduce_work_wait_gpu(self):
        self._test_allreduce_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank,
        )

    @skip_if_lt_x_gpu(2)
    def test_allgather_work_wait_gpu(self):
        self._test_allgather_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    @skip_if_lt_x_gpu(2)
    def test_broadcast_work_wait_gpu(self):
        self._test_broadcast_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    @skip_if_lt_x_gpu(2)
    def test_nested_comm_tensor_wrapping_gpu(self):
        self._test_nested_comm_tensor_wrapping(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    @skip_if_lt_x_gpu(2)
    def test_consecutive_comm_work_wait_gpu(self):
        self._test_consecutive_comm_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    def test_allreduce_work_wait_cpu(self):
        self._test_allreduce_work_wait(
            torch.ones(2, 2) * self.rank,
        )

    def test_allgather_work_wait_cpu(self):
        self._test_allgather_work_wait(
            torch.ones(2, 2) * self.rank
        )

    def test_broadcast_work_wait_cpu(self):
        self._test_broadcast_work_wait(
            torch.ones(2, 2) * self.rank
        )

    def test_nested_comm_tensor_wrapping_cpu(self):
        self._test_nested_comm_tensor_wrapping(
            torch.ones(2, 2) * self.rank
        )

    def test_consecutive_comm_work_wait_cpu(self):
        self._test_consecutive_comm_work_wait(
            torch.ones(2, 2) * self.rank
        )


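# UccProcessGroupWithDispatchedCollectivesTests runs the dispatcher-based
# collective tests from test_c10d_common against "ucc", plus an
# all_gather_into_tensor check on CUDA tensors.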
class UccProcessGroupWithDispatchedCollectivesTests(
    test_c10d_common.ProcessGroupWithDispatchedCollectivesTests
):
    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_collectives(self):
        # includes reduce, broadcast, all_reduce, all_gather, reduce_scatter, barrier, all_to_all, scatter
        self._test_collectives(backend="ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_allgather_base(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "ucc",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda"
        tensor = torch.ones(10, 10, device=torch.device(device))
        output_tensor = torch.zeros(10, 10, device=torch.device(device))
        dist.all_gather_into_tensor(output_tensor, tensor)
        self.assertEqual(output_tensor, tensor)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()