pytorch

Форк
0
/
test_wrap.py 
957 строк · 37.4 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import functools
4
import itertools
5
import os
6
import tempfile
7
import unittest
8
from enum import auto, Enum
9
from typing import Callable, Union
10

11
import torch
12
import torch.nn as nn
13
import torch.nn.functional as F
14
from torch.distributed.fsdp._wrap_utils import _validate_frozen_params
15
from torch.distributed.fsdp.fully_sharded_data_parallel import (
16
    BackwardPrefetch,
17
    CPUOffload,
18
    FullyShardedDataParallel as FSDP,
19
    MixedPrecision,
20
    ShardingStrategy,
21
)
22
from torch.distributed.fsdp.wrap import (
23
    _or_policy,
24
    _Policy,
25
    _wrap_module_cls_individually,
26
    always_wrap_policy,
27
    CustomPolicy,
28
    enable_wrap,
29
    ModuleWrapPolicy,
30
    size_based_auto_wrap_policy,
31
    transformer_auto_wrap_policy,
32
    wrap,
33
)
34
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
35
from torch.nn.modules.batchnorm import _BatchNorm
36
from torch.testing._internal.common_cuda import TEST_MULTIGPU
37
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
38
from torch.testing._internal.common_fsdp import (
39
    _maybe_cuda,
40
    CUDAInitMode,
41
    DummyProcessGroup,
42
    FSDPInitMode,
43
    FSDPTest,
44
    TransformerWithSharedParams,
45
)
46
from torch.testing._internal.common_utils import (
47
    FILE_SCHEMA,
48
    find_free_port,
49
    instantiate_parametrized_tests,
50
    parametrize,
51
    run_tests,
52
    TEST_CUDA,
53
    TestCase,
54
)
55

56

57
class BatchNormNet(nn.Module):
    """Toy module with one bias-free linear layer and every batch norm flavor.

    No ``forward`` is defined; the module is only used to inspect which of its
    children end up wrapped.
    """

    def __init__(self):
        super().__init__()
        num_features = 10
        self.lin = nn.Linear(num_features, num_features, bias=False)
        self.bn1 = nn.BatchNorm1d(num_features)
        self.bn2 = nn.BatchNorm2d(num_features)
        self.bn3 = nn.BatchNorm3d(num_features)
        self.sync_bn = nn.SyncBatchNorm(num_features)
65

66

67
class LoraModel(nn.Module):
    """This is a toy LoRA decoder model.

    The token embedding and the final layer norm are frozen; the four decoder
    layers decide for themselves which of their parameters stay trainable.
    """

    def __init__(self):
        super().__init__()
        hidden_dim = 32
        self.embed_tokens = nn.Embedding(100, hidden_dim)
        self.layers = nn.ModuleList(LoraDecoder() for _ in range(4))
        self.norm = nn.LayerNorm(hidden_dim)
        # Freeze the embedding table and both affine parameters of the norm.
        for frozen_param in (
            self.embed_tokens.weight,
            self.norm.weight,
            self.norm.bias,
        ):
            frozen_param.requires_grad_(False)
78

79

80
class LoraDecoder(nn.Module):
    """One decoder layer of the toy LoRA model: attention, MLP, two frozen norms."""

    def __init__(self):
        super().__init__()
        self.attn = LoraAttention()
        self.mlp = LoraMLP()
        self.inp_layernorm = nn.LayerNorm(32)
        self.post_attn_layernorm = nn.LayerNorm(32)
        # Both layer norms are frozen (weight first, then bias).
        for layernorm in (self.inp_layernorm, self.post_attn_layernorm):
            layernorm.weight.requires_grad_(False)
            layernorm.bias.requires_grad_(False)
91

92

93
class LoraAttention(nn.Module):
    """Attention stand-in with frozen q/k/v/o projections and trainable LoRA A/B.

    No ``forward`` is defined; only the parameter layout matters.
    """

    def __init__(self):
        super().__init__()
        dim, rank = 32, 8
        # Registration order is kept as-is since it fixes the parameter order.
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.lora_A = nn.Linear(dim, rank, bias=False)
        self.lora_B = nn.Linear(rank, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        # Only the low-rank adapters (lora_A / lora_B) remain trainable.
        for base_proj in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
            base_proj.weight.requires_grad_(False)
106

107

108
class LoraMLP(nn.Module):
    """Two bias-free projections (32 -> 128 -> 32), both fully frozen."""

    def __init__(self):
        super().__init__()
        dim, hidden = 32, 128
        self.proj1 = nn.Linear(dim, hidden, bias=False)
        self.proj2 = nn.Linear(hidden, dim, bias=False)
        for proj in (self.proj1, self.proj2):
            proj.weight.requires_grad_(False)
115

116

117
class WrapMethod(Enum):
    """The two ways the tests apply FSDP wrapping."""

    # Pass the policy / kwargs straight to the FSDP constructor. This is the
    # supported way forward.
    FSDP_CTOR = auto()
    # Go through ``enable_wrap`` / ``wrap``. Kept in case we miss any use cases
    # and fix them to work with FSDP_CTOR over time.
    WRAP_API = auto()
122

123

124
class TestFSDPWrap(FSDPTest):
    """
    Tests main API for wrapping FSDP, which is to pass auto_wrap_policy into
    FSDP constructor.
    """

    def setUp(self) -> None:
        super().setUp()

    class NestedSequentialModel:
        """Builder and structural checks for a small nested ``nn.Sequential``."""

        @staticmethod
        def get_model(cuda=True):
            """Return two 5x5 linears followed by a nested pair of 5x5 linears.

            Args:
                cuda: When true, the model is moved to CUDA before returning.
            """
            sequential = nn.Sequential(
                nn.Linear(5, 5),
                nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            )
            if cuda:
                sequential = sequential.cuda()
            return sequential

        @staticmethod
        def verify_model_all_wrapped(cls, model):
            """Assert that ``model`` and each of its submodules is an FSDP instance.

            Args:
                cls: The running test case; despite the name, callers pass a
                    test instance that supplies the assertion methods.
                model: The wrapped result of :meth:`get_model`.
            """
            cls.assertIsInstance(model, FSDP)
            cls.assertIsInstance(model.module[0], FSDP)
            cls.assertIsInstance(model.module[1], FSDP)
            cls.assertIsInstance(model.module[2], FSDP)
            cls.assertIsInstance(model.module[2].module[0], FSDP)
            cls.assertIsInstance(model.module[2].module[1], FSDP)

        @staticmethod
        def verify_model(cls, model):
            """Assert that only the root and the nested ``nn.Sequential`` are wrapped.

            Args:
                cls: The running test case; despite the name, callers pass a
                    test instance that supplies the assertion methods.
                model: The wrapped result of :meth:`get_model`.
            """
            cls.assertIsInstance(model, FSDP)
            cls.assertIsInstance(model.module[0], nn.Linear)
            cls.assertIsInstance(model.module[1], nn.Linear)
            cls.assertIsInstance(model.module[2], FSDP)
            # following modules were not wrapped by the policy.
            cls.assertIsInstance(model.module[2].module[0], nn.Linear)
            cls.assertIsInstance(model.module[2].module[1], nn.Linear)

    def _get_linear(self, fin, fout):
        """Return a bias-free ``nn.Linear(fin, fout)``."""
        return nn.Linear(fin, fout, bias=False)

    def _get_already_wrapped_fsdp(
        self, cuda_init_mode=CUDAInitMode.CUDA_BEFORE, nested=False
    ) -> nn.Module:
        """Build an unwrapped module whose children are already FSDP instances.

        Args:
            cuda_init_mode: Children are moved to CUDA before being wrapped
                only for ``CUDAInitMode.CUDA_BEFORE``.
            nested: If true, ``lin1`` is an ``nn.Sequential`` whose second
                element is the FSDP instance instead of being FSDP itself.

        Returns:
            The plain (not FSDP-wrapped) parent module.
        """
        fn_self = self

        class MyModel(nn.Module):
            def __init__(self, nested):
                super().__init__()
                # TODO: test the various init modes.
                move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
                # if nested=True, the FSDP module will be nested one layer deep
                # and we should pick that up.
                if nested:
                    self.lin1 = nn.Sequential(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda),
                        FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)),
                    )
                else:
                    self.lin1 = FSDP(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)
                    )
                self.lin2 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
                self.lin3 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                return self.lin3(self.lin2(self.lin1(input)))

        return MyModel(nested=nested)

    @skip_if_lt_x_gpu(2)
    @parametrize("nested", [True, False])
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_error_already_wrapped(self, nested, cuda_init_mode):
        """
        Test that an error is raised if we attempt to wrap when submodules are
        already FSDP.
        """
        wrapped_fsdp = self._get_already_wrapped_fsdp(
            nested=nested, cuda_init_mode=cuda_init_mode
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_fsdp = wrapped_fsdp.cuda()

        wrapped_module_name = "lin1.1" if nested else "lin1"
        with self.assertRaisesRegex(
            ValueError,
            "FSDP auto wrapping requires modules to not already have FSDP "
            f"applied but found {wrapped_module_name} in",
        ):
            FSDP(wrapped_fsdp, auto_wrap_policy=size_based_auto_wrap_policy)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_or_policy", [True, False])
    def test_wrap_batchnorm_individually(self, use_or_policy):
        """Batch norm children get their own FSDP wrapper; the linear does not."""

        def never_wrap_policy(*args, **kwargs):
            return False

        wrap_batchnorm_individually = functools.partial(
            _wrap_module_cls_individually,
            module_classes=[
                _BatchNorm,
            ],
        )
        policy = (
            functools.partial(
                _or_policy, policies=[never_wrap_policy, wrap_batchnorm_individually]
            )
            if use_or_policy
            else wrap_batchnorm_individually
        )
        model = BatchNormNet()
        fsdp = FSDP(model, auto_wrap_policy=policy)
        # Batchnorms should be wrapped
        for layer in [fsdp.bn1, fsdp.bn2, fsdp.bn3, fsdp.sync_bn]:
            self.assertIsInstance(layer, FSDP)

        self.assertNotIsInstance(fsdp.lin, FSDP)

    @skip_if_lt_x_gpu(2)
    def test_bn_always_wrapped_individually(self):
        """
        Ensures that by using _or_policy with _wrap_module_cls_individually, even
        if the other policy results in a module containing a BN unit being
        wrapped, the contained BN unit will still be individually wrapped.
        """

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.bn_container = BatchNormNet()

        def wrap_bn_container(module, recurse, *args, **kwargs):
            if recurse:
                return True
            return isinstance(module, BatchNormNet)

        wrap_batchnorm_individually = functools.partial(
            _wrap_module_cls_individually,
            module_classes=[
                _BatchNorm,
            ],
        )

        my_policy = functools.partial(
            _or_policy, policies=[wrap_bn_container, wrap_batchnorm_individually]
        )
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=my_policy)

        # Wrapping should be FSDP(FSDP(BatchNormNet(FSDP(BN))))
        # and not FSDP(FSDP(BatchNormNet(BN))) (in the latter the inner
        # BN is not individually wrapped.)

        for bn in [
            fsdp.bn_container.bn1,
            fsdp.bn_container.bn2,
            fsdp.bn_container.bn3,
            fsdp.bn_container.sync_bn,
        ]:
            self.assertIsInstance(bn, FSDP)

        # if we just wrapped BN container, individual batchnorms are not
        # wrapped.
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=wrap_bn_container)
        self.assertIsInstance(mod.bn_container, FSDP)
        for bn in [
            fsdp.bn_container.bn1,
            fsdp.bn_container.bn2,
            fsdp.bn_container.bn3,
            fsdp.bn_container.sync_bn,
        ]:
            self.assertNotIsInstance(bn, FSDP)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE],
    )
    @parametrize("forward_prefetch", [False, True])
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_main_wrap_api(
        self,
        cpu_offload: CPUOffload,
        backward_prefetch: BackwardPrefetch,
        forward_prefetch: bool,
        cuda_init_mode: CUDAInitMode,
    ):
        """Wrap every module via the constructor and check settings propagate.

        Each wrapped module is checked for the requested CPU offload and
        prefetch settings, then a few optimizer steps run as a sanity check.
        """
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            # they don't work together, expected
            return

        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        class Nested(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            auto_wrap_policy=functools.partial(
                size_based_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            forward_prefetch=forward_prefetch,
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        modules_in_fsdp_graph_order = [
            wrapped_model.module.lin1,
            wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4.module.nested_lin,
            wrapped_model.module.lin4,
            wrapped_model,
        ]

        for module in modules_in_fsdp_graph_order:
            self.assertIsInstance(module, FSDP)
            self._check_cpu_offload(module, cpu_offload)
            self._check_backward_prefetch(module, backward_prefetch)
            self._check_forward_prefetch(module, forward_prefetch)

        # Run model a few times for sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()
382

383

384
class TestAutoWrap(TestCase):
385
    def setUp(self) -> None:
        """Give every test in this class a fake single-rank process group."""
        super().setUp()
        # A dummy group (rank 0 of 1) is used for all the tests here.
        self.process_group = DummyProcessGroup(rank=0, size=1)
389

390
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
391
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
392
    def test_wrap(self, wrap_method):
        """Wrap a single linear layer via either API and check its process group."""
        pg = self.process_group
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, process_group=pg):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            size_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=1
            )
            layer = FSDP(
                nn.Linear(5, 5),
                process_group=pg,
                auto_wrap_policy=size_policy,
            )
        self.assertIsInstance(layer, FSDP)
        self.assertEqual(layer.rank, pg.rank())
        self.assertEqual(layer.world_size, pg.size())
408

409
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
410
    def test_wrap_disabled_outside_context(self):
        """``wrap`` called outside an ``enable_wrap`` context leaves the module as is."""
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        # Build the model before entering the context so the inner ``wrap``
        # call runs with wrapping disabled.
        unwrapped = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(unwrapped)

        self.assertIsInstance(model, FSDP)
        self.assertNotIsInstance(model.lin, FSDP)
        self.assertIsInstance(model.lin, nn.Linear)
425

426
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
427
    def test_wrap_override_defaults(self):
        """Kwargs passed to ``wrap`` take precedence over the ``enable_wrap`` ones."""
        override_pg = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=override_pg)
        self.assertIsInstance(layer, FSDP)
        self.assertIs(layer.process_group, override_pg)
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)
435

436
    @unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
437
    def test_always_wrap(self):
        """
        Test to ensure that if `always_wrap_policy` is
        passed into FSDP, all submodules are wrapped.
        """
        helper = TestFSDPWrap.NestedSequentialModel
        model = FSDP(
            helper.get_model(cuda=True),
            process_group=self.process_group,
            auto_wrap_policy=always_wrap_policy,
        )
        helper.verify_model_all_wrapped(self, model)
447

448
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
449
    def test_transformer_auto_wrap_policy(self):
        """Tests the ``transformer_auto_wrap_policy``."""
        layer_classes = {TransformerEncoderLayer, TransformerDecoderLayer}
        self._test_transformer_wrapping(
            functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls=layer_classes,
            )
        )
456

457
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
458
    def test_module_wrap_policy(self):
        """Tests the ``ModuleWrapPolicy``."""
        layer_classes = {TransformerEncoderLayer, TransformerDecoderLayer}
        self._test_transformer_wrapping(ModuleWrapPolicy(layer_classes))
464

465
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
466
    def test_module_wrap_policy_callable(self):
        """Tests the ``ModuleWrapPolicy`` as a ``Callable``."""
        module_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        # Routing the policy object through ``_or_policy`` makes it get used
        # as a plain callable.
        self._test_transformer_wrapping(
            functools.partial(_or_policy, policies=[module_policy])
        )
473

474
    def _test_transformer_wrapping(self, auto_wrap_policy: Union[Callable, _Policy]):
        """Check that only the root and the transformer layers become FSDP.

        Args:
            auto_wrap_policy: Policy handed to FSDP through ``fsdp_kwargs``.
        """
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            {"auto_wrap_policy": auto_wrap_policy},
        )
        # Modules hash by identity, so set membership is an identity check.
        expected_fsdp_modules = {fsdp_model}
        expected_fsdp_modules.update(fsdp_model.module.transformer.encoder.layers)
        expected_fsdp_modules.update(fsdp_model.module.transformer.decoder.layers)
        for module in fsdp_model.modules():
            if module in expected_fsdp_modules:
                self.assertIsInstance(module, FSDP)
            else:
                self.assertNotIsInstance(module, FSDP)
494

495
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
496
    def test_custom_policy(self):
        """
        Tests ``CustomPolicy`` with both a lambda function that uses uniform
        kwargs (so only returns ``False`` or ``True``) and a lambda function
        that uses non-uniform kwargs (so returns a dict to override the root
        kwargs).
        """
        # Non-uniform kwargs are exercised first, then uniform ones.
        for uniform in (False, True):
            self._test_custom_policy(uniform)
505

506
    def _test_custom_policy(self, use_uniform_kwargs: bool):
        """Wrap a transformer with ``CustomPolicy`` and check per-module settings.

        Args:
            use_uniform_kwargs: If true, the policy function only returns
                booleans; otherwise it returns kwarg-override dicts for the
                batch norm module and the decoder layers.
        """
        print(f"use_uniform_kwargs={use_uniform_kwargs}")
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {},
        )

        if use_uniform_kwargs:

            def lambda_fn(module: nn.Module):
                if module is model.bn:
                    return True
                return isinstance(
                    module, (TransformerEncoderLayer, TransformerDecoderLayer)
                )

        else:

            def lambda_fn(module: nn.Module):
                if module is model.bn:
                    return {"sharding_strategy": ShardingStrategy.NO_SHARD}
                if isinstance(module, TransformerEncoderLayer):
                    return True
                if isinstance(module, TransformerDecoderLayer):
                    return {
                        "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
                        "backward_prefetch": BackwardPrefetch.BACKWARD_POST,
                    }
                return False

        # Use a size-2 dummy PG to avoid clamping the sharding strategy to
        # `NO_SHARD` as for a size-1 PG
        process_group = DummyProcessGroup(rank=0, size=2)
        fp16_mp = MixedPrecision(param_dtype=torch.float16)
        fp32_mp = MixedPrecision()
        wrapped = FSDP(
            model,
            process_group=process_group,
            auto_wrap_policy=CustomPolicy(lambda_fn),
            mixed_precision=fp16_mp,
        )
        encoder_layers = set(wrapped.module.transformer.encoder.layers)
        decoder_layers = set(wrapped.module.transformer.decoder.layers)
        bn = wrapped.module.bn

        # Root and encoder layers always use the root's kwargs; batch norm and
        # decoder layers differ only when the policy returned override dicts.
        root_strategy = encoder_strategy = ShardingStrategy.FULL_SHARD
        root_prefetch = encoder_prefetch = BackwardPrefetch.BACKWARD_PRE
        bn_prefetch = BackwardPrefetch.BACKWARD_PRE
        if use_uniform_kwargs:
            bn_strategy = ShardingStrategy.FULL_SHARD
            decoder_strategy = ShardingStrategy.FULL_SHARD
            decoder_prefetch = BackwardPrefetch.BACKWARD_PRE
        else:
            bn_strategy = ShardingStrategy.NO_SHARD
            decoder_strategy = ShardingStrategy.SHARD_GRAD_OP
            decoder_prefetch = BackwardPrefetch.BACKWARD_POST

        for module in wrapped.modules():
            if module is bn:
                # We currently override batch norm modules to use fp32
                expected = (bn_strategy, bn_prefetch, fp32_mp)
            elif module in encoder_layers:
                expected = (encoder_strategy, encoder_prefetch, fp16_mp)
            elif module in decoder_layers:
                expected = (decoder_strategy, decoder_prefetch, fp16_mp)
            elif module is wrapped:
                expected = (root_strategy, root_prefetch, fp16_mp)
            else:
                self.assertNotIsInstance(module, FSDP)
                continue
            strategy, prefetch, mixed_precision = expected
            self.assertIsInstance(module, FSDP)
            self.assertEqual(module.sharding_strategy, strategy)
            self.assertEqual(module.backward_prefetch, prefetch)
            self.assertEqual(module.mixed_precision, mixed_precision)
597

598
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
599
    def test_auto_wrap_api(self):
        """
        Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
        ``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do.
        """
        helper = TestFSDPWrap.NestedSequentialModel
        size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)
        model = FSDP(
            helper.get_model(cuda=False),
            process_group=self.process_group,
            auto_wrap_policy=size_policy,
        )

        helper.verify_model(self, model)
615

616
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
617
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless if the total param size is greater than the
        min_num_params. the size_based_auto_wrap_policy excludes wrapping for {nn.ModuleList, nn.ModuleDict}
        """
        module_list = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)

        model = FSDP(
            module_list,
            process_group=self.process_group,
            auto_wrap_policy=size_policy,
        )

        self.assertIsInstance(model, FSDP)
        for index in (0, 1):
            self.assertIsInstance(model[index], nn.Linear)
636

637
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
638
    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but children are if param size is greater than
        min_num_params
        """
        module_list = nn.ModuleList([nn.Linear(10, 10)])
        size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)
        model = FSDP(
            module_list,
            process_group=self.process_group,
            auto_wrap_policy=size_policy,
        )

        self.assertIsInstance(model, FSDP)
        self.assertIsInstance(model[0], FSDP)
655

656
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
657
    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and children are not wrapped. The
        size_based_auto_wrap_policy forces leaf modules of type {nn.MultiheadAttention} to not be wrapped
        """
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=size_policy,
        )
        linear, attention = model.module[0], model.module[1]
        self.assertIsInstance(linear, FSDP)
        # Assert children of multihead attention are not wrapped
        self.assertIsInstance(attention, nn.MultiheadAttention)
        self.assertIsInstance(attention.out_proj, nn.Linear)
675

676
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
677
    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        # Extend the policy's default force-leaf set with ``nn.Linear``.
        force_leaf = size_based_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear})
        size_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=force_leaf,
        )
        sequential = nn.Sequential(
            nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=size_policy,
        )
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertIsInstance(model, FSDP)
        self.assertIsInstance(model.module[0], nn.Linear)
        self.assertIsInstance(model.module[1], nn.ModuleList)
700

701
    @unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
702
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_BEFORE, CUDAInitMode.CUDA_AFTER])
703
    @parametrize(
704
        "cpu_offload",
705
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
706
    )
707
    @parametrize("use_device_id", [True, False])
708
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id):
        """Smoke-test size-based auto wrapping through one forward/backward pass.

        A single-rank NCCL process group is initialized for the duration of
        the test. The group is destroyed and the temporary rendezvous file is
        removed even when the test body raises.

        Args:
            cuda_init_mode: Whether the model is moved to CUDA before or after
                being wrapped.
            cpu_offload: CPU offload configuration passed to FSDP.
            use_device_id: If true, pass the current CUDA device as
                ``device_id``; otherwise pass ``None``.
        """
        # CPU offload and CUDA after don't work together as expected.
        if cpu_offload.offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (
            torch.device("cuda", torch.cuda.current_device()) if use_device_id else None
        )

        # Random port in case the next test run quickly, same port would cause conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        # Only the path is needed, so close the handle right away.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            file_name = tmp_file.name

        try:
            torch.distributed.init_process_group(
                backend="nccl",
                init_method=f"{FILE_SCHEMA}_{file_name}",
                rank=0,
                world_size=1,
            )

            # NOTE: We move model to CUDA after init with FSDP to simulate real use
            # cases where full model cannot be loaded onto GPU, but their shards can.
            cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
            try:
                sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                    cuda=(not cuda_after_init)
                )
                my_auto_wrap_policy = functools.partial(
                    size_based_auto_wrap_policy, min_num_params=40
                )
                model = FSDP(
                    sequential,
                    cpu_offload=cpu_offload,
                    auto_wrap_policy=my_auto_wrap_policy,
                    device_id=device_id,
                )
                TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
                if cuda_after_init:
                    model = model.cuda()
                inp = torch.rand((1, 5), dtype=torch.float).to(device)
                output = model(inp)
                loss = F.mse_loss(inp, output)
                loss.backward()
            finally:
                torch.distributed.destroy_process_group()
        finally:
            # Clean up on every exit path, not just on success.
            try:
                os.remove(file_name)
            except FileNotFoundError:
                pass
761

762
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
763
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
764
    def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        """
        Tests that ``always_wrap_policy`` wraps every non-ignored module with
        FSDP while the modules passed as ``ignored_modules`` stay plain
        ``nn.Linear``, for both the FSDP constructor and the ``wrap()`` API.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": always_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            # Raise explicitly: a bare ``assert`` is stripped under
            # ``python -O``, which would leave ``model`` unbound below
            raise AssertionError(f"Unsupported wrap method: {wrap_method}")
        # All non-ignored modules should be wrapped with FSDP
        self.assertIsInstance(model, FSDP)
        self.assertIsInstance(model.module[0], FSDP)
        self.assertIsInstance(model.module[1], nn.Linear)
        self.assertIsInstance(model.module[2], FSDP)
        self.assertIsInstance(model.module[2].module[0], nn.Linear)
        self.assertIsInstance(model.module[2].module[1], FSDP)
786

787
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
788
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
789
    def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        """
        Tests size-based auto wrapping (``min_num_params=40``) when some
        modules are passed as ``ignored_modules``: only the root is expected
        to be an FSDP instance, for both the FSDP constructor and the
        ``wrap()`` API.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
        )
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": my_auto_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            # Raise explicitly: a bare ``assert`` is stripped under
            # ``python -O``, which would leave ``model`` unbound below
            raise AssertionError(f"Unsupported wrap method: {wrap_method}")
        # Since the 2nd linear (`sequential[1]`) is ignored, the wrapping
        # policy does not exceed the parameter threshold before the inner
        # sequential (`sequential[2]`) anymore; hence, it flattens
        # `sequential[0]` and `sequential[2][1]` into `model` and leaves
        # `sequential[1]` and `sequential[2][0]` as-is since they are ignored
        self.assertIsInstance(model, FSDP)
        self.assertIsInstance(model.module[0], nn.Linear)
        self.assertIsInstance(model.module[1], nn.Linear)
        self.assertIsInstance(model.module[2], nn.Sequential)
        self.assertIsInstance(model.module[2][0], nn.Linear)
        self.assertIsInstance(model.module[2][1], nn.Linear)
819

820
    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
821
    def test_frozen_params(self):
        """
        Checks FSDP construction when a wrapped module mixes frozen and
        non-frozen parameters: an error is expected for
        ``use_orig_params=False`` and a warning for ``True``.

        Each combination of ``use_orig_params`` and wrapping policy is
        delegated to ``_test_frozen_params``.
        """
        lora_classes = (LoraAttention, LoraMLP, LoraDecoder)

        def wrap_if_lora(module: nn.Module):
            return isinstance(module, lora_classes)

        def wrap_with_attn_override(module: nn.Module):
            # Attention modules return a kwargs dict instead of a plain bool
            if isinstance(module, LoraAttention):
                return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
            return isinstance(module, lora_classes)

        policies = [
            ModuleWrapPolicy(lora_classes),
            CustomPolicy(wrap_if_lora),
            CustomPolicy(wrap_with_attn_override),
        ]
        for use_orig_params in [True, False]:
            for policy in policies:
                self._test_frozen_params(use_orig_params, policy)
851

852
    def _test_frozen_params(self, use_orig_params: bool, policy: _Policy):
        """
        Wraps a CUDA ``LoraModel`` with FSDP under ``policy`` and checks the
        reaction to ``layers.0.attn`` mixing frozen and non-frozen parameters:
        a ``UserWarning`` when ``use_orig_params`` is true, otherwise a
        ``ValueError``.
        """
        lora_model = LoraModel().cuda()
        prefix = "layers.0.attn has both parameters with requires_grad=True and False. "
        if not use_orig_params:
            expected = (
                prefix
                + "FSDP does not support wrapping such modules when use_orig_params=False."
            )
            expectation = self.assertRaisesRegex(ValueError, expected)
        else:
            expected = prefix + "We do not recommend wrapping such modules"
            expectation = self.assertWarnsRegex(UserWarning, expected)
        with expectation:
            FSDP(
                lora_model,
                process_group=self.process_group,
                auto_wrap_policy=policy,
                use_orig_params=use_orig_params,
            )
868

869

870
class TestWrapUtils(TestCase):
    """Tests for the FSDP wrapping helper ``_validate_frozen_params()``."""

    def test_validate_frozen_params(self):
        """Runs the ``_validate_frozen_params()`` scenario for both flag values."""
        for use_orig_params in (True, False):
            self._test_validate_frozen_params(use_orig_params)

    def _test_validate_frozen_params(self, use_orig_params: bool):
        """
        Calls ``_validate_frozen_params()`` on a ``LoraModel`` with a growing
        set of modules to wrap, then removes the LoRA-A modules from that set
        and expects a warning (``use_orig_params=True``) or a ``ValueError``
        (``use_orig_params=False``) with a specific message. Finally, passes
        the LoRA-A parameters as ignored and validates once more.
        """
        model = LoraModel()
        # Wrap only LoRA modules
        wrapped = set()
        for name, submodule in model.named_modules():
            if "lora_A" in name or "lora_B" in name:
                wrapped.add(submodule)
        _validate_frozen_params(model, wrapped, set(), use_orig_params)
        # Additionally wrap attention, then decoders, validating after each
        for extra_cls in (LoraAttention, LoraDecoder):
            wrapped.update(m for m in model.modules() if isinstance(m, extra_cls))
            _validate_frozen_params(model, wrapped, set(), use_orig_params)
        # Do not wrap the LoRA-A modules (meaning mixed frozen/non-frozen)
        lora_a_modules = [
            submodule
            for name, submodule in model.named_modules()
            if "lora_A" in name
        ]
        for submodule in lora_a_modules:
            wrapped.remove(submodule)
        pattern = "layers.0.attn has both parameters with requires_grad=True and False."
        if use_orig_params:
            # Wrapping the attention manages all parameters except those from
            # the LoRA-B module, which is separately wrapped and all nonfrozen
            attn = model.layers[0].attn
            lorab_numel = sum(p.numel() for p in attn.lora_B.parameters())
            frozen_numel = 0
            trainable_numel = 0
            for param in attn.parameters():
                if param.requires_grad:
                    trainable_numel += param.numel()
                else:
                    frozen_numel += param.numel()
            trainable_numel -= lorab_numel
            total_numel = frozen_numel + trainable_numel
            pattern += (
                " We do not recommend wrapping such modules since the "
                r"gradient memory usage will be higher than expected \("
                f"{total_numel} numel instead of {trainable_numel} numel "
                r"before sharding via reduce-scatter\). "
            )
        else:
            pattern += " FSDP does not support wrapping such modules when use_orig_params=False. "
        pattern += (
            "If possible, wrap the frozen parameters with FSDP separately.\n"
            "The following parameters have requires_grad=True:\n"
            r"\['layers.0.attn.lora_A.weight'\]\n"
            "The following parameters have requires_grad=False:\n"
            r"\['layers.0.attn.q_proj.weight', 'layers.0.attn.k_proj.weight', "
            r"'layers.0.attn.v_proj.weight', 'layers.0.attn.o_proj.weight'\]"
        )
        expectation = (
            self.assertWarnsRegex(UserWarning, pattern)
            if use_orig_params
            else self.assertRaisesRegex(ValueError, pattern)
        )
        with expectation:
            _validate_frozen_params(model, wrapped, set(), use_orig_params)
        # Now ignore those LoRA-A modules' parameters
        ignored_params = {
            param
            for name, submodule in model.named_modules()
            if "lora_A" in name
            for param in submodule.parameters()
        }
        _validate_frozen_params(model, wrapped, ignored_params, use_orig_params)
951

952

953
# Instantiate the parametrized tests of both test classes at import time
instantiate_parametrized_tests(TestFSDPWrap)
instantiate_parametrized_tests(TestAutoWrap)

# Run the test suite only when this file is executed as a script
if __name__ == "__main__":
    run_tests()
958

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.