pytorch / test_fsdp_backward_prefetch.py
# Owner(s): ["oncall: distributed"]
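# This test checks FSDP backward prefetching: it monkey-patches
# _get_handle_to_prefetch to record the parameter FQNs of every handle selected
# for prefetching during the backward pass, then verifies that the recorded
# FQNs match the expected per-layer order (decoder layers, then encoder layers,
# in reverse index order) for both BackwardPrefetch.BACKWARD_PRE and
# BACKWARD_POST, and that nothing is recorded when prefetching is disabled.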

import sys
from typing import List
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._flat_param import HandleTrainingState
from torch.distributed.fsdp._runtime_utils import (
    _get_handle_to_prefetch,
    _get_training_state,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


NUM_ITERS = 2
DECODER_PARAM_FQNS = [
    "decoder.layers.{index}.self_attn.in_proj_weight",
    "decoder.layers.{index}.self_attn.in_proj_bias",
    "decoder.layers.{index}.self_attn.out_proj.weight",
    "decoder.layers.{index}.self_attn.out_proj.bias",
    "decoder.layers.{index}.multihead_attn.in_proj_weight",
    "decoder.layers.{index}.multihead_attn.in_proj_bias",
    "decoder.layers.{index}.multihead_attn.out_proj.weight",
    "decoder.layers.{index}.multihead_attn.out_proj.bias",
    "decoder.layers.{index}.linear1.weight",
    "decoder.layers.{index}.linear1.bias",
    "decoder.layers.{index}.linear2.weight",
    "decoder.layers.{index}.linear2.bias",
    "decoder.layers.{index}.norm1.weight",
    "decoder.layers.{index}.norm1.bias",
    "decoder.layers.{index}.norm2.weight",
    "decoder.layers.{index}.norm2.bias",
    "decoder.layers.{index}.norm3.weight",
    "decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
    "encoder.layers.{index}.self_attn.in_proj_weight",
    "encoder.layers.{index}.self_attn.in_proj_bias",
    "encoder.layers.{index}.self_attn.out_proj.weight",
    "encoder.layers.{index}.self_attn.out_proj.bias",
    "encoder.layers.{index}.linear1.weight",
    "encoder.layers.{index}.linear1.bias",
    "encoder.layers.{index}.linear2.weight",
    "encoder.layers.{index}.linear2.bias",
    "encoder.layers.{index}.norm1.weight",
    "encoder.layers.{index}.norm1.bias",
    "encoder.layers.{index}.norm2.weight",
    "encoder.layers.{index}.norm2.bias",
]
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5
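# The counts above encode the expected prefetch order (see the comments in
# _dist_train): nn.Transformer defaults to 6 encoder and 6 decoder layers, each
# auto-wrapped as its own FSDP unit, so BACKWARD_PRE prefetches
# decoder 5..0 -> encoder 5..0 (12 handles, encoder entries starting at
# recorded index 6) and BACKWARD_POST prefetches decoder 4..0 -> encoder 5..0
# (11 handles, encoder entries starting at recorded index 5).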

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
        rank = self.rank
        orig_get_handle_to_prefetch = _get_handle_to_prefetch

        torch.manual_seed(0)
        policy = ModuleWrapPolicy(
            {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
        model = FSDP(
            nn.Transformer(d_model=1024, nhead=8, device="cuda"),
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=policy,
            use_orig_params=True,
            backward_prefetch=backward_prefetch,
        )
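        # The auto-wrap policy wraps every TransformerEncoderLayer /
        # TransformerDecoderLayer as its own FSDP unit (6 of each by default),
        # with the root FSDP instance owning the remaining parameters.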
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)

        # prepare input
        torch.manual_seed(rank + 1)
        src = torch.randn((10, 1, 1024), device="cuda")
        tgt = torch.randn((20, 1, 1024), device="cuda")
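        # src/tgt use nn.Transformer's default (seq_len, batch, d_model)
        # layout; seeding with rank + 1 gives each rank distinct inputs.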

        # Monkey-patch _get_handle_to_prefetch to record the FQNs of every
        # handle selected for backward prefetching.
        all_handle_fqns: List[List[str]] = []

        def patched_get_handle_to_prefetch(*args, **kwargs):
            handle = orig_get_handle_to_prefetch(*args, **kwargs)

            self.assertEqual(
                len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
            )
            state = args[0]
            current_handle = args[1]
            training_state = _get_training_state(current_handle)
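            # Only record prefetches triggered from the hook that matches the
            # configured backward_prefetch mode (pre- vs. post-backward).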
            if (
                training_state == HandleTrainingState.BACKWARD_PRE
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
            ) or (
                training_state == HandleTrainingState.BACKWARD_POST
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
            ):
                nonlocal all_handle_fqns
                # FQNs prefixed from the root module
                # state._exec_order_data.param_to_fqn
                fqns = _get_handle_fqns_from_root(state, handle)
                all_handle_fqns.append(fqns)
            return handle

        # flat params from prefetch handle should match
        # DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS
        with patch(
            "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
            patched_get_handle_to_prefetch,
        ):
            for _ in range(NUM_ITERS):
                optim.zero_grad()
                loss = model(src, tgt).sum()
                loss.backward()
                optim.step()
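                # After backward, all_handle_fqns holds one entry per call to
                # _get_handle_to_prefetch made from the matching hook, including
                # the trailing None return(s) described below.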
                if backward_prefetch is None:
                    self.assertEqual(len(all_handle_fqns), 0)
                    continue
                elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order
                    #     encoder 0...5 -> decoder 0...5 -> root
                    # pre-backward hook order
                    #     root -> decoder 5...0 -> encoder 5...0
                    # prefetch order
                    #     decoder 5...0 -> encoder 5...0 -> None
                    # None: when current_handle=encoder 0,
                    #       _get_handle_to_prefetch returns None
                    # +1 is for the above None
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
                    )
                elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order (same as BACKWARD_PRE)
                    #     encoder 0...5 -> decoder 0...5 -> root
                    # post-backward hook (AccumulateGrad) order
                    #     decoder 5, 4...0 -> encoder 5...0 -> root
                    # prefetch order
                    #     decoder 4...0 -> encoder 5...0 -> None -> None
                    # 1st None: when current_handle=encoder 0,
                    #           _get_handle_to_prefetch returns None
                    # 2nd None: when current_handle=root, decoder 5 is found
                    #           inside _get_handle_to_prefetch but skipped since
                    #           its gradients have already been computed
                    # +2 is for the above Nones
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
                    )

                # ith_prefetch: index of the recorded prefetch (0th, 1st, 2nd, ...)
                for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if ith_prefetch >= 0 and ith_prefetch < encoder_begin_index:
                        layer_index = encoder_begin_index - 1 - ith_prefetch
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
                        )
                    elif (
                        ith_prefetch >= encoder_begin_index
                        and ith_prefetch <= encoder_begin_index + ENCODER_PREFETCH_NUM
                    ):
                        layer_index = (
                            encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
                        )
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
                        )
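                    # Any remaining entries correspond to the None returns
                    # explained in the branch comments above.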
                    else:
                        self.assertTrue(fqns is None)

                all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
    def test_backward_prefetch(self):
        # run as subtests to reuse the process group and shorten test time
        self.run_subtests(
            {
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_backward_prefetch,
        )

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
        self._dist_train(backward_prefetch)


if __name__ == "__main__":
    run_tests()