pytorch
445 строк · 16.4 Кб
1# Owner(s): ["oncall: distributed"]
2
3import contextlib
4import sys
5from collections import Counter
6from enum import auto, Enum
7from functools import partial
8from typing import List, Optional, Tuple
9
10import torch
11import torch.distributed as dist
12import torch.distributed.fsdp._traversal_utils as traversal_utils
13import torch.nn as nn
14from torch.distributed.device_mesh import init_device_mesh
15from torch.distributed.distributed_c10d import _rank_not_in_group
16from torch.distributed.fsdp import (
17FullyShardedDataParallel as FSDP,
18ShardingStrategy,
19StateDictType,
20)
21from torch.distributed.fsdp._init_utils import (
22_init_intra_and_inter_node_groups,
23HYBRID_SHARDING_STRATEGIES,
24)
25
26from torch.distributed.fsdp.wrap import ModuleWrapPolicy
27from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
28from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
29from torch.testing._internal.common_fsdp import (
30CUDAInitMode,
31FSDPInitMode,
32FSDPTest,
33TransformerWithSharedParams,
34)
35from torch.testing._internal.common_utils import (
36instantiate_parametrized_tests,
37run_tests,
38TEST_WITH_DEV_DBG_ASAN,
39)
40
41if not dist.is_available():
42print("Distributed not available, skipping tests", file=sys.stderr)
43sys.exit(0)
44
45if TEST_WITH_DEV_DBG_ASAN:
46print(
47"Skip dev-asan as torch + multiprocessing spawn have known issues",
48file=sys.stderr,
49)
50sys.exit(0)
51
52
53@contextlib.contextmanager
54def patch_allreduce(new_allreduce):
55"""
56Patches dist.all_reduce with a new all_reduce and
57restores upon exiting.
58"""
59orig_ar = dist.all_reduce
60dist.all_reduce = new_allreduce
61try:
62yield
63finally:
64dist.all_reduce = orig_ar
65
66
67@contextlib.contextmanager
68def patch_reduce_scatter(new_reduce_scatter):
69"""
70Patches dist.reduce_scatter_tensor with a new reduce_scatter_tensor and
71restores upon exiting.
72"""
73orig_reduce_scatter = dist.reduce_scatter_tensor
74dist.reduce_scatter_tensor = new_reduce_scatter
75try:
76yield
77finally:
78dist.reduce_scatter_tensor = orig_reduce_scatter
79
80
81class MyModel(nn.Module):
82def __init__(self):
83super().__init__()
84self.lin1 = nn.Linear(10, 10)
85self.lin2 = nn.Linear(10, 10)
86self.lin3 = nn.Linear(10, 10)
87
88def forward(self, x):
89return self.lin3(self.lin2(self.lin1(x)))
90
91
92class ShardingStrategyMode(Enum):
93ALL_HYBRID_SHARD = auto()
94MIXED_HYBRID_FULL_SHARD = auto()
95
96
97class TestFSDPHybridShard(FSDPTest):
98@property
99def world_size(self):
100return max(torch.cuda.device_count(), 2)
101
102@property
103def process_group(self):
104return dist.distributed_c10d._get_default_group()
105
106@skip_if_lt_x_gpu(2)
107def test_raises_manual_wrap_hybrid_shard_when_none_policy(self):
108model = MyModel().cuda()
109err_ctx = self.assertRaisesRegex(
110ValueError,
111"requires explicit specification of process group or device_mesh.",
112)
113
114with err_ctx:
115model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD)
116
117with err_ctx:
118model = FSDP(model, sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2)
119
120@skip_if_lt_x_gpu(4)
121def test_hsdp_save_load_state_dict(self):
122model = MyModel().cuda()
123num_node_devices = torch.cuda.device_count()
124shard_rank_lists = list(range(0, num_node_devices // 2)), list(
125range(num_node_devices // 2, num_node_devices)
126)
127shard_groups = (
128dist.new_group(shard_rank_lists[0]),
129dist.new_group(shard_rank_lists[1]),
130)
131my_shard_group = (
132shard_groups[0] if self.rank in shard_rank_lists[0] else shard_groups[1]
133)
134my_replicate_group = None
135my_rank = self.rank
136# Create groups like (0, 4), (1, 5), (2, 6) etc and assign appropriately
137shard_factor = len(shard_rank_lists[0])
138for i in range(num_node_devices // 2):
139replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
140replicate_group = dist.new_group(replicate_group_ranks)
141if my_rank in replicate_group_ranks:
142my_replicate_group = replicate_group
143
144fsdp_ctor = partial(
145FSDP,
146sharding_strategy=ShardingStrategy.HYBRID_SHARD,
147use_orig_params=True,
148process_group=(my_shard_group, my_replicate_group),
149)
150model = fsdp_ctor(model)
151optim = torch.optim.AdamW(model.parameters())
152# Initialize optimizer states
153model(torch.randn(2, 10)).sum().backward()
154optim.step()
155shard_g = model.process_group
156replicate_g = model._inter_node_pg
157assert shard_g == my_shard_group
158assert replicate_g == my_replicate_group
159with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
160msd = model.state_dict()
161osd = FSDP.optim_state_dict(model, optim)
162
163load_model = fsdp_ctor(MyModel().cuda())
164load_optim = torch.optim.AdamW(load_model.parameters())
165with FSDP.state_dict_type(load_model, StateDictType.SHARDED_STATE_DICT):
166load_model.load_state_dict(msd)
167FSDP.optim_state_dict_to_load(load_model, load_optim, osd)
168load_optim.load_state_dict(osd)
169
170@skip_if_lt_x_gpu(4)
171def test_hsdp_sync_module_state(self):
172model = MyModel().cuda()
173num_node_devices = torch.cuda.device_count()
174shard_rank_lists = list(range(0, num_node_devices // 2)), list(
175range(num_node_devices // 2, num_node_devices)
176)
177shard_groups = (
178dist.new_group(shard_rank_lists[0]),
179dist.new_group(shard_rank_lists[1]),
180)
181my_shard_group = (
182shard_groups[0] if self.rank in shard_rank_lists[0] else shard_groups[1]
183)
184my_replicate_group = None
185my_rank = self.rank
186# Create groups like (0, 4), (1, 5), (2, 6) etc and assign appropriately
187shard_factor = len(shard_rank_lists[0])
188for i in range(num_node_devices // 2):
189replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
190replicate_group = dist.new_group(replicate_group_ranks)
191if my_rank in replicate_group_ranks:
192my_replicate_group = replicate_group
193
194nn.init.constant_(model.lin1.weight, self.rank)
195nn.init.constant_(model.lin2.weight, self.rank)
196nn.init.constant_(model.lin3.weight, self.rank)
197
198fsdp_ctor = partial(
199FSDP,
200sharding_strategy=ShardingStrategy.HYBRID_SHARD,
201use_orig_params=True,
202sync_module_states=True,
203process_group=(my_shard_group, my_replicate_group),
204)
205model = fsdp_ctor(model)
206
207with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
208self.assertTrue((model.lin1.weight == 0).all())
209self.assertTrue((model.lin2.weight == 0).all())
210self.assertTrue((model.lin3.weight == 0).all())
211
212@skip_if_lt_x_gpu(2)
213def test_invalid_pg_specification_raises(self):
214pol = ModuleWrapPolicy({nn.Linear})
215model = MyModel().cuda()
216with self.assertRaisesRegex(
217ValueError, "Expected process_group to be passed in"
218):
219model = FSDP(
220model,
221auto_wrap_policy=pol,
222process_group=self.process_group,
223sharding_strategy=ShardingStrategy.HYBRID_SHARD,
224)
225
226# TODO - add test for ZeRO-2 style sharding ensure params are not
227# resharded after forward.
228
229@skip_if_lt_x_gpu(2)
230def test_fsdp_hybrid_shard_basic_setup(self):
231"""
232Tests basic functionality of HYBRID_SHARD and _HYBRID_SHARD_ZERO2:
2331. Inter and intra-node process groups are correctly setup
2342. Process groups are the same across FSDP wrapped instances
2353. reduce_scatter and allreduce called the expected no. of times
236"""
237self.run_subtests(
238{
239"hsdp_sharding_strategy": [
240ShardingStrategy.HYBRID_SHARD,
241ShardingStrategy._HYBRID_SHARD_ZERO2,
242],
243"sharding_strategy_mode": [
244ShardingStrategyMode.ALL_HYBRID_SHARD,
245ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD,
246],
247"use_orig_params": [False, True],
248"use_device_mesh": [False, True],
249},
250self._test_fsdp_hybrid_shard_basic_setup,
251)
252
253def _test_fsdp_hybrid_shard_basic_setup(
254self,
255hsdp_sharding_strategy: ShardingStrategy,
256sharding_strategy_mode: ShardingStrategyMode,
257use_orig_params: bool,
258use_device_mesh: bool,
259):
260if use_device_mesh:
261device_mesh = init_device_mesh("cuda", (1, self.world_size))
262else:
263device_mesh = None
264hsdp_model = self._init_hsdp_model(
265hsdp_sharding_strategy,
266sharding_strategy_mode,
267use_orig_params,
268hsdp_device_mesh=device_mesh,
269)
270# All FSDP modules should have state.process_group as the process group over which to
271# shard (default process group), and state._inter_node_pg (process group containing only
272# this rank)
273intra_node_pgs = set()
274inter_node_pgs = set()
275for fsdp_module in hsdp_model.fsdp_modules(hsdp_model):
276# TODO: This needs to be replaced if we deprecate
277# `FSDP.sharding_strategy` to only use the handle one.
278# https://github.com/pytorch/pytorch/issues/90857
279if fsdp_module.sharding_strategy not in HYBRID_SHARDING_STRATEGIES:
280self.assertEqual(
281sharding_strategy_mode, ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD
282)
283self.assertEqual(
284fsdp_module.sharding_strategy, ShardingStrategy.FULL_SHARD
285)
286continue
287# process_group should be across the node, which is just the
288# whole world here.
289self.assertEqual(
290dist.get_world_size(fsdp_module.process_group),
291dist.get_world_size(self.process_group),
292)
293intra_node_pgs.add(fsdp_module.process_group)
294inter_node_pg = fsdp_module._inter_node_pg
295inter_node_pgs.add(inter_node_pg)
296self.assertEqual(1, dist.get_world_size(inter_node_pg))
297self.assertFalse(_rank_not_in_group(inter_node_pg))
298self.assertEqual(hsdp_sharding_strategy, fsdp_module.sharding_strategy)
299# All fsdp modules should share the same process groups
300self.assertEqual(1, len(intra_node_pgs))
301self.assertEqual(1, len(inter_node_pgs))
302
303orig_ar = dist.all_reduce
304orig_rs = dist.reduce_scatter_tensor
305
306def patched_collective(orig_collective, counter, *args, **kwargs):
307counter[orig_collective] += 1
308return orig_collective(*args, **kwargs)
309
310cntr = Counter()
311patched_allreduce = partial(patched_collective, orig_ar, cntr)
312patched_reduce_scatter = partial(patched_collective, orig_rs, cntr)
313with patch_allreduce(patched_allreduce), patch_reduce_scatter(
314patched_reduce_scatter
315):
316inp = hsdp_model.get_input(device=torch.cuda.current_device())
317out = hsdp_model(inp[0], inp[1])
318loss = hsdp_model.get_loss(inp, out)
319loss.backward()
320
321if sharding_strategy_mode == ShardingStrategyMode.ALL_HYBRID_SHARD:
322num_flat_params = len(list(traversal_utils._get_fsdp_handles(hsdp_model)))
323self.assertEqual(num_flat_params, cntr[orig_ar])
324self.assertEqual(num_flat_params, cntr[orig_rs])
325elif sharding_strategy_mode == ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD:
326num_hsdp_flat_params = len(
327list(traversal_utils._get_fsdp_handles(hsdp_model.transformer))
328)
329num_flat_params = len(list(traversal_utils._get_fsdp_handles(hsdp_model)))
330self.assertEqual(num_hsdp_flat_params, cntr[orig_ar])
331self.assertEqual(num_flat_params, cntr[orig_rs])
332
333@skip_if_lt_x_gpu(4)
334def test_fsdp_hybrid_shard_parity(self):
335self.run_subtests(
336{
337"hsdp_sharding_strategy": [
338ShardingStrategy.HYBRID_SHARD,
339ShardingStrategy._HYBRID_SHARD_ZERO2,
340],
341"use_orig_params": [False, True],
342},
343self._test_fsdp_hybrid_shard_parity,
344)
345
346def _test_fsdp_hybrid_shard_parity(
347self, hsdp_sharding_strategy: ShardingStrategy, use_orig_params: bool
348):
349fsdp_model = self._init_fsdp_model(use_orig_params)
350global_pg = dist.distributed_c10d._get_default_group()
351hsdp_pgs = _init_intra_and_inter_node_groups(global_pg, 2)
352hsdp_model = self._init_hsdp_model(
353hsdp_sharding_strategy,
354ShardingStrategyMode.ALL_HYBRID_SHARD,
355use_orig_params,
356hsdp_process_groups=hsdp_pgs,
357)
358assert (
359hsdp_model._inter_node_pg.size() > 1
360), "HSDP model initialized without replication"
361fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
362hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2)
363torch.manual_seed(global_pg.rank() + 1)
364for _ in range(5):
365inp = fsdp_model.module.get_input(torch.device("cuda"))
366losses: List[torch.Tensor] = []
367for model, optim in ((fsdp_model, fsdp_optim), (hsdp_model, hsdp_optim)):
368optim.zero_grad()
369loss = model(*inp).sum()
370losses.append(loss)
371loss.backward()
372optim.step()
373self.assertEqual(losses[0], losses[1])
374
375def _init_fsdp_model(self, use_orig_params: bool) -> nn.Module:
376auto_wrap_policy = ModuleWrapPolicy(
377{TransformerEncoderLayer, TransformerDecoderLayer},
378)
379hsdp_kwargs = {
380"auto_wrap_policy": auto_wrap_policy,
381"device_id": torch.cuda.current_device(),
382"use_orig_params": use_orig_params,
383}
384fsdp_model = TransformerWithSharedParams.init(
385self.process_group,
386FSDPInitMode.RECURSIVE,
387CUDAInitMode.CUDA_BEFORE,
388hsdp_kwargs,
389deterministic=True,
390)
391return fsdp_model
392
393def _init_hsdp_model(
394self,
395hsdp_sharding_strategy: ShardingStrategy,
396sharding_strategy_mode: str,
397use_orig_params: bool,
398hsdp_process_groups: Optional[
399Tuple[dist.ProcessGroup, dist.ProcessGroup]
400] = None,
401hsdp_device_mesh: Optional = None,
402):
403assert hsdp_process_groups is None or hsdp_device_mesh is None
404auto_wrap_policy = ModuleWrapPolicy(
405{TransformerEncoderLayer, TransformerDecoderLayer},
406)
407hsdp_kwargs = {
408"device_id": torch.cuda.current_device(),
409"auto_wrap_policy": auto_wrap_policy,
410"sharding_strategy": hsdp_sharding_strategy,
411"use_orig_params": use_orig_params,
412"device_mesh": hsdp_device_mesh,
413}
414if sharding_strategy_mode == ShardingStrategyMode.ALL_HYBRID_SHARD:
415hsdp_model = TransformerWithSharedParams.init(
416hsdp_process_groups or self.process_group,
417FSDPInitMode.RECURSIVE,
418CUDAInitMode.CUDA_BEFORE,
419hsdp_kwargs,
420deterministic=True,
421)
422elif sharding_strategy_mode == ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD:
423model = TransformerWithSharedParams.init(
424hsdp_process_groups or self.process_group,
425FSDPInitMode.NO_FSDP,
426CUDAInitMode.CUDA_BEFORE,
427{},
428deterministic=True,
429)
430# Use the HSDP strategy for the transformer module
431model.transformer = FSDP(model.transformer, **hsdp_kwargs)
432# Use `FULL_SHARD` for the embedding and output projection
433hsdp_model = FSDP(
434model,
435device_id=torch.cuda.current_device(),
436sharding_strategy=ShardingStrategy.FULL_SHARD,
437use_orig_params=use_orig_params,
438)
439return hsdp_model
440
441
442instantiate_parametrized_tests(TestFSDPHybridShard)
443
444if __name__ == "__main__":
445run_tests()
446