"""
This file includes public APIs for FSDP such as the classes used for the
constructor arguments.
"""

from dataclasses import dataclass
from enum import auto, Enum

from typing import Optional, Sequence, Type

import torch
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = [
    "ShardingStrategy",
    "BackwardPrefetch",
    "MixedPrecision",
    "CPUOffload",
    "StateDictType",
    "StateDictConfig",
    "FullStateDictConfig",
    "LocalStateDictConfig",
    "ShardedStateDictConfig",
    "OptimStateDictConfig",
    "FullOptimStateDictConfig",
    "LocalOptimStateDictConfig",
    "ShardedOptimStateDictConfig",
    "StateDictSettings",
]


class ShardingStrategy(Enum):
    """
    This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
      For the parameters, this strategy unshards (via all-gather) before the
      forward, reshards after the forward, unshards before the backward
      computation, and reshards after the backward computation. For gradients,
      it synchronizes and shards them (via reduce-scatter) after the backward
      computation. The sharded optimizer states are updated locally per rank.
    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
      computation, and additionally, parameters are sharded outside
      computation. For the parameters, this strategy unshards before the
      forward, does not reshard them after the forward, and only reshards them
      after the backward computation. The sharded optimizer states are updated
      locally per rank. Inside ``no_sync()``, the parameters are not resharded
      after the backward computation.
    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
      but instead replicated across ranks similar to PyTorch's
      :class:`DistributedDataParallel` API. For gradients, this strategy
      synchronizes them (via all-reduce) after the backward computation. The
      unsharded optimizer states are updated locally per rank.
    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across
      nodes. This results in reduced communication volume as expensive all-gathers and
      reduce-scatters are only done within a node, which can be more performant for
      medium-sized models.
    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across
      nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput
      since the unsharded parameters are not freed after the forward pass, saving the
      all-gathers in the pre-backward.
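
    For example, a strategy is selected when constructing
    :class:`FullyShardedDataParallel` (a sketch; ``model`` is assumed to be
    defined elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import ShardingStrategy
        >>> fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)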
    """
64

65
    FULL_SHARD = auto()
66
    SHARD_GRAD_OP = auto()
67
    NO_SHARD = auto()
68
    HYBRID_SHARD = auto()
69
    _HYBRID_SHARD_ZERO2 = auto()
70

71

72
class BackwardPrefetch(Enum):
73
    """
74
    This configures explicit backward prefetching, which improves throughput by
75
    enabling communication and computation overlap in the backward pass at the
76
    cost of slightly increased memory usage.
77

78
    - ``BACKWARD_PRE``: This enables the most overlap but increases memory
79
      usage the most. This prefetches the next set of parameters *before* the
80
      current set of parameters' gradient computation. This overlaps the *next
81
      all-gather* and the *current gradient computation*, and at the peak, it
82
      holds the current set of parameters, next set of parameters, and current
83
      set of gradients in memory.
84
    - ``BACKWARD_POST``: This enables less overlap but requires less memory
85
      usage. This prefetches the next set of parameters *after* the current
86
      set of parameters' gradient computation. This overlaps the *current
87
      reduce-scatter* and the *next gradient computation*, and it frees the
88
      current set of parameters before allocating memory for the next set of
89
      parameters, only holding the next set of parameters and current set of
90
      gradients in memory at the peak.
91
    - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
92
      the backward prefetching altogether. This has no overlap and does not
93
      increase memory usage. In general, we do not recommend this setting since
94
      it may degrade throughput significantly.
95

96
    For more technical context: For a single process group using NCCL backend,
97
    any collectives, even if issued from different streams, contend for the
98
    same per-device NCCL stream, which implies that the relative order in which
99
    the collectives are issued matters for overlapping. The two backward
100
    prefetching values correspond to different issue orders.
101
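
    For example, the prefetching mode is passed to the FSDP constructor (a
    sketch; ``model`` is assumed to be defined elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(model, backward_prefetch=BackwardPrefetch.BACKWARD_PRE)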
    """
102

103
    # NOTE: For both modes, the ordering that defines "current" and "next" is
104
    # not always exact in the current implementation. A mistargeted prefetch
105
    # simply means that the parameter memory is allocated earlier than needed,
106
    # possibly increasing peak memory usage, but does not affect correctness.
107
    BACKWARD_PRE = auto()
108
    BACKWARD_POST = auto()
109

110

111
@dataclass
112
class MixedPrecision:
113
    """
114
    This configures FSDP-native mixed precision training.
115

116
    Attributes:
117
        param_dtype (Optional[torch.dtype]): This specifies the dtype for model
118
            parameters during forward and backward and thus the dtype for
119
            forward and backward computation. Outside forward and backward, the
120
            *sharded* parameters are kept in full precision (e.g. for the
121
            optimizer step), and for model checkpointing, the parameters are
122
            always saved in full precision. (Default: ``None``)
123
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
124
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
125
            ``None`` but ``param_dtype`` is not ``None``, then this takes on
126
            the ``param_dtype`` value, still running gradient reduction in low
127
            precision. This is permitted to differ from ``param_dtype``, e.g.
128
            to force gradient reduction to run in full precision. (Default:
129
            ``None``)
130
        buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
131
            buffers. FSDP does not shard buffers. Rather, FSDP casts them to
132
            ``buffer_dtype`` in the first forward pass and keeps them in that
133
            dtype thereafter. For model checkpointing, the buffers are saved
134
            in full precision except for ``LOCAL_STATE_DICT``. (Default:
135
            ``None``)
136
        keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
137
            gradients to full precision after the backward pass in preparation
138
            for the optimizer step. If ``True``, then FSDP keeps the gradients
139
            in the dtype used for gradient reduction, which can save memory if
140
            using a custom optimizer that supports running in low precision.
141
            (Default: ``False``)
142
        cast_forward_inputs (bool): If ``True``, then this FSDP module casts
143
            its forward args and kwargs to ``param_dtype``. This is to ensure
144
            that parameter and input dtypes match for forward computation, as
145
            required by many ops. This may need to be set to ``True`` when only
146
            applying mixed precision to some but not all FSDP modules, in which
147
            case a mixed-precision FSDP submodule needs to recast its inputs.
148
            (Default: ``False``)
149
        cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
150
            casts its forward args and kwargs to ``param_dtype``, overriding
151
            the value of ``cast_forward_inputs``. For non-root FSDP modules,
152
            this does not do anything. (Default: ``True``)
153
        _module_classes_to_ignore: (Sequence[Type[nn.Module]]): This specifies
154
            module classes to ignore for mixed precision when using an
155
            ``auto_wrap_policy``: Modules of these classes will have FSDP
156
            applied to them separately with mixed precision disabled (meaning
157
            that the final FSDP construction would deviate from the specified
158
            policy). If ``auto_wrap_policy`` is not specified, then this does
159
            not do anything. This API is experimental and subject to change.
160
            (Default: ``(_BatchNorm,)``)
161

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: Layer norm and batch norm accumulate in ``float32`` even when
        their inputs are in a low precision like ``float16`` or ``bfloat16``.
        Disabling FSDP's mixed precision for those norm modules only means that
        the affine parameters are kept in ``float32``. However, this incurs
        separate all-gathers and reduce-scatters for those norm modules, which
        may be inefficient, so if the workload permits, the user should prefer
        to still apply mixed precision to those modules.

    .. note:: By default, if the user passes a model with any ``_BatchNorm``
        modules and specifies an ``auto_wrap_policy``, then the batch norm
        modules will have FSDP applied to them separately with mixed precision
        disabled. See the ``_module_classes_to_ignore`` argument.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting is
        sufficient for the typical case where each FSDP instance has the same
        ``MixedPrecision`` configuration and only needs to cast inputs to the
        ``param_dtype`` at the beginning of the model's forward pass.
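
    For example, a typical configuration that runs forward/backward computation
    and gradient reduction in ``bfloat16`` might look like the following sketch
    (``model`` is assumed to be defined elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
        >>> bf16_policy = MixedPrecision(
        >>>     param_dtype=torch.bfloat16,
        >>>     reduce_dtype=torch.bfloat16,
        >>>     buffer_dtype=torch.bfloat16,
        >>> )
        >>> fsdp_model = FSDP(model, mixed_precision=bf16_policy)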

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to configure casting inputs or not before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation dtype
        being changed due to a different ``MixedPrecision`` configuration.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
            >>> model[1] = FSDP(
            >>>     model[1],
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
            >>> )
            >>> model = FSDP(
            >>>     model,
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
            >>> )

        The above shows a working example. On the other hand, if ``model[1]``
        were replaced with ``model[0]``, meaning that the submodule using
        different ``MixedPrecision`` ran its forward first, then ``model[1]``
        would incorrectly see ``float16`` activations instead of ``bfloat16``
        ones.

    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True
    _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)


@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If ``True``, then this
            offloads gradients to CPU as well, meaning that the optimizer step
            runs on CPU.
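
    For example, parameter offloading is enabled by passing this config to the
    FSDP constructor (a sketch; ``model`` is assumed to be defined elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))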
    """
240

241
    offload_params: bool = False
242

243

244
class StateDictType(Enum):
245
    """
246
    This enum indicates that which type of ``state_dict`` the FSDP module is
247
    currently processing (returning or loading).
248
    The default value is FULL_STATE_DICT to comply the PyTorch convention.
249
    ..note::
250
        FSDP currently supports three types of ``state_dict``:
251
            1. ``state_dict/load_state_dict`: this pair of APIs return and load
252
               the non-sharded, unflattened parameters. The semantics is the
253
               same as using DDP.
254
            2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return
255
               and load local sharded, flattened parameters. The values returned
256
               by ``_local_state_dict`` can be directly used by FSDP and is only
257
               meaningful to FSDP (because parameters are flattened). Note that
258
               these APIs are meant for use via the :func:`state_dict_type`
259
               context manager as follows:
260
                   >>> # xdoctest: +SKIP("undefined variables")
261
                   >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
262
                   ...     state = fsdp.state_dict()  # loads local state dict
263
            3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs
264
               return and load sharded, unflattened parameters. The ``state_dict``
265
               return by ``sharded_state_dict`` can be used by all other parallel
266
               schemes (resharding may be required).
267
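               Like the local state dict APIs, these are meant for use via the
               :func:`state_dict_type` context manager (a sketch; ``fsdp`` is
               assumed to be an FSDP-wrapped module):
                   >>> # xdoctest: +SKIP("undefined variables")
                   >>> with fsdp.state_dict_type(StateDictType.SHARDED_STATE_DICT):
                   ...     state = fsdp.state_dict()  # gets the sharded state dict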
    """
268

269
    FULL_STATE_DICT = auto()
270
    LOCAL_STATE_DICT = auto()
271
    SHARDED_STATE_DICT = auto()
272

273

274
@dataclass
275
class StateDictConfig:
276
    """
277
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
278
    classes. Users should instantiate a child class (e.g.
279
    ``FullStateDictConfig``) in order to configure settings for the
280
    corresponding ``state_dict`` type supported by FSDP.
281

282
    Attributes:
283
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
284
            values to CPU, and if ``False``, then FSDP keeps them on GPU.
285
            (Default: ``False``)
286
    """
287

288
    offload_to_cpu: bool = False
289


@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
    ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
    dicts to save GPU memory and CPU memory, respectively. This config class
    is meant to be used via the :func:`state_dict_type` context manager as
    follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # `state` will be empty on nonzero ranks and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
        >>> # communicates loaded checkpoint states from rank 0 to the rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have FSDP model with loaded checkpoint.

    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
    """

    rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
    pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
    """
    ``ShardedStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig`
      and it is used by FSDP to determine the type of state dict values. Users should not
      manually modify ``_use_dtensor``.
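
    For example, this config is typically supplied through the
    :func:`state_dict_type` context manager (a sketch; ``fsdp_model`` is
    assumed to be an FSDP-wrapped module)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> cfg = ShardedStateDictConfig(offload_to_cpu=True)
        >>> with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT, cfg):
        >>>     state = fsdp_model.state_dict()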
    """
348

349
    _use_dtensor: bool = False
350

351

352
@dataclass
353
class OptimStateDictConfig:
354
    """
355
    ``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
356
    configuration classes.  Users should instantiate a child class (e.g.
357
    ``FullOptimStateDictConfig``) in order to configure settings for the
358
    corresponding ``optim_state_dict`` type supported by FSDP.
359

360
    Attributes:
361
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
362
            tensor values to CPU, and if ``False``, then FSDP keeps them on the
363
            original device (which is GPU unless parameter CPU offloading is
364
            enabled). (Default: ``True``)
365
    """
366

367
    offload_to_cpu: bool = True
368

369

370
@dataclass
371
class FullOptimStateDictConfig(OptimStateDictConfig):
372
    """
373
    Attributes:
374
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
375
            dict, and nonzero ranks save an empty dict. If ``False``, then all
376
            ranks save the full state dict. (Default: ``False``)
377
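
    For example, both the model and optimizer state dict configs can be set
    through the :func:`state_dict_type` context manager before saving (a
    sketch; ``fsdp_model`` and ``optim`` are assumed to be defined elsewhere)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> with FSDP.state_dict_type(
        >>>     fsdp_model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        >>>     FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        >>> ):
        >>>     model_state = fsdp_model.state_dict()
        >>>     optim_state = FSDP.optim_state_dict(fsdp_model, optim)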
    """
378

379
    rank0_only: bool = False
380

381

382
@dataclass
383
class LocalOptimStateDictConfig(OptimStateDictConfig):
384
    offload_to_cpu: bool = False
385

386

387
@dataclass
388
class ShardedOptimStateDictConfig(OptimStateDictConfig):
389
    """
390
    ``ShardedOptimStateDictConfig`` is a config class meant to be used with
391
    ``StateDictType.SHARDED_STATE_DICT``.
392

393
    Attributes:
394
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
395
            as ``DTensor``, and if ``False``, then FSDP saves them as
396
            ``ShardedTensor``. (Default: ``False``)
397

398
    .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig`
399
      and it is used by FSDP to determine the type of state dict values. Users should not
400
      manually modify ``_use_dtensor``.
401
    """
402

403
    _use_dtensor: bool = False
404

405

406
@dataclass
407
class StateDictSettings:
408
    state_dict_type: StateDictType
409
    state_dict_config: StateDictConfig
410
    optim_state_dict_config: OptimStateDictConfig
411
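
# Usage sketch: ``StateDictSettings`` bundles the three settings above; for
# example, FSDP's ``set_state_dict_type`` and ``get_state_dict_type`` APIs
# accept and return it (``fsdp_model`` is assumed to be an FSDP-wrapped module):
#
#     >>> # xdoctest: +SKIP("undefined variables")
#     >>> prev_settings = FSDP.set_state_dict_type(
#     >>>     fsdp_model,
#     >>>     StateDictType.SHARDED_STATE_DICT,
#     >>>     ShardedStateDictConfig(),
#     >>>     ShardedOptimStateDictConfig(),
#     >>> )
#     >>> settings: StateDictSettings = FSDP.get_state_dict_type(fsdp_model)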
