# Owner(s): ["oncall: distributed"]
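# Tests for the UCC distributed backend: ProcessGroupUCC collectives
# (broadcast, allreduce, allgather, reduce, send/recv, barrier), DDP and its
# communication hooks on top of "ucc", and the common comm/compiler test
# suites run against the "ucc" backend.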

import copy
import logging
import math
import operator
import os
import random
import sys
import tempfile
from functools import reduce

import torch
import torch.distributed as c10d

if not c10d.is_available() or not c10d.is_ucc_available():
    print("c10d UCC not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import test_c10d_common
import torch.distributed as dist
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from test_c10d_common import (
    gpus_for_rank,
    Task,
    ModuleForDdpCommHook,
    SparseGradientModule,
)
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_ucc,
    skip_if_lt_x_gpu,
    verify_ddp_error_logged,
)
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    retry_on_connect_failures,
    skip_but_pass_in_sandcastle,
)


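# Each entry is a (ReduceOp, per-rank input tensor, expected result tensor)
# triple, consumed by the allreduce and reduce tests below.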
def simple_reduce_tests(rank, world_size):
    tests = [
        (
            c10d.ReduceOp.SUM,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(world_size * (world_size + 1) / 2)]),
        ),
        (
            c10d.ReduceOp.PRODUCT,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(math.factorial(world_size))]),
        ),
        (
            c10d.ReduceOp.MIN,
            torch.tensor([rank + 1.0]),
            torch.tensor([1.0]),
        ),
        (
            c10d.ReduceOp.MAX,
            torch.tensor([rank + 1.0]),
            torch.tensor([world_size]),
        ),
    ]

    # Generate tests for BAND.
    # The bit that is set changes in every iteration to check
    # that the output changes accordingly.
    for i in range(4):
        vin = rank | (1 << i)
        vout = 1 << i
        tests.append(
            (
                c10d.ReduceOp.BAND,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    # Generate tests for BOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-OR'ed.
    for i in range(1, 5):
        vin = reduce(operator.or_, [rank * i + j for j in range(i)])
        vout = reduce(operator.or_, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    # Generate tests for XOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-XOR'ed.
    for i in range(1, 5):
        vin = reduce(operator.xor, [rank * i + j for j in range(i)])
        vout = reduce(operator.xor, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BXOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    return tests


class RendezvousEnvTest(TestCase):
    @requires_ucc()
    @retry_on_connect_failures
    def test_logging_init(self):
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())
        os.environ["RANK"] = "0"

        previous_handlers = logging.root.handlers

        c10d.init_process_group(backend="ucc", init_method="env://")

        current_handlers = logging.root.handlers
        self.assertEqual(len(previous_handlers), len(current_handlers))
        for current, previous in zip(current_handlers, previous_handlers):
            self.assertEqual(current, previous)

        c10d.destroy_process_group()


class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
    @requires_ucc()
    @retry_on_connect_failures
    def test_default_store_timeout_ucc(self):
        self._test_default_store_timeout("ucc")


class ProcessGroupUCCTest(MultiProcessTestCase):
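    # Each test builds a standalone ProcessGroupUCC on top of a FileStore;
    # the temporary file name is provided by MultiProcessTestCase.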
    def _create_process_group_ucc(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        return c10d.ProcessGroupUCC(store, self.rank, self.world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    def test_empty_tensors(self):
        pg = self._create_process_group_ucc()

        xs = [torch.FloatTensor([])]
        fut = pg.broadcast(xs).get_future()
        fut.wait()
        output = fut.value()
        self.assertEqual(0, output[0].numel())
        self.assertEqual(xs[0], output[0], exact_dtype=False)

    # TODO: add error check testing

    def _test_broadcast_basics(self, fn):
        pg = self._create_process_group_ucc()

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            fut = pg.broadcast(xs, opts).get_future()
            fut.wait()
            return fut.value()

        # Every rank is root once
        for i in range(self.world_size):
            # Run with 1 input tensor
            x = fn(torch.tensor([self.rank]))
            output = broadcast([x], i, 0)
            self.assertEqual(torch.tensor([i]), output[0], exact_dtype=False)

            # TODO: UCC currently does not support multi tensor input

        # Test overloaded convenience function
        x = torch.tensor([self.rank + 1.0])
        fut = pg.broadcast(x, root=0).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(torch.tensor([1.0]), result[0])

    @requires_ucc()
    def test_broadcast_basics(self):
        self._test_broadcast_basics(lambda t: t.clone())

    # TODO: test_broadcast_basics_cuda times out locally

    def _test_allreduce_basics(self, fn):
        pg = self._create_process_group_ucc()

        # Single input tests
        tests = simple_reduce_tests(self.rank, self.world_size)
        for (op, input, expected) in tests:
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            tensor = fn(input)
            fut = pg.allreduce([tensor], opts).get_future()
            fut.wait()
            result = fut.value()
            self.assertEqual(expected, result[0], exact_dtype=False)

        # TODO: UCC currently does not support multi tensor input

        # Test overloaded convenience function (defaults to using sum)
        x = fn(torch.tensor([self.rank + 1.0]))
        fut = pg.allreduce(x).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(
            torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
            result[0],
        )

    @requires_ucc()
    def test_allreduce_basics(self):
        self._test_allreduce_basics(lambda t: t.clone())

    # TODO: test_allreduce_basics_cuda times out locally

    def _test_allgather_basics(self, fn):
        pg = self._create_process_group_ucc()

        # TODO: Run with N input tensor per rank; for now, UCC only supports single tensor input so N=1
        for n in [1]:
            input = [fn(torch.tensor([n * self.rank + i])) for i in range(n)]
            output = [
                [fn(torch.tensor([-1])) for _ in range(n * self.world_size)]
                for _ in range(n)
            ]
            expected_output = [
                [fn(torch.tensor([i])) for i in range(n * self.world_size)]
                for _ in range(n)
            ]
            fut = pg.allgather(output, input).get_future()
            fut.wait()
            result = fut.value()
            if n == 1:
                result = [result]
            self.assertEqual(expected_output, result)

    def test_allgather_basics(self):
        self._test_allgather_basics(lambda t: t.clone())

    def _test_reduce_basics(self, fn):
        pg = self._create_process_group_ucc()
        for (op, input, output) in simple_reduce_tests(self.rank, self.world_size):
            for root in range(self.world_size):
                opts = c10d.ReduceOptions()
                opts.reduceOp = op
                opts.rootRank = root
                tmp = fn(input)
                fut = pg.reduce([tmp], opts).get_future()
                fut.wait()
                result = fut.value()
                if root == self.rank:
                    self.assertEqual(output, result[0], exact_dtype=False)

    @requires_ucc()
    def test_reduce_basics(self):
        self._test_reduce_basics(lambda t: t.clone())

    # TODO: test_reduce_basics_cuda times out locally

    @requires_ucc()
    def test_send_recv_all_to_all(self):
        pg = self._create_process_group_ucc()

        # Preallocate tensors for input/output
        inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)]
        outputs = [torch.tensor([-1]) for _ in range(self.world_size)]

        # Issue sends
        send_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            send_work.append(pg.send([inputs[i]], i, 0))

        # Issue recvs
        recv_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            recv_work.append(pg.recv([outputs[i]], i, 0))

        # Wait for sends to complete
        for work in send_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Wait for recvs to complete
        for work in recv_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Test that every output other than our own contains the respective rank
        for i in range(self.world_size):
            if i == self.rank:
                continue
            self.assertEqual(torch.tensor([i]), outputs[i])

    # TODO: test_barrier_implies_wait fails with numerical mismatch, will investigate later
    @skip_but_pass_in_sandcastle("fails with numerical mismatch, skip for now")
    @requires_ucc()
    def test_barrier_implies_wait(self):
        pg = self._create_process_group_ucc()

        # Kick off allreduce operations
        size = (100, 100)
        num = 16
        tensors = [torch.full(size, float(i)) for i in range(num)]
        for tensor in tensors:
            # Note: leak the returned work handle
            pg.allreduce(tensor)

        # Barrier should ensure all previous work has completed
        pg.barrier().get_future().wait()

        for i, tensor in enumerate(tensors):
            self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
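    # These tests exercise DistributedDataParallel on the default process
    # group initialized with the "ucc" backend.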
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def _get_process_group(self):
        store = self._get_store()
        c10d.init_process_group("ucc", store=store, rank=self.rank, world_size=self.world_size)
        return c10d.distributed_c10d._get_default_group()

    def _test_ucc_backend(
        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
    ):
        process_group = self._get_process_group()
        self._test_ddp_with_process_group(
            process_group, devices, device_ids, multi_device, gradient_as_bucket_view
        )

    @requires_ucc()
    def test_ucc_backend_cpu_module(self):
        self._test_ucc_backend([torch.device("cpu")], None)

    @requires_ucc()
    def test_ucc_backend_cpu_module_grad_is_view(self):
        self._test_ucc_backend(
            [torch.device("cpu")], None, gradient_as_bucket_view=True
        )

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_integer_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, int_devices)

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_torch_device_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, devices)

    # TODO: test_ucc_backend_2gpu_module and test_ucc_backend_4gpu_module
    # require broadcast_coalesced which is not supported by ucc currently
    @skip_but_pass_in_sandcastle("requires broadcast coalesced, which is not supported by ucc currently")
    @requires_ucc()
    @skip_if_lt_x_gpu(4)
    def test_ucc_backend_2gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, None, multi_device=True)

    @skip_but_pass_in_sandcastle("requires broadcast coalesced, which is not supported by ucc currently")
    @requires_ucc()
    @skip_if_lt_x_gpu(8)
    def test_ucc_backend_4gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, None, multi_device=True)

    def _test_global_local_unused_params_grad(
        self, gradient_as_bucket_view=False, static_graph=False
    ):
        """
        By simulating a multi-task training, this test is to make sure:
        1) DDP does not touch the grad of globally unused parameters.
        2) DDP does update the grad of locally unused parameters.
        """

        class GlobalLocalUnusedParamModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()
                self.task_unused = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p, self.task_unused.p)

            def forward(self, x, rank):
                return self.t0(x) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            t0_p, t1_p, task_unused_p = model.module.task_parameters()
            self.assertIsNone(t0_p.grad)
            self.assertIsNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

            # Run backward
            output.mean().backward()

            # Now locally unused parameter should have grad updated on all ranks.
            # However the globally unused parameter should still have None grad.
            self.assertIsNotNone(t0_p.grad)
            self.assertIsNotNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(gpu_model)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad(self):
        self._test_global_local_unused_params_grad()

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_grad_is_view(self):
        self._test_global_local_unused_params_grad(gradient_as_bucket_view=True)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_static_graph(self):
        self._test_global_local_unused_params_grad(static_graph=True)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_find_unused_parameters_when_unused_parameters_empty(self):
        """
        An empty unused_parameters array does not imply find_unused_parameters =
        false. This test makes sure that DDP allreduces unused parameters
        accordingly where the forward pass in some process uses all parameters.
        This unit test creates a module that uses all parameters in rank = 0, and
        has unused parameters in other ranks.
        """

        class FindUnusedParamModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p)

            def forward(self, x, rank):
                return self.t1(self.t0(x)) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            [self.assertIsNone(t_p.grad) for t_p in model.module.task_parameters()]

            # Run backward
            output.mean().backward()

            # Now locally unused parameter should have grad updated on all ranks.
            [self.assertIsNotNone(t_p.grad) for t_p in model.module.task_parameters()]

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            FindUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            FindUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(gpu_model)

    @requires_ucc()
    def test_ignored_output(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called.
        """
        process_group = self._get_process_group()

        class IgnoredOutput(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutput().float(),
            process_group=process_group,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    @requires_ucc()
    def test_ignored_output_with_unused_parameters(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called, if not all model
        parameters participated in computing the model output.
        """
        process_group = self._get_process_group()

        class IgnoredOutputWithUnusedParameters(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.fc3 = nn.Linear(4, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutputWithUnusedParameters().float(),
            process_group=process_group,
            find_unused_parameters=True,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
        mult = 2
        batch_size = mult * self.world_size
        criterion = nn.CrossEntropyLoss()
        input = torch.randint(0, 10, [batch_size, 2])
        target = torch.randint(0, 10, [batch_size])

        # Run with entire batch against single process version
        criterion(vanilla_model(input), target).backward()

        # Run with partial batch against multi process version
        partial_input = input.split(mult)[self.rank]
        partial_target = target.split(mult)[self.rank]
        criterion(ddp_model(partial_input), partial_target).backward()

        # Check that the gradients are sparse and identical
        vanilla_parameter = next(vanilla_model.parameters())
        ddp_parameter = next(ddp_model.parameters())
        self.assertEqual(vanilla_parameter.grad.coalesce(), ddp_parameter.grad.coalesce())

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_save_load_checkpoint(self):
        dist.init_process_group(
            "ucc",
            init_method=f"file://{self.file_name}",
            world_size=self.world_size,
            rank=self.rank,
        )

        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        def train_loop(model, optimizer, iterations):
            for _ in range(iterations):
                optimizer.zero_grad()
                output = model(input)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        device_id = gpus_for_rank(self.world_size)[self.rank][0]

        model_withload = TestModel().float().to(device_id)
        model_withoutload = TestModel().float().to(device_id)

        ddp_withload = DistributedDataParallel(
            model_withload,
            device_ids=[device_id],
        )
        ddp_withoutload = DistributedDataParallel(
            model_withoutload,
            device_ids=[device_id],
        )

        # ensure that all three models start with the same set of parameters; by default they are randomized on construction
        for p in ddp_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in model_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in ddp_withoutload.parameters():
            with torch.no_grad():
                p.zero_()

        batch_size = 4
        criterion = nn.CrossEntropyLoss()

        optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001)
        optimizer_non_ddp_withload = torch.optim.SGD(
            model_withload.parameters(), lr=0.001
        )
        optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001)

        input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
            device_id
        )

        # run the model for 6 iterations, with a checkpoint in the middle
        train_loop(ddp_withload, optimizer_withload, 3)

        # zero out parameters of both DDP and non-DDP models and reload them from the DDP state dict
        checkpoint_path = tempfile.gettempdir() + "/model.checkpoint"
        if self.rank == 0:
            torch.save(ddp_withload.state_dict(), checkpoint_path)

        dist.barrier()
        map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
        ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)

        for model in [ddp_withload, model_withload]:
            for p in model.parameters():
                with torch.no_grad():
                    p.zero_()
        ddp_withload.load_state_dict(ddp_state_dict)
        # the non-DDP model needs to first remove the prefix of "module." from the DDP state dict
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        model_withload.load_state_dict(ddp_state_dict)

        train_loop(ddp_withload, optimizer_withload, 3)
        train_loop(model_withload, optimizer_non_ddp_withload, 3)

        # re-run the model with the same inputs for 6 iterations with no checkpoint
        train_loop(ddp_withoutload, optimizer_withoutload, 6)

        for p_withload, p_withoutload, p_non_ddp_withload in zip(
            ddp_withload.parameters(),
            ddp_withoutload.parameters(),
            model_withload.parameters(),
        ):
            self.assertEqual(p_withload, p_withoutload)
            self.assertEqual(p_non_ddp_withload, p_withoutload)

    def _test_sparse_gradients(self, gradient_as_bucket_view=False):
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)

    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients(self):
        self._test_sparse_gradients()

    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients_grad_is_view(self):
        self._test_sparse_gradients(gradient_as_bucket_view=True)

    @requires_ucc()
    def test_ddp_comm_hook_future_passing_cpu(self):
        """
        This unit test verifies whether the Future object is passed properly.
        The callback function creates a Future object and sets a value to it.
        """
        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().cpu(), process_group=process_group
        )

        # Register DDP Communication Hook
        cpu_model.register_comm_hook(None, self._simple_hook)

        # check whether the grads are equal to what the then callback returns.
        # without the comm_hook, result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2))

    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a DDP communication hook if any.
        if hook is not None:
            gpu_model.register_comm_hook(state, hook)

        return gpu_model

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_future_passing_gpu_ucc(self):
        """
        This unit test verifies whether the Future object is passed properly using ucc backend.
        The hook callback function creates a Future object and sets a value to it.
        """
        process_group = self._get_process_group()

        # Get GPU model with simple_hook registered.
        gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook)

        # check whether the grads are equal to what simple_hook's then callback returns.
        # without the comm_hook, result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2))

    @requires_ucc()
    def test_ddp_invalid_comm_hook_init(self):
        """
        This unit test makes sure that register_comm_hook properly checks the format
        of hook defined by user. The Python hook must be callable. This test also
        checks whether bucket annotation checked properly if defined.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        with self.assertRaisesRegex(TypeError, "Communication hook must be callable."):
            model.register_comm_hook(state=None, hook=1)

        with self.assertRaisesRegex(
            ValueError, "bucket annotation should be dist.GradBucket."
        ):

            def comm_hook(
                state: object, bucket: int
            ) -> torch.futures.Future[torch.Tensor]:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

    @requires_ucc()
    def test_ddp_invalid_comm_hook_return_type(self):
        """
        This test checks whether return annotation checked properly if defined. It also
        checks whether an internal error is thrown if return type is incorrect and user
        hasn't specified any return type annotation.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        expected_err = "Communication hook: return annotation should be torch.futures.Future"
        with self.assertRaisesRegex(
            ValueError,
            expected_err,
        ):

            def comm_hook(state: object, bucket: dist.GradBucket) -> int:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

        verify_ddp_error_logged(model, expected_err)

        with self.assertRaisesRegex(
            RuntimeError,
            "callback must return a torch.futures.Future object, but got",
        ):

            def comm_hook(state: object, bucket: dist.GradBucket):
                return 1

            model.register_comm_hook(state=None, hook=comm_hook)

            # Run forward
            output = model(8, self.rank)

            # Run backward
            output.mean().backward()

    @requires_ucc()
    def test_ddp_comm_hook_register_just_once(self):
        """
        DDP communication hook can only be registered once. This test validates whether
        the error is thrown properly when register_comm_hook is called more than once.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        def dummy_hook(state, bucket):
            fut = torch.futures.Future()
            fut.set_result([bucket.buffer()])
            return fut

        model.register_comm_hook(None, dummy_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "register_comm_hook or register_builtin_comm_hook can only be called once.",
        ):
            model.register_comm_hook(None, dummy_hook)

    # TODO: backward pass: input tensor must be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_ddp_comm_hook_sparse_gradients(self):
        """
        Runs "test_sparse_gradients" unit test with DDP communication hook. We define a
        simple hook that does allreduce and works with ucc backend for this test.
        """
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
        )

        def allreduce_hook_ucc(
            state: object, bucket: dist.GradBucket
        ) -> torch.futures.Future[torch.Tensor]:
            def div_by_world_size(fut):
                # Divide the result by world_size.
                return fut.wait()[0] / self.world_size

            # Prepare allreduced grad bucket tensors by running an async work.
            fut = process_group.allreduce([bucket.buffer()]).get_future()
            return fut.then(div_by_world_size)

        ddp_model.register_comm_hook(None, allreduce_hook_ucc)

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)


class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
    @property
    def device(self):
        return "cpu"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_default_pg_ucc(self):
        self._test_sequence_num_set_default_pg(backend="ucc")

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_ucc_new_group(self):
        self._test_sequence_num_set_new_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_default(self):
        self._test_sequence_num_incremented_default_group("ucc")

    @skip_if_lt_x_gpu(4)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_subgroup(self):
        if self.world_size < 4:
            return skip_but_pass_in_sandcastle("Test requires world_size of at least 4")
        self._test_sequence_num_incremented_subgroup("ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    def test_ucc_barrier_device_ids(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="ucc", rank=self.rank, world_size=self.world_size, store=store
        )

        with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
            c10d.barrier(device_ids=[self.rank])

    @skip_but_pass_in_sandcastle("Fails on M60")
    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_warn_not_in_group(self):
        self._test_warn_not_in_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_rank_membership(self):
        self._test_rank_membership(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_mismatch(self):
        self._test_tensor_dtype_mismatch(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_complex(self):
        self._test_tensor_dtype_complex(backend="ucc")


class CompilerTest(test_c10d_common.CompilerTest):
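    # The GPU variants below place each rank's tensor on device `self.rank`
    # (i.e. cuda:<rank>); the CPU variants run the same checks on host tensors.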

    @property
    def world_size(self):
        return 2

    def _get_default_group(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="ucc",
            rank=self.rank,
            world_size=self.world_size,
            store=store,
        )
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    def test_allreduce_work_wait_gpu(self):
        self._test_allreduce_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank,
        )

    @skip_if_lt_x_gpu(2)
    def test_allgather_work_wait_gpu(self):
        self._test_allgather_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    @skip_if_lt_x_gpu(2)
    def test_broadcast_work_wait_gpu(self):
        self._test_broadcast_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    @skip_if_lt_x_gpu(2)
    def test_nested_comm_tensor_wrapping_gpu(self):
        self._test_nested_comm_tensor_wrapping(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    @skip_if_lt_x_gpu(2)
    def test_consecutive_comm_work_wait_gpu(self):
        self._test_consecutive_comm_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    def test_allreduce_work_wait_cpu(self):
        self._test_allreduce_work_wait(
            torch.ones(2, 2) * self.rank,
        )

    def test_allgather_work_wait_cpu(self):
        self._test_allgather_work_wait(
            torch.ones(2, 2) * self.rank
        )

    def test_broadcast_work_wait_cpu(self):
        self._test_broadcast_work_wait(
            torch.ones(2, 2) * self.rank
        )

    def test_nested_comm_tensor_wrapping_cpu(self):
        self._test_nested_comm_tensor_wrapping(
            torch.ones(2, 2) * self.rank
        )

    def test_consecutive_comm_work_wait_cpu(self):
        self._test_consecutive_comm_work_wait(
            torch.ones(2, 2) * self.rank
        )


class UccProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGroupWithDispatchedCollectivesTests):

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_collectives(self):
        # includes reduce, broadcast, all_reduce, all_gather, reduce_scatter, barrier, all_to_all, scatter
        self._test_collectives(backend="ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_allgather_base(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "ucc",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda"
        tensor = torch.ones(10, 10, device=torch.device(device))
        output_tensor = torch.zeros(10, 10, device=torch.device(device))
        dist.all_gather_into_tensor(output_tensor, tensor)
        self.assertEqual(output_tensor, tensor)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()