# Owner(s): ["oncall: distributed"]

import sys
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import (
    requires_nccl,
    requires_nccl_version,
    skip_but_pass_in_sandcastle_if,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# bfloat16 is only supported by CUDA 11+
BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
    torch.version.cuda is not None or torch.version.hip is not None
)


class Net(nn.Module):
    def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
        # to ensure determinism
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        super().__init__()

        if has_wrapping:
            self.net = FSDP(
                nn.Sequential(
                    nn.Linear(8, 16),
                    nn.ReLU(),
                    FSDP(
                        nn.Linear(16, 8),
                        device_id=torch.cuda.current_device(),
                        sharding_strategy=sharding_strategy,
                        mixed_precision=mixed_precision,
                    ),
                ),
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
        else:
            self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))

        self.out = nn.Linear(8, 4)

    def forward(self, x):
        return self.out(F.relu(self.net(x)))


class DummyState:
    __slots__ = ["process_group", "noise"]

    def __init__(self, process_group: dist.ProcessGroup, noise: int):
        self.process_group = process_group
        self.noise = noise


class DummyHook:
    def dummy_hook_for_no_shard_fsdp(self, state: DummyState, grad: torch.Tensor):
        """
        This communication hook is for illustration and testing purposes only.
        This communication hook is used during FSDP ``NO_SHARD`` training. It adds some noise to
        the provided ``grad`` parameter and uses ``all_reduce`` to communicate the full, flattened,
        unsharded gradient.
        """
        grad.add_(state.noise)
        dist.all_reduce(grad, group=state.process_group)

    def custom_reduce_scatter(self, output, input, group=None):
        """
        This function is for illustrative purposes only.
        It is meant to implement a custom reduce-scatter
        of a flattened tensor to all processes in a group.
        Currently a no-op.
        """
        pass
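        # A minimal sketch (kept commented out so this method stays a no-op) of what an
        # actual implementation could do, assuming FSDP passes the full flattened gradient
        # as ``input`` and a rank-local shard buffer as ``output``:
        #
        #     dist.reduce_scatter_tensor(
        #         output, input, op=dist.ReduceOp.AVG, group=group
        #     )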

    def dummy_hook_for_sharded_fsdp(
        self, state: DummyState, grad: torch.Tensor, output: torch.Tensor
    ):
        """
        This communication hook is for illustration and testing purposes only.
        This communication hook is used during FSDP ``FULL_SHARD`` or ``SHARD_GRAD_OP`` training.
        It adds some noise to the provided ``grad`` parameter, uses
        ``reduce_scatter`` for gradient communication and stores a sharded gradient in ``output``.
        """
        grad.add_(state.noise)
        self.custom_reduce_scatter(output, grad, group=state.process_group)


class TestCommunicationHooks(FSDPTest):
    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_default_communication_hook_behavior(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's default communication hook's behavior and correctness.
        This test creates a simple linear net with weight shape ``1 X N``,
        where ``N`` is the number of workers.
        For sharded cases, each worker gets 1 element of the weight parameter. This test
        checks that after backward, each worker has a proper value in its chunk of
        the gradient, or the whole gradient on every worker is equal to an expected value.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """
        out_dim = self.world_size
        net = torch.nn.Linear(1, out_dim, bias=False)
        inpt = torch.tensor([self.rank]).float().cuda(self.rank)

        net_default_hook = FSDP(
            net,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
        ).to(self.rank)

        # Check that by default, `_comm_hook` is None
        for entry in FSDP.fsdp_modules(net_default_hook):
            self.assertEqual(entry._comm_hook, None)

        for _ in range(4):
            # Clear gradients
            net_default_hook.zero_grad()
            loss = net_default_hook(inpt).sum()
            loss.backward()

            # Each worker's local gradient on the weight equals its rank; the default
            # hook averages these gradients across workers.
            grad = net_default_hook.params[0].grad
            expected_grad = (
                sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
            )
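            # e.g. with world_size == 2: rank 0 contributes 0, rank 1 contributes 1,
            # so the averaged gradient is (0 + 1) / 2 = 0.5 on every rank.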
            # Verify default hook produces expected gradients
            self.assertEqual(
                grad[0].item(),
                expected_grad,
                msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}",
            )

    def _get_submodules(self, fsdp_net):
        return [
            submodule
            for submodule in FSDP.fsdp_modules(fsdp_net)
            if not submodule.check_is_root()
        ]

    def _init_model(self, core, sharding_strategy, mixed_precision=None):
        device = torch.device("cuda")
        return FSDP(
            core,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
        ).to(device)

    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_default_communication_hook_initialization(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook interface behavior.

        Arguments:
            has_wrapping (bool): Configures wrapping of a module.
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        # Initialize a model
        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )

        # Check that by default, `_comm_hook` is None
        for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(fsdp_module._comm_hook, None)

        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy != ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )

        fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

        # Check that we can't register comm hook twice
        with self.assertRaisesRegex(
            AssertionError, "^A communication hook is already registered$"
        ):
            fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

        # Check dummy hook was registered for the root and all submodules if any
        for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(fsdp_module._comm_hook, dummy_hook)
            self.assertEqual(fsdp_module._comm_hook_state, dummy_state)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_registering_hook_non_root(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook registration for submodules.
        Makes sure a hook can't be registered on non-root submodules.
        Tests the ``NO_SHARD``, ``FULL_SHARD``, and ``SHARD_GRAD_OP`` strategies.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=True, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy != ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
        # Creating a list of non-root submodules to test
        submodules = self._get_submodules(fsdp_model_with_hook)
        # Check that assertion is raised for registering a comm hook on a non-root
        with self.assertRaisesRegex(
            AssertionError,
            "^register_comm_hook can only be called on a root instance.$",
        ):
            submodules[1].register_comm_hook(dummy_state, dummy_hook)

    @skip_if_lt_x_gpu(2)
    def test_registering_hook_hybrid_strategy(self):
        for sharding_strategy in (
            ShardingStrategy.HYBRID_SHARD,
            ShardingStrategy._HYBRID_SHARD_ZERO2,
        ):
            model = Net(False, None, None).cuda()
            fsdp_model = FSDP(
                model,
                auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
                sharding_strategy=sharding_strategy,
            )
            dummy_state = DummyState(process_group=None, noise=1234)
            dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp
            with self.assertRaisesRegex(
                AssertionError,
                "Communication hook is not supported for hybrid strategies",
            ):
                fsdp_model.register_comm_hook(dummy_state, dummy_hook)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_registering_hook_submodules(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook registration for submodules.
        Checks the behavior when a hook has already been registered for a non-root submodule.
        Tests the ``NO_SHARD``, ``FULL_SHARD``, and ``SHARD_GRAD_OP`` strategies.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """

        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=True, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy != ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
        submodules = self._get_submodules(fsdp_model_with_hook)

        # Simulate a registration of a hook on a submodule
        submodules[1]._comm_hook = dummy_hook
        # Check that an error is raised when some submodules already have a non-default hook assigned
        with self.assertRaisesRegex(
            AssertionError, "^A communication hook is already registered$"
        ):
            fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

    def _check_low_precision_hook(
        self, state, hook, sharding_strategy, dtype, has_wrapping
    ):
        # keep everything deterministic for input data
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        fsdp_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        fsdp_with_hook.register_comm_hook(state, hook)

        mp_only_grad = MixedPrecision(reduce_dtype=dtype)
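        # ``reduce_dtype`` only lowers the precision of gradient reduction; parameters
        # and buffers keep their full precision, so this config is the built-in
        # counterpart of the low-precision communication hook registered above.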
        fsdp_with_mp = self._init_model(
            Net(
                has_wrapping=has_wrapping,
                sharding_strategy=sharding_strategy,
                mixed_precision=mp_only_grad,
            ),
            sharding_strategy=sharding_strategy,
            mixed_precision=mp_only_grad,
        )

        optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
        optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

        in_data = torch.rand(16, 8).cuda()
        fsdp_with_hook.train()
        fsdp_with_mp.train()
        loss_hook = fsdp_with_hook(in_data).sum()
        loss_mp = fsdp_with_mp(in_data).sum()
        loss_hook.backward()
        # Make sure grads were cast to the parameter's precision
        self.assertEqual(fsdp_with_hook.params[0].grad.dtype, state.parameter_type)
        loss_mp.backward()
        optim_hook.step()
        optim_mp.step()

        dist.barrier()

        for hook_param, mp_param in zip(
            fsdp_with_hook.parameters(), fsdp_with_mp.parameters()
        ):
            self.assertEqual(hook_param.grad, mp_param.grad)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_fp16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.fp16_compress_hook

        self._check_low_precision_hook(
            state, hook, sharding_strategy, torch.float16, has_wrapping
        )

    @requires_nccl()
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
    @skip_but_pass_in_sandcastle_if(
        not BFLOAT16_AVAILABLE,
        "BFloat16 is only supported by CUDA 11+",
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_bf16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.bf16_compress_hook

        self._check_low_precision_hook(
            state, hook, sharding_strategy, torch.bfloat16, has_wrapping
        )

instantiate_parametrized_tests(TestCommunicationHooks)
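# instantiate_parametrized_tests() expands each @parametrize'd test above into one
# concrete test method per parameter combination (e.g. per sharding strategy).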

if __name__ == "__main__":
    run_tests()
