# pytorch (fork count: 0) / test_zero_redundancy_optimizer.py — 1445 lines, 52.5 KB
1
# Owner(s): ["oncall: distributed"]
2

3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
4
#
5
# This source code is licensed under the BSD license found in the
6
# LICENSE file in the root directory of this source tree.
7

8
import copy
9
import os
10
import sys
11
import unittest
12
from contextlib import nullcontext
13
from typing import Any, cast, List
14

15
import numpy as np
16

17
import torch
18
import torch.distributed as dist
19

20

21
if not dist.is_available():
    # None of the distributed imports below can succeed without the
    # distributed package, so report the suite as skipped and exit cleanly.
    print("Distributed not available, skipping tests", file=sys.stderr)
    raise SystemExit(0)
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
25
    hook_with_zero_step,
26
    hook_with_zero_step_interleaved,
27
)
28
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
29
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
30
from torch.distributed.optim import ZeroRedundancyOptimizer
31
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
32
from torch.nn.parallel import DistributedDataParallel as DDP
33
from torch.optim import AdamW, SGD
34
from torch.testing._internal import common_distributed
35
from torch.testing._internal.common_utils import (
36
    instantiate_parametrized_tests,
37
    IS_WINDOWS,
38
    parametrize,
39
    run_tests,
40
    TEST_WITH_ASAN,
41
    TEST_WITH_DEV_DBG_ASAN,
42
)
43

44

45
try:
46
    import torchvision
47

48
    HAS_TORCHVISION = True
49
except ImportError:
50
    HAS_TORCHVISION = False
51

52

53
# Use GLOO on GPU when running CUDA + Windows
54
def _get_backend_for_tests():
55
    return (
56
        dist.Backend.NCCL
57
        if not IS_WINDOWS and torch.cuda.is_available()
58
        # Windows only has GLOO, but GLOO GPU works. And use GLOO CPU when
59
        # no GPUs are available.
60
        else dist.Backend.GLOO
61
    )
62

63

64
# Backend resolved once at import time; it is also the default ``backend``
# argument of ``TestZeroRedundancyOptimizer.dist_init``.
BACKEND: str = _get_backend_for_tests()


@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.")
class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase):
    """Base fixture for the ZeroRedundancyOptimizer tests.

    Spawns the worker processes in ``setUp``, exposes the device to run on,
    and provides :meth:`dist_init` to create the default process group from a
    file-backed store.
    """

    def setUp(self):
        super().setUp()
        # Publish the world size through the environment before spawning.
        os.environ["WORLD_SIZE"] = str(self.world_size)
        self._spawn_processes()

    @property
    def device(self):
        """CUDA when available, otherwise CPU."""
        return (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    @property
    def world_size(self):
        """Number of ranks; this base fixture uses a single rank."""
        return 1

    def tearDown(self):
        """Destroy the default process group, if any, and delete the store file.

        The base-class ``tearDown`` is always invoked afterwards so whatever
        cleanup it performs is not skipped (NOTE(review): presumably this
        includes terminating spawned worker processes — confirm in
        ``MultiProcessTestCase``).
        """
        try:
            # Only tear down a group that actually exists instead of relying
            # on ``destroy_process_group`` raising when none was created.
            if dist.is_initialized():
                dist.destroy_process_group()
            try:
                os.remove(self.file_name)
            except OSError:
                # The store file may already be gone; nothing left to clean.
                pass
        finally:
            super().tearDown()

    def dist_init(self, rank, world_size=-1, backend=BACKEND):
        """Initialize the default process group for ``rank``.

        Args:
            rank: rank of the calling process.
            world_size: number of participating ranks; any value below 1
                falls back to ``self.world_size``.
            backend: backend name, defaulting to the module-level ``BACKEND``.

        Returns:
            The result of ``dist.init_process_group``.
        """
        if world_size < 1:
            world_size = self.world_size
        # All ranks rendezvous through the same file-backed store.
        store = dist.FileStore(self.file_name, world_size)
        return dist.init_process_group(
            backend=backend,
            store=store,
            rank=rank,
            world_size=world_size,
        )


# TODO: skip_but_pass_in_sandcastle_if does not work here.
@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.")
class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
    """ZeroRedundancyOptimizer tests that run with the inherited world size
    of 1, so rank 0 owns every parameter."""

    def test_state_dict(self):
        """Check that ZeroRedundancyOptimizer exposes the expected state dict
        interface, irrespective of the sharding."""
        self.dist_init(self.rank)
        LR1 = 0.1
        LR2 = 0.01
        MOMENTUM = 0.9
        RECIPIENT_RANK = 0  # rank 0 is the only rank since the world size is 1
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGD,
            lr=LR1,
            momentum=MOMENTUM,
        )
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=self.device))
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=self.device),
        )

        o.zero_grad()
        o.consolidate_state_dict(to=RECIPIENT_RANK)
        state_dict = o.state_dict()

        # Check that the state dict has keys compliant with PyTorch
        self.assertIn("param_groups", state_dict.keys())
        self.assertIn("state", state_dict.keys())

        # Check that the state has the expected values
        first_group = state_dict["param_groups"][0]
        self.assertEqual(first_group["lr"], LR1)
        self.assertEqual(first_group["momentum"], MOMENTUM)
        self.assertFalse(first_group["nesterov"])
        self.assertEqual(first_group["weight_decay"], 0.0)
        self.assertEqual(first_group["dampening"], 0.0)

        # Check that the state and the `param_groups` attribute are in sync
        for k in first_group:
            if k != "params":
                self.assertEqual(first_group[k], o.param_groups[0][k])

        # Check that the state is reloaded with the correct values and device
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR2)
        o.load_state_dict(state_dict)
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=self.device),
        )

        # We should be using `LR1` and not `LR2` after reloading, both within
        # the optimizer and as exposed by the `param_groups` attribute
        self.assertEqual(o.param_groups[0]["lr"], LR1)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.71], device=self.device))
        self.assertEqual(
            o.optim.state[x]["momentum_buffer"],
            torch.tensor([1.9], device=self.device),
        )

        # Check that the exposed `param_groups` are on the proper device
        self.assertEqual(o.param_groups[0]["params"][0].device, x.device)

    def test_lr_scheduler(self):
        """Check that a normal PyTorch ``lr_scheduler`` is usable with
        ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        NUM_ITERS = 5
        LR = 0.01
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR)
        o2 = torch.optim.SGD([x2], lr=LR)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(NUM_ITERS):
            # Zero *before* backward: zeroing between backward() and step()
            # would discard the gradient, turn every step into a no-op and
            # make the parity check below vacuous.
            o.zero_grad()
            x.backward()
            o.step()
            s.step()
            o2.zero_grad()
            x2.backward()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)

    def test_step_with_kwargs(self):
        """Check that the ``step(**kwargs)`` interface is properly exposed."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithStepKWArg(torch.optim.SGD):
            def step(self, closure=None, kwarg=None):
                super().step()
                kwarg.append(5)

        kwarg: List[Any] = []
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGDWithStepKWArg,
            lr=LR,
        )
        x.backward()
        o.step(0, kwarg=kwarg)
        # The extra keyword argument must have reached the wrapped step()
        self.assertEqual(kwarg, [5])
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_step_with_extra_inner_key(self):
        """Check that ZeroRedundancyOptimizer wrapping an optimizer that adds
        extra keys to ``param_groups`` exposes those keys through ZeRO's own
        ``param_groups``."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithNewKey(torch.optim.SGD):
            # Dummy optimizer which adds a new key to the param groups
            def step(self, closure=None):
                super().step()
                self.param_groups[0]["new_key"] = 0.1

        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGDWithNewKey, lr=LR)
        x.backward()
        o.step()
        self.assertEqual(o.param_groups[0]["new_key"], 0.1)
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_step_without_closure(self):
        """Check that the ``step()`` method (without closure) is handled as
        expected."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithoutClosure(torch.optim.SGD):
            def step(self):
                return super().step()

        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer(
            [x],
            optimizer_class=SGDWithoutClosure,
            lr=LR,
        )
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=self.device))

    def test_zero_grad(self):
        """Check that the ``zero_grad`` method is properly handled."""
        self.dist_init(self.rank)
        LR = 0.01
        x = torch.rand(1)
        m = torch.nn.Linear(1, 1)
        o = ZeroRedundancyOptimizer(m.parameters(), optimizer_class=SGD, lr=LR)
        y = m(x)
        y.backward(x)
        # Both the weight and the bias must have received a gradient
        self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
        self.assertNotEqual(m.bias.grad, torch.zeros_like(m.bias))
        o.zero_grad()
        self.assertIsNone(m.weight.grad)
        self.assertIsNone(m.bias.grad)

    def test_constructor(self):
        """Check the robustness of the ZeroRedundancyOptimizer constructor by
        passing different values for the ``params`` argument."""
        self.dist_init(self.rank)
        LR = 0.01
        m = torch.nn.Sequential(
            torch.nn.Linear(5, 10),
            torch.nn.Linear(10, 10),
            torch.nn.Linear(10, 10),
        )
        # Test various constructor inputs in the form: (input, expected error)
        ctor_inputs = [
            ([], ValueError),  # empty parameter list
            (torch.randn(1), TypeError),  # non-iterable: `torch.Tensor`
            (1.2, TypeError),  # non-iterable: `float`
            (
                [
                    {"params": [layer.weight for layer in m]},
                    {"params": [layer.bias for layer in m]},
                ],
                None,
            ),  # iterable of dict
            (
                list(m.parameters()) + [42],
                TypeError,
            ),  # iterable containing invalid type
            (m.parameters(), None),  # `params` as a generator
            (list(m.parameters()), None),  # `params` as a list
        ]
        for ctor_input, error in ctor_inputs:
            context = self.assertRaises(error) if error else nullcontext()
            with context:
                ZeroRedundancyOptimizer(
                    ctor_input,
                    optimizer_class=SGD,
                    lr=LR,
                )

        # Test constructing with multiple parameter groups more thoroughly
        WD = 0.01
        BETAS = (0.9, 0.999)
        EPS = 1e-8
        params = [
            {"params": [layer.weight for layer in m], "weight_decay": 0.0},
            {"params": [layer.bias for layer in m], "weight_decay": WD},
        ]
        o = ZeroRedundancyOptimizer(
            params,
            optimizer_class=AdamW,
            lr=LR,
            betas=BETAS,
            eps=EPS,
        )
        self.assertEqual(
            len(o.param_groups),
            2,
            msg=f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}",
        )
        self.assertEqual(
            len(o.optim.param_groups),
            2,
            msg=(
                "Expected 2 local optimizer param groups, but got "
                f"{len(o.optim.param_groups)}"
            ),
        )

    def test_same_dense_param_type(self):
        """Check that ZeroRedundancyOptimizer raises an exception if the input
        parameters include sparse tensors or different dense types.

        NOTE: This test should be removed once support for sparse parameters
        and varying parameter types is added.
        """
        self.dist_init(self.rank)
        LR = 0.01
        invalid_param_lists = [
            [torch.sparse_coo_tensor(size=(2, 3))],
            [torch.FloatTensor(1), torch.DoubleTensor(1)],
            [
                torch.FloatTensor(1),
                torch.FloatTensor(1),
                torch.sparse_coo_tensor(size=(2, 3)),
            ],
        ]
        for params in invalid_param_lists:
            with self.assertRaises(ValueError):
                ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)


class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
362
    @property
    def device(self):
        """This rank's CUDA device when CUDA is available, otherwise CPU."""
        if torch.cuda.is_available():
            return torch.device(self.rank)
        return torch.device("cpu")

    @property
    def world_size(self):
        """One rank per visible GPU, clamped to the range [2, 4]."""
        gpu_count = torch.cuda.device_count()
        return max(2, min(4, gpu_count))

    @property
    def context(self):
        """Context manager selecting this rank's GPU; a no-op without CUDA."""
        if torch.cuda.is_available():
            return torch.cuda.device(self.rank)
        return nullcontext()

    def _check_same_model_params(
        self,
        model_a: torch.nn.Module,
        model_b: torch.nn.Module,
        message: str = "",
    ) -> None:
        """Assert that two models have matching parameters and buffers.

        Parameters are compared with ``atol=1e-3`` and ``rtol=1e-5``; buffers
        use ``torch.testing.assert_close``'s default tolerances. ``message`` is
        appended to any failure message.
        """
        param_pairs = zip(model_a.parameters(), model_b.parameters())
        for param_a, param_b in param_pairs:
            torch.testing.assert_close(
                param_a,
                param_b,
                atol=1e-3,
                rtol=1e-5,
                msg=f"Model parameters differ:\n{param_a} {param_b}\n" + message,
            )
        buffer_pairs = zip(model_a.buffers(), model_b.buffers())
        for buffer_a, buffer_b in buffer_pairs:
            torch.testing.assert_close(
                buffer_a,
                buffer_b,
                msg=f"Model buffers differ:\n{buffer_a} {buffer_b}\n" + message,
            )

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    def test_step(self):
        """Check that ZeroRedundancyOptimizer properly exposes the ``step()``
        interface."""
        self.dist_init(self.rank, world_size=self.world_size)
        LR = 0.01

        with self.context:
            x = torch.tensor([float(self.rank + 1)], device=self.device)
            m = torch.nn.Linear(1, 1)
            m.weight.data = torch.tensor([[1.0]])
            m.bias.data = torch.tensor([2.0])
            m = m.to(self.device)
            m_zero = copy.deepcopy(m).to(self.device)

            o = SGD(m.parameters(), lr=LR)
            o_zero = ZeroRedundancyOptimizer(
                m_zero.parameters(),
                optimizer_class=SGD,
                lr=LR,
            )

            # Same forward/backward on the reference model and the ZeRO copy
            for model in (m, m_zero):
                model(x).backward(x)

            # Average gradients across ranks by hand, then step each optimizer
            for model, optimizer in ((m, o), (m_zero, o_zero)):
                for param in model.parameters():
                    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                    param.grad.data /= self.world_size
                optimizer.step()

            self.assertEqual(m.weight, m_zero.weight)
            self.assertEqual(m.bias, m_zero.bias)

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    def test_step_with_closure(self):
        """Check that ZeroRedundancyOptimizer properly exposes the
        ``step(closure)`` interface."""
        self.dist_init(self.rank, world_size=self.world_size)

        with self.context:
            for bucket_view in (False, True):
                x_val = self.rank + 1
                weight, bias, error = 1.0, 2.0, 1.0
                # The target is offset by `error` from the model's output
                target = torch.tensor(
                    [x_val * weight + bias + error],
                    device=self.device,
                )
                loss_fn = torch.nn.L1Loss()

                x = torch.tensor([float(x_val)], device=self.device)
                m = torch.nn.Linear(1, 1)
                m.weight.data = torch.tensor([[weight]])
                m.bias.data = torch.tensor([bias])
                m.to(self.device)

                o = ZeroRedundancyOptimizer(
                    m.parameters(),
                    optimizer_class=SGD,
                    parameters_as_bucket_view=bucket_view,
                    lr=0.1,
                )

                m(x).backward(x)
                for param in m.parameters():
                    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                    param.grad.data /= self.world_size

                def closure():
                    o.zero_grad()
                    loss = loss_fn(m(x), target)
                    loss.backward()
                    return loss

                loss = o.step(closure=closure)

                self.assertEqual(loss, torch.tensor(error))
                self.assertEqual(m.weight, torch.tensor([[1.1]]))
                self.assertEqual(m.bias, torch.tensor([2.1]))

    @common_distributed.skip_if_no_gpu
    def test_lr_scheduler(self):
        """Check that a normal PyTorch ``lr_scheduler`` is usable with
        ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        NUM_ITERS = 5
        LR = 0.01
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR)
        o2 = torch.optim.SGD([x2], lr=LR)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(NUM_ITERS):
            # Zero *before* backward: zeroing between backward() and step()
            # would discard the gradient, turn every step into a no-op and
            # make the parity check below vacuous.
            o.zero_grad()
            x.backward()
            o.step()
            s.step()
            o2.zero_grad()
            x2.backward()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)

    def test_sharding(self):
        """
        Check ZeroRedundancyOptimizer's parameter sharding at construction
        time.

        NOTE: The correctness of this test depends on the ZeRO implementation
        using the sorted-greedy partitioning algorithm. For details, see
        ``ZeroRedundancyOptimizer._partition_parameters()`` in
        zero_redundancy_optimizer.py.
        """
        self.dist_init(self.rank)
        LR = 0.01
        sizes = [9, 7, 5, 3]
        # `world_size` repetitions of the same sizes; each rank's local shard
        # is then expected to hold sum(sizes) elements in total.
        params = [torch.rand(size, 1) for size in sizes * self.world_size]
        o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
        local_numel = sum(p.numel() for p in o.optim.param_groups[0]["params"])
        self.assertEqual(local_numel, sum(sizes))

    def test_add_param_group(self):
        """Check that ZeroRedundancyOptimizer properly handles adding a new
        parameter group a posteriori and that all ranks get a shard of the
        contained parameters.

        NOTE: The correctness of this test depends on the ZeRO implementation
        using the sorted-greedy partitioning algorithm. For details, see
        ``ZeroRedundancyOptimizer._partition_parameters()`` in
        zero_redundancy_optimizer.py.
        """
        self.dist_init(self.rank)
        LR = 0.01

        def all_trainable():
            """Every parameter is trainable from the start."""
            sizes = [9, 7, 5, 3]
            # Hold back the final size-3 tensor; a tensor of the same size is
            # added below as a new parameter group.
            params = [torch.rand(size, 1) for size in (sizes * self.world_size)[:-1]]
            # Trainable params are factored into the size-based partitioning
            for p in params:
                p.requires_grad = True

            o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
            self.assertEqual(len(o.param_groups), 1)
            o.add_param_group({"params": [torch.rand(3, 1)]})
            # The new group must go to the partition that leaves every rank
            # with the same number of elements
            self.assertEqual(len(o.param_groups), 2)
            local_numel = sum(
                p.numel() for group in o.optim.param_groups for p in group["params"]
            )
            self.assertEqual(local_numel, sum(sizes))
            self.assertEqual(len(o.optim.param_groups), 2)

        def some_trainable():
            """Pathological config whose first (and largest) param is frozen."""
            params = [torch.rand(size, 1) for size in [100, 3, 5, 2, 6, 4]]
            # All but the first param are trainable, so only they are
            # factored into the size-based partitioning
            for p in params[1:]:
                p.requires_grad = True

            o = ZeroRedundancyOptimizer(params, optimizer_class=SGD, lr=LR)
            self.assertEqual(len(o.param_groups), 1)
            o.add_param_group({"params": [torch.rand(3, 1)]})
            self.assertEqual(len(o.param_groups), 2)
            self.assertEqual(len(o.optim.param_groups), 2)

        all_trainable()
        some_trainable()

    @common_distributed.skip_if_no_gpu
    def test_multiple_param_groups(self):
        """
        Check parity between constructing ZeRO with multiple parameter groups
        upfront versus adding parameter groups to ZeRO after construction
        versus a non-sharded optimizer.
        """
        self.dist_init(self.rank)
        BATCH_SIZE, NUM_ITERS = 8, 3
        INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5
        WD, LR = 0.01, 0.01
        model1 = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        )
        # Identical copies so that all three optimizers start from the same
        # weights
        model2 = copy.deepcopy(model1)
        model3 = copy.deepcopy(model1)
        model1 = model1.to(self.device)
        model2 = model2.to(self.device)
        model3 = model3.to(self.device)
        batches = [
            torch.randn(BATCH_SIZE, INPUT_DIM).to(self.device) for _ in range(NUM_ITERS)
        ]
        # Construct `optim1` with both parameter groups upfront
        optim1 = ZeroRedundancyOptimizer(
            [
                {"params": [layer.weight for layer in model1], "weight_decay": 0.0},
                {"params": [layer.bias for layer in model1], "weight_decay": WD},
            ],
            optimizer_class=AdamW,
            lr=LR,
        )
        # Construct `optim2` with the weights only, then add the biases as a
        # second parameter group after construction
        optim2 = ZeroRedundancyOptimizer(
            [layer.weight for layer in model2],
            optimizer_class=AdamW,
            lr=LR,
            weight_decay=0.0,
        )
        optim2.add_param_group(
            {"params": [layer.bias for layer in model2], "weight_decay": WD}
        )
        # Construct `optim3` as a non-sharded optimizer
        optim3 = AdamW(
            [
                {"params": [layer.weight for layer in model3], "weight_decay": 0.0},
                {"params": [layer.bias for layer in model3], "weight_decay": WD},
            ],
            lr=LR,
        )
        # Check parity over a few iterations
        for batch in batches:
            for model, optim in (
                (model1, optim1),
                (model2, optim2),
                (model3, optim3),
            ):
                optim.zero_grad()
                loss = model(batch).sum()
                loss.backward()
                optim.step()
            for layer1, layer2, layer3 in zip(model1, model2, model3):
                torch.testing.assert_close(layer1.weight, layer2.weight)
                torch.testing.assert_close(layer1.weight, layer3.weight)
                torch.testing.assert_close(layer1.bias, layer2.bias)
                torch.testing.assert_close(layer1.bias, layer3.bias)

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    def test_collect_shards(self):
        """Check the state consolidation mechanism and the state dict exposed
        by ZeroRedundancyOptimizer."""
        self.dist_init(self.rank)
        LR = 1e-3
        MOMENTUM = 0.99
        BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5
        REFERENCE_RANK = 0
        target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=self.device)
        inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=self.device)
        model = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        ).to(self.device)
        loss_fn = torch.nn.L1Loss().to(self.device)
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=SGD,
            lr=LR,
            momentum=MOMENTUM,  # ensure there exists state to shard
        )

        def closure():
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), target)
            loss.backward()
            return loss

        # One step so that there is optimizer state to consolidate
        optimizer.step(closure=closure)

        # Consolidate onto the reference rank; other ranks start from an
        # empty dict and receive the state through the broadcast below
        optimizer.consolidate_state_dict(to=REFERENCE_RANK)
        optimizer_state_dict = {}
        if self.rank == REFERENCE_RANK:
            optimizer_state_dict = optimizer.state_dict()
            # One state entry is expected per model parameter
            num_params = len(list(model.parameters()))
            self.assertEqual(len(optimizer_state_dict["state"]), num_params)

        # Every rank must load the broadcast state without raising
        optimizer_state_dict = _broadcast_object(
            optimizer_state_dict,
            src_rank=REFERENCE_RANK,
            group=dist.group.WORLD,
            device=self.device,
        )
        optimizer.load_state_dict(optimizer_state_dict)

    def test_nondefault_process_group(self):
        """Check that ZeroRedundancyOptimizer works with a non-default process
        group consisting only of even ranks."""
        # Below the minimum world size the even-rank subgroup makes the test
        # trivial, so log and return instead of running it
        MIN_WORLD_SIZE = 4
        if self.world_size < MIN_WORLD_SIZE:
            common_distributed.logger.info(
                "Skipping `test_nondefault_process_group()` since world size "
                "of %s is less than %s",
                self.world_size,
                MIN_WORLD_SIZE,
            )
            return
        backend = dist.Backend.GLOO
        self.dist_init(self.rank, self.world_size, backend)
        # Gloo supports both CPU and GPU tensors: use this rank's GPU only when
        # there is one GPU per rank, and fall back to CPU otherwise
        has_gpu_per_rank = (
            torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size
        )
        device = torch.device(self.rank) if has_gpu_per_rank else torch.device("cpu")
        # A subgroup of the even ranks exercises the case where global and
        # subgroup ranks do not necessarily match
        subgroup_ranks = list(range(0, self.world_size, 2))
        process_group = dist.new_group(ranks=subgroup_ranks, backend=backend)
        # Ranks outside the subgroup were only needed to create it
        if self.rank not in subgroup_ranks:
            return

        # Per-rank seeds give each rank different training data, which is
        # what makes the cross-rank parameter comparison meaningful
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)

        EPOCHS, BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 3, 20, 10, 5
        LR = 1e-3
        MOMENTUM = 0.99
        REFERENCE_RANK = 0
        assert (
            REFERENCE_RANK in subgroup_ranks
        ), "Reference rank must be in the new process group"
        loss_fn = torch.nn.L1Loss().to(device)

        model = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        ).to(device)
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=SGD,
            lr=LR,
            momentum=MOMENTUM,  # ensure there exists state to shard
            process_group=process_group,
        )

        is_reference = self.rank == REFERENCE_RANK
        for _ in range(EPOCHS):
            target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=device)
            inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=device)

            def closure():
                optimizer.zero_grad()
                loss = loss_fn(model(inputs), target)
                loss /= self.world_size
                loss.backward()
                dist.all_reduce(loss, group=process_group)
                return loss

            optimizer.step(closure=closure)

            # After each step, gather every parameter on the reference rank
            # and check that all subgroup ranks hold identical values
            for param_group in optimizer.param_groups:
                for p in param_group["params"]:
                    receptacle = [p.clone() for _ in subgroup_ranks] if is_reference else []
                    dist.gather(
                        p,
                        receptacle,
                        dst=REFERENCE_RANK,
                        group=process_group,
                    )
                    if is_reference:
                        reference_param, *other_params = receptacle
                        for param in other_params:
                            torch.testing.assert_close(
                                reference_param,
                                param,
                                msg="Models differ between ranks",
                            )

    @common_distributed.skip_if_no_gpu
    @parametrize(
        "optimizer_class_str",
        ["Adam", "AdamW", "SGD"],
        # Use string to appease the internal test name parser
    )
    @parametrize(
        "maximize",
        [False, True],
    )
    def test_local_optimizer_parity(
        self,
        optimizer_class_str: str,
        maximize: bool,
    ):
        """When combined with DDP, check that a local optimizer gives the same
        results as wrapping that optimizer with ZeroRedundancyOptimizer.

        Two copies of the same model are trained side by side: one with a
        plain ``optimizer_class`` instance and one with ZeRO. Losses and
        parameters must match after every step, after toggling parameter
        trainability, and after cross-loading the two optimizers'
        ``state_dict`` checkpoints.

        Args:
            optimizer_class_str: Name of the ``torch.optim`` class to test
                (``"Adam"``, ``"AdamW"`` or ``"SGD"``).
            maximize: Whether to pass ``maximize=True`` to both optimizers.

        Raises:
            ValueError: If ``optimizer_class_str`` is not a supported name.
        """
        self.dist_init(self.rank)
        BATCHES = 20
        BATCH_SIZE = 64
        LR = 1e-3
        INPUT_DIM = 2
        HIDDEN_DIM = 3
        OUTPUT_DIM = 3
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)
        optimizer_classes = {
            "Adam": torch.optim.Adam,
            "AdamW": torch.optim.AdamW,
            "SGD": torch.optim.SGD,
        }
        # Use an explicit raise rather than `assert 0`: asserts are stripped
        # under `python -O`, which would leave `optimizer_class` unbound
        if optimizer_class_str not in optimizer_classes:
            raise ValueError(f"Unsupported optimizer class: {optimizer_class_str}")
        optimizer_class = optimizer_classes[optimizer_class_str]

        with self.context:
            # Define a base model with a different buffer for each rank
            model = torch.nn.Sequential(
                torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
            ).to(self.device)
            model.test_buffer = torch.nn.Buffer(
                torch.ones((1), device=self.device) * self.rank,
            )
            # Define models/optimizers for DDP with ZeRO and DDP with local
            # optimizer
            defaults = {"maximize": True} if maximize else {}
            sharded_optimizer = ZeroRedundancyOptimizer(
                params=model.parameters(),
                optimizer_class=optimizer_class,
                lr=LR,
                **defaults,
            )
            sharded_ddp_model = DDP(
                module=model,
                device_ids=[self.rank],
                broadcast_buffers=True,
                find_unused_parameters=True,
            )
            local_model = copy.deepcopy(model).to(self.device)
            ddp_optimizer = optimizer_class(
                local_model.parameters(),
                lr=LR,
                **defaults,
            )
            ddp_model = DDP(
                local_model,
                device_ids=[self.rank],
                broadcast_buffers=True,
                find_unused_parameters=True,
            )
            # Check that the model is properly synchronized between ranks
            # at construction time
            self._check_same_model_params(
                sharded_ddp_model,
                ddp_model,
                "Models differ from the start",
            )

            def check_step():
                # One shared random batch for both models. It is created on
                # CPU and relies on DDP (constructed with `device_ids`) to
                # move it to the module's device.
                input_tensor = torch.rand((BATCH_SIZE, INPUT_DIM))

                # Bind the batch as a default argument so each closure uses
                # this step's tensor
                def closure_ddp(input_tensor=input_tensor):
                    ddp_optimizer.zero_grad()
                    ddp_loss = ddp_model(input_tensor).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                def closure_sharded(input_tensor=input_tensor):
                    sharded_optimizer.zero_grad()
                    sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                    sharded_loss.backward()
                    return sharded_loss

                loss_ddp = cast(
                    torch.Tensor,
                    ddp_optimizer.step(closure=closure_ddp),
                )
                loss_sharded_optim = cast(
                    torch.Tensor,
                    sharded_optimizer.step(closure=closure_sharded),
                )
                torch.testing.assert_close(
                    loss_ddp,
                    loss_sharded_optim,
                    msg="Losses differ between local optimizer and ZeRO",
                )
                self._check_same_model_params(
                    sharded_ddp_model,
                    ddp_model,
                    "Models differ after a step",
                )

            # Check that parity is maintained
            for i in range(BATCHES):
                check_step()
                # For the second half of batches (i > BATCHES // 2), toggle
                # the first parameter's trainability to further test parity
                if i > BATCHES // 2:
                    next(ddp_model.parameters()).requires_grad = bool(i % 2)
                    next(sharded_ddp_model.parameters()).requires_grad = bool(i % 2)

            # Check that the `state_dict` checkpoints are compatible between
            # the local optimizer and ZeRO
            REFERENCE_RANK = 0
            # - Get states. ZeRO's full state is only materialized on
            #   REFERENCE_RANK, so broadcast it to every rank.
            ddp_state_dict = ddp_optimizer.state_dict()
            sharded_optimizer.consolidate_state_dict(to=REFERENCE_RANK)
            sharded_optim_state_dict = [
                sharded_optimizer.state_dict() if self.rank == REFERENCE_RANK else {}
            ]
            dist.broadcast_object_list(
                sharded_optim_state_dict,
                src=REFERENCE_RANK,
                group=dist.group.WORLD,
            )
            sharded_optim_state_dict = sharded_optim_state_dict[0]

            # - Cross-load the states
            # Run one step and check that the models are still the same
            ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)
            ddp_optimizer.load_state_dict(sharded_optim_state_dict)
            sharded_optimizer.load_state_dict(ddp_state_dict)
            check_step()

            # - Reload their respective states
            # Run one step and check that the models are still the same
            ddp_optimizer.load_state_dict(ddp_state_dict_ref)
            sharded_optimizer.load_state_dict(sharded_optim_state_dict)
            check_step()

    def _test_zero_join(self, device):
        """Check that the ZeRO join hook allows training with uneven inputs
        when using the given device.

        Steps:
        1. Each rank trains a DDP model with a local Adam optimizer under
           ``DDP.join()``.
        2. The last-joining (largest) rank records the gradients and
           parameters at every iteration as ground truth.
        3. The ground truth is broadcast to all ranks.
        4. Each rank replays the ground-truth gradients into a ZeRO-wrapped
           copy of the model under ``Join``, checking parameter parity after
           every step.

        Ranks that run out of inputs early keep setting the remaining
        ground-truth gradients through a custom join hook. As a result, the
        optimizer steps taken on their behalf use the same gradients as the
        DDP baseline.

        Args:
            device: ``torch.device`` to train on. CUDA devices use the default
                test backend; any other device uses Gloo.
        """
        NUM_INPUTS = 3
        NUM_EPOCHS = 2
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        world_size = self.world_size
        is_last_rank = rank == world_size - 1
        is_gpu = device.type == "cuda"
        backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO
        self.dist_init(rank, world_size, backend)

        model = torch.nn.Sequential(
            torch.nn.Linear(2, 3),
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 3),
        )
        model.to(device)

        # DDP ensures correct gradients in data parallel training, so DDP with
        # local optimizers on uneven inputs should be equivalent to ZeRO on
        # uneven inputs with gradients being manually set
        ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
        local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        zero_model = copy.deepcopy(model)
        zero_model.to(device)
        zero_optim = ZeroRedundancyOptimizer(
            zero_model.parameters(),
            torch.optim.Adam,
            lr=LR,
        )
        loss_fn = torch.nn.MSELoss()

        # Use uneven inputs: rank i has i extra inputs
        inputs = [torch.randn(20, 2).to(device) for _ in range(NUM_INPUTS + rank)]
        labels = torch.randn(20, 3).to(device)

        # Save the gradients and parameters from DDP as the ground truth; do
        # so on the last-joining rank (in this case, the largest rank).
        # NOTE: there is no `zero_grad()` here, so gradients accumulate across
        # iterations. This is consistent because the ZeRO run below sets
        # exactly these recorded gradients.
        grads_at_each_iter = []
        params_at_each_iter = []
        with ddp_model.join():
            for _ in range(NUM_EPOCHS):
                for inp in inputs:
                    output = ddp_model(inp)
                    loss_fn(output, labels).backward()
                    if is_last_rank:
                        grads = [
                            p.grad.detach().clone().to(device)
                            for p in ddp_model.parameters()
                        ]
                    local_optim.step()
                    if is_last_rank:
                        params = [
                            p.detach().clone().to(device)
                            for p in ddp_model.parameters()
                        ]
                        grads_at_each_iter.append(grads)
                        params_at_each_iter.append(params)

        # Broadcast the saved gradients and parameters to all of the other
        # ranks (which joined early)
        # TODO: Replace this `_broadcast_object` with `broadcast_object_list`
        # once the latter supports loading to the destination device instead
        # of the source device
        grads_at_each_iter, params_at_each_iter = _broadcast_object(
            [grads_at_each_iter, params_at_each_iter],
            src_rank=world_size - 1,
            group=dist.group.WORLD,
            device=device,
        )

        # A process must still set the remaining gradients after joining, so we
        # define a join hook to do this before the ZeRO join hook
        class _JoinGradInfo:
            def __init__(self, grads):
                self.grads = grads  # remaining gradients to set (in order)
                self.index = 0  # position of the next gradients to set

        class _SetGradsJoinHook(JoinHook):
            def __init__(self, zero_optim, grads):
                zero_optim._join_grad_info = _JoinGradInfo(grads)
                self.zero = zero_optim
                super().__init__()

            def main_hook(self):
                # Install the next remaining ground-truth gradients on all of
                # ZeRO's parameters, then advance the cursor
                join_grad_info = self.zero._join_grad_info
                grads = join_grad_info.grads[join_grad_info.index]
                join_grad_info.index += 1
                for p, grad in zip(self.zero._all_params, grads):
                    p.grad = grad.detach().clone().to(device)

        class _GradientSetter(Joinable):
            def __init__(self) -> None:
                super().__init__()

            def join_hook(self, **kwargs):
                assert "zero_optim" in kwargs
                assert "grads" in kwargs
                zero_optim = kwargs["zero_optim"]
                grads = kwargs["grads"]
                return _SetGradsJoinHook(zero_optim, grads)

            @property
            def join_device(self):
                return device

            @property
            def join_process_group(self):
                return dist.group.WORLD

        # This rank finishes `num_grads_after_joining` iterations before the
        # last rank, so it must replay that many trailing gradients. Slice from
        # an explicit start index: for the last rank the count is 0, and
        # `xs[-0:]` would return the whole list rather than an empty one.
        num_grads_after_joining = NUM_EPOCHS * (world_size - rank - 1)
        first_remaining = len(grads_at_each_iter) - num_grads_after_joining
        grads = grads_at_each_iter[first_remaining:]
        gradient_setter = _GradientSetter()
        # `gradient_setter` is listed before `zero_optim` so that, per the
        # comment above, its hook sets gradients before ZeRO's join hook runs
        with Join(
            [gradient_setter, zero_optim],
            zero_optim=zero_optim,
            grads=grads,
        ):
            # One step per (epoch, input) pair, in the same order as above
            for step_idx in range(NUM_EPOCHS * len(inputs)):
                # Notify join context that this process has not joined
                Join.notify_join_context(gradient_setter)
                # Set gradients manually
                for p, grad in zip(
                    zero_model.parameters(),
                    grads_at_each_iter[step_idx],
                ):
                    p.grad = grad.detach().clone().to(device)
                # Perform optimizer step and check parity
                zero_optim.step()
                for p, ddp_p in zip(
                    zero_model.parameters(),
                    params_at_each_iter[step_idx],
                ):
                    torch.testing.assert_close(
                        p,
                        ddp_p,
                        msg="Parameters differ between using ZeRO and "
                        "local optimizer",
                    )

    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    def test_zero_join_gpu(self):
        """Run the uneven-inputs ZeRO join check on this rank's GPU.

        Delegates to ``_test_zero_join`` with ``self.device``.
        """
        gpu_device = self.device
        self._test_zero_join(gpu_device)

    @common_distributed.requires_gloo()
    def test_zero_join_cpu(self):
        """Run the uneven-inputs ZeRO join check on CPU.

        Delegates to ``_test_zero_join`` with a CPU device, which makes that
        helper initialize the Gloo backend.
        """
        cpu_device = torch.device("cpu")
        self._test_zero_join(cpu_device)

    def _test_zero_model_parallel(self, parameters_as_bucket_view: bool):
        """Check parity between ZeRO on a two-device model-parallel DDP model
        and an equivalent single-device model with a local Adam optimizer.

        Must only be called on rank 0 or 1. Rank ``r`` places the model's two
        linear layers on CUDA devices ``2 * r`` and ``2 * r + 1``.

        Args:
            parameters_as_bucket_view: Forwarded to
                ``ZeroRedundancyOptimizer``.
        """
        # Use two processes each with two GPUs
        assert self.rank < 2
        NUM_EPOCHS = 2
        NUM_INPUTS = 4
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        class ModelParallelModel(torch.nn.Module):
            # `net0` lives on `dev0` and `net1` on `dev1`; activations are
            # moved between devices inside `forward`
            def __init__(self, dev0, dev1):
                super().__init__()
                self.dev0 = dev0
                self.dev1 = dev1
                self.net0 = torch.nn.Linear(10, 10).to(dev0)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5).to(dev1)

            def forward(self, x):
                x = x.to(self.dev0)
                x = self.relu(self.net0(x))
                x = x.to(self.dev1)
                return self.net1(x)

        class LocalModel(torch.nn.Module):
            # Same architecture as `ModelParallelModel`, on a single device
            def __init__(self) -> None:
                super().__init__()
                self.net0 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5)

            def forward(self, x):
                return self.net1(self.relu(self.net0(x)))

        dev0 = torch.device(2 * self.rank)
        dev1 = torch.device(2 * self.rank + 1)
        mp_model = ModelParallelModel(dev0, dev1)
        ddp_model = DDP(mp_model)
        local_model = LocalModel().to(dev0)

        # Ensure the parameters are the same across the two models
        def copy_param(p):
            return torch.nn.Parameter(p.detach().clone().to(dev0))

        local_model.net0.weight = copy_param(mp_model.net0.weight)
        local_model.net0.bias = copy_param(mp_model.net0.bias)
        local_model.net1.weight = copy_param(mp_model.net1.weight)
        local_model.net1.bias = copy_param(mp_model.net1.bias)

        # Compare parity between DDP with model parallelism using ZeRO and
        # a local model using a local optimizer
        zero_optim = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            parameters_as_bucket_view=parameters_as_bucket_view,
            lr=LR,
        )
        local_optim = torch.optim.Adam(local_model.parameters(), lr=LR)
        inputs = [torch.randn(20, 10).to(dev0) for _ in range(NUM_INPUTS)]

        for _ in range(NUM_EPOCHS):
            for inp in inputs:
                # Bind the batch as a default argument so each closure uses
                # this iteration's tensor
                def closure_local(inp=inp):
                    local_optim.zero_grad()
                    local_loss = local_model(inp).abs().sum()
                    local_loss.backward()
                    return local_loss

                def closure_ddp(inp=inp):
                    zero_optim.zero_grad()
                    ddp_loss = ddp_model(inp).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                local_loss = cast(torch.Tensor, local_optim.step(closure=closure_local))
                ddp_loss = cast(torch.Tensor, zero_optim.step(closure=closure_ddp))

                # Increased tolerances are needed to pass when using TF32
                # See: https://github.com/pytorch/pytorch/issues/67764
                # The failure message goes through `msg=`: a trailing
                # `, "message"` after the call only builds a discarded tuple
                torch.testing.assert_close(
                    local_loss.cpu(),
                    ddp_loss.cpu(),
                    rtol=1e-03,
                    atol=1e-08,
                    msg="Losses differ between local optimizer and ZeRO",
                )

                for local_p, ddp_p in zip(
                    local_model.parameters(), ddp_model.parameters()
                ):
                    torch.testing.assert_close(
                        local_p.cpu(),
                        ddp_p.cpu(),
                        rtol=1e-03,
                        atol=1e-04,
                        msg="Models differ after a step",
                    )

    @common_distributed.skip_if_lt_x_gpu(4)
    @parametrize(
        "parameters_as_bucket_view",
        [False, True],
    )
    def test_zero_model_parallel(
        self,
        parameters_as_bucket_view: bool,
    ):
        """Check that ZeRO works with model parallelism where the model's
        layers are assigned to different devices.

        Only ranks 0 and 1 participate, forming a two-process group; every
        other rank returns without doing anything.
        """
        if self.rank < 2:
            self.dist_init(self.rank, world_size=2)
            self._test_zero_model_parallel(parameters_as_bucket_view)

    def _test_ddp_zero_overlap(
        self,
        device,
        hook_constructor,
        gradient_as_bucket_view,
        static_graph,
        **kwargs,
    ):
        """Check parity between DDP overlapped with ZeRO and DDP with a local
        optimizer.

        For each test model, two DDP copies are built:
        - one whose communication hook comes from ``hook_constructor``,
          wrapping ``allreduce_hook`` and a ZeRO optimizer created with
          ``overlap_with_ddp=True``;
        - one trained with a plain local SGD optimizer.

        After both are run on the same inputs, their parameters must be equal
        to each other and must differ from the initial parameters.

        Args:
            device: ``torch.device`` to run on. If it is a CUDA device, it is
                also made the current CUDA device.
            hook_constructor: Callable (e.g. ``hook_with_zero_step``) that
                builds the DDP communication hook.
            gradient_as_bucket_view: Forwarded to both ``DDP`` constructors.
            static_graph: Whether to call ``_set_static_graph()`` on both DDP
                models; it also selects the number of warmup iterations.
            **kwargs: Extra keyword arguments forwarded to ``hook_constructor``.
        """
        SGD_LR = 0.01
        SGD_MOMENTUM = 0.9
        SGD_WEIGHT_DECAY = 0.001
        NUM_INPUTS = 5
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        is_gpu = device.type == "cuda"
        if is_gpu:
            torch.cuda.set_device(device)
        device_ids = [rank] if is_gpu else None
        # The ZeRO-wrapped SGD and the local SGD share one hyperparameter set
        sgd_kwargs = {
            "lr": SGD_LR,
            "momentum": SGD_MOMENTUM,
            "weight_decay": SGD_WEIGHT_DECAY,
        }

        def wrap_in_ddp(module):
            # Each DDP instance gets its own deep copy of `module`, so the two
            # training runs never share parameters
            wrapped = DDP(
                copy.deepcopy(module).to(device),
                device_ids=device_ids,
                gradient_as_bucket_view=gradient_as_bucket_view,
            )
            if static_graph:
                wrapped._set_static_graph()
            return wrapped

        # Build each model immediately before sampling its inputs, so random
        # draws happen in a fixed order after seeding
        linear_model = torch.nn.Sequential(
            torch.nn.Linear(1000, 2000),
            torch.nn.Linear(2000, 500),
        )
        linear_inputs = [torch.randn(1, 1000).to(device) for _ in range(NUM_INPUTS)]
        models_to_test = [(linear_model, linear_inputs)]
        if HAS_TORCHVISION:
            resnet = torchvision.models.resnet50()
            resnet_inputs = [
                torch.randn(1, 3, 3, 1000).to(device) for _ in range(NUM_INPUTS)
            ]
            models_to_test.append((resnet, resnet_inputs))

        # NOTE: Overlapping currently requires 2 or 3 warmup iterations
        # to ensure DDP buckets have been rebuilt (depending on the
        # value of `static_graph`)
        num_warmup_inputs = 3 if static_graph else 2

        for model, inputs in models_to_test:
            # Enable determinism in cudnn operators
            with torch.backends.cudnn.flags(
                enabled=True, deterministic=True, benchmark=False
            ):
                # Set up the DDP model overlapping with ZeRO
                ddp_model_overlap = wrap_in_ddp(model)
                zero_optim = ZeroRedundancyOptimizer(
                    ddp_model_overlap.parameters(),
                    optimizer_class=torch.optim.SGD,
                    overlap_with_ddp=True,
                    **sgd_kwargs,
                )
                comm_hook = hook_constructor(
                    allreduce_hook,
                    ddp_model_overlap,
                    zero_optim,
                    **kwargs,
                )
                ddp_model_overlap.register_comm_hook(None, comm_hook)

                # Set up the DDP model with local optimizer
                ddp_model_local = wrap_in_ddp(model)
                local_optim = torch.optim.SGD(
                    ddp_model_local.parameters(), **sgd_kwargs
                )

                # Check that the parameters match initially
                for overlap_p, local_p in zip(
                    ddp_model_overlap.parameters(), ddp_model_local.parameters()
                ):
                    self.assertEqual(overlap_p, local_p)

                # Save the parameters to ensure they were updated
                init_params_overlap = copy.deepcopy(
                    list(ddp_model_overlap.parameters())
                )

                # Ensure that this test runs independently
                dist.barrier()

                # Run the DDP model overlapping with ZeRO: warmup passes do
                # forward/backward only (no `zero_grad`), then one pass per
                # input with `zero_grad` first
                for batch in inputs[:num_warmup_inputs]:
                    ddp_model_overlap(batch).sum().backward()
                for batch in inputs:
                    zero_optim.zero_grad()
                    ddp_model_overlap(batch).sum().backward()

                # Run the DDP model with local optimizer
                for batch in inputs:
                    local_optim.zero_grad()
                    ddp_model_local(batch).sum().backward()
                    local_optim.step()
                dist.barrier()

                # Check that the parameters are equal
                for overlap_p, local_p in zip(
                    ddp_model_overlap.parameters(), ddp_model_local.parameters()
                ):
                    self.assertEqual(overlap_p, local_p)

                # Check that the parameters were updated
                self.assertNotEqual(
                    init_params_overlap,
                    list(ddp_model_overlap.parameters()),
                )

                # Ensure that this test runs independently
                dist.barrier()

    # NOTE: The test is skipped if using Windows since functional optimizers
    # are not currently supported.
    @common_distributed.skip_if_win32()
    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    @parametrize(
        "use_gpu",
        [True],
        # Add `False` once the Gloo sync issue causing hangs is fixed
        # See: https://github.com/pytorch/pytorch/issues/62300
    )
    @parametrize("use_interleaved_hook", [False, True])
    @parametrize("gradient_as_bucket_view", [False, True])
    @parametrize("static_graph", [False, True])
    @parametrize("shard_buckets", [False, True])
    def test_ddp_zero_overlap(
        self,
        use_gpu: bool,
        use_interleaved_hook: bool,
        gradient_as_bucket_view: bool,
        static_graph: bool,
        shard_buckets: bool,
    ):
        """
        Check that overlapping DDP with ZeRO achieves parity with DDP using a
        local optimizer.

        The overlap method is chosen by ``use_interleaved_hook``
        (``hook_with_zero_step_interleaved`` vs. ``hook_with_zero_step``) and
        ``shard_buckets``. The remaining flags are forwarded as ZeRO/DDP
        arguments.
        """
        if use_gpu:
            device = torch.device(self.rank)
        else:
            device = torch.device("cpu")
        self.dist_init(self.rank, self.world_size, _get_backend_for_tests())

        if use_interleaved_hook:
            hook_constructor = hook_with_zero_step_interleaved
        else:
            hook_constructor = hook_with_zero_step

        self._test_ddp_zero_overlap(
            device,
            hook_constructor,
            gradient_as_bucket_view,
            static_graph,
            shard_buckets=shard_buckets,
        )


# Expand every `@parametrize` decorator into concrete test methods
for _parametrized_cls in (
    TestZeroRedundancyOptimizerSingleRank,
    TestZeroRedundancyOptimizerDistributed,
):
    instantiate_parametrized_tests(_parametrized_cls)
# Drop the loop variable: unittest's loader scans module globals for
# `TestCase` subclasses and would otherwise pick up the last class twice
del _parametrized_cls

if __name__ == "__main__":
    # Use `run_tests()` rather than `unittest.main()`; otherwise the tests
    # are not properly registered
    run_tests()

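# NOTE: The text below is a cookie-consent footer captured from the web page
# this file was copied from. It is not part of the test module and is kept
# only as comments so that the file remains valid Python.
#
# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.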