1"""
2This file includes public APIs for FSDP such as the classes used for the
3constructor arguments.
4"""
5
6from dataclasses import dataclass7from enum import auto, Enum8
9from typing import Optional, Sequence, Type10
11import torch12from torch.nn.modules.batchnorm import _BatchNorm13
14__all__ = [15"ShardingStrategy",16"BackwardPrefetch",17"MixedPrecision",18"CPUOffload",19"StateDictType",20"StateDictConfig",21"FullStateDictConfig",22"LocalStateDictConfig",23"ShardedStateDictConfig",24"OptimStateDictConfig",25"FullOptimStateDictConfig",26"LocalOptimStateDictConfig",27"ShardedOptimStateDictConfig",28"StateDictSettings",29]
30
31
32class ShardingStrategy(Enum):33"""34This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
      For the parameters, this strategy unshards (via all-gather) before the
      forward, reshards after the forward, unshards before the backward
      computation, and reshards after the backward computation. For gradients,
      it synchronizes and shards them (via reduce-scatter) after the backward
      computation. The sharded optimizer states are updated locally per rank.
    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
      computation, and additionally, parameters are sharded outside
      computation. For the parameters, this strategy unshards before the
      forward, does not reshard them after the forward, and only reshards them
      after the backward computation. The sharded optimizer states are updated
      locally per rank. Inside ``no_sync()``, the parameters are not resharded
      after the backward computation.
    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
      but instead replicated across ranks similar to PyTorch's
      :class:`DistributedDataParallel` API. For gradients, this strategy
      synchronizes them (via all-reduce) after the backward computation. The
      unsharded optimizer states are updated locally per rank.
    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate
      parameters across nodes. This results in reduced communication volume as
      expensive all-gathers and reduce-scatters are only done within a node,
      which can be more performant for medium-sized models.
    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and
      replicate parameters across nodes. This is like ``HYBRID_SHARD``, except
      this may provide even higher throughput since the unsharded parameters
      are not freed after the forward pass, saving the all-gathers in the
      pre-backward.
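
    A minimal usage sketch (assuming ``model`` is a defined ``nn.Module`` and
    the default process group is initialized); the strategy is passed via the
    FSDP constructor's ``sharding_strategy`` argument::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> # Select ZeRO-3-style full sharding for ``model``.
        >>> fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)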
63"""
64
65FULL_SHARD = auto()66SHARD_GRAD_OP = auto()67NO_SHARD = auto()68HYBRID_SHARD = auto()69_HYBRID_SHARD_ZERO2 = auto()70
71
72class BackwardPrefetch(Enum):73"""74This configures explicit backward prefetching, which improves throughput by
    enabling communication and computation overlap in the backward pass at the
    cost of slightly increased memory usage.

    - ``BACKWARD_PRE``: This enables the most overlap but increases memory
      usage the most. This prefetches the next set of parameters *before* the
      current set of parameters' gradient computation. This overlaps the *next
      all-gather* and the *current gradient computation*, and at the peak, it
      holds the current set of parameters, next set of parameters, and current
      set of gradients in memory.
    - ``BACKWARD_POST``: This enables less overlap but requires less memory
      usage. This prefetches the next set of parameters *after* the current
      set of parameters' gradient computation. This overlaps the *current
      reduce-scatter* and the *next gradient computation*, and it frees the
      current set of parameters before allocating memory for the next set of
      parameters, only holding the next set of parameters and the current set
      of gradients in memory at the peak.
    - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
      backward prefetching altogether. This has no overlap and does not
      increase memory usage. In general, we do not recommend this setting since
      it may degrade throughput significantly.

    For more technical context: for a single process group using the NCCL
    backend, any collectives, even if issued from different streams, contend
    for the same per-device NCCL stream, which implies that the relative order
    in which the collectives are issued matters for overlapping. The two
    backward prefetching values correspond to different issue orders.
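
    A minimal usage sketch (assuming ``model`` is a defined ``nn.Module``); the
    enum is passed via the FSDP constructor's ``backward_prefetch`` argument::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> # Prefetch the next parameters before the current gradient computation.
        >>> fsdp_model = FSDP(model, backward_prefetch=BackwardPrefetch.BACKWARD_PRE)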
101"""
102
103# NOTE: For both modes, the ordering that defines "current" and "next" is104# not always exact in the current implementation. A mistargeted prefetch105# simply means that the parameter memory is allocated earlier than needed,106# possibly increasing peak memory usage, but does not affect correctness.107BACKWARD_PRE = auto()108BACKWARD_POST = auto()109
110
111@dataclass
112class MixedPrecision:113"""114This configures FSDP-native mixed precision training.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for model
            parameters during forward and backward and thus the dtype for
            forward and backward computation. Outside forward and backward, the
            *sharded* parameters are kept in full precision (e.g. for the
            optimizer step), and for model checkpointing, the parameters are
            always saved in full precision. (Default: ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then this takes on
            the ``param_dtype`` value, still running gradient reduction in low
            precision. This is permitted to differ from ``param_dtype``, e.g.
            to force gradient reduction to run in full precision. (Default:
            ``None``)
        buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
            buffers. FSDP does not shard buffers. Rather, FSDP casts them to
            ``buffer_dtype`` in the first forward pass and keeps them in that
            dtype thereafter. For model checkpointing, the buffers are saved
            in full precision except for ``LOCAL_STATE_DICT``. (Default:
            ``None``)
        keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
            gradients to full precision after the backward pass in preparation
            for the optimizer step. If ``True``, then FSDP keeps the gradients
            in the dtype used for gradient reduction, which can save memory if
            using a custom optimizer that supports running in low precision.
            (Default: ``False``)
        cast_forward_inputs (bool): If ``True``, then this FSDP module casts
            its forward args and kwargs to ``param_dtype``. This is to ensure
            that parameter and input dtypes match for forward computation, as
            required by many ops. This may need to be set to ``True`` when only
            applying mixed precision to some but not all FSDP modules, in which
            case a mixed-precision FSDP submodule needs to recast its inputs.
            (Default: ``False``)
        cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
            casts its forward args and kwargs to ``param_dtype``, overriding
            the value of ``cast_forward_inputs``. For non-root FSDP modules,
            this does not do anything. (Default: ``True``)
        _module_classes_to_ignore (Sequence[Type[nn.Module]]): This specifies
            module classes to ignore for mixed precision when using an
            ``auto_wrap_policy``: Modules of these classes will have FSDP
            applied to them separately with mixed precision disabled (meaning
            that the final FSDP construction would deviate from the specified
            policy). If ``auto_wrap_policy`` is not specified, then this does
            not do anything. This API is experimental and subject to change.
            (Default: ``(_BatchNorm,)``)

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: Layer norm and batch norm accumulate in ``float32`` even when
        their inputs are in a low precision like ``float16`` or ``bfloat16``.
        Disabling FSDP's mixed precision for those norm modules only means that
        the affine parameters are kept in ``float32``. However, this incurs
        separate all-gathers and reduce-scatters for those norm modules, which
        may be inefficient, so if the workload permits, the user should prefer
        to still apply mixed precision to those modules.

    .. note:: By default, if the user passes a model with any ``_BatchNorm``
        modules and specifies an ``auto_wrap_policy``, then the batch norm
        modules will have FSDP applied to them separately with mixed precision
        disabled. See the ``_module_classes_to_ignore`` argument.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting is
        sufficient for the typical case where each FSDP instance has the same
        ``MixedPrecision`` configuration and only needs to cast inputs to the
        ``param_dtype`` at the beginning of the model's forward pass.
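
    For that typical single-configuration case, a minimal sketch (assuming
    ``model`` is a defined ``nn.Module``)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> # Compute in bfloat16 but reduce gradients in full precision.
        >>> mp = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
        >>> fsdp_model = FSDP(model, mixed_precision=mp)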

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to configure casting inputs or not before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation dtype
        being changed due to a different ``MixedPrecision`` configuration.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
            >>> model[1] = FSDP(
            >>>     model[1],
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
            >>> )
            >>> model = FSDP(
            >>>     model,
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
            >>> )

        The above shows a working example. On the other hand, if ``model[1]``
        were replaced with ``model[0]``, meaning that the submodule using
        different ``MixedPrecision`` ran its forward first, then ``model[1]``
        would incorrectly see ``float16`` activations instead of ``bfloat16``
        ones.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True
    _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)


@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If ``True``, then this
            offloads gradients to CPU as well, meaning that the optimizer step
            runs on CPU.
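
    A minimal usage sketch (assuming ``model`` is a defined ``nn.Module``); the
    config is passed via the FSDP constructor's ``cpu_offload`` argument::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))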
239"""
240
241offload_params: bool = False242
243
244class StateDictType(Enum):245"""246This enum indicates that which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading).
    The default value is ``FULL_STATE_DICT`` to comply with the PyTorch
    convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:

        1. ``state_dict``/``load_state_dict``: this pair of APIs returns and
           loads the non-sharded, unflattened parameters. The semantics are
           the same as using DDP.
        2. ``_local_state_dict``/``_load_local_state_dict``: this pair of APIs
           returns and loads the local sharded, flattened parameters. The
           values returned by ``_local_state_dict`` can be directly used by
           FSDP and are only meaningful to FSDP (because the parameters are
           flattened). Note that these APIs are meant for use via the
           :func:`state_dict_type` context manager as follows:

           >>> # xdoctest: +SKIP("undefined variables")
           >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
           ...     state = fsdp.state_dict()  # returns the local state dict

        3. ``_sharded_state_dict``/``_load_sharded_state_dict``: this pair of
           APIs returns and loads sharded, unflattened parameters. The
           ``state_dict`` returned by ``_sharded_state_dict`` can be used by
           all other parallel schemes (resharding may be required).
267"""
268
269FULL_STATE_DICT = auto()270LOCAL_STATE_DICT = auto()271SHARDED_STATE_DICT = auto()272
273
274@dataclass
275class StateDictConfig:276"""277``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the
    corresponding ``state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
            values to CPU, and if ``False``, then FSDP keeps them on GPU.
            (Default: ``False``)
286"""
287
288offload_to_cpu: bool = False289
290
291@dataclass
292class FullStateDictConfig(StateDictConfig):293"""294``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
    ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
    dicts to save GPU memory and CPU memory, respectively. This config class
    is meant to be used via the :func:`state_dict_type` context manager as
    follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # `state` will be empty on nonzero ranks and contain CPU tensors on rank 0.
        >>> # To reload the checkpoint for inference, finetuning, transfer learning, etc.:
        >>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize the FSDP module as usual. The ``sync_module_states``
        >>> # argument communicates the loaded checkpoint states from rank 0 to the
        >>> # rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have the FSDP model with the loaded checkpoint.

    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
323"""
324
325rank0_only: bool = False326
327
328@dataclass
329class LocalStateDictConfig(StateDictConfig):330pass331
332
333@dataclass
334class ShardedStateDictConfig(StateDictConfig):335"""336``ShardedStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedStateDictConfig` and it is used by FSDP to determine
        the type of state dict values. Users should not manually modify
        ``_use_dtensor``.
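
    A minimal usage sketch (assuming ``fsdp_model`` is an FSDP-wrapped module)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> with FSDP.state_dict_type(
        >>>     fsdp_model,
        >>>     StateDictType.SHARDED_STATE_DICT,
        >>>     ShardedStateDictConfig(offload_to_cpu=True),
        >>> ):
        >>>     sharded_sd = fsdp_model.state_dict()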
347"""
348
349_use_dtensor: bool = False350
351
352@dataclass
353class OptimStateDictConfig:354"""355``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
    configuration classes. Users should instantiate a child class (e.g.
    ``FullOptimStateDictConfig``) in order to configure settings for the
    corresponding ``optim_state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
            tensor values to CPU, and if ``False``, then FSDP keeps them on the
            original device (which is GPU unless parameter CPU offloading is
            enabled). (Default: ``True``)
365"""
366
367offload_to_cpu: bool = True368
369
370@dataclass
371class FullOptimStateDictConfig(OptimStateDictConfig):372"""373Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
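
    A minimal usage sketch (assuming ``fsdp_model`` and its optimizer ``optim``
    are defined)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> with FSDP.state_dict_type(
        >>>     fsdp_model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True),
        >>> ):
        >>>     osd = FSDP.optim_state_dict(fsdp_model, optim)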
377"""
378
379rank0_only: bool = False380
381
382@dataclass
383class LocalOptimStateDictConfig(OptimStateDictConfig):384offload_to_cpu: bool = False385
386
387@dataclass
388class ShardedOptimStateDictConfig(OptimStateDictConfig):389"""390``ShardedOptimStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedOptimStateDictConfig` and it is used by FSDP to
        determine the type of state dict values. Users should not manually
        modify ``_use_dtensor``.
401"""
402
403_use_dtensor: bool = False404
405
406@dataclass
407class StateDictSettings:408state_dict_type: StateDictType409state_dict_config: StateDictConfig410optim_state_dict_config: OptimStateDictConfig411