pytorch / _init_utils.py
import collections
import itertools
import os
import warnings
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp._exec_order_utils as exec_order_utils
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
import torch.nn as nn
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._common_utils import (
    _FSDPDeviceHandle,
    _FSDPState,
    _get_module_fsdp_state,
    _is_fsdp_flattened,
    _named_parameters_with_duplicates,
    clean_tensor_name,
    TrainingState,
)
from torch.distributed.fsdp._flat_param import (
    _FSDP_USE_FULL_PREC_IN_EVAL,
    FlatParameter,
    FlatParamHandle,
    HandleShardingStrategy,
)
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    MixedPrecision,
    ShardingStrategy,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import _Policy
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
from torch.distributed.utils import _sync_params_and_buffers

from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.hooks import RemovableHandle

_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init, fake  # type: ignore[import]
except ImportError:
    _TORCHDISTX_AVAIL = False

PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
FSDP_SYNCED = "_fsdp_synced"
# Specification of process groups for hybrid sharding strategies.
HybridShardProcessGroupType = Tuple[dist.ProcessGroup, dist.ProcessGroup]
# Overall specification of process group.
ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]]
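# Editor's note (illustrative, not part of the original file): `ProcessGroupType`
# covers the two accepted forms of the `process_group` argument. Names below are
# hypothetical.
#
#     FSDP(model, process_group=single_pg)  # plain sharding strategies
#     FSDP(
#         model,
#         process_group=(intra_node_pg, inter_node_pg),
#         sharding_strategy=ShardingStrategy.HYBRID_SHARD,
#     )  # hybrid sharding strategies take an (intra-node, inter-node) pair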


# TODO (awgu): Refactor this later
SHARDING_STRATEGY_MAP = {
    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
}
HYBRID_SHARDING_STRATEGIES = [
    ShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
]
NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
    ShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
)


# NOTE: Since non-self attributes cannot be type annotated, several attributes
# on `state` are defined first as local variables before being assigned.


@no_type_check
def _init_process_group_state(
    state: _FSDPState,
    process_group: ProcessGroupType,
    sharding_strategy: ShardingStrategy,
    policy: Optional[_Policy],
    device_mesh: Optional[DeviceMesh] = None,
) -> _FSDPState:
    if process_group is not None and device_mesh is not None:
        raise ValueError(
            "Cannot pass both process_group and device_mesh at the "
            "same time. Please just pass only one of them."
        )
    is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES
    if is_hybrid_strategy:
        if process_group is None and policy is None and device_mesh is None:
            # Raise an error here: since this is manual wrapping with no
            # process group passed in, there is no way to ensure that all
            # wrapped FSDP instances use the same process groups.
            raise ValueError(
                f"Manual wrapping with {sharding_strategy} requires explicit "
                "specification of process group or device_mesh."
            )
        else:
            state = _init_process_group_state_for_hybrid_shard(
                state, process_group, device_mesh
            )
    else:
        if device_mesh:
            state._device_mesh = device_mesh
            state.process_group = device_mesh.get_group(mesh_dim=0)
        else:
            state.process_group = (
                process_group if process_group is not None else _get_default_group()
            )

    state.rank = state.process_group.rank()
    state.world_size = state.process_group.size()
    data_parallel_world_size = state.world_size
    if is_hybrid_strategy:
        data_parallel_world_size *= state._inter_node_pg.size()
    state._gradient_predivide_factor = (
        default_hooks.DefaultState._get_gradient_predivide_factor(
            data_parallel_world_size
        )
    )
    state._gradient_postdivide_factor = (
        data_parallel_world_size / state._gradient_predivide_factor
    )
    return state
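# Editor's note (illustrative arithmetic, not part of the original file): for a
# hybrid strategy on, say, 2 nodes x 8 GPUs, `state.world_size` is the intra-node
# size (8) while `data_parallel_world_size` is 8 * 2 = 16. By construction,
# `_gradient_predivide_factor * _gradient_postdivide_factor` equals
# `data_parallel_world_size`, so the gradient division is split into a pre- and a
# post-reduction step (which helps avoid overflow/underflow in low precision).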


@no_type_check
def _init_process_group_state_for_hybrid_shard(
    state: _FSDPState,
    process_group: ProcessGroupType,
    device_mesh: DeviceMesh,
) -> _FSDPState:
    if device_mesh:
        if _is_valid_hybrid_shard_device_mesh(device_mesh):
            state._device_mesh = device_mesh
            # We currently only allow _inter_node_pg to be the outermost
            # dimension and the process_group (intra-node) to be the innermost
            # dimension.
            state._inter_node_pg = device_mesh.get_group(mesh_dim=0)
            state.process_group = device_mesh.get_group(mesh_dim=1)
        else:
            raise ValueError(
                "Expected device_mesh to have ndim=2 "
                f"but got {device_mesh.ndim}"
            )
    elif process_group is None:
        default_group = _get_default_group()
        intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
            default_group, state._device_handle.device_count()
        )
        # We shard across the intra-node group.
        state.process_group = intra_node_group
        # Save _inter_node_pg to all-reduce across.
        state._inter_node_pg = inter_node_group
    else:
        # Check the type and assign state.process_group and state._inter_node_pg.
        if _is_valid_hybrid_shard_pg_type(process_group):
            # Assume that the user passed in the intra-node group and the
            # inter-node group, as documented.
            state.process_group, state._inter_node_pg = process_group
        else:
            raise ValueError(
                "Expected process_group to be passed in as either None or "
                f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}"
            )
    # Create state for the inter-node all-reduce
    state._inter_node_state = _get_default_comm_hook_state(
        process_group=state._inter_node_pg,
    )
    return state
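# Editor's sketch (hypothetical usage, not from the original file): with a 2-D
# device mesh, dim 0 becomes the inter-node (replication/all-reduce) group and
# dim 1 the intra-node (sharding) group, matching the assignment above.
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     mesh_2d = init_device_mesh("cuda", (num_nodes, gpus_per_node))
#     model = FSDP(
#         model,
#         device_mesh=mesh_2d,
#         sharding_strategy=ShardingStrategy.HYBRID_SHARD,
#     )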


@no_type_check
def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
    return (
        isinstance(process_group, tuple)
        and len(process_group) == 2
        and all(isinstance(pg, dist.ProcessGroup) for pg in process_group)
    )


@no_type_check
def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool:
    return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2


@no_type_check
def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup:
    """
    Return a process group across the current node.

    For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
    This API would return an intra-node subgroup across
    [0, 7] or [8, 15] depending on the process's rank.
    For example, rank 3 would get [0, 7].
    """
    intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node)
    return intra_node_subgroup


@no_type_check
def _init_inter_node_process_group(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> dist.ProcessGroup:
    """
    Return an inter-node process group where each contained rank has the same local rank.

    For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
    This API would return inter-node process group {0, 8}, {1, 9}, {2, 10}, and so forth
    depending on the process's rank. For example, rank 1 would get {1, 9}, rank 5
    would get {5, 13}.
    """
    # The inter-node process group that is returned
    inter_node_pg = None
    sharding_backend = dist.get_backend(global_process_group)
    world_size = dist.get_world_size(global_process_group)
    # Assuming fully homogeneous setup
    num_nodes = world_size // num_devices_per_node
    my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node
    for local_rank in range(num_devices_per_node):
        ranks_for_inter_group = [
            local_rank + (i * num_devices_per_node) for i in range(num_nodes)
        ]
        # Every rank always needs to call dist.new_group
        grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
        if local_rank == my_local_rank:
            inter_node_pg = grp

    assert (
        inter_node_pg is not None
    ), f"{my_local_rank} expected to assign inter-node pg, but did not"
    return inter_node_pg


def _init_intra_and_inter_node_groups(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
    """
    Initialize intra- and inter-node process groups and return the ones corresponding to this process's rank.

    This function can be used to initialize process groups for ``HYBRID_SHARD`` or
    ``_HYBRID_SHARD_ZERO2`` in FSDP.
    This function assumes each node has an equal number of CUDA-enabled devices.
    Returns:
        Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra- and inter-node process group.
    """
    return (
        _init_intra_node_process_group(num_devices_per_node),
        _init_inter_node_process_group(global_process_group, num_devices_per_node),
    )
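# Editor's worked example (illustrative): for 2 nodes x 8 GPUs (global ranks
# 0-15), `_init_intra_and_inter_node_groups(_get_default_group(), 8)` yields
#   - intra-node groups [0..7] and [8..15] (used for sharding), and
#   - inter-node groups {0, 8}, {1, 9}, ..., {7, 15} (used for gradient
#     all-reduce),
# so global rank 10 shards within [8..15] and all-reduces within {2, 10}.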


@no_type_check
def _init_ignored_module_states(
    state: _FSDPState,
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]],
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> _FSDPState:
    if ignored_modules is not None and ignored_states is not None:
        raise ValueError(
            "Cannot pass both ignored_modules and ignored_states at the "
            "same time. Please just pass ignored_states."
        )
    ignored_parameters = None
    passed_as_ignored_states = ignored_states is not None
    if passed_as_ignored_states:
        ignored_states_list = list(ignored_states)
        _check_ignored_states(ignored_states_list, True)
    else:
        ignored_states_list = []
        _check_ignored_states(
            list(ignored_modules) if ignored_modules is not None else [], False
        )
    if len(ignored_states_list) > 0:
        if isinstance(ignored_states_list[0], nn.Parameter):
            ignored_parameters = ignored_states_list
        else:
            ignored_modules = ignored_states_list
    state._ignored_modules = _get_ignored_modules(module, ignored_modules)
    state._ignored_params = _get_ignored_params(
        module,
        state._ignored_modules,
        ignored_parameters,
    )
    state._ignored_buffer_names = _get_ignored_buffer_names(
        module,
        state._ignored_modules,
    )
    # TODO: FSDP's contract for buffers is not well-defined. They are
    # implicitly ignored for most functionality since they are not sharded;
    # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed
    # precision). We should formalize this contract and decide if we need to
    # compute and store `_ignored_buffers`.
    return state


def _check_ignored_states(
    ignored_states: List[Any], passed_as_ignored_states: bool
) -> None:
    """
    Check that the ignored states are uniformly parameters or uniformly modules.

    We may remove this check in the future if we permit mixing.
    """
    if len(ignored_states) == 0:
        return
    if passed_as_ignored_states:
        all_params = all(isinstance(state, nn.Parameter) for state in ignored_states)
        all_modules = all(isinstance(state, nn.Module) for state in ignored_states)
        if not all_params and not all_modules:
            # Sort for consistent ordering for unit test regex matching
            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
            raise ValueError(
                "ignored_states expects all nn.Parameter or all nn.Module list "
                f"elements but got types {sorted_types}"
            )
    else:
        if not all(isinstance(state, nn.Module) for state in ignored_states):
            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
            raise ValueError(
                "ignored_modules expects nn.Module list elements but got "
                f"types {sorted_types}"
            )


@no_type_check
def _init_device_handle(
    state: _FSDPState,
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> _FSDPState:
    """
    Determine the device handle used for initializing FSDP.

    If a device is specified by ``device_id``, then this returns the device
    handle corresponding to that device type. Otherwise, if the module is
    already on a non-CPU device, then the device type is that non-CPU device
    type. If the module is on CPU or meta, then the device type is the current
    CUDA device.

    This method should be called after the ignored parameters have been
    determined, since the device handle may be needed for other initialization.
    """
    determined_device = None
    if device_id is not None:
        determined_device = (
            device_id
            if isinstance(device_id, torch.device)
            else torch.device(device_id)
        )
    if determined_device is None:
        for param in _get_orig_params(module, ignored_params):
            if param.device.type in {"cpu", "meta"}:
                continue
            if determined_device is None:
                determined_device = param.device
            else:
                if param.device.type != determined_device.type:
                    raise RuntimeError(
                        f"FSDP does not support modules with different device types "
                        f"but got params on {determined_device.type} and {param.device.type}"
                    )
        determined_device = determined_device or torch.device(
            "cuda", torch.cuda.current_device()
        )

    state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
    return state
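# Editor's note (illustrative): the precedence above is `device_id` first, then
# any non-CPU parameter device, then the current CUDA device. For example,
# `device_id=torch.device("cuda", 1)` on a CPU/meta module yields a "cuda"
# handle, while a module whose parameters already live on a non-CPU accelerator
# yields a handle for that device type instead.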


@no_type_check
def _init_buffer_state(
    state: _FSDPState,
    module: nn.Module,
) -> _FSDPState:
    state._buffer_names = _get_buffer_names(module)
    # Save a mapping from clean fully-qualified buffer name (starting from
    # `module`) to its original dtype for restoring that dtype during model
    # checkpointing when buffer mixed precision is enabled. The names should
    # be clean since the casting happens in a `summon_full_params()` context.
    _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
    for buffer_name, buffer in module.named_buffers():
        buffer_name = clean_tensor_name(buffer_name)
        _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
    state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
    return state


@no_type_check
def _init_core_state(
    state: _FSDPState,
    sharding_strategy: Optional[ShardingStrategy],
    mixed_precision: Optional[MixedPrecision],
    cpu_offload: Optional[CPUOffload],
    limit_all_gathers: bool,
    use_orig_params: bool,
    backward_prefetch_limit: int,
    forward_prefetch_limit: int,
) -> _FSDPState:
    # We clamp the strategy to `NO_SHARD` for a world size of 1 since the two
    # are currently functionally equivalent. This may change if/when we
    # integrate FSDP with MoE.
    if state.world_size == 1:
        if sharding_strategy != ShardingStrategy.NO_SHARD:
            warnings.warn(
                "FSDP is switching to use `NO_SHARD` instead of "
                f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
                "the world size is 1."
            )
        sharding_strategy = ShardingStrategy.NO_SHARD
    elif sharding_strategy == ShardingStrategy.NO_SHARD:
        warnings.warn(
            "The `NO_SHARD` sharding strategy is deprecated. If having issues, "
            "please use DistributedDataParallel instead.",
            # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and
            # level 3 is from the true caller
            stacklevel=3,
        )
    state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
    state.mixed_precision = mixed_precision or MixedPrecision()
    if mixed_precision is not None:
        torch._C._log_api_usage_once(
            f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}"
        )
    state._use_full_prec_in_eval = (
        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
    )
    state.cpu_offload = cpu_offload or CPUOffload()
    state.limit_all_gathers = limit_all_gathers
    state._use_orig_params = use_orig_params
    state.training_state = TrainingState.IDLE
    state._is_root = None
    state._free_event_queue = _FreeEventQueue()
    state._debug_level = dist.get_debug_level()
    state._exec_order_data = exec_order_utils._ExecOrderData(
        state._debug_level,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    # Mapping from fully sharded module to the handles it is responsible to
    # unshard and reshard (see [Note: Fully Sharded Module])
    _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = dict()
    state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
    # Invariant: `state.params` contains exactly the `FlatParameter`s of the
    # handles in `state._handle`
    _handle: FlatParamHandle = None
    state._handle = _handle
    params: List[FlatParameter] = []
    state.params = params
    return state
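# Editor's note (illustrative): `_use_full_prec_in_eval` is controlled solely by
# the environment variable named by `_FSDP_USE_FULL_PREC_IN_EVAL` (assumed to
# resolve to "FSDP_USE_FULL_PREC_IN_EVAL"), e.g.
#
#     FSDP_USE_FULL_PREC_IN_EVAL=1 torchrun --nproc_per_node=8 train.py
#
# Also note that with a world size of 1 the requested strategy is clamped to
# `NO_SHARD` (with a warning).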


@no_type_check
def _init_runtime_state(
    state: _FSDPState,
) -> _FSDPState:
    _root_pre_forward_handles: List[RemovableHandle] = []
    state._root_pre_forward_handles = _root_pre_forward_handles
    _pre_forward_handles: List[RemovableHandle] = []
    state._pre_forward_handles = _pre_forward_handles
    _post_forward_handles: List[RemovableHandle] = []
    state._post_forward_handles = _post_forward_handles
    state._sync_gradients = True
    state._comm_hook = None
    state._comm_hook_state = None
    # Used to prevent running the pre-backward hook multiple times
    return state


@no_type_check
def _init_prefetching_state(
    state: _FSDPState,
    backward_prefetch: BackwardPrefetch,
    forward_prefetch: bool,
) -> _FSDPState:
    state.backward_prefetch = backward_prefetch
    state.forward_prefetch = forward_prefetch
    # The data structures use tuples of handles to generalize over the case
    # where a module's forward involves multiple handles.
    return state


@no_type_check
def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
    # TODO: We need to add an additional check once we support FSDP + PiPPy.
    # This check is currently sufficient, since we only support FSDP + TP.
    if device_mesh and _mesh_resources.get_parent_mesh(state._device_mesh) is not None:
        state._fsdp_extension = DTensorExtensions(state._device_handle)
    else:
        # We need to explicitly set _fsdp_extension to None.
        # Otherwise, we will run into an infinite recursion when getting the attribute.
        state._fsdp_extension = None
    return state


@no_type_check
def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
    state._state_dict_type = StateDictType.FULL_STATE_DICT
    state_dict_config: StateDictConfig = FullStateDictConfig()
    state._optim_state_dict_config = FullOptimStateDictConfig()
    state._state_dict_config = state_dict_config
    unshard_params_ctx: Dict[nn.Module, Generator] = {}
    state._unshard_params_ctx = unshard_params_ctx

    return state


@no_type_check
def _init_param_handle_from_module(
    state: _FSDPState,
    fully_sharded_module: nn.Module,
    device_id: Optional[Union[int, torch.device]],
    param_init_fn: Optional[Callable[[nn.Module], None]],
    sync_module_states: bool,
) -> _FSDPState:
    """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``."""
    _check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
    device_from_device_id = _get_device_from_device_id(device_id, state.rank)
    is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
        fully_sharded_module, state._ignored_params, state._ignored_modules
    )
    # Materialize the module if needed
    if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
        _materialize_with_param_init_fn(
            fully_sharded_module, param_init_fn, state._ignored_modules
        )
    elif is_meta_module:
        _materialize_meta_module(
            fully_sharded_module, device_id, state._ignored_modules
        )
    elif is_torchdistX_deferred_init:
        deferred_init.materialize_module(
            fully_sharded_module,
            check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None
            and submodule not in state._ignored_modules,
        )

    ignored_buffers = {
        buffer
        for ignored_module in state._ignored_modules
        for buffer in ignored_module.buffers()
    }

    _move_module_to_device(
        fully_sharded_module,
        state._ignored_params,
        ignored_buffers,
        device_from_device_id,
    )
    state.compute_device = _get_compute_device(
        fully_sharded_module,
        state._ignored_params,
        device_from_device_id,
        state.rank,
    )

    managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
    if sync_module_states:
        _sync_module_params_and_buffers(
            fully_sharded_module, managed_params, state.process_group
        )
        if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
            _sync_module_params_and_buffers(
                fully_sharded_module, managed_params, state._inter_node_pg
            )
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
    return state


@no_type_check
def _init_param_handle_from_params(
    state: _FSDPState,
    params: List[nn.Parameter],
    fully_sharded_module: nn.Module,
):
    if len(params) == 0:
        return
    handle = FlatParamHandle(
        params,
        fully_sharded_module,
        state.compute_device,
        SHARDING_STRATEGY_MAP[state.sharding_strategy],
        state.cpu_offload.offload_params,
        state.mixed_precision.param_dtype,
        state.mixed_precision.reduce_dtype,
        state.mixed_precision.keep_low_precision_grads,
        state.process_group,
        state._use_orig_params,
        fsdp_extension=state._fsdp_extension,
    )
    handle.shard()
    assert not state._handle
    state.params.append(handle.flat_param)
    state._handle = handle
    state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle
    cpu_device = torch.device("cpu")
    if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
        handle.flat_param_to(cpu_device)


def _get_ignored_modules(
    root_module: nn.Module,
    _ignored_modules: Optional[Iterable[torch.nn.Module]],
) -> Set[nn.Module]:
    """
    Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances.

    Return the modules contained in their module
    subtrees as a :class:`set`. Nested FSDP instances are excluded, but their
    already-computed ignored modules are included.

    ``_ignored_modules`` represents the argument passed by the user to FSDP.
    """
    msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
    try:
        ignored_root_modules = (
            set(_ignored_modules) if _ignored_modules is not None else set()
        )
    except TypeError as e:
        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
    for module in ignored_root_modules:
        if not isinstance(module, torch.nn.Module):
            raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
        if _get_module_fsdp_state(module):
            # TODO: We may relax this by taking the FSDP instance's wrapped
            # module to provide more flexibility to the user.
            raise ValueError("`ignored_modules` should not include FSDP modules")
    # Treat modules that cannot compose with `fully_shard` as ignored modules,
    # meaning that their subtrees are ignored
    for module in root_module.modules():
        if not traversal_utils._composable(module):
            ignored_root_modules.add(module)
    # NOTE: Even if `ignored_root_modules` is empty, do not return early so
    # that this FSDP instance can get any ignored modules from its children.

    # Include child modules and exclude nested FSDP modules themselves
    ignored_modules = {
        child
        for module in ignored_root_modules
        for child in module.modules()
        if not isinstance(child, fsdp_file.FullyShardedDataParallel)
    }
    if root_module in ignored_modules:
        warnings.warn(
            "Trying to ignore the top-level module passed into the FSDP "
            "constructor itself will result in all parameters being "
            f"ignored and is not well-supported: {root_module}"
        )
    # Include nested FSDP modules' ignored modules
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_modules")
            ignored_modules.update(optional_fsdp_state._ignored_modules)
    return ignored_modules


def _get_ignored_params(
    root_module: torch.nn.Module,
    ignored_modules: Set[torch.nn.Module],
    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
) -> Set[torch.nn.Parameter]:
    """
    Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``.

    :class:`FlatParameter` s are excluded from the result.
    """
    all_ignored_params: Set[torch.nn.Parameter] = set()

    params_in_ignored_modules = {
        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
    }

    all_ignored_params.update(params_in_ignored_modules)

    if ignored_parameters is not None:
        params_in_ignored_parameters = {
            p for p in ignored_parameters if not _is_fsdp_flattened(p)
        }
        all_ignored_params.update(params_in_ignored_parameters)

    # Always include nested FSDP modules' ignored parameters
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_params")
            all_ignored_params.update(optional_fsdp_state._ignored_params)

    return all_ignored_params


def _get_ignored_buffer_names(
    root_module: torch.nn.Module,
    ignored_modules: Set[torch.nn.Module],
) -> Set[str]:
    """Return the cleaned buffer FQNs in ``ignored_modules``."""
    all_ignored_buffer_names: Set[str] = set()

    buffers_in_ignored_modules = {
        buffer for m in ignored_modules for buffer in m.buffers()
    }

    all_ignored_buffer_names.update(
        {
            clean_tensor_name(buffer_name)
            for buffer_name, buffer in root_module.named_buffers()
            if buffer in buffers_in_ignored_modules
        }
    )

    # Always include nested FSDP modules' ignored buffer names
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_buffer_names")
            all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names)

    return all_ignored_buffer_names


def _get_buffer_names(root_module: nn.Module) -> Set[str]:
    """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a :class:`set`."""
    return {
        clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
    }


def _check_single_device_module(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> None:
    """
    Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``.

    Thus, after this method, the
    module must be either fully on the CPU or fully on a non-CPU device.
    """
    devices = {param.device for param in _get_orig_params(module, ignored_params)}
    # We allow the module to be partially on CPU and partially on GPU if
    # device_id is not None, since the device_id arg will result in the CPU
    # portion being moved to GPU. This is useful in cases where part of the
    # module may be parallelized by another algorithm and may already be on
    # GPU. We require device_id to not be None in that case; otherwise, we
    # would flatten parameters in a mixed-device module, which is not
    # supported.
    if len(devices) == 2 and torch.device("cpu") in devices:
        if device_id is None:
            raise RuntimeError(
                "To support a module with both CPU and GPU params, "
                "please pass in device_id argument."
            )
    elif len(devices) > 1:
        raise RuntimeError(
            f"FSDP only supports single device modules but got params on {devices}"
        )


def _get_device_from_device_id(
    device_id: Optional[Union[int, torch.device]],
    rank: int,
) -> Optional[torch.device]:
    """
    Return a ``torch.device`` for the specified ``device_id``.

    Processes ``device_id`` and returns either the corresponding device or
    ``None`` if ``device_id`` is ``None``.
    """
    if device_id is None:
        return None
    device = (
        device_id if isinstance(device_id, torch.device) else torch.device(device_id)
    )
    if device == torch.device("cuda"):
        warnings.warn(
            f"FSDP got the argument `device_id` {device_id} on rank "
            f"{rank}, which does not have an explicit index. "
            f"FSDP will use the current device {torch.cuda.current_device()}. "
            "If this is incorrect, please explicitly call `torch.cuda.set_device()` "
            "before FSDP initialization or pass in the explicit device "
            "index as the `device_id` argument."
        )
        device = torch.device("cuda", torch.cuda.current_device())
    return device
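# Editor's note (illustrative): `device_id=1` and
# `device_id=torch.device("cuda", 1)` both normalize to `torch.device("cuda", 1)`
# here, while a bare `torch.device("cuda")` (no index) triggers the warning above
# and falls back to the current CUDA device.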


def _need_to_materialize_module(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    ignored_modules: Set[nn.Module],
) -> Tuple[bool, bool]:
    """
    Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization.

    At most one of the returned bools can
    be ``True``. If either is ``True``, then ``module`` needs to be
    materialized.
    """
    managed_params = list(_get_orig_params(module, ignored_params))
    is_meta_module = any(param.is_meta for param in managed_params)
    # TODO: We need to establish a contract for FSDP and buffers. For now, we
    # skip checking for meta buffers from ignored modules. We should consider
    # refactoring the initialization holistically to avoid so many traversals.
    for submodule in module.modules():
        if submodule in ignored_modules:
            continue
        for buf in submodule.buffers(recurse=False):
            is_meta_module |= buf.is_meta
    is_torchdistX_deferred_init = (
        not is_meta_module
        and _TORCHDISTX_AVAIL
        and any(fake.is_fake(param) for param in managed_params)
    )
    return is_meta_module, is_torchdistX_deferred_init
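# Editor's sketch (illustrative, not part of the original file): a module built
# under the meta device is flagged by the `is_meta` check above, e.g.
#
#     with torch.device("meta"):
#         m = nn.Linear(4, 4)
#     _need_to_materialize_module(m, set(), set())  # -> (True, False)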


def _materialize_with_param_init_fn(
    root_module: nn.Module,
    param_init_fn: Callable[[nn.Module], None],
    ignored_modules: Set[nn.Module],
) -> None:
    if not callable(param_init_fn):
        raise ValueError(
            f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}"
        )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    for module in modules_to_materialize:
        param_init_fn(module)


def _materialize_meta_module(
    root_module: nn.Module,
    device_from_device_id: Optional[torch.device],
    ignored_modules: Set[nn.Module],
):
    # Run default meta device initialization
    materialization_device = device_from_device_id or torch.device(
        torch.cuda.current_device()
    )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    try:
        # Assume that each module's `reset_parameters()` only initializes its
        # own parameters and not those of its children
        with torch.no_grad():
            for module in modules_to_materialize:
                # As a contract to the user, only call `reset_parameters()` if
                # the module has directly managed parameters/buffers
                module_state_iter = itertools.chain(
                    module.parameters(recurse=False), module.buffers(recurse=False)
                )
                has_module_states = len(list(module_state_iter)) > 0
                if has_module_states:
                    module.to_empty(device=materialization_device, recurse=False)
                    module.reset_parameters()  # type: ignore[operator]
    except BaseException as e:
        warnings.warn(
            "Unable to call `reset_parameters()` for module on meta "
            f"device with error {str(e)}. Please ensure that your module of "
            f"type {type(module)} implements a `reset_parameters()` method."  # type: ignore[possibly-undefined]
        )
        raise e
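# Editor's sketch (hypothetical module, not from the original file): the contract
# above only requires `reset_parameters()` on modules that directly own
# parameters/buffers, and it should initialize only that directly owned state
# (children are visited separately by the BFS below).
#
#     class MyLayer(nn.Module):
#         def __init__(self) -> None:
#             super().__init__()
#             self.weight = nn.Parameter(torch.empty(16, 16))
#
#         def reset_parameters(self) -> None:
#             nn.init.kaiming_uniform_(self.weight)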


def _get_modules_to_materialize(
    root_module: nn.Module, ignored_modules: Set[nn.Module]
) -> List[nn.Module]:
    # Run BFS to collect the modules to materialize via `reset_parameters()`,
    # stopping at any module with FSDP already applied or at ignored modules.
    modules_to_materialize: List[nn.Module] = []
    queue = collections.deque([root_module])
    visited_modules: Set[nn.Module] = {root_module}
    while queue:
        module = queue.popleft()
        modules_to_materialize.append(module)
        for child_module in module.children():
            if (
                child_module not in visited_modules
                and _get_module_fsdp_state(child_module) is None
                and child_module not in ignored_modules
            ):
                visited_modules.add(child_module)
                queue.append(child_module)
    return modules_to_materialize


def _move_module_to_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    ignored_buffers: Set[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move ``module`` depending on ``device_from_device_id`` and its current device.

    This includes moving ignored modules' parameters.

    - If ``device_from_device_id`` is not ``None``, then this moves
    ``module`` to the device.
    - If ``device_from_device_id`` is ``None``, then this does not move
    ``module`` but warns the user if it is on CPU.

    Precondition: ``_check_single_device_module()``.
    """
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # BFS from `module` without traversing any nested FSDP instances to
        # collect the parameters/buffers that have not yet been managed
        queue: Deque[nn.Module] = collections.deque()
        queue.append(module)
        params: List[nn.Parameter] = []
        buffers: List[torch.Tensor] = []
        while queue:
            curr_module = queue.popleft()
            # NOTE: We include a check to only move parameters/buffers that are
            # on CPU device. If they are on a CUDA device different from the
            # one specified by `device_id`, then this does NOT move them. This
            # is so that we can raise an error in `_get_compute_device()`.
            params.extend(
                param
                for param in curr_module.parameters(recurse=False)
                if param.device == cpu_device
            )
            buffers.extend(
                buffer
                for buffer in curr_module.buffers(recurse=False)
                if buffer.device == cpu_device
            )
            for submodule in curr_module.children():
                if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
                    queue.append(submodule)
        params_to_move = [p for p in params if p not in ignored_params]
        bufs_to_move = [p for p in buffers if p not in ignored_buffers]
        _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id)
        return
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device == cpu_device:
        _warn_cpu_init()


def _move_states_to_device(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move states to the specified device.

    Precondition: ``_check_single_device_module()`` and module's parameters and
    buffers have been materialized if needed.
    """
    if len(params) == 0 and len(buffers) == 0:
        return
    if len(params) > 0:
        current_device = params[0].device
    elif len(buffers) > 0:
        current_device = buffers[0].device
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # Move the parameters and buffers like the `.data` code path in
        # `nn.Module._apply()`, which underlies `nn.Module.to()`
        for param in params:
            with torch.no_grad():
                param.data = param.to(device_from_device_id)
                if param.grad is not None:
                    param.grad.data = param.grad.to(device_from_device_id)
        for buffer in buffers:
            buffer.data = buffer.to(device_from_device_id)
    elif current_device == cpu_device:  # type: ignore[possibly-undefined]
        _warn_cpu_init()


def _warn_cpu_init():
    warnings.warn(
        "The passed-in `module` is on CPU and will thus have FSDP's sharding "
        "initialization run on CPU, which may be slower than on GPU. We "
        "recommend passing in the `device_id` argument for FSDP to move "
        "`module` to GPU for the sharding initialization. `module` must also "
        "be on GPU device to work with the `sync_module_states=True` flag "
        "since that requires GPU communication."
    )


def _get_compute_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_from_device_id: Optional[torch.device],
    rank: int,
) -> torch.device:
    """
    Determine and return this FSDP instance's compute device.

    If a device is
    specified by ``device_id``, then returns that device. Otherwise, if the
    module is already on a non-CPU device, then the compute device is that non-CPU
    device. If the module is on CPU, then the compute device is the current
    device.

    Since this method should be called after materializing the module, any
    non-CPU device should not be meta device. For now, the compute device is
    always a CUDA GPU device with its explicit index.

    Precondition: ``_check_single_device_module()`` and
    ``_move_module_to_device()``.
    """
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device.type != "cpu":
        compute_device = param.device  # Determined by model param placement
    else:
        if device_from_device_id is not None and device_from_device_id.type != "cuda":
            compute_device = device_from_device_id  # Determined by custom backend
        else:
            compute_device = torch.device("cuda", torch.cuda.current_device())
    if device_from_device_id is not None and compute_device != device_from_device_id:
        raise ValueError(
            f"Inconsistent compute device and `device_id` on rank {rank}: "
            f"{compute_device} vs {device_from_device_id}"
        )
    return compute_device


# TODO: See how to deprecate!
def _sync_module_params_and_buffers(
    module: nn.Module,
    params: List[nn.Parameter],
    process_group: dist.ProcessGroup,
) -> None:
    """
    Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks.

    Precondition: ``sync_module_states == True`` and ``self.process_group`` has
    been set.
    """
    module_states: List[torch.Tensor] = []
    for buffer in module.buffers():
        # Avoid re-synchronizing buffers in case of nested wrapping
        if not getattr(buffer, FSDP_SYNCED, False):
            setattr(buffer, FSDP_SYNCED, True)
            detached_buffer = buffer.detach()
            if is_traceable_wrapper_subclass(detached_buffer):
                # NOTE: Here we assume no nested subclasses, at most one level of subclass
                # in both model's buffers and params
                attrs, _ = detached_buffer.__tensor_flatten__()  # type: ignore[attr-defined]
                inner_buffers = [getattr(detached_buffer, attr) for attr in attrs]
                module_states.extend(inner_buffers)
            else:
                module_states.append(detached_buffer)

    for param in params:
        detached_param = param.detach()
        if is_traceable_wrapper_subclass(detached_param):
            attrs, _ = detached_param.__tensor_flatten__()  # type: ignore[attr-defined]
            inner_params = [getattr(detached_param, attr) for attr in attrs]
            module_states.extend(inner_params)
        else:
            module_states.append(detached_param)

    _check_module_states_for_sync_module_states(module_states)
    _sync_params_and_buffers(
        process_group,
        module_states,
        PARAM_BROADCAST_BUCKET_SIZE,
        src=0,
    )
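# Editor's note (illustrative): this is the broadcast behind
# `FSDP(..., sync_module_states=True)`. Rank 0's parameters and not-yet-synced
# buffers are bucketed (up to PARAM_BROADCAST_BUCKET_SIZE bytes per bucket) and
# broadcast so that every rank starts from identical weights, e.g. after rank 0
# alone has loaded a checkpoint into the unwrapped module.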


def _sync_module_states(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    process_group: dist.ProcessGroup,
) -> None:
    # Assumes that each call to this method passes in disjoint `params` and
    # `buffers` across calls, so there is no chance of re-synchronizing
    params_and_buffers = [param.detach() for param in params] + [
        buffer.detach() for buffer in buffers
    ]
    _check_module_states_for_sync_module_states(params_and_buffers)
    _sync_params_and_buffers(
        process_group,
        params_and_buffers,
        PARAM_BROADCAST_BUCKET_SIZE,
        src=0,
    )


def _check_module_states_for_sync_module_states(
    module_states: List[torch.Tensor],
) -> None:
    if module_states and any(
        tensor.device == torch.device("cpu") for tensor in module_states
    ):
        raise ValueError(
            "The module has CPU parameters or buffers when `sync_module_states=True`, "
            "which requires them to be on GPU. Please specify the `device_id` argument "
            "or move the module to GPU before passing it to FSDP."
        )


def _get_orig_params(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
) -> Iterator[nn.Parameter]:
    """
    Return an iterator over the original parameters in ``module``.

    The iterator does not return
    the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be
    present due to nested FSDP wrapping), or any original parameters already
    flattened (only relevant when ``use_orig_params=True``).
    """
    param_gen = module.parameters()
    try:
        while True:
            param = next(param_gen)
            if param not in ignored_params and not _is_fsdp_flattened(param):
                yield param
    except StopIteration:
        pass


def _check_orig_params_flattened(
    fsdp_module,
    ignored_params: Set[nn.Parameter],
) -> None:
    """
    Check that original parameters in ``fsdp_module`` have been flattened.

    The flattened parameters are made
    invisible to ``named_parameters()`` for the module hierarchy rooted at
    ``fsdp_module``. This should be called as a sanity check after flattening
    the wrapped module's parameters.
    """
    for param_name, param in _named_parameters_with_duplicates(fsdp_module):
        if param not in ignored_params and not _is_fsdp_flattened(param):
            raise RuntimeError(
                f"Found an unflattened parameter: {param_name}; "
                f"{param.size()} {param.__class__}"
            )


def _get_default_comm_hook(sharding_strategy: ShardingStrategy):
    return (
        default_hooks.allreduce_hook
        if sharding_strategy == ShardingStrategy.NO_SHARD
        else default_hooks.reduce_scatter_hook
    )


def _get_default_comm_hook_state(
    process_group: dist.ProcessGroup,
) -> default_hooks.DefaultState:
    return default_hooks.DefaultState(process_group=process_group)