pytorch

Форк
0
/
test_c10d_ucc.py 
1082 строки · 37.0 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import copy
4
import logging
5
import math
6
import operator
7
import os
8
import random
9
import sys
10
import tempfile
11
from functools import reduce
12

13
import torch
14
import torch.distributed as c10d
15

16

17
if not (c10d.is_available() and c10d.is_ucc_available()):
    # Nothing below can run without c10d + UCC support, so report it on
    # stderr and exit with status 0 instead of failing at import time.
    print("c10d UCC not available, skipping tests", file=sys.stderr)
    sys.exit(0)
20

21
import test_c10d_common
22
from test_c10d_common import (
23
    gpus_for_rank,
24
    ModuleForDdpCommHook,
25
    SparseGradientModule,
26
    Task,
27
)
28

29
import torch.distributed as dist
30
import torch.nn.functional as F
31
import torch.testing._internal.common_utils as common
32
from torch import nn
33
from torch.nn.parallel import DistributedDataParallel
34
from torch.testing._internal.common_distributed import (
35
    MultiProcessTestCase,
36
    requires_ucc,
37
    skip_if_lt_x_gpu,
38
    verify_ddp_error_logged,
39
)
40
from torch.testing._internal.common_utils import (
41
    retry_on_connect_failures,
42
    run_tests,
43
    skip_but_pass_in_sandcastle,
44
    TestCase,
45
)
46

47

48
def simple_reduce_tests(rank, world_size):
    """Build ``(op, input, expected)`` triples for single-tensor reductions.

    Each rank contributes ``input``; ``expected`` is the value every
    participant (or the root) should observe after reducing the inputs of
    ranks ``0 .. world_size - 1`` with ``op``.

    The list holds, in order: SUM, PRODUCT, MIN and MAX on float inputs
    ``rank + 1``; four BAND cases; four BOR cases; four BXOR cases (the
    bitwise cases use ``torch.int32`` tensors).
    """

    def int_case(op, contributed, result):
        # Bitwise reductions are only meaningful on integer tensors.
        return (
            op,
            torch.tensor([contributed], dtype=torch.int32),
            torch.tensor([result], dtype=torch.int32),
        )

    cases = [
        (
            c10d.ReduceOp.SUM,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(world_size * (world_size + 1) / 2)]),
        ),
        (
            c10d.ReduceOp.PRODUCT,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(math.factorial(world_size))]),
        ),
        (c10d.ReduceOp.MIN, torch.tensor([rank + 1.0]), torch.tensor([1.0])),
        (c10d.ReduceOp.MAX, torch.tensor([rank + 1.0]), torch.tensor([world_size])),
    ]

    # BAND: every rank sets the same single bit (rank 0 sets nothing else),
    # so the AND across ranks is exactly that bit. The bit moves each case
    # so a stale result is caught.
    cases.extend(
        int_case(c10d.ReduceOp.BAND, rank | (1 << bit), 1 << bit) for bit in range(4)
    )

    # BOR / BXOR: emulate a world `per_rank` times larger by letting each
    # rank pre-combine `per_rank` consecutive values.
    for op, combine in (
        (c10d.ReduceOp.BOR, operator.or_),
        (c10d.ReduceOp.BXOR, operator.xor),
    ):
        for per_rank in range(1, 5):
            first = rank * per_rank
            contributed = reduce(combine, range(first, first + per_rank))
            expected = reduce(combine, range(world_size * per_rank))
            cases.append(int_case(op, contributed, expected))

    return cases
115

116

117
class RendezvousEnvTest(TestCase):
    """UCC process-group initialization through the ``env://`` rendezvous."""

    @requires_ucc()
    @retry_on_connect_failures
    def test_logging_init(self):
        """Initializing a UCC process group must not alter the root logger's handlers."""
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())
        os.environ["RANK"] = "0"

        # Snapshot the handlers. ``logging.root.handlers`` is the live list
        # that ``addHandler`` mutates. Keeping a bare reference would compare
        # the list with itself below and never detect an added handler.
        previous_handlers = list(logging.root.handlers)

        c10d.init_process_group(backend="ucc", init_method="env://")
        try:
            current_handlers = logging.root.handlers
            self.assertEqual(len(previous_handlers), len(current_handlers))
            for current, previous in zip(current_handlers, previous_handlers):
                self.assertEqual(current, previous)
        finally:
            # Destroy the group even when an assertion fails, so the default
            # group is not left initialized.
            c10d.destroy_process_group()
136

137

138
class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
    """Runs the shared ``AbstractTimeoutTest`` store-timeout check against UCC."""

    @requires_ucc()
    @retry_on_connect_failures
    def test_default_store_timeout_ucc(self):
        """Delegate to ``_test_default_store_timeout`` with the ``ucc`` backend."""
        backend_name = "ucc"
        self._test_default_store_timeout(backend_name)
143

144

145
class ProcessGroupUCCTest(MultiProcessTestCase):
    """Tests that drive ``c10d.ProcessGroupUCC`` directly, without a default group.

    ``setUp`` spawns one process per rank. Each test builds its own process
    group on a ``FileStore`` backed by ``self.file_name``.
    """

    def _create_process_group_ucc(self):
        """Return a ``ProcessGroupUCC`` for this rank that rendezvouses via a ``FileStore``."""
        store = c10d.FileStore(self.file_name, self.world_size)
        return c10d.ProcessGroupUCC(store, self.rank, self.world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        # Best-effort cleanup of the store file, which may already be gone.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    def test_empty_tensors(self):
        """Broadcasting a zero-element tensor completes and yields an empty tensor."""
        pg = self._create_process_group_ucc()

        xs = [torch.FloatTensor([])]
        fut = pg.broadcast(xs).get_future()
        fut.wait()
        output = fut.value()
        self.assertEqual(0, output[0].numel())
        self.assertEqual(xs[0], output[0], exact_dtype=False)

    # TODO: add error check testing

    def _test_broadcast_basics(self, fn):
        """Broadcast from every root rank, then through the ``root=`` overload.

        Args:
            fn: Applied to each freshly built input tensor before it is passed
                to the process group (``test_broadcast_basics`` passes
                ``lambda t: t.clone()``).
        """
        pg = self._create_process_group_ucc()

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            fut = pg.broadcast(xs, opts).get_future()
            fut.wait()
            return fut.value()

        # Every rank is root once
        for i in range(self.world_size):
            # Run with 1 input tensor
            x = fn(torch.tensor([self.rank]))
            output = broadcast([x], i, 0)
            self.assertEqual(torch.tensor([i]), output[0], exact_dtype=False)

            # TODO: UCC currently does not support multi tensor input

        # Test the overloaded convenience function. ``fn`` is applied here
        # as well, as in ``_test_allreduce_basics``, so the overload receives
        # the same kind of tensor as the explicit-options path.
        x = fn(torch.tensor([self.rank + 1.0]))
        fut = pg.broadcast(x, root=0).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(torch.tensor([1.0]), result[0])

    @requires_ucc()
    def test_broadcast_basics(self):
        self._test_broadcast_basics(lambda t: t.clone())

    # TODO: test_broadcast_basics_cuda times out locally

    def _test_allreduce_basics(self, fn):
        """Allreduce every ``simple_reduce_tests`` case, then the default (SUM) overload.

        Args:
            fn: Applied to each input tensor before it is reduced.
        """
        pg = self._create_process_group_ucc()

        # Single input tests
        tests = simple_reduce_tests(self.rank, self.world_size)
        for op, inp, expected in tests:
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            tensor = fn(inp)
            fut = pg.allreduce([tensor], opts).get_future()
            fut.wait()
            result = fut.value()
            self.assertEqual(expected, result[0], exact_dtype=False)

        # TODO: UCC currently does not support multi tensor input

        # Test overloaded convenience function (defaults to using sum)
        x = fn(torch.tensor([self.rank + 1.0]))
        fut = pg.allreduce(x).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(
            torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
            result[0],
        )

    @requires_ucc()
    def test_allreduce_basics(self):
        self._test_allreduce_basics(lambda t: t.clone())

    # TODO: test_allreduce_basics_cuda times out locally

    def _test_allgather_basics(self, fn):
        """Allgather one tensor per rank and check that every rank receives all values.

        Args:
            fn: Applied to every input, output and expected tensor.
        """
        pg = self._create_process_group_ucc()

        # TODO: Run with N input tensor per rank; for now, UCC only supports single tensor input so N=1
        for n in [1]:
            inputs = [fn(torch.tensor([n * self.rank + i])) for i in range(n)]
            output = [
                [fn(torch.tensor([-1])) for _ in range(n * self.world_size)]
                for _ in range(n)
            ]
            expected_output = [
                [fn(torch.tensor([i])) for i in range(n * self.world_size)]
                for _ in range(n)
            ]
            fut = pg.allgather(output, inputs).get_future()
            fut.wait()
            result = fut.value()
            # Presumably the future's value is not nested per input when there
            # is only one input, so wrap it to match ``expected_output``.
            if n == 1:
                result = [result]
            self.assertEqual(expected_output, result)

    @requires_ucc()
    def test_allgather_basics(self):
        self._test_allgather_basics(lambda t: t.clone())

    def _test_reduce_basics(self, fn):
        """Reduce each ``simple_reduce_tests`` case to every root. Only the root checks the result.

        Args:
            fn: Applied to each input tensor before it is reduced.
        """
        pg = self._create_process_group_ucc()
        for op, inp, expected in simple_reduce_tests(self.rank, self.world_size):
            for root in range(self.world_size):
                opts = c10d.ReduceOptions()
                opts.reduceOp = op
                opts.rootRank = root
                tmp = fn(inp)
                fut = pg.reduce([tmp], opts).get_future()
                fut.wait()
                result = fut.value()
                if root == self.rank:
                    self.assertEqual(expected, result[0], exact_dtype=False)

    @requires_ucc()
    def test_reduce_basics(self):
        self._test_reduce_basics(lambda t: t.clone())

    # TODO: test_reduce_basics_cuda times out locally

    @requires_ucc()
    def test_send_recv_all_to_all(self):
        """Every rank sends its rank id to every other rank and checks what it receives."""
        pg = self._create_process_group_ucc()

        # Preallocate tensors for input/output
        inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)]
        outputs = [torch.tensor([-1]) for _ in range(self.world_size)]

        # Issue sends
        send_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            send_work.append(pg.send([inputs[i]], i, 0))

        # Issue recvs
        recv_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            recv_work.append(pg.recv([outputs[i]], i, 0))

        # Wait for sends to complete
        for work in send_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Wait for recvs to complete
        for work in recv_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Test that every output other than our own contains the respective rank
        for i in range(self.world_size):
            if i == self.rank:
                continue
            self.assertEqual(torch.tensor([i]), outputs[i])

    # TODO: test_barrier_implies_wait fails with numerical mismatch, will investigate later
    @skip_but_pass_in_sandcastle("fails with numerical mismatch, skip for now")
    @requires_ucc()
    def test_barrier_implies_wait(self):
        """A barrier after un-awaited allreduces should leave all of them completed."""
        pg = self._create_process_group_ucc()

        # Kick off allreduce operations
        size = (100, 100)
        num = 16
        tensors = [torch.full(size, float(i)) for i in range(num)]
        for tensor in tensors:
            # Note: leak the returned work handle
            pg.allreduce(tensor)

        # Barrier should ensure all previous work has completed
        pg.barrier().get_future().wait()

        for i, tensor in enumerate(tensors):
            self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)
340

341

342
class DistributedDataParallelTest(
343
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
344
):
345
    def setUp(self):
        """Run the base-class setup, then launch one worker process per rank."""
        super().setUp()
        spawn_workers = self._spawn_processes
        spawn_workers()
348

349
    def _get_process_group(self):
        """Initialize the default process group with the UCC backend and return it."""
        c10d.init_process_group(
            "ucc",
            store=self._get_store(),
            rank=self.rank,
            world_size=self.world_size,
        )
        return c10d.distributed_c10d._get_default_group()
355

356
    def _test_ucc_backend(
        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
    ):
        """Run the shared DDP check ``_test_ddp_with_process_group`` on a new UCC group."""
        self._test_ddp_with_process_group(
            self._get_process_group(),
            devices,
            device_ids,
            multi_device,
            gradient_as_bucket_view,
        )
363

364
    @requires_ucc()
    def test_ucc_backend_cpu_module(self):
        """DDP over UCC with a CPU module and ``device_ids=None``."""
        cpu_devices = [torch.device("cpu")]
        self._test_ucc_backend(cpu_devices, None)
367

368
    @requires_ucc()
    def test_ucc_backend_cpu_module_grad_is_view(self):
        """Same as the CPU-module test, with gradients as views into the buckets."""
        cpu_devices = [torch.device("cpu")]
        self._test_ucc_backend(cpu_devices, None, gradient_as_bucket_view=True)
373

374
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_integer_list(self):
        """DDP over UCC on this rank's first GPU, passing ``device_ids`` as ints."""
        first_gpu = gpus_for_rank(self.world_size)[self.rank][:1]
        cuda_devices = [torch.device(f"cuda:{idx}") for idx in first_gpu]
        self._test_ucc_backend(cuda_devices, first_gpu)
380

381
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_torch_device_list(self):
        """DDP over UCC on this rank's first GPU, passing ``device_ids`` as ``torch.device``s."""
        first_gpu = gpus_for_rank(self.world_size)[self.rank][:1]
        cuda_devices = [torch.device(f"cuda:{idx}") for idx in first_gpu]
        self._test_ucc_backend(cuda_devices, cuda_devices)
387

388
    # TODO: test_ucc_backend_2gpu_module and test_ucc_backend_4gpu_module
    # require broadcast_coalesced which is not supported by ucc currently
    @skip_but_pass_in_sandcastle(
        "requires broadcast coalesced, which is not supported by ucc currently"
    )
    @requires_ucc()
    @skip_if_lt_x_gpu(4)
    def test_ucc_backend_2gpu_module(self):
        """Multi-device DDP module spanning this rank's first two GPUs."""
        rank_gpus = gpus_for_rank(self.world_size)[self.rank][:2]
        cuda_devices = [torch.device(f"cuda:{idx}") for idx in rank_gpus]
        self._test_ucc_backend(cuda_devices, None, multi_device=True)
399

400
    @skip_but_pass_in_sandcastle(
        "requires broadcast coalesced, which is not supported by ucc currently"
    )
    @requires_ucc()
    @skip_if_lt_x_gpu(8)
    def test_ucc_backend_4gpu_module(self):
        """Multi-device DDP module spanning this rank's first four GPUs."""
        rank_gpus = gpus_for_rank(self.world_size)[self.rank][:4]
        cuda_devices = [torch.device(f"cuda:{idx}") for idx in rank_gpus]
        self._test_ucc_backend(cuda_devices, None, multi_device=True)
409

410
    def _test_global_local_unused_params_grad(
        self, gradient_as_bucket_view=False, static_graph=False
    ):
        """
        Simulate multi-task training and check that:
        1) DDP leaves the grad of globally unused parameters untouched (None).
        2) DDP does populate the grad of locally unused parameters.

        Runs once with a CPU module and once on this rank's first GPU.
        """

        class GlobalLocalUnusedParamModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()
                self.task_unused = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p, self.task_unused.p)

            def forward(self, x, rank):
                # Rank 0 only uses t0 and every other rank only uses t1, so
                # each of t0/t1 is locally unused somewhere; task_unused is
                # never used anywhere.
                if rank == 0:
                    return self.t0(x)
                return self.t1(x)

        def run_and_verify_grad(model):
            output = model(8, self.rank)

            t0_p, t1_p, unused_p = model.module.task_parameters()
            # Nothing has been backpropagated yet.
            for param in (t0_p, t1_p, unused_p):
                self.assertIsNone(param.grad)

            output.mean().backward()

            # Locally unused parameters get grads on every rank; the globally
            # unused one must stay None.
            for param in (t0_p, t1_p):
                self.assertIsNotNone(param.grad)
            self.assertIsNone(unused_p.grad)

        process_group = self._get_process_group()
        ddp_kwargs = dict(
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )

        # CPU variant.
        run_and_verify_grad(
            DistributedDataParallel(GlobalLocalUnusedParamModule().cpu(), **ddp_kwargs)
        )

        # GPU variant.
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        run_and_verify_grad(
            DistributedDataParallel(
                GlobalLocalUnusedParamModule().to(device_id),
                device_ids=[device_id],
                **ddp_kwargs,
            )
        )
474

475
    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad(self):
        """Unused-parameter grad check with default DDP settings."""
        self._test_global_local_unused_params_grad(
            gradient_as_bucket_view=False, static_graph=False
        )
481

482
    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_grad_is_view(self):
        """Unused-parameter grad check with gradients as bucket views."""
        self._test_global_local_unused_params_grad(
            gradient_as_bucket_view=True, static_graph=False
        )
488

489
    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_static_graph(self):
        """Unused-parameter grad check with ``static_graph=True``."""
        self._test_global_local_unused_params_grad(
            gradient_as_bucket_view=False, static_graph=True
        )
495

496
    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_find_unused_parameters_when_unused_parameters_empty(self):
        """
        An empty unused_parameters array does not imply find_unused_parameters =
        false. This test makes sure that DDP allreduces unused parameters
        accordingly where the forward pass in some process uses all parameters.
        This unit test creates a module that uses all parameters in rank = 0, and
        has unused parameters in other ranks.
        """

        class FindUnusedParamModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p)

            def forward(self, x, rank):
                # Rank 0 routes through both tasks; other ranks skip t0.
                return self.t1(self.t0(x)) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point. These
            # are plain loops because the assertions are side effects, not
            # values to collect into a list.
            for t_p in model.module.task_parameters():
                self.assertIsNone(t_p.grad)

            # Run backward
            output.mean().backward()

            # Now locally unused parameter should have grad updated on all ranks.
            for t_p in model.module.task_parameters():
                self.assertIsNotNone(t_p.grad)

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            FindUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            FindUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(gpu_model)
553

554
    @requires_ucc()
    def test_ignored_output(self):
        """
        Check that a DDP model's output may be discarded: nothing requires
        ``backward`` to run after every forward pass.
        """
        process_group = self._get_process_group()

        class IgnoredOutput(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                hidden = self.relu(self.fc1(x))
                logits = self.relu(self.fc2(hidden))
                return F.softmax(logits, dim=1)

        model = DistributedDataParallel(
            IgnoredOutput().float(),
            process_group=process_group,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        batch = torch.rand([batch_size, 2], dtype=torch.float)
        labels = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Phase 1: forward passes whose results are thrown away immediately.
        for _ in range(4):
            model(batch)

        # Phase 2: forward passes that are followed by backward.
        for _ in range(4):
            criterion(model(batch), labels).backward()
594

595
    @requires_ucc()
    def test_ignored_output_with_unused_parameters(self):
        """
        Same as ``test_ignored_output``, but the model owns a parameter
        (``fc3``) that never contributes to the output, and DDP runs with
        ``find_unused_parameters=True``.
        """
        process_group = self._get_process_group()

        class IgnoredOutputWithUnusedParameters(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.fc3 = nn.Linear(4, 4, bias=False)  # deliberately unused
                self.relu = nn.ReLU()

            def forward(self, x):
                hidden = self.relu(self.fc1(x))
                logits = self.relu(self.fc2(hidden))
                return F.softmax(logits, dim=1)

        model = DistributedDataParallel(
            IgnoredOutputWithUnusedParameters().float(),
            process_group=process_group,
            find_unused_parameters=True,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        batch = torch.rand([batch_size, 2], dtype=torch.float)
        labels = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Phase 1: forward passes whose results are thrown away immediately.
        for _ in range(4):
            model(batch)

        # Phase 2: forward passes that are followed by backward.
        for _ in range(4):
            criterion(model(batch), labels).backward()
638

639
    def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
        """Compare the sparse grads of a single-process model against a DDP replica.

        ``vanilla_model`` sees the whole batch; each DDP rank sees only its own
        ``per_rank``-sized slice. After backward, the first parameter's
        coalesced gradients must match.
        """
        per_rank = 2
        global_batch = per_rank * self.world_size
        criterion = nn.CrossEntropyLoss()
        full_input = torch.randint(0, 10, [global_batch, 2])
        full_target = torch.randint(0, 10, [global_batch])

        # Reference: one process, entire batch.
        criterion(vanilla_model(full_input), full_target).backward()

        # DDP: this rank's slice only.
        local_input = full_input.split(per_rank)[self.rank]
        local_target = full_target.split(per_rank)[self.rank]
        criterion(ddp_model(local_input), local_target).backward()

        reference_param = next(vanilla_model.parameters())
        replica_param = next(ddp_model.parameters())
        self.assertEqual(
            reference_param.grad.coalesce(), replica_param.grad.coalesce()
        )
660

661
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_save_load_checkpoint(self):
        """Resuming training from a checkpoint must match uninterrupted training.

        - ``ddp_withload`` trains 3 iterations. Rank 0 then checkpoints it, and
          the model is zeroed, reloaded and trained 3 more iterations.
        - ``model_withload`` is reloaded from the same state dict (with the
          ``module.`` prefix stripped) and trained 3 iterations.
        - ``ddp_withoutload`` trains 6 iterations straight.

        All parameter sets are compared at the end.
        """
        dist.init_process_group(
            "ucc",
            init_method=f"file://{self.file_name}",
            world_size=self.world_size,
            rank=self.rank,
        )

        class TestModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        def train_loop(model, optimizer, iterations):
            # ``inputs``, ``target`` and ``criterion`` come from the enclosing
            # scope. They are bound below, before the first call.
            for _ in range(iterations):
                optimizer.zero_grad()
                output = model(inputs)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        def zero_parameters(module):
            with torch.no_grad():
                for p in module.parameters():
                    p.zero_()

        device_id = gpus_for_rank(self.world_size)[self.rank][0]

        model_withload = TestModel().float().to(device_id)
        model_withoutload = TestModel().float().to(device_id)

        # NOTE(review): DDP wraps ``model_withload`` itself, so ``ddp_withload``
        # and ``model_withload`` share parameter tensors. The non-DDP training
        # leg below keeps updating those same tensors. Confirm this is intended.
        ddp_withload = DistributedDataParallel(
            model_withload,
            device_ids=[device_id],
        )
        ddp_withoutload = DistributedDataParallel(
            model_withoutload,
            device_ids=[device_id],
        )

        # Ensure all three models start from the same parameters. By default
        # they are randomized on construction.
        for model in (ddp_withload, model_withload, ddp_withoutload):
            zero_parameters(model)

        batch_size = 4
        criterion = nn.CrossEntropyLoss()

        optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001)
        optimizer_non_ddp_withload = torch.optim.SGD(
            model_withload.parameters(), lr=0.001
        )
        optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001)

        inputs = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
            device_id
        )

        # run the model for 6 iterations, with a checkpoint in the middle
        train_loop(ddp_withload, optimizer_withload, 3)

        checkpoint_path = os.path.join(tempfile.gettempdir(), "model.checkpoint")
        if self.rank == 0:
            torch.save(ddp_withload.state_dict(), checkpoint_path)

        # Every rank waits until rank 0 has written the checkpoint.
        dist.barrier()
        # The checkpoint is assumed to have been saved from rank 0's device,
        # cuda:0. Remap it onto this rank's own device, ``device_id``, which is
        # not ``cuda:<rank>`` when a rank owns more than one GPU.
        map_location = {"cuda:0": f"cuda:{device_id}"}
        ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)

        # Zero out the parameters of both the DDP and non-DDP models, then
        # reload them from the DDP state dict.
        for model in (ddp_withload, model_withload):
            zero_parameters(model)
        ddp_withload.load_state_dict(ddp_state_dict)
        # the non-DDP model needs to first remove the prefix of "module." from the DDP state dict
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        model_withload.load_state_dict(ddp_state_dict)

        train_loop(ddp_withload, optimizer_withload, 3)
        train_loop(model_withload, optimizer_non_ddp_withload, 3)

        # re-run the model with the same inputs for 6 iterations with no checkpoint
        train_loop(ddp_withoutload, optimizer_withoutload, 6)

        for p_withload, p_withoutload, p_non_ddp_withload in zip(
            ddp_withload.parameters(),
            ddp_withoutload.parameters(),
            model_withload.parameters(),
        ):
            self.assertEqual(p_withload, p_withoutload)
            self.assertEqual(p_non_ddp_withload, p_withoutload)
766

767
    def _test_sparse_gradients(self, gradient_as_bucket_view=False):
        """Wrap a ``SparseGradientModule`` copy in DDP and compare its grads with the original."""
        process_group = self._get_process_group()

        # Same seed on every rank, so weights and inputs match across processes.
        torch.manual_seed(1337)

        reference = SparseGradientModule()
        replica = DistributedDataParallel(
            copy.deepcopy(reference),
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )
        self._run_and_verify_sparse_gradients(reference, replica)
781

782
    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients(self):
        """Sparse-gradient DDP check with default bucket handling."""
        self._test_sparse_gradients(gradient_as_bucket_view=False)
787

788
    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients_grad_is_view(self):
        """Sparse-gradient DDP check with gradients as bucket views."""
        self._test_sparse_gradients(gradient_as_bucket_view=True)
793

794
    @requires_ucc()
    def test_ddp_comm_hook_future_passing_cpu(self):
        """
        Verify that the Future produced by a comm hook reaches DDP intact on
        CPU. ``_simple_hook``'s callback sets the Future's value, and the
        resulting grads must equal that value.
        """
        process_group = self._get_process_group()

        cpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().cpu(), process_group=process_group
        )
        cpu_model.register_comm_hook(None, self._simple_hook)

        # Without the hook the grads would be 0.25 * torch.ones(2, 2).
        hooked_grad = 2 * torch.ones(2, 2)
        self._run_and_verify_hook(cpu_model, 8, hooked_grad)
813

814
    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        """Build a DDP ``ModuleForDdpCommHook`` on this rank's first GPU.

        If ``hook`` is given, it is registered with ``state`` before the model
        is returned.
        """
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )
        if hook is None:
            return gpu_model
        gpu_model.register_comm_hook(state, hook)
        return gpu_model
830

831
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_future_passing_gpu_ucc(self):
        """
        Verify, using the ucc backend on GPU, that the Future produced by a DDP
        communication hook is passed through properly: the hook's callback
        creates a Future and sets a value on it.
        """
        pg = self._get_process_group()

        # GPU model with `_simple_hook` already registered.
        model_on_gpu = self._gpu_model_with_ddp_comm_hook(pg, self._simple_hook)

        # Grads are expected to equal what `_simple_hook`'s then-callback
        # returns; without the hook they would be 0.25 * torch.ones(2, 2).
        expected_grads = 2 * torch.ones(2, 2)
        self._run_and_verify_hook(model_on_gpu, 8, expected_grads)
846

847
    @requires_ucc()
    def test_ddp_invalid_comm_hook_init(self):
        """
        Check that register_comm_hook validates the user-supplied hook:
        a non-callable is rejected with TypeError, and a hook whose `bucket`
        parameter is annotated with something other than dist.GradBucket is
        rejected with ValueError.
        """
        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=self._get_process_group()
        )

        # An int is not callable.
        with self.assertRaisesRegex(TypeError, "Communication hook must be callable."):
            model.register_comm_hook(state=None, hook=1)

        # `bucket` is deliberately annotated as int instead of dist.GradBucket.
        def mis_annotated_hook(
            state: object, bucket: int
        ) -> torch.futures.Future[torch.Tensor]:
            return torch.futures.Future()

        with self.assertRaisesRegex(
            ValueError, "bucket annotation should be dist.GradBucket."
        ):
            model.register_comm_hook(state=None, hook=mis_annotated_hook)
873

874
    @requires_ucc()
    def test_ddp_invalid_comm_hook_return_type(self):
        """
        Check the hook's return type. A return annotation other than
        torch.futures.Future must be rejected, and the error must be logged by
        DDP. A hook without a return annotation that returns a non-Future must
        raise a RuntimeError once it is exercised.
        """
        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=self._get_process_group()
        )

        # Case 1: explicit (wrong) return annotation.
        expected_err = (
            "Communication hook: return annotation should be torch.futures.Future"
        )

        def hook_annotated_as_int(state: object, bucket: dist.GradBucket) -> int:
            return torch.futures.Future()

        with self.assertRaisesRegex(ValueError, expected_err):
            model.register_comm_hook(state=None, hook=hook_annotated_as_int)

        verify_ddp_error_logged(model, expected_err)

        # Case 2: no return annotation, but the hook returns an int. The
        # assertion covers registration, forward and backward; the error is
        # expected once the hook actually runs during backward.
        def hook_returning_int(state: object, bucket: dist.GradBucket):
            return 1

        with self.assertRaisesRegex(
            RuntimeError,
            "callback must return a torch.futures.Future object, but got",
        ):
            model.register_comm_hook(state=None, hook=hook_returning_int)
            out = model(8, self.rank)
            out.mean().backward()
917

918
    @requires_ucc()
    def test_ddp_comm_hook_register_just_once(self):
        """
        A DDP communication hook may be registered only once. Registering a
        second time on the same model must raise a RuntimeError.
        """
        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=self._get_process_group()
        )

        def passthrough_hook(state, bucket):
            # Return an already-completed Future holding the bucket's buffer.
            done = torch.futures.Future()
            done.set_result([bucket.buffer()])
            return done

        model.register_comm_hook(None, passthrough_hook)

        # The second registration attempt must be refused.
        with self.assertRaisesRegex(
            RuntimeError,
            "register_comm_hook or register_builtin_comm_hook can only be called once.",
        ):
            model.register_comm_hook(None, passthrough_hook)
942

943
    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_ddp_comm_hook_sparse_gradients(self):
        """
        Runs the "test_sparse_gradients" scenario with a DDP communication hook.
        The hook defined here allreduces each bucket through the ucc process
        group and divides the result by world_size.
        """
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
        )

        def allreduce_hook_ucc(
            state: object, bucket: dist.GradBucket
        ) -> torch.futures.Future[torch.Tensor]:
            """Allreduce the bucket asynchronously, then divide it by world_size."""

            def div_by_world_size(fut):
                # The future's value is a list of tensors; take the first and
                # divide by world_size (presumably turning the allreduce's SUM,
                # the default reduce op, into an average).
                return fut.wait()[0] / self.world_size

            # Launch the allreduce on the bucket's buffer and chain the division.
            fut = process_group.allreduce([bucket.buffer()]).get_future()
            return fut.then(div_by_world_size)

        ddp_model.register_comm_hook(None, allreduce_hook_ucc)

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)
976

977

978
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
    """Runs the shared AbstractCommTest helpers with backend="ucc"."""

    @property
    def device(self):
        return "cpu"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        # Best-effort cleanup of the FileStore backing file; it may not exist.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_default_pg_ucc(self):
        self._test_sequence_num_set_default_pg(backend="ucc")

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_ucc_new_group(self):
        self._test_sequence_num_set_new_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_default(self):
        self._test_sequence_num_incremented_default_group("ucc")

    @skip_if_lt_x_gpu(4)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_subgroup(self):
        if self.world_size < 4:
            # skip_but_pass_in_sandcastle is a decorator factory: calling it
            # inside the body merely builds a decorator, so the test would
            # silently "pass" without running or being skipped. Reproduce its
            # semantics explicitly: report success in Sandcastle, skip elsewhere.
            if common.IS_SANDCASTLE:
                return
            self.skipTest("Test requires world_size of at least 4")
        self._test_sequence_num_incremented_subgroup("ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    def test_ucc_barrier_device_ids(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="ucc", rank=self.rank, world_size=self.world_size, store=store
        )

        with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
            c10d.barrier(device_ids=[self.rank])

    @skip_but_pass_in_sandcastle("Fails on M60")
    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_warn_not_in_group(self):
        self._test_warn_not_in_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_rank_membership(self):
        self._test_rank_membership(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_mismatch(self):
        self._test_tensor_dtype_mismatch(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_complex(self):
        self._test_tensor_dtype_complex(backend="ucc")
1047

1048

1049
class UccProcessGroupWithDispatchedCollectivesTests(
    test_c10d_common.ProcessGroupWithDispatchedCollectivesTests
):
    """UCC-backend variants of the dispatched-collectives tests."""

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_collectives(self):
        # includes reduce, broadcast, all_reduce, all_gather, reduce_scatter, barrier, all_to_all, scatter
        self._test_collectives(backend="ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_allgather_base(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "ucc",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = torch.device("cuda")
        tensor = torch.ones(10, 10, device=device)
        # all_gather_into_tensor concatenates every rank's input along dim 0,
        # so the output must have room for world_size copies of the input.
        # Sizing it from world_size keeps the test valid for any world size
        # (for world_size == 1 it is simply a 10x10 buffer).
        output_tensor = torch.zeros(
            self.world_size * tensor.shape[0], tensor.shape[1], device=device
        )
        dist.all_gather_into_tensor(output_tensor, tensor)
        # Every rank contributed a tensor of ones, so the result is all ones.
        self.assertEqual(output_tensor, torch.ones_like(output_tensor))
1075

1076

1077
if __name__ == "__main__":
    # The parent process must not have created a CUDA context before the
    # test suite starts; the test classes spawn their own worker processes.
    cuda_ready_in_parent = torch.cuda._initialized
    assert not cuda_ready_in_parent, (
        "test_distributed must not have initialized CUDA context on main process"
    )

    run_tests()
1083

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.