import contextlib
import warnings
from typing import cast, Generator

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _has_fsdp_params,
    _module_handle,
    HandleTrainingState,
    TrainingState,
)
from torch.distributed.fsdp._runtime_utils import (
    _get_fsdp_root_states_with_modules,
    _lazy_init,
    _reset_flat_param_grad_info_if_needed,
    _reshard,
    _reshard_grads,
    _unshard,
    _unshard_grads,
)
from torch.distributed.utils import _p_assert

from ._flat_param import FlatParamHandle

FLAT_PARAM = "_flat_param"


@torch.no_grad()
def _writeback_to_local_shard(
    handle: FlatParamHandle,
    writeback_grad: bool,
):
    """
    For the handle, writes back this rank's shard of the unsharded
    flattened parameter to the sharded flattened parameter. If
    ``writeback_grad=True``, then writes back to the sharded gradient as
    well.

    Precondition: The handle's ``FlatParameter`` 's data points to the
    padded unsharded flattened parameter.
    """

    def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor:
        if handle.uses_sharded_strategy:
            # For sharded strategies, get the *unpadded* shard instead of
            # the *padded* shard to persist user changes to the padding
            # (though FSDP does not explicitly support this)
            shard, _ = FlatParamHandle._get_unpadded_shard(
                flat_param_or_grad,
                handle.rank,
                handle.world_size,
            )
            return shard
        # For `NO_SHARD`, the `flat_param` or its gradient may be modified,
        # so we write it back directly
        return flat_param_or_grad

    param_shard = _get_shard(handle.flat_param)
    handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard)  # type: ignore[attr-defined]
    if writeback_grad:
        existing_grad = handle.sharded_grad
        if existing_grad is not None:
            assert handle.flat_param.grad is not None
            grad_shard = _get_shard(handle.flat_param.grad)
            existing_grad[: grad_shard.numel()].copy_(grad_shard)


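# NOTE: Illustrative sketch only (not an additional code path). Within this
# module, `_writeback_to_local_shard` is called from `_unshard_fsdp_state_params`
# after the caller has finished mutating the unsharded flat parameter, roughly:
#
#   free_unsharded_flat_param = handle.needs_unshard()
#   computation_stream = state._device_handle.current_stream()
#   _unshard(state, handle, computation_stream, computation_stream)
#   ...  # caller mutates the (now padded, unsharded) flat parameter in place
#   _writeback_to_local_shard(handle, writeback_grad=False)
#   _reshard(state, handle, free_unsharded_flat_param)
#
# The `_unshard` call is what establishes the precondition in the docstring:
# `handle.flat_param`'s data points to the padded unsharded flattened parameter.

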
def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    De-registers the flattened parameter from the wrapped module, hiding it
    from ``nn.Module`` methods.

    We do not use ``del`` because we want ``FLAT_PARAM`` to always be an
    attribute but dynamically change whether it is visible to ``nn.Module``
    methods.
    """
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None)


def _register_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    Registers the flattened parameter to the wrapped module, making it
    visible to ``nn.Module`` methods.

    We do not use :meth:`nn.Module.register_parameter` because we want
    ``FLAT_PARAM`` to always be an attribute but dynamically change whether
    it is visible to ``nn.Module`` methods.
    """
    handle = _module_handle(state, module)
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param


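# NOTE: Illustrative sketch only. The two helpers above toggle whether the
# flattened parameter appears to `nn.Module` introspection by adding/removing
# the FLAT_PARAM key in the wrapped module's `_parameters` dict:
#
#   _deregister_flat_param(state, module)
#   # "_flat_param" no longer shows up in `module.named_parameters()`
#   _register_flat_param(state, module)
#   # "_flat_param" is visible to `nn.Module` methods again

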
@contextlib.contextmanager
def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator:
    """
    Assumes that the flattened parameter is unsharded. When in the context,
    de-registers the flattened parameter and unflattens the original
    parameters as ``nn.Parameter`` views into the flattened parameter.
    After the context, re-registers the flattened parameter and restores
    the original parameters as ``Tensor`` views into the flattened
    parameter.
    """
    handle = _module_handle(state, module)
    if not handle:
        yield
    else:
        _deregister_flat_param(state, module)
        try:
            with handle.unflatten_as_params():
                yield
        finally:
            if not handle._use_orig_params:
                _register_flat_param(state, module)


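# NOTE: Illustrative sketch only. `_unshard_fsdp_state_params` below enters this
# context (only when ``use_orig_params=False``) once the flat parameter has been
# unsharded, so that callers temporarily see the original parameters:
#
#   with _unflatten_as_params(state, module):
#       # `module.named_parameters()` yields the original (unflattened)
#       # parameters as views into the unsharded flat parameter
#       ...
#   # on exit, the flat parameter is re-registered (unless `_use_orig_params`)

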
def _validate_unshard_params_args(
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
) -> None:
    if with_grads and (offload_to_cpu or not state._use_orig_params):
        raise NotImplementedError(
            f"with_grads={with_grads}, "
            f"use_orig_params={state._use_orig_params}, "
            f"offload_to_cpu={offload_to_cpu} "
            f"is not supported yet"
        )
    if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy):
        raise NotImplementedError(
            "offload_to_cpu=True and NO_SHARD is not supported yet"
        )
    if writeback and rank0_only:
        # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
        # persist the changes.
        raise NotImplementedError(
            "writeback=True and rank0_only=True is not supported yet"
        )
    if offload_to_cpu and not rank0_only:
        warnings.warn(
            "offload_to_cpu=True and rank0_only=False may result in the "
            "unsharded parameters being redundantly copied to CPU memory for "
            "GPUs sharing the same CPU memory, which risks CPU OOM. We "
            "recommend using offload_to_cpu=True with rank0_only=True."
        )


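# NOTE: Quick reference for the checks above (mirrors the code, adds nothing new):
#
#   with_grads=True  and (offload_to_cpu=True or use_orig_params=False) -> NotImplementedError
#   offload_to_cpu=True with a NO_SHARD (unsharded) strategy            -> NotImplementedError
#   writeback=True   and rank0_only=True                                -> NotImplementedError
#   offload_to_cpu=True and rank0_only=False                            -> warning (possible CPU OOM)

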
@contextlib.contextmanager
def _unshard_fsdp_state_params(
    module: nn.Module,
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This unshards the parameters for a single FSDP state ``state`` that
    corresponds to ``module``.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    state._device_handle.synchronize()
    # If handles are shared by other module(s), the handle may be already unsharded.
    maybe_handle = _module_handle(state, module)
    handle = None
    if (
        maybe_handle
        and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS
    ):
        handle = maybe_handle
    if not handle:
        yield
        return

    assert (
        handle._training_state == HandleTrainingState.IDLE
    ), f"Expects the handle training to be IDLE but got {handle._training_state}"

    handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

    _reset_flat_param_grad_info_if_needed(handle)
    free_unsharded_flat_param = handle.needs_unshard()
    # No need to call `wait_stream()` since we unshard in the computation
    # stream directly
    computation_stream = state._device_handle.current_stream()
    _unshard(state, handle, computation_stream, computation_stream)
    if with_grads:
        _unshard_grads(handle)

    if rank0_only and state.rank != 0:
        # Free the unsharded flattened parameter early
        _reshard(state, handle, free_unsharded_flat_param)
        if with_grads:
            _reshard_grads(handle)
        try:
            yield
        finally:
            handle._training_state = HandleTrainingState.IDLE
    else:
        # Unflatten the unsharded flattened parameters
        with contextlib.ExitStack() as stack:
            # Invariant: rank == 0 or !rank0_only
            if offload_to_cpu and handle.uses_sharded_strategy:
                stack.enter_context(handle.to_cpu())
                # NOTE: Since PyTorch enforces that a parameter and its
                # gradients need to match metadata (e.g. device), we must
                # move gradients to CPU *after* we move parameters.
            # NOTE: This assumes 1 `FlatParameter`
            if not state._use_orig_params:
                stack.enter_context(_unflatten_as_params(state, module))
            try:
                yield
            finally:
                stack.close()
                if writeback:
                    _writeback_to_local_shard(handle, with_grads)
                _reshard(state, handle, free_unsharded_flat_param)
                if with_grads:
                    _reshard_grads(handle)
                handle._training_state = HandleTrainingState.IDLE


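# NOTE: Illustrative sketch only; `state` and `module` stand for an FSDP state
# and its module as passed in by `_unshard_params_recurse` below. Typical use is
# as a context manager, with in-place edits persisted on exit when `writeback`:
#
#   with _unshard_fsdp_state_params(
#       module=module,
#       state=state,
#       writeback=True,
#       rank0_only=False,
#       offload_to_cpu=False,
#       with_grads=False,
#   ):
#       ...  # parameters are unsharded here; edits are written back on exit

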
@contextlib.contextmanager
def _unshard_params_recurse(
    module: nn.Module,
    state: _FSDPState,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This is a helper for :func:`_unshard_params` that recursively calls
    :func:`_unshard_fsdp_state_params` on FSDP states if ``recurse=True``.
    NOTE: This runs lazy initialization.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    if recurse:
        with contextlib.ExitStack() as stack:
            # TODO (awgu): The traversal function does not traverse through
            # incompatible composable APIs. Verify if this is the desired
            # behavior for this function.
            for state, fsdp_module in zip(
                *traversal_utils._get_fsdp_states_with_modules(module)
            ):
                stack.enter_context(
                    _unshard_params_recurse(
                        module=fsdp_module,
                        state=state,
                        recurse=False,
                        writeback=writeback,
                        rank0_only=rank0_only,
                        offload_to_cpu=offload_to_cpu,
                        with_grads=with_grads,
                    )
                )
            yield
        return
    _lazy_init(state, module)
    if state.training_state == TrainingState.FORWARD_BACKWARD:
        raise AssertionError(
            "Cannot manually unshard parameters during forward/backward"
        )
    elif state.training_state == TrainingState.SUMMON_FULL_PARAMS:
        raise AssertionError(
            "Cannot manually unshard parameters when already unsharding parameters"
        )
    with _unshard_fsdp_state_params(
        module=module,
        state=state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    ):
        try:
            state.training_state = TrainingState.SUMMON_FULL_PARAMS
            yield
        finally:
            state.training_state = TrainingState.IDLE


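# NOTE: Illustrative sketch only. With ``recurse=True``, the call above expands
# into one non-recursive call per FSDP state found under ``module``, all held
# open in a single ``ExitStack``, conceptually:
#
#   states, modules = traversal_utils._get_fsdp_states_with_modules(module)
#   with contextlib.ExitStack() as stack:
#       for s, m in zip(states, modules):
#           stack.enter_context(
#               _unshard_params_recurse(
#                   module=m, state=s, recurse=False, writeback=writeback,
#                   rank0_only=rank0_only, offload_to_cpu=offload_to_cpu,
#                   with_grads=with_grads,
#               )
#           )
#       ...  # every discovered FSDP state is unsharded here

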
@contextlib.contextmanager
def _unshard_params(
    module: nn.Module,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This unshards FSDP-managed parameters for all modules with FSDP applied in
    the module tree rooted at ``module``.
    """
    root_fsdp_states, root_fsdp_modules = _get_fsdp_root_states_with_modules(module)
    with contextlib.ExitStack() as stack:
        for root_fsdp_state, root_fsdp_module in zip(
            root_fsdp_states, root_fsdp_modules
        ):
            stack.enter_context(
                _unshard_params_recurse(
                    module=root_fsdp_module,
                    state=root_fsdp_state,
                    recurse=recurse,
                    writeback=writeback,
                    rank0_only=rank0_only,
                    offload_to_cpu=offload_to_cpu,
                    with_grads=with_grads,
                )
            )
        yield
    return


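# NOTE: Illustrative sketch only. This is the module-tree-wide entry point: given
# a root module (`model` below is a hypothetical module with FSDP applied
# somewhere beneath it), every FSDP-managed parameter under it can be unsharded:
#
#   with _unshard_params(
#       module=model,
#       recurse=True,
#       writeback=True,
#       rank0_only=False,
#       offload_to_cpu=False,
#       with_grads=False,
#   ):
#       ...  # full (unsharded) parameters are visible on every rank here

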
def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the original parameters; registers the ``FlatParameter``.
    """
    handle = _module_handle(state, module)
    if not handle:
        return
    _p_assert(
        handle._use_orig_params,
        f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
        f"handle: {handle._use_orig_params}",
    )
    handle._deregister_orig_params()
    _register_flat_param(state, module)


def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the ``FlatParameter``; registers the original parameters.
    """
    handle = _module_handle(state, module)
    if not handle:
        return
    _deregister_flat_param(state, module)
    if handle.is_sharded(handle.flat_param):
        handle._use_sharded_views()
        handle._use_sharded_grad_views()
    else:
        handle._use_unsharded_views(as_params=True)
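# NOTE: Illustrative sketch only. `_deregister_orig_params` / `_register_orig_params`
# are the ``use_orig_params=True`` counterpart of the FLAT_PARAM toggling above:
#
#   _deregister_orig_params(state, module)   # original params hidden; FlatParameter registered
#   ...                                      # e.g. operate on the flat parameter directly
#   _register_orig_params(state, module)     # original params restored as (sharded or
#                                            # unsharded) views into the flat parameter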