pytorch / _unshard_param_utils.py
357 lines · 12.1 KB
import contextlib
import warnings
from typing import cast, Generator

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _has_fsdp_params,
    _module_handle,
    HandleTrainingState,
    TrainingState,
)
from torch.distributed.fsdp._runtime_utils import (
    _get_fsdp_root_states_with_modules,
    _lazy_init,
    _reset_flat_param_grad_info_if_needed,
    _reshard,
    _reshard_grads,
    _unshard,
    _unshard_grads,
)
from torch.distributed.utils import _p_assert

from ._flat_param import FlatParamHandle

FLAT_PARAM = "_flat_param"


@torch.no_grad()
def _writeback_to_local_shard(
    handle: FlatParamHandle,
    writeback_grad: bool,
):
    """
    For the handle, writes back this rank's shard of the unsharded
    flattened parameter to the sharded flattened parameter. If
    ``writeback_grad=True``, then writes back to the sharded gradient as
    well.

    Precondition: The handle's ``FlatParameter`` 's data points to the
    padded unsharded flattened parameter.
    """

    def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor:
        if handle.uses_sharded_strategy:
            # For sharded strategies, get the *unpadded* shard instead of
            # the *padded* shard to persist user changes to the padding
            # (though FSDP does not explicitly support this)
            shard, _ = FlatParamHandle._get_unpadded_shard(
                flat_param_or_grad,
                handle.rank,
                handle.world_size,
            )
            return shard
        # For `NO_SHARD`, the `flat_param` or its gradient may be modified,
        # so we write it back directly
        return flat_param_or_grad

    param_shard = _get_shard(handle.flat_param)
    handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard)  # type: ignore[attr-defined]
    if writeback_grad:
        existing_grad = handle.sharded_grad
        if existing_grad is not None:
            assert handle.flat_param.grad is not None
            grad_shard = _get_shard(handle.flat_param.grad)
            existing_grad[: grad_shard.numel()].copy_(grad_shard)
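
# A minimal usage sketch (hypothetical, not part of this module): with the
# handle's flat parameter currently pointing to the padded *unsharded* data,
# as the precondition above requires, user edits can be persisted to this
# rank's local shard like so:
#
#     with torch.no_grad():
#         handle.flat_param.add_(1.0)  # edit the unsharded flat parameter in place
#     _writeback_to_local_shard(handle, writeback_grad=False)
#     # handle.flat_param._local_shard now holds this rank's slice of the edit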


def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    De-registers the flattened parameter from the wrapped module, hiding it
    from ``nn.Module`` methods.

    We do not use ``del`` because we want ``FLAT_PARAM`` to always be an
    attribute but dynamically change whether it is visible to ``nn.Module``
    methods.
    """
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None)


def _register_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    Registers the flattened parameter to the wrapped module, making it
    visible to ``nn.Module`` methods.

    We do not use :meth:`nn.Module.register_parameter` because we want
    ``FLAT_PARAM`` to always be an attribute but dynamically change whether
    it is visible to ``nn.Module`` methods.
    """
    handle = _module_handle(state, module)
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param
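
# Hedged illustration of the visibility toggle implemented by the two helpers
# above (`state` and `module` are assumed to come from an FSDP instance whose
# handle owns a FlatParameter):
#
#     _deregister_flat_param(state, module)
#     assert FLAT_PARAM not in dict(module.module.named_parameters())
#     _register_flat_param(state, module)
#     assert FLAT_PARAM in dict(module.module.named_parameters())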


@contextlib.contextmanager
def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator:
    """
    Assumes that the flattened parameter is unsharded. When in the context,
    de-registers the flattened parameter and unflattens the original
    parameters as ``nn.Parameter`` views into the flattened parameter.
    After the context, re-registers the flattened parameter and restores
    the original parameters as ``Tensor`` views into the flattened
    parameter.
    """
    handle = _module_handle(state, module)
    if not handle:
        yield
    else:
        _deregister_flat_param(state, module)
        try:
            with handle.unflatten_as_params():
                yield
        finally:
            if not handle._use_orig_params:
                _register_flat_param(state, module)
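
# A minimal sketch of how this context is meant to be used (hypothetical;
# assumes the flat parameter is already unsharded, as the docstring requires,
# and that FSDP was constructed with use_orig_params=False):
#
#     with _unflatten_as_params(state, module):
#         # the original parameter names are visible as nn.Parameter views
#         for name, param in module.module.named_parameters():
#             print(name, tuple(param.shape))
#     # afterwards, only "_flat_param" is registered again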


def _validate_unshard_params_args(
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
) -> None:
    if with_grads and (offload_to_cpu or not state._use_orig_params):
        raise NotImplementedError(
            f"with_grads={with_grads}, "
            f"use_orig_params={state._use_orig_params}, "
            f"offload_to_cpu={offload_to_cpu} "
            f"is not supported yet"
        )
    if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy):
        raise NotImplementedError(
            "offload_to_cpu=True and NO_SHARD is not supported yet"
        )
    if writeback and rank0_only:
        # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
        # persist the changes.
        raise NotImplementedError(
            "writeback=True and rank0_only=True is not supported yet"
        )
    if offload_to_cpu and not rank0_only:
        warnings.warn(
            "offload_to_cpu=True and rank0_only=False may result in the "
            "unsharded parameters being redundantly copied to CPU memory for "
            "GPUs sharing the same CPU memory, which risks CPU OOM. We "
            "recommend using offload_to_cpu=True with rank0_only=True."
        )
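
# Summary of the combinations rejected or warned about above:
#   with_grads=True      requires use_orig_params=True and offload_to_cpu=False
#   offload_to_cpu=True  is incompatible with the NO_SHARD strategy
#   writeback=True       is incompatible with rank0_only=True
#   offload_to_cpu=True  with rank0_only=False only warns (risk of CPU OOM)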


@contextlib.contextmanager
def _unshard_fsdp_state_params(
    module: nn.Module,
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This unshards the parameters for a single FSDP state ``state`` that
    corresponds to ``module``.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    state._device_handle.synchronize()
    # If handles are shared by other module(s), the handle may be already unsharded.
    maybe_handle = _module_handle(state, module)
    handle = None
    if (
        maybe_handle
        and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS
    ):
        handle = maybe_handle
    if not handle:
        yield
        return

    assert (
        handle._training_state == HandleTrainingState.IDLE
    ), f"Expects the handle's training state to be IDLE but got {handle._training_state}"

    handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

    _reset_flat_param_grad_info_if_needed(handle)
    free_unsharded_flat_param = handle.needs_unshard()
    # No need to call `wait_stream()` since we unshard in the computation
    # stream directly
    computation_stream = state._device_handle.current_stream()
    _unshard(state, handle, computation_stream, computation_stream)
    if with_grads:
        _unshard_grads(handle)

    if rank0_only and state.rank != 0:
        # Free the unsharded flattened parameter early
        _reshard(state, handle, free_unsharded_flat_param)
        if with_grads:
            _reshard_grads(handle)
        try:
            yield
        finally:
            handle._training_state = HandleTrainingState.IDLE
    else:
        # Unflatten the unsharded flattened parameters
        with contextlib.ExitStack() as stack:
            # Invariant: rank == 0 or !rank0_only
            if offload_to_cpu and handle.uses_sharded_strategy:
                stack.enter_context(handle.to_cpu())
                # NOTE: Since PyTorch enforces that a parameter and its
                # gradients need to match metadata (e.g. device), we must
                # move gradients to CPU *after* we move parameters.
            # NOTE: This assumes 1 `FlatParameter`
            if not state._use_orig_params:
                stack.enter_context(_unflatten_as_params(state, module))
            try:
                yield
            finally:
                stack.close()
                if writeback:
                    _writeback_to_local_shard(handle, with_grads)
                _reshard(state, handle, free_unsharded_flat_param)
                if with_grads:
                    _reshard_grads(handle)
                handle._training_state = HandleTrainingState.IDLE
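
# Rough shape of the non-rank0_only path above (summary, not executable code):
#
#     _unshard(...)                     # all-gather into the padded unsharded flat param
#     [_unshard_grads(...)]             # only when with_grads=True
#     [handle.to_cpu()]                 # only when offload_to_cpu=True with a sharded strategy
#     [_unflatten_as_params(...)]       # only when use_orig_params=False
#     yield                             # caller sees the full parameters
#     [_writeback_to_local_shard(...)]  # only when writeback=True
#     _reshard(...)                     # free the unsharded flat param if it was gathered here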


@contextlib.contextmanager
def _unshard_params_recurse(
    module: nn.Module,
    state: _FSDPState,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This is a helper for :func:`_unshard_params` that recursively calls
    :func:`_unshard_fsdp_state_params` on FSDP states if ``recurse=True``.
    NOTE: This runs lazy initialization.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    if recurse:
        with contextlib.ExitStack() as stack:
            # TODO (awgu): The traversal function does not traverse through
            # incompatible composable APIs. Verify if this is the desired
            # behavior for this function.
            for state, fsdp_module in zip(
                *traversal_utils._get_fsdp_states_with_modules(module)
            ):
                stack.enter_context(
                    _unshard_params_recurse(
                        module=fsdp_module,
                        state=state,
                        recurse=False,
                        writeback=writeback,
                        rank0_only=rank0_only,
                        offload_to_cpu=offload_to_cpu,
                        with_grads=with_grads,
                    )
                )
            yield
        return
    _lazy_init(state, module)
    if state.training_state == TrainingState.FORWARD_BACKWARD:
        raise AssertionError(
            "Cannot manually unshard parameters during forward/backward"
        )
    elif state.training_state == TrainingState.SUMMON_FULL_PARAMS:
        raise AssertionError(
            "Cannot manually unshard parameters when already unsharding parameters"
        )
    with _unshard_fsdp_state_params(
        module=module,
        state=state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    ):
        try:
            state.training_state = TrainingState.SUMMON_FULL_PARAMS
            yield
        finally:
            state.training_state = TrainingState.IDLE
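
# A hedged usage sketch for the non-recursive case (hypothetical;
# `fsdp_submodule` is a single FSDP-managed module and `fsdp_state` its
# _FSDPState):
#
#     with _unshard_params_recurse(
#         module=fsdp_submodule,
#         state=fsdp_state,
#         recurse=False,        # recurse=True would walk the whole subtree
#         writeback=True,
#         rank0_only=False,
#         offload_to_cpu=False,
#         with_grads=False,
#     ):
#         ...  # parameters owned by this state are unsharded here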


@contextlib.contextmanager
def _unshard_params(
    module: nn.Module,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This unshards FSDP-managed parameters for all modules with FSDP applied in
    the module tree rooted at ``module``.
    """
    root_fsdp_states, root_fsdp_modules = _get_fsdp_root_states_with_modules(module)
    with contextlib.ExitStack() as stack:
        for root_fsdp_state, root_fsdp_module in zip(
            root_fsdp_states, root_fsdp_modules
        ):
            stack.enter_context(
                _unshard_params_recurse(
                    module=root_fsdp_module,
                    state=root_fsdp_state,
                    recurse=recurse,
                    writeback=writeback,
                    rank0_only=rank0_only,
                    offload_to_cpu=offload_to_cpu,
                    with_grads=with_grads,
                )
            )
        yield
    return
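
# For context (hedged; based on the public FSDP API rather than this file):
# FullyShardedDataParallel.summon_full_params() is the user-facing context
# manager built on this unsharding path. A typical pattern for inspecting the
# full parameters on a single rank looks roughly like:
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     with FSDP.summon_full_params(
#         model, writeback=False, rank0_only=True, offload_to_cpu=True
#     ):
#         if torch.distributed.get_rank() == 0:
#             for name, param in model.named_parameters():
#                 print(name, tuple(param.shape))  # full (unsharded) shapes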


def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the original parameters; registers the ``FlatParameter``.
    """
    handle = _module_handle(state, module)
    if not handle:
        return
    _p_assert(
        handle._use_orig_params,
        f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
        f"handle: {handle._use_orig_params}",
    )
    handle._deregister_orig_params()
    _register_flat_param(state, module)


def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the ``FlatParameter``; registers the original parameters.
    """
    handle = _module_handle(state, module)
    if not handle:
        return
    _deregister_flat_param(state, module)
    if handle.is_sharded(handle.flat_param):
        handle._use_sharded_views()
        handle._use_sharded_grad_views()
    else:
        handle._use_unsharded_views(as_params=True)
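
# The pair _deregister_orig_params/_register_orig_params toggles between
# exposing only the FlatParameter and exposing the original parameters; it is
# only meaningful when FSDP was constructed with use_orig_params=True. A
# hedged sketch of the round trip:
#
#     _deregister_orig_params(state, module)  # nn.Module sees only "_flat_param"
#     ...                                     # operate on the FlatParameter directly
#     _register_orig_params(state, module)    # restore per-parameter views, matching
#                                             # the flat parameter's current
#                                             # sharded/unsharded layout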