# Owner(s): ["oncall: distributed"]

import contextlib
import sys
from collections import Counter
from enum import auto, Enum
from functools import partial
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import _rank_not_in_group
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp._init_utils import (
    _init_intra_and_inter_node_groups,
    HYBRID_SHARDING_STRATEGIES,
)

from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


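# Helpers that monkey-patch the collectives so tests can count how many
# all_reduce / reduce_scatter calls FSDP issues during backward.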
@contextlib.contextmanager
def patch_allreduce(new_allreduce):
    """
    Patches dist.all_reduce with a new all_reduce and
    restores upon exiting.
    """
    orig_ar = dist.all_reduce
    dist.all_reduce = new_allreduce
    try:
        yield
    finally:
        dist.all_reduce = orig_ar


@contextlib.contextmanager
def patch_reduce_scatter(new_reduce_scatter):
    """
    Patches dist.reduce_scatter_tensor with a new reduce_scatter_tensor and
    restores upon exiting.
    """
    orig_reduce_scatter = dist.reduce_scatter_tensor
    dist.reduce_scatter_tensor = new_reduce_scatter
    try:
        yield
    finally:
        dist.reduce_scatter_tensor = orig_reduce_scatter


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(10, 10)
        self.lin2 = nn.Linear(10, 10)
        self.lin3 = nn.Linear(10, 10)

    def forward(self, x):
        return self.lin3(self.lin2(self.lin1(x)))


class ShardingStrategyMode(Enum):
    ALL_HYBRID_SHARD = auto()
    MIXED_HYBRID_FULL_SHARD = auto()


class TestFSDPHybridShard(FSDPTest):
    @property
    def world_size(self):
        return max(torch.cuda.device_count(), 2)

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    def test_raises_manual_wrap_hybrid_shard_when_none_policy(self):
        model = MyModel().cuda()
        err_ctx = self.assertRaisesRegex(
            ValueError,
            "requires explicit specification of process group or device_mesh.",
        )

        with err_ctx:
            model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD)

        with err_ctx:
            model = FSDP(model, sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2)

    @skip_if_lt_x_gpu(4)
    def test_hsdp_save_load_state_dict(self):
        model = MyModel().cuda()
        num_node_devices = torch.cuda.device_count()
        shard_rank_lists = list(range(0, num_node_devices // 2)), list(
            range(num_node_devices // 2, num_node_devices)
        )
        shard_groups = (
            dist.new_group(shard_rank_lists[0]),
            dist.new_group(shard_rank_lists[1]),
        )
        my_shard_group = (
            shard_groups[0] if self.rank in shard_rank_lists[0] else shard_groups[1]
        )
        my_replicate_group = None
        my_rank = self.rank
        # Create groups like (0, 4), (1, 5), (2, 6) etc and assign appropriately
        shard_factor = len(shard_rank_lists[0])
        for i in range(num_node_devices // 2):
            replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
            replicate_group = dist.new_group(replicate_group_ranks)
            if my_rank in replicate_group_ranks:
                my_replicate_group = replicate_group

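        # Passing a (shard group, replicate group) tuple as `process_group` makes FSDP
        # shard parameters within `my_shard_group` and replicate gradients across
        # `my_replicate_group`; the assertions below verify this wiring.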
        fsdp_ctor = partial(
            FSDP,
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            use_orig_params=True,
            process_group=(my_shard_group, my_replicate_group),
        )
        model = fsdp_ctor(model)
        optim = torch.optim.AdamW(model.parameters())
        # Initialize optimizer states
        model(torch.randn(2, 10)).sum().backward()
        optim.step()
        shard_g = model.process_group
        replicate_g = model._inter_node_pg
        assert shard_g == my_shard_group
        assert replicate_g == my_replicate_group
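        # SHARDED_STATE_DICT saves each rank's local shard, so the checkpoint can be
        # loaded back into a model wrapped with the same HSDP process groups.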
        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            msd = model.state_dict()
            osd = FSDP.optim_state_dict(model, optim)

        load_model = fsdp_ctor(MyModel().cuda())
        load_optim = torch.optim.AdamW(load_model.parameters())
        with FSDP.state_dict_type(load_model, StateDictType.SHARDED_STATE_DICT):
            load_model.load_state_dict(msd)
            FSDP.optim_state_dict_to_load(load_model, load_optim, osd)
        load_optim.load_state_dict(osd)

    @skip_if_lt_x_gpu(4)
    def test_hsdp_sync_module_state(self):
        model = MyModel().cuda()
        num_node_devices = torch.cuda.device_count()
        shard_rank_lists = list(range(0, num_node_devices // 2)), list(
            range(num_node_devices // 2, num_node_devices)
        )
        shard_groups = (
            dist.new_group(shard_rank_lists[0]),
            dist.new_group(shard_rank_lists[1]),
        )
        my_shard_group = (
            shard_groups[0] if self.rank in shard_rank_lists[0] else shard_groups[1]
        )
        my_replicate_group = None
        my_rank = self.rank
        # Create groups like (0, 4), (1, 5), (2, 6) etc and assign appropriately
        shard_factor = len(shard_rank_lists[0])
        for i in range(num_node_devices // 2):
            replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
            replicate_group = dist.new_group(replicate_group_ranks)
            if my_rank in replicate_group_ranks:
                my_replicate_group = replicate_group

        nn.init.constant_(model.lin1.weight, self.rank)
        nn.init.constant_(model.lin2.weight, self.rank)
        nn.init.constant_(model.lin3.weight, self.rank)

        fsdp_ctor = partial(
            FSDP,
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            use_orig_params=True,
            sync_module_states=True,
            process_group=(my_shard_group, my_replicate_group),
        )
        model = fsdp_ctor(model)

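        # `sync_module_states=True` broadcasts module states from global rank 0 at wrap
        # time; rank 0 initialized its weights to 0 above, so every rank should now see
        # all-zero weights.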
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            self.assertTrue((model.lin1.weight == 0).all())
            self.assertTrue((model.lin2.weight == 0).all())
            self.assertTrue((model.lin3.weight == 0).all())

    @skip_if_lt_x_gpu(2)
    def test_invalid_pg_specification_raises(self):
        pol = ModuleWrapPolicy({nn.Linear})
        model = MyModel().cuda()
        with self.assertRaisesRegex(
            ValueError, "Expected process_group to be passed in"
        ):
            model = FSDP(
                model,
                auto_wrap_policy=pol,
                process_group=self.process_group,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )

    # TODO - add test for ZeRO-2 style sharding to ensure params are not
    # resharded after forward.

    @skip_if_lt_x_gpu(2)
    def test_fsdp_hybrid_shard_basic_setup(self):
        """
        Tests basic functionality of HYBRID_SHARD and _HYBRID_SHARD_ZERO2:
            1. Inter and intra-node process groups are correctly set up
            2. Process groups are the same across FSDP wrapped instances
            3. reduce_scatter and allreduce are called the expected number of times
        """
        self.run_subtests(
            {
                "hsdp_sharding_strategy": [
                    ShardingStrategy.HYBRID_SHARD,
                    ShardingStrategy._HYBRID_SHARD_ZERO2,
                ],
                "sharding_strategy_mode": [
                    ShardingStrategyMode.ALL_HYBRID_SHARD,
                    ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD,
                ],
                "use_orig_params": [False, True],
                "use_device_mesh": [False, True],
            },
            self._test_fsdp_hybrid_shard_basic_setup,
        )

    def _test_fsdp_hybrid_shard_basic_setup(
        self,
        hsdp_sharding_strategy: ShardingStrategy,
        sharding_strategy_mode: ShardingStrategyMode,
        use_orig_params: bool,
        use_device_mesh: bool,
    ):
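        # For HSDP, a 2D device mesh is (replicate dim, shard dim), so a
        # (1, world_size) mesh replicates over a single-rank group and shards across
        # all ranks, matching the default process groups checked below.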
        if use_device_mesh:
            device_mesh = init_device_mesh("cuda", (1, self.world_size))
        else:
            device_mesh = None
        hsdp_model = self._init_hsdp_model(
            hsdp_sharding_strategy,
            sharding_strategy_mode,
            use_orig_params,
            hsdp_device_mesh=device_mesh,
        )
        # All FSDP modules should have state.process_group as the process group over which to
        # shard (default process group), and state._inter_node_pg (process group containing only
        # this rank)
        intra_node_pgs = set()
        inter_node_pgs = set()
        for fsdp_module in hsdp_model.fsdp_modules(hsdp_model):
            # TODO: This needs to be replaced if we deprecate
            # `FSDP.sharding_strategy` to only use the handle one.
            # https://github.com/pytorch/pytorch/issues/90857
            if fsdp_module.sharding_strategy not in HYBRID_SHARDING_STRATEGIES:
                self.assertEqual(
                    sharding_strategy_mode, ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD
                )
                self.assertEqual(
                    fsdp_module.sharding_strategy, ShardingStrategy.FULL_SHARD
                )
                continue
            # process_group should be across the node, which is just the
            # whole world here.
            self.assertEqual(
                dist.get_world_size(fsdp_module.process_group),
                dist.get_world_size(self.process_group),
            )
            intra_node_pgs.add(fsdp_module.process_group)
            inter_node_pg = fsdp_module._inter_node_pg
            inter_node_pgs.add(inter_node_pg)
            self.assertEqual(1, dist.get_world_size(inter_node_pg))
            self.assertFalse(_rank_not_in_group(inter_node_pg))
            self.assertEqual(hsdp_sharding_strategy, fsdp_module.sharding_strategy)
        # All fsdp modules should share the same process groups
        self.assertEqual(1, len(intra_node_pgs))
        self.assertEqual(1, len(inter_node_pgs))

        orig_ar = dist.all_reduce
        orig_rs = dist.reduce_scatter_tensor

        def patched_collective(orig_collective, counter, *args, **kwargs):
            counter[orig_collective] += 1
            return orig_collective(*args, **kwargs)

        cntr = Counter()
        patched_allreduce = partial(patched_collective, orig_ar, cntr)
        patched_reduce_scatter = partial(patched_collective, orig_rs, cntr)
        with patch_allreduce(patched_allreduce), patch_reduce_scatter(
            patched_reduce_scatter
        ):
            inp = hsdp_model.get_input(device=torch.cuda.current_device())
            out = hsdp_model(inp[0], inp[1])
            loss = hsdp_model.get_loss(inp, out)
            loss.backward()

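        # Each HSDP flat parameter should trigger one reduce_scatter (over the shard
        # group) plus one all_reduce (over the replicate group) in backward; FULL_SHARD
        # submodules in the mixed mode contribute only the reduce_scatter.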
        if sharding_strategy_mode == ShardingStrategyMode.ALL_HYBRID_SHARD:
            num_flat_params = len(list(traversal_utils._get_fsdp_handles(hsdp_model)))
            self.assertEqual(num_flat_params, cntr[orig_ar])
            self.assertEqual(num_flat_params, cntr[orig_rs])
        elif sharding_strategy_mode == ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD:
            num_hsdp_flat_params = len(
                list(traversal_utils._get_fsdp_handles(hsdp_model.transformer))
            )
            num_flat_params = len(list(traversal_utils._get_fsdp_handles(hsdp_model)))
            self.assertEqual(num_hsdp_flat_params, cntr[orig_ar])
            self.assertEqual(num_flat_params, cntr[orig_rs])

    @skip_if_lt_x_gpu(4)
    def test_fsdp_hybrid_shard_parity(self):
        self.run_subtests(
            {
                "hsdp_sharding_strategy": [
                    ShardingStrategy.HYBRID_SHARD,
                    ShardingStrategy._HYBRID_SHARD_ZERO2,
                ],
                "use_orig_params": [False, True],
            },
            self._test_fsdp_hybrid_shard_parity,
        )

    def _test_fsdp_hybrid_shard_parity(
        self, hsdp_sharding_strategy: ShardingStrategy, use_orig_params: bool
    ):
        fsdp_model = self._init_fsdp_model(use_orig_params)
        global_pg = dist.distributed_c10d._get_default_group()
        hsdp_pgs = _init_intra_and_inter_node_groups(global_pg, 2)
        hsdp_model = self._init_hsdp_model(
            hsdp_sharding_strategy,
            ShardingStrategyMode.ALL_HYBRID_SHARD,
            use_orig_params,
            hsdp_process_groups=hsdp_pgs,
        )
        assert (
            hsdp_model._inter_node_pg.size() > 1
        ), "HSDP model initialized without replication"
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
        hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2)
        torch.manual_seed(global_pg.rank() + 1)
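        # Ranks are seeded differently so they train on distinct inputs; FSDP and HSDP
        # both average gradients over all ranks, so their losses should match step for
        # step.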
        for _ in range(5):
            inp = fsdp_model.module.get_input(torch.device("cuda"))
            losses: List[torch.Tensor] = []
            for model, optim in ((fsdp_model, fsdp_optim), (hsdp_model, hsdp_optim)):
                optim.zero_grad()
                loss = model(*inp).sum()
                losses.append(loss)
                loss.backward()
                optim.step()
            self.assertEqual(losses[0], losses[1])

    def _init_fsdp_model(self, use_orig_params: bool) -> nn.Module:
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer},
        )
        hsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "device_id": torch.cuda.current_device(),
            "use_orig_params": use_orig_params,
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            hsdp_kwargs,
            deterministic=True,
        )
        return fsdp_model

    def _init_hsdp_model(
        self,
        hsdp_sharding_strategy: ShardingStrategy,
        sharding_strategy_mode: ShardingStrategyMode,
        use_orig_params: bool,
        hsdp_process_groups: Optional[
            Tuple[dist.ProcessGroup, dist.ProcessGroup]
        ] = None,
        hsdp_device_mesh: Optional[DeviceMesh] = None,
    ):
        assert hsdp_process_groups is None or hsdp_device_mesh is None
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer},
        )
        hsdp_kwargs = {
            "device_id": torch.cuda.current_device(),
            "auto_wrap_policy": auto_wrap_policy,
            "sharding_strategy": hsdp_sharding_strategy,
            "use_orig_params": use_orig_params,
            "device_mesh": hsdp_device_mesh,
        }
        if sharding_strategy_mode == ShardingStrategyMode.ALL_HYBRID_SHARD:
            hsdp_model = TransformerWithSharedParams.init(
                hsdp_process_groups or self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                hsdp_kwargs,
                deterministic=True,
            )
        elif sharding_strategy_mode == ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD:
            model = TransformerWithSharedParams.init(
                hsdp_process_groups or self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_BEFORE,
                {},
                deterministic=True,
            )
            # Use the HSDP strategy for the transformer module
            model.transformer = FSDP(model.transformer, **hsdp_kwargs)
            # Use `FULL_SHARD` for the embedding and output projection
            hsdp_model = FSDP(
                model,
                device_id=torch.cuda.current_device(),
                sharding_strategy=ShardingStrategy.FULL_SHARD,
                use_orig_params=use_orig_params,
            )
        return hsdp_model


instantiate_parametrized_tests(TestFSDPHybridShard)

if __name__ == "__main__":
    run_tests()