# Owner(s): ["oncall: distributed"]

import sys
from typing import List
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._flat_param import HandleTrainingState
from torch.distributed.fsdp._runtime_utils import (
    _get_handle_to_prefetch,
    _get_training_state,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


NUM_ITERS = 2
DECODER_PARAM_FQNS = [
    "decoder.layers.{index}.self_attn.in_proj_weight",
    "decoder.layers.{index}.self_attn.in_proj_bias",
    "decoder.layers.{index}.self_attn.out_proj.weight",
    "decoder.layers.{index}.self_attn.out_proj.bias",
    "decoder.layers.{index}.multihead_attn.in_proj_weight",
    "decoder.layers.{index}.multihead_attn.in_proj_bias",
    "decoder.layers.{index}.multihead_attn.out_proj.weight",
    "decoder.layers.{index}.multihead_attn.out_proj.bias",
    "decoder.layers.{index}.linear1.weight",
    "decoder.layers.{index}.linear1.bias",
    "decoder.layers.{index}.linear2.weight",
    "decoder.layers.{index}.linear2.bias",
    "decoder.layers.{index}.norm1.weight",
    "decoder.layers.{index}.norm1.bias",
    "decoder.layers.{index}.norm2.weight",
    "decoder.layers.{index}.norm2.bias",
    "decoder.layers.{index}.norm3.weight",
    "decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
    "encoder.layers.{index}.self_attn.in_proj_weight",
    "encoder.layers.{index}.self_attn.in_proj_bias",
    "encoder.layers.{index}.self_attn.out_proj.weight",
    "encoder.layers.{index}.self_attn.out_proj.bias",
    "encoder.layers.{index}.linear1.weight",
    "encoder.layers.{index}.linear1.bias",
    "encoder.layers.{index}.linear2.weight",
    "encoder.layers.{index}.linear2.bias",
    "encoder.layers.{index}.norm1.weight",
    "encoder.layers.{index}.norm1.bias",
    "encoder.layers.{index}.norm2.weight",
    "encoder.layers.{index}.norm2.bias",
]
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5
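# NOTE: these expected counts assume nn.Transformer's defaults of 6 encoder and
# 6 decoder layers. With BACKWARD_PRE all 12 layer handles appear as prefetch
# targets; with BACKWARD_POST decoder layer 5 runs backward first and is never
# prefetched, leaving 11 (see the per-branch comments in _dist_train below).
# ENCODER_BEGIN_INDEX_* is the position in the recorded prefetch order where
# encoder handles start, and ENCODER_PREFETCH_NUM is the highest prefetched
# encoder layer index.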

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
        rank = self.rank
        orig_get_handle_to_prefetch = _get_handle_to_prefetch

        torch.manual_seed(0)
        policy = ModuleWrapPolicy(
            {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
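        # Each TransformerEncoderLayer/TransformerDecoderLayer instance becomes its
        # own FSDP unit under this wrap policy, so every layer owns a separate
        # flat-parameter handle whose prefetch order is checked below.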
        model = FSDP(
            nn.Transformer(d_model=1024, nhead=8, device="cuda"),
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=policy,
            use_orig_params=True,
            backward_prefetch=backward_prefetch,
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)

        # prepare input
        torch.manual_seed(rank + 1)
        src = torch.randn((10, 1, 1024), device="cuda")
        tgt = torch.randn((20, 1, 1024), device="cuda")

        # monkey patch
        all_handle_fqns: List[List[str]] = []

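        # Wrap the original _get_handle_to_prefetch so that every prefetch decision
        # made during backward is recorded: for calls in the matching backward state
        # (BACKWARD_PRE or BACKWARD_POST), store the parameter FQNs of the prefetched
        # handle, or None when there is nothing left to prefetch.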
        def patched_get_handle_to_prefetch(*args, **kwargs):
            handle = orig_get_handle_to_prefetch(*args, **kwargs)

            self.assertEqual(
                len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
            )
            state = args[0]
            current_handle = args[1]
            training_state = _get_training_state(current_handle)
            if (
                training_state == HandleTrainingState.BACKWARD_PRE
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
            ) or (
                training_state == HandleTrainingState.BACKWARD_POST
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
            ):
                nonlocal all_handle_fqns
                # FQNs prefixed from the root module
                # state._exec_order_data.param_to_fqn
                fqns = _get_handle_fqns_from_root(state, handle)
                all_handle_fqns.append(fqns)
            return handle

        # flat params from prefetch handle should match
        # DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS
        with patch(
            "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
            patched_get_handle_to_prefetch,
        ):
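            # Run NUM_ITERS iterations and validate the prefetch order recorded
            # during each backward pass.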
            for _ in range(NUM_ITERS):
                optim.zero_grad()
                loss = model(src, tgt).sum()
                loss.backward()
                optim.step()
                if backward_prefetch is None:
                    self.assertEqual(len(all_handle_fqns), 0)
                    continue
                elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                    # state._exec_order_data.handles_post_forward_order
                    #   equals forward order:
                    #   encoder 0...5 -> decoder 0...5 -> root
                    # pre-backward hook order:
                    #   root -> decoder 5...0 -> encoder 5...0
                    # prefetch order:
                    #   decoder 5...0 -> encoder 5...0 -> None
                    #   None: when current_handle=encoder 0,
                    #   _get_handle_to_prefetch returns None
                    # +1 is for the above None
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
                    )
                elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
                    # state._exec_order_data.handles_post_forward_order
                    #   equals forward order (same as BACKWARD_PRE):
                    #   encoder 0...5 -> decoder 0...5 -> root
                    # post-backward hook (AccumulateGrad) order:
                    #   decoder 5, 4...0 -> encoder 5...0 -> root
                    # prefetch order:
                    #   decoder 4...0 -> encoder 5...0 -> None -> None
                    #   1st None: when current_handle=encoder 0,
                    #   _get_handle_to_prefetch returns None
                    #   2nd None: when current_handle=root,
                    #   _get_handle_to_prefetch finds decoder 5 internally but
                    #   returns None since decoder 5 is computed already
                    # +2 is for the above Nones
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
                    )

                # ith_prefetch: 0, 1st, 2nd, 3rd, 4th ... ith prefetch
                for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if ith_prefetch >= 0 and ith_prefetch < encoder_begin_index:
                        layer_index = encoder_begin_index - 1 - ith_prefetch
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
                        )
                    elif (
                        ith_prefetch >= encoder_begin_index
                        and ith_prefetch <= encoder_begin_index + ENCODER_PREFETCH_NUM
                    ):
                        layer_index = (
                            encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
                        )
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
                        )
                    else:
                        self.assertTrue(fqns is None)

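                # clear the recorded prefetch FQNs before the next iteration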
                all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
    def test_backward_prefetch(self):
        # subtests reuse the process group to shorten test time
        self.run_subtests(
            {
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_backward_prefetch,
        )

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
        self._dist_train(backward_prefetch)


if __name__ == "__main__":
    run_tests()