import collections
import itertools
import os
import warnings
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp._exec_order_utils as exec_order_utils
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
import torch.nn as nn
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._common_utils import (
    _FSDPDeviceHandle,
    _FSDPState,
    _get_module_fsdp_state,
    _is_fsdp_flattened,
    _named_parameters_with_duplicates,
    clean_tensor_name,
    TrainingState,
)
from torch.distributed.fsdp._flat_param import (
    _FSDP_USE_FULL_PREC_IN_EVAL,
    FlatParameter,
    FlatParamHandle,
    HandleShardingStrategy,
)
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    MixedPrecision,
    ShardingStrategy,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import _Policy
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
from torch.distributed.utils import _sync_params_and_buffers

from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.hooks import RemovableHandle

_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init, fake  # type: ignore[import]
except ImportError:
    _TORCHDISTX_AVAIL = False

PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
FSDP_SYNCED = "_fsdp_synced"
# Specification of process groups for hybrid sharding strategies.
HybridShardProcessGroupType = Tuple[dist.ProcessGroup, dist.ProcessGroup]
# Overall specification of process group.
ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]]


# TODO (awgu): Refactor this later
SHARDING_STRATEGY_MAP = {
    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
}
HYBRID_SHARDING_STRATEGIES = [
    ShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
]
NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
    ShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
)


# NOTE: Since non-self attributes cannot be type annotated, several attributes
# on `state` are defined first as local variables before being assigned.

@no_type_check
def _init_process_group_state(
    state: _FSDPState,
    process_group: ProcessGroupType,
    sharding_strategy: ShardingStrategy,
    policy: Optional[_Policy],
    device_mesh: Optional[DeviceMesh] = None,
) -> _FSDPState:
    if process_group is not None and device_mesh is not None:
        raise ValueError(
            "Cannot pass both process_group and device_mesh at the "
            "same time. Please just pass only one of them."
        )
    is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES
    if is_hybrid_strategy:
        if process_group is None and policy is None and device_mesh is None:
            # Raise an error here: since this is manual wrapping with no
            # process group passed in, there is no way to ensure all wrapped
            # FSDP instances use the same process groups.
            raise ValueError(
                f"Manual wrapping with {sharding_strategy} "
                "requires explicit specification of process group or device_mesh."
            )
        else:
            state = _init_process_group_state_for_hybrid_shard(
                state, process_group, device_mesh
            )
    else:
        if device_mesh:
            state._device_mesh = device_mesh
            state.process_group = device_mesh.get_group(mesh_dim=0)
        else:
            state.process_group = (
                process_group if process_group is not None else _get_default_group()
            )

    state.rank = state.process_group.rank()
    state.world_size = state.process_group.size()
    data_parallel_world_size = state.world_size
    if is_hybrid_strategy:
        data_parallel_world_size *= state._inter_node_pg.size()
    state._gradient_predivide_factor = (
        default_hooks.DefaultState._get_gradient_predivide_factor(
            data_parallel_world_size
        )
    )
    state._gradient_postdivide_factor = (
        data_parallel_world_size / state._gradient_predivide_factor
    )
    return state

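# Editorial worked example (hedged): under ``HYBRID_SHARD`` with 2 nodes of
# 8 GPUs each, ``data_parallel_world_size`` is 8 * 2 = 16. Assuming the default
# implementation of ``_get_gradient_predivide_factor``, which picks a power of
# two close to sqrt(world_size) (4 for world size 16), gradients are divided by
# 4 before reduction and by 16 / 4 = 4 afterward. Splitting the division across
# the two sides bounds both overflow (pre-divide) and underflow (post-divide)
# when reducing gradients in low precision.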

@no_type_check
def _init_process_group_state_for_hybrid_shard(
    state: _FSDPState,
    process_group: ProcessGroupType,
    device_mesh: DeviceMesh,
) -> _FSDPState:
    if device_mesh:
        if _is_valid_hybrid_shard_device_mesh(device_mesh):
            state._device_mesh = device_mesh
            # We currently only allow _inter_node_pg to be the outermost
            # dimension and the process_group (intra-node) to be the innermost
            # dimension.
            state._inter_node_pg = device_mesh.get_group(mesh_dim=0)
            state.process_group = device_mesh.get_group(mesh_dim=1)
        else:
            raise ValueError(
                f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}"
            )
    elif process_group is None:
        default_group = _get_default_group()
        intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
            default_group, state._device_handle.device_count()
        )
        # We shard across the intra-node group.
        state.process_group = intra_node_group
        # Save _inter_node_pg to allreduce across.
        state._inter_node_pg = inter_node_group
    else:
        # Check the type and assign state.process_group and state._inter_node_pg.
        if _is_valid_hybrid_shard_pg_type(process_group):
            # Assume that the user passed in the intra-node group and the
            # inter-node group, as documented.
            state.process_group, state._inter_node_pg = process_group
        else:
            raise ValueError(
                "Expected process_group to be passed in as either None or "
                f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}"
            )
    # Create state for allreduce
    state._inter_node_state = _get_default_comm_hook_state(
        process_group=state._inter_node_pg,
    )
    return state


@no_type_check
def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
    return (
        isinstance(process_group, tuple)
        and len(process_group) == 2
        and all(isinstance(pg, dist.ProcessGroup) for pg in process_group)
    )


@no_type_check
def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool:
    return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2

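# Editorial illustration (hedged): the two accepted hybrid-shard specifications
# validated above, assuming a hypothetical 2-node x 4-GPU job:
#
#   >>> # (a) An explicit tuple: (intra-node sharding PG, inter-node allreduce
#   >>> # PG), where `intra_node_pg` and `inter_node_pg` are built elsewhere:
#   >>> _is_valid_hybrid_shard_pg_type((intra_node_pg, inter_node_pg))
#   True
#   >>> # (b) A 2D device mesh with dim 0 = inter-node (outer) and
#   >>> # dim 1 = intra-node (inner):
#   >>> from torch.distributed.device_mesh import init_device_mesh
#   >>> mesh = init_device_mesh("cuda", (2, 4))
#   >>> _is_valid_hybrid_shard_device_mesh(mesh)
#   True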

@no_type_check
def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup:
    """
    Return a process group across the current node.

    For example, given each row is a distinct node:
    0  1  2  3  4  5  6  7
    8  9 10 11 12 13 14 15
    This API would return an intra-node subgroup across
    [0, 7] or [8, 15] depending on the process's rank.
    For example, rank 3 would get [0, 7].
    """
    intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node)
    return intra_node_subgroup


@no_type_check
def _init_inter_node_process_group(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> dist.ProcessGroup:
    """
    Return an inter-node process group where each contained rank has the same local rank.

    For example, given each row is a distinct node:
    0  1  2  3  4  5  6  7
    8  9 10 11 12 13 14 15
    This API would return inter-node process group {0, 8}, {1, 9}, {2, 10}, and
    so forth depending on the process's rank. For example, rank 1 would get
    {1, 9}, and rank 5 would get {5, 13}.
    """
    # The inter-node process group that is returned
    inter_node_pg = None
    sharding_backend = dist.get_backend(global_process_group)
    world_size = dist.get_world_size(global_process_group)
    # Assuming a fully homogeneous setup
    num_nodes = world_size // num_devices_per_node
    my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node
    for local_rank in range(num_devices_per_node):
        ranks_for_inter_group = [
            local_rank + (i * num_devices_per_node) for i in range(num_nodes)
        ]
        # Every rank always needs to call `dist.new_group()`
        grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
        if local_rank == my_local_rank:
            inter_node_pg = grp

    assert (
        inter_node_pg is not None
    ), f"{my_local_rank} expected to assign inter-node pg, but did not"
    return inter_node_pg


def _init_intra_and_inter_node_groups(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
    """
    Initialize intra- and inter-node process groups and return the ones corresponding to this process's rank.

    This function can be used to initialize process groups for ``HYBRID_SHARD``
    or ``_HYBRID_SHARD_ZERO2`` in FSDP.
    This function assumes each node has an equal number of CUDA-enabled devices.

    Returns:
        Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra- and inter-node process group.
    """
    return (
        _init_intra_node_process_group(num_devices_per_node),
        _init_inter_node_process_group(global_process_group, num_devices_per_node),
    )

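# Editorial illustration (hedged): the group layout produced by the helpers
# above for a hypothetical 2-node x 4-GPU job (world size 8):
#
#   node 0: ranks 0 1 2 3        node 1: ranks 4 5 6 7
#   intra-node (sharding) groups:  [0, 1, 2, 3] and [4, 5, 6, 7]
#   inter-node (allreduce) groups: [0, 4], [1, 5], [2, 6], [3, 7]
#
#   >>> intra_pg, inter_pg = _init_intra_and_inter_node_groups(
#   ...     _get_default_group(), num_devices_per_node=4
#   ... )
#   >>> # On rank 5, intra_pg spans {4, 5, 6, 7} and inter_pg spans {1, 5}.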

@no_type_check
def _init_ignored_module_states(
    state: _FSDPState,
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]],
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> _FSDPState:
    if ignored_modules is not None and ignored_states is not None:
        raise ValueError(
            "Cannot pass both ignored_modules and ignored_states at the "
            "same time. Please just pass ignored_states."
        )
    ignored_parameters = None
    passed_as_ignored_states = ignored_states is not None
    if passed_as_ignored_states:
        ignored_states_list = list(ignored_states)
        _check_ignored_states(ignored_states_list, True)
    else:
        ignored_states_list = []
        _check_ignored_states(
            list(ignored_modules) if ignored_modules is not None else [], False
        )
    if len(ignored_states_list) > 0:
        if isinstance(ignored_states_list[0], nn.Parameter):
            ignored_parameters = ignored_states_list
        else:
            ignored_modules = ignored_states_list
    state._ignored_modules = _get_ignored_modules(module, ignored_modules)
    state._ignored_params = _get_ignored_params(
        module,
        state._ignored_modules,
        ignored_parameters,
    )
    state._ignored_buffer_names = _get_ignored_buffer_names(
        module,
        state._ignored_modules,
    )
    # TODO: FSDP's contract for buffers is not well-defined. They are
    # implicitly ignored for most functionality since they are not sharded;
    # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed
    # precision). We should formalize this contract and decide if we need to
    # compute and store `_ignored_buffers`.
    return state

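# Editorial illustration (hedged): `ignored_states` must be uniformly
# parameters or uniformly modules (see `_check_ignored_states` below). A
# minimal sketch with a hypothetical two-layer model:
#
#   >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   >>> model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)).cuda()
#   >>> FSDP(model, ignored_states=[model[0]])                  # all modules: OK
#   >>> FSDP(model, ignored_states=list(model[0].parameters())) # all params: OK
#   >>> FSDP(model, ignored_states=[model[0], model[1].weight]) # mixed: ValueError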

def _check_ignored_states(
    ignored_states: List[Any], passed_as_ignored_states: bool
) -> None:
    """
    Check that the ignored states are uniformly parameters or uniformly modules.

    We may remove this check in the future if we permit mixing.
    """
    if len(ignored_states) == 0:
        return
    if passed_as_ignored_states:
        all_params = all(isinstance(state, nn.Parameter) for state in ignored_states)
        all_modules = all(isinstance(state, nn.Module) for state in ignored_states)
        if not all_params and not all_modules:
            # Sort for consistent ordering for unit test regex matching
            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
            raise ValueError(
                "ignored_states expects all nn.Parameter or all nn.Module list "
                f"elements but got types {sorted_types}"
            )
    else:
        if not all(isinstance(state, nn.Module) for state in ignored_states):
            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
            raise ValueError(
                "ignored_modules expects nn.Module list elements but got "
                f"types {sorted_types}"
            )


@no_type_check
def _init_device_handle(
    state: _FSDPState,
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> _FSDPState:
    """
    Determine the device handle used for initializing FSDP.

    If a device is specified by ``device_id``, then the returned device handle
    corresponds to that device type. Otherwise, if the module is already on a
    non-CPU device, then the device type is that non-CPU device type. If the
    module is on CPU or meta, then the device type is the current CUDA device.

    This method should be called after the ignored parameters have been
    determined, as the device handle may be needed for other initialization.
    """
    determined_device = None
    if device_id is not None:
        determined_device = (
            device_id
            if isinstance(device_id, torch.device)
            else torch.device(device_id)
        )
    if determined_device is None:
        for param in _get_orig_params(module, ignored_params):
            if param.device.type in {"cpu", "meta"}:
                continue
            if determined_device is None:
                determined_device = param.device
            else:
                if param.device.type != determined_device.type:
                    raise RuntimeError(
                        f"FSDP does not support modules with different device types "
                        f"but got params on {determined_device.type} and {param.device.type}"
                    )
        determined_device = determined_device or torch.device(
            "cuda", torch.cuda.current_device()
        )

    state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
    return state

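# Editorial summary (hedged) of the device-type resolution above, assuming
# CUDA is the accelerator in use:
#
#   >>> # 1. An explicit `device_id` wins, e.g. torch.device("cuda", 1).
#   >>> # 2. Otherwise, the first non-CPU/non-meta parameter device wins
#   >>> #    (mixed non-CPU device types raise a RuntimeError).
#   >>> # 3. A CPU- or meta-only module falls back to
#   >>> #    torch.device("cuda", torch.cuda.current_device()).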

@no_type_check
def _init_buffer_state(
    state: _FSDPState,
    module: nn.Module,
) -> _FSDPState:
    state._buffer_names = _get_buffer_names(module)
    # Save a mapping from clean fully-qualified buffer name (starting from
    # `module`) to its original dtype for restoring that dtype during model
    # checkpointing when buffer mixed precision is enabled. The names should
    # be clean since the casting happens in a `summon_full_params()` context.
    _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
    for buffer_name, buffer in module.named_buffers():
        buffer_name = clean_tensor_name(buffer_name)
        _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
    state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
    return state


@no_type_check
def _init_core_state(
    state: _FSDPState,
    sharding_strategy: Optional[ShardingStrategy],
    mixed_precision: Optional[MixedPrecision],
    cpu_offload: Optional[CPUOffload],
    limit_all_gathers: bool,
    use_orig_params: bool,
    backward_prefetch_limit: int,
    forward_prefetch_limit: int,
) -> _FSDPState:
    # We clamp the strategy to `NO_SHARD` for a world size of 1 since the
    # strategies are currently functionally equivalent in that case. This may
    # change if/when we integrate FSDP with MoE.
    if state.world_size == 1:
        if sharding_strategy != ShardingStrategy.NO_SHARD:
            warnings.warn(
                "FSDP is switching to use `NO_SHARD` instead of "
                f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
                "the world size is 1."
            )
        sharding_strategy = ShardingStrategy.NO_SHARD
    elif sharding_strategy == ShardingStrategy.NO_SHARD:
        warnings.warn(
            "The `NO_SHARD` sharding strategy is deprecated. If having issues, "
            "please use DistributedDataParallel instead.",
            # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and
            # level 3 is from the true caller
            stacklevel=3,
        )
    state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
    state.mixed_precision = mixed_precision or MixedPrecision()
    if mixed_precision is not None:
        torch._C._log_api_usage_once(
            f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}"
        )
    state._use_full_prec_in_eval = (
        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
    )
    state.cpu_offload = cpu_offload or CPUOffload()
    state.limit_all_gathers = limit_all_gathers
    state._use_orig_params = use_orig_params
    state.training_state = TrainingState.IDLE
    state._is_root = None
    state._free_event_queue = _FreeEventQueue()
    state._debug_level = dist.get_debug_level()
    state._exec_order_data = exec_order_utils._ExecOrderData(
        state._debug_level,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    # Mapping from fully sharded module to the handles it is responsible to
    # unshard and reshard (see [Note: Fully Sharded Module])
    _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = dict()
    state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
    # Invariant: `state.params` contains exactly the `FlatParameter`s of the
    # handles in `state._handle`
    _handle: FlatParamHandle = None
    state._handle = _handle
    params: List[FlatParameter] = []
    state.params = params
    return state

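# Editorial illustration (hedged): with a single-rank process group, any
# requested strategy is clamped to `NO_SHARD` (with a warning), since sharding
# across one rank is a no-op:
#
#   >>> state = _init_core_state(
#   ...     state,  # state.world_size == 1
#   ...     ShardingStrategy.FULL_SHARD,
#   ...     None,  # mixed_precision
#   ...     None,  # cpu_offload
#   ...     limit_all_gathers=True,
#   ...     use_orig_params=False,
#   ...     backward_prefetch_limit=1,
#   ...     forward_prefetch_limit=1,
#   ... )
#   >>> state.sharding_strategy is ShardingStrategy.NO_SHARD
#   True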

@no_type_check
def _init_runtime_state(
    state: _FSDPState,
) -> _FSDPState:
    _root_pre_forward_handles: List[RemovableHandle] = []
    state._root_pre_forward_handles = _root_pre_forward_handles
    _pre_forward_handles: List[RemovableHandle] = []
    state._pre_forward_handles = _pre_forward_handles
    _post_forward_handles: List[RemovableHandle] = []
    state._post_forward_handles = _post_forward_handles
    state._sync_gradients = True
    state._comm_hook = None
    state._comm_hook_state = None
    # Used to prevent running the pre-backward hook multiple times
    return state


@no_type_check
def _init_prefetching_state(
    state: _FSDPState,
    backward_prefetch: BackwardPrefetch,
    forward_prefetch: bool,
) -> _FSDPState:
    state.backward_prefetch = backward_prefetch
    state.forward_prefetch = forward_prefetch
    # The data structures use tuples of handles to generalize over the case
    # where a module's forward involves multiple handles.
    return state


@no_type_check
def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
    # TODO: We need to add an additional check once we support FSDP + PiPPy.
    # This check is currently sufficient, since we only support FSDP + TP.
    if device_mesh and _mesh_resources.get_parent_mesh(state._device_mesh) is not None:
        state._fsdp_extension = DTensorExtensions(state._device_handle)
    else:
        # We need to explicitly set _fsdp_extension to None. Otherwise, we
        # will run into an infinite recursion when getting the attribute.
        state._fsdp_extension = None
    return state

@no_type_check
def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
    state._state_dict_type = StateDictType.FULL_STATE_DICT
    state_dict_config: StateDictConfig = FullStateDictConfig()
    state._optim_state_dict_config = FullOptimStateDictConfig()
    state._state_dict_config = state_dict_config
    unshard_params_ctx: Dict[nn.Module, Generator] = {}
    state._unshard_params_ctx = unshard_params_ctx

    return state

@no_type_check
def _init_param_handle_from_module(
    state: _FSDPState,
    fully_sharded_module: nn.Module,
    device_id: Optional[Union[int, torch.device]],
    param_init_fn: Optional[Callable[[nn.Module], None]],
    sync_module_states: bool,
) -> _FSDPState:
    """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``."""
    _check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
    device_from_device_id = _get_device_from_device_id(device_id, state.rank)
    is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
        fully_sharded_module, state._ignored_params, state._ignored_modules
    )
    # Materialize the module if needed
    if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
        _materialize_with_param_init_fn(
            fully_sharded_module, param_init_fn, state._ignored_modules
        )
    elif is_meta_module:
        _materialize_meta_module(
            fully_sharded_module, device_id, state._ignored_modules
        )
    elif is_torchdistX_deferred_init:
        deferred_init.materialize_module(
            fully_sharded_module,
            check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None
            and submodule not in state._ignored_modules,
        )

    ignored_buffers = {
        buffer
        for ignored_module in state._ignored_modules
        for buffer in ignored_module.buffers()
    }

    _move_module_to_device(
        fully_sharded_module,
        state._ignored_params,
        ignored_buffers,
        device_from_device_id,
    )
    state.compute_device = _get_compute_device(
        fully_sharded_module,
        state._ignored_params,
        device_from_device_id,
        state.rank,
    )

    managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
    if sync_module_states:
        _sync_module_params_and_buffers(
            fully_sharded_module, managed_params, state.process_group
        )
        if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
            _sync_module_params_and_buffers(
                fully_sharded_module, managed_params, state._inter_node_pg
            )
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
    return state

@no_type_check
def _init_param_handle_from_params(
    state: _FSDPState,
    params: List[nn.Parameter],
    fully_sharded_module: nn.Module,
):
    if len(params) == 0:
        return
    handle = FlatParamHandle(
        params,
        fully_sharded_module,
        state.compute_device,
        SHARDING_STRATEGY_MAP[state.sharding_strategy],
        state.cpu_offload.offload_params,
        state.mixed_precision.param_dtype,
        state.mixed_precision.reduce_dtype,
        state.mixed_precision.keep_low_precision_grads,
        state.process_group,
        state._use_orig_params,
        fsdp_extension=state._fsdp_extension,
    )
    handle.shard()
    assert not state._handle
    state.params.append(handle.flat_param)
    state._handle = handle
    state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle
    cpu_device = torch.device("cpu")
    if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
        handle.flat_param_to(cpu_device)

def _get_ignored_modules(
    root_module: nn.Module,
    _ignored_modules: Optional[Iterable[torch.nn.Module]],
) -> Set[nn.Module]:
    """
    Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances.

    Return the modules contained in their module
    subtrees as a :class:`set`. Nested FSDP instances are excluded, but their
    already-computed ignored modules are included.

    ``_ignored_modules`` represents the argument passed by the user to FSDP.
    """
    msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
    try:
        ignored_root_modules = (
            set(_ignored_modules) if _ignored_modules is not None else set()
        )
    except TypeError as e:
        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
    for module in ignored_root_modules:
        if not isinstance(module, torch.nn.Module):
            raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
        if _get_module_fsdp_state(module):
            # TODO: We may relax this by taking the FSDP instance's wrapped
            # module to provide more flexibility to the user.
            raise ValueError("`ignored_modules` should not include FSDP modules")
    # Treat modules that cannot compose with `fully_shard` as ignored modules,
    # meaning that their subtrees are ignored
    for module in root_module.modules():
        if not traversal_utils._composable(module):
            ignored_root_modules.add(module)
    # NOTE: Even if `ignored_root_modules` is empty, do not return early so
    # that this FSDP instance can get any ignored modules from its children.

    # Include child modules and exclude nested FSDP modules themselves
    ignored_modules = {
        child
        for module in ignored_root_modules
        for child in module.modules()
        if not isinstance(child, fsdp_file.FullyShardedDataParallel)
    }
    if root_module in ignored_modules:
        warnings.warn(
            "Trying to ignore the top-level module passed into the FSDP "
            "constructor itself will result in all parameters being "
            f"ignored and is not well-supported: {root_module}"
        )
    # Include nested FSDP modules' ignored modules
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_modules")
            ignored_modules.update(optional_fsdp_state._ignored_modules)
    return ignored_modules

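# Editorial illustration (hedged): ignoring a module ignores its entire
# subtree. For a hypothetical `model.block` that contains `model.block.linear`:
#
#   >>> ignored = _get_ignored_modules(model, [model.block])
#   >>> model.block in ignored and model.block.linear in ignored
#   True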

def _get_ignored_params(
    root_module: torch.nn.Module,
    ignored_modules: Set[torch.nn.Module],
    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
) -> Set[torch.nn.Parameter]:
    """
    Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``.

    :class:`FlatParameter` s are excluded from the result.
    """
    all_ignored_params: Set[torch.nn.Parameter] = set()

    params_in_ignored_modules = {
        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
    }

    all_ignored_params.update(params_in_ignored_modules)

    if ignored_parameters is not None:
        params_in_ignored_parameters = {
            p for p in ignored_parameters if not _is_fsdp_flattened(p)
        }
        all_ignored_params.update(params_in_ignored_parameters)

    # Always include nested FSDP modules' ignored parameters
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_params")
            all_ignored_params.update(optional_fsdp_state._ignored_params)

    return all_ignored_params


def _get_ignored_buffer_names(
    root_module: torch.nn.Module,
    ignored_modules: Set[torch.nn.Module],
) -> Set[str]:
    """Return the cleaned buffer FQNs in ``ignored_modules``."""
    all_ignored_buffer_names: Set[str] = set()

    buffers_in_ignored_modules = {
        buffer for m in ignored_modules for buffer in m.buffers()
    }

    all_ignored_buffer_names.update(
        {
            clean_tensor_name(buffer_name)
            for buffer_name, buffer in root_module.named_buffers()
            if buffer in buffers_in_ignored_modules
        }
    )

    # Always include nested FSDP modules' ignored buffer names
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_buffer_names")
            all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names)

    return all_ignored_buffer_names


def _get_buffer_names(root_module: nn.Module) -> Set[str]:
    """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a :class:`set`."""
    return {
        clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
    }

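# Editorial note (hedged): to the editor's understanding, `clean_tensor_name`
# strips FSDP's wrapper prefixes (and activation-checkpoint wrapper prefixes)
# so that the returned FQNs match the unwrapped module's naming:
#
#   >>> # e.g. "_fsdp_wrapped_module.bn.running_mean" -> "bn.running_mean"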

def _check_single_device_module(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> None:
    """
    Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``.

    Thus, after this method, the module must be either fully on CPU or fully
    on a non-CPU device.
    """
    devices = {param.device for param in _get_orig_params(module, ignored_params)}
    # We allow the module to be partially on CPU and partially on GPU if
    # `device_id` is not None, since `device_id` will result in the CPU
    # portion being moved to GPU. This is useful in cases where part of the
    # module may be parallelized by another algorithm and may already be on
    # GPU. We enforce `device_id` to not be None in that case, since otherwise
    # we would flatten parameters in a mixed module, which is not supported.
    if len(devices) == 2 and torch.device("cpu") in devices:
        if device_id is None:
            raise RuntimeError(
                "To support a module with both CPU and GPU params, "
                "please pass in the device_id argument."
            )
    elif len(devices) > 1:
        raise RuntimeError(
            f"FSDP only supports single device modules but got params on {devices}"
        )


def _get_device_from_device_id(
    device_id: Optional[Union[int, torch.device]],
    rank: int,
) -> Optional[torch.device]:
    """
    Return a ``torch.device`` for the specified ``device_id``.

    Processes ``device_id`` and returns either the corresponding device or
    ``None`` if ``device_id`` is ``None``.
    """
    if device_id is None:
        return None
    device = (
        device_id if isinstance(device_id, torch.device) else torch.device(device_id)
    )
    if device == torch.device("cuda"):
        warnings.warn(
            f"FSDP got the argument `device_id` {device_id} on rank "
            f"{rank}, which does not have an explicit index. "
            f"FSDP will use the current device {torch.cuda.current_device()}. "
            "If this is incorrect, please explicitly call `torch.cuda.set_device()` "
            "before FSDP initialization or pass in the explicit device "
            "index as the `device_id` argument."
        )
        device = torch.device("cuda", torch.cuda.current_device())
    return device


def _need_to_materialize_module(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    ignored_modules: Set[nn.Module],
) -> Tuple[bool, bool]:
    """
    Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization.

    At most one of the returned bools can be ``True``. If either is ``True``,
    then ``module`` needs to be materialized.
    """
    managed_params = list(_get_orig_params(module, ignored_params))
    is_meta_module = any(param.is_meta for param in managed_params)
    # TODO: We need to establish a contract for FSDP and buffers. For now, we
    # skip checking for meta buffers from ignored modules. We should consider
    # refactoring the initialization holistically to avoid so many traversals.
    for submodule in module.modules():
        if submodule in ignored_modules:
            continue
        for buf in submodule.buffers(recurse=False):
            is_meta_module |= buf.is_meta
    is_torchdistX_deferred_init = (
        not is_meta_module
        and _TORCHDISTX_AVAIL
        and any(fake.is_fake(param) for param in managed_params)
    )
    return is_meta_module, is_torchdistX_deferred_init

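# Editorial illustration (hedged): a module constructed on the meta device has
# shape/dtype metadata but no storage, so FSDP must materialize it before
# sharding:
#
#   >>> with torch.device("meta"):
#   ...     m = nn.Linear(4, 4)
#   >>> _need_to_materialize_module(m, set(), set())
#   (True, False)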

def _materialize_with_param_init_fn(
    root_module: nn.Module,
    param_init_fn: Callable[[nn.Module], None],
    ignored_modules: Set[nn.Module],
) -> None:
    if not callable(param_init_fn):
        raise ValueError(
            f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}"
        )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    for module in modules_to_materialize:
        param_init_fn(module)


def _materialize_meta_module(
    root_module: nn.Module,
    device_from_device_id: Optional[torch.device],
    ignored_modules: Set[nn.Module],
):
    # Run default meta device initialization
    materialization_device = device_from_device_id or torch.device(
        torch.cuda.current_device()
    )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    try:
        # Assume that each module's `reset_parameters()` only initializes its
        # own parameters and not those of its children
        with torch.no_grad():
            for module in modules_to_materialize:
                # As a contract to the user, only call `reset_parameters()` if
                # the module has directly managed parameters/buffers
                module_state_iter = itertools.chain(
                    module.parameters(recurse=False), module.buffers(recurse=False)
                )
                has_module_states = len(list(module_state_iter)) > 0
                if has_module_states:
                    module.to_empty(device=materialization_device, recurse=False)
                    module.reset_parameters()  # type: ignore[operator]
    except BaseException as e:
        warnings.warn(
            "Unable to call `reset_parameters()` for module on meta "
            f"device with error {str(e)}. Please ensure that your module of "
            f"type {type(module)} implements a `reset_parameters()` method."  # type: ignore[possibly-undefined]
        )
        raise e


def _get_modules_to_materialize(
    root_module: nn.Module, ignored_modules: Set[nn.Module]
) -> List[nn.Module]:
    # Run BFS to collect the modules to materialize via `reset_parameters()`,
    # stopping at any module with FSDP already applied or at ignored modules.
    modules_to_materialize: List[nn.Module] = []
    queue = collections.deque([root_module])
    visited_modules: Set[nn.Module] = {root_module}
    while queue:
        module = queue.popleft()
        modules_to_materialize.append(module)
        for child_module in module.children():
            if (
                child_module not in visited_modules
                and _get_module_fsdp_state(child_module) is None
                and child_module not in ignored_modules
            ):
                visited_modules.add(child_module)
                queue.append(child_module)
    return modules_to_materialize

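# Editorial illustration (hedged): the BFS above stops at ignored modules and
# at modules that already have FSDP applied, so only not-yet-managed modules
# are materialized. For a hypothetical root whose children are `a` (plain) and
# `b` (already FSDP-wrapped):
#
#   >>> _get_modules_to_materialize(root, ignored_modules=set())
#   [root, a]  # `b` and its subtree are skipped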

def _move_module_to_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    ignored_buffers: Set[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move ``module`` depending on ``device_from_device_id`` and its current device.

    This includes moving ignored modules' parameters.

    - If ``device_from_device_id`` is not ``None``, then this moves
      ``module`` to the device.
    - If ``device_from_device_id`` is ``None``, then this does not move
      ``module`` but warns the user if it is on CPU.

    Precondition: ``_check_single_device_module()``.
    """
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # BFS from `module` without traversing any nested FSDP instances to
        # collect the parameters/buffers that have not yet been managed
        queue: Deque[nn.Module] = collections.deque()
        queue.append(module)
        params: List[nn.Parameter] = []
        buffers: List[torch.Tensor] = []
        while queue:
            curr_module = queue.popleft()
            # NOTE: We include a check to only move parameters/buffers that are
            # on CPU device. If they are on a CUDA device different from the
            # one specified by `device_id`, then this does NOT move them. This
            # is so that we can raise an error in `_get_compute_device()`.
            params.extend(
                param
                for param in curr_module.parameters(recurse=False)
                if param.device == cpu_device
            )
            buffers.extend(
                buffer
                for buffer in curr_module.buffers(recurse=False)
                if buffer.device == cpu_device
            )
            for submodule in curr_module.children():
                if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
                    queue.append(submodule)
        params_to_move = [p for p in params if p not in ignored_params]
        bufs_to_move = [b for b in buffers if b not in ignored_buffers]
        _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id)
        return
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device == cpu_device:
        _warn_cpu_init()


def _move_states_to_device(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move states to the specified device.

    Precondition: ``_check_single_device_module()`` and module's parameters and
    buffers have been materialized if needed.
    """
    if len(params) == 0 and len(buffers) == 0:
        return
    if len(params) > 0:
        current_device = params[0].device
    elif len(buffers) > 0:
        current_device = buffers[0].device
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # Move the parameters and buffers like the `.data` code path in
        # `nn.Module._apply()`, which underlies `nn.Module.to()`
        for param in params:
            with torch.no_grad():
                param.data = param.to(device_from_device_id)
                if param.grad is not None:
                    param.grad.data = param.grad.to(device_from_device_id)
        for buffer in buffers:
            buffer.data = buffer.to(device_from_device_id)
    elif current_device == cpu_device:  # type: ignore[possibly-undefined]
        _warn_cpu_init()

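# Editorial note (hedged): assigning to `param.data` above, rather than
# rebinding the attribute, mirrors `nn.Module._apply()`: the `nn.Parameter`
# object's identity is preserved, so references held by the module and by any
# optimizer stay valid while only the underlying storage moves. A minimal
# sketch with a hypothetical CPU-initialized layer:
#
#   >>> linear = nn.Linear(4, 4)
#   >>> p = linear.weight
#   >>> _move_states_to_device(list(linear.parameters()), [], torch.device("cuda", 0))
#   >>> p is linear.weight, p.device
#   (True, device(type='cuda', index=0))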

def _warn_cpu_init():
    warnings.warn(
        "The passed-in `module` is on CPU and will thus have FSDP's sharding "
        "initialization run on CPU, which may be slower than on GPU. We "
        "recommend passing in the `device_id` argument for FSDP to move "
        "`module` to GPU for the sharding initialization. `module` must also "
        "be on GPU device to work with the `sync_module_states=True` flag "
        "since that requires GPU communication."
    )


def _get_compute_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_from_device_id: Optional[torch.device],
    rank: int,
) -> torch.device:
    """
    Determine and return this FSDP instance's compute device.

    If a device is specified by ``device_id``, then this returns that device.
    Otherwise, if the module is already on a non-CPU device, then the compute
    device is that non-CPU device. If the module is on CPU, then the compute
    device is the current device.

    Since this method should be called after materializing the module, any
    non-CPU device should not be a meta device. For now, the compute device is
    always a CUDA GPU device with its explicit index.

    Precondition: ``_check_single_device_module()`` and
    ``_move_module_to_device()``.
    """
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device.type != "cpu":
        compute_device = param.device  # Determined by model param placement
    else:
        if device_from_device_id is not None and device_from_device_id.type != "cuda":
            compute_device = device_from_device_id  # Determined by custom backend
        else:
            compute_device = torch.device("cuda", torch.cuda.current_device())
    if device_from_device_id is not None and compute_device != device_from_device_id:
        raise ValueError(
            f"Inconsistent compute device and `device_id` on rank {rank}: "
            f"{compute_device} vs {device_from_device_id}"
        )
    return compute_device


# TODO: See how to deprecate!
def _sync_module_params_and_buffers(
    module: nn.Module,
    params: List[nn.Parameter],
    process_group: dist.ProcessGroup,
) -> None:
    """
    Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks.

    Precondition: ``sync_module_states == True`` and ``self.process_group`` has
    been set.
    """
    module_states: List[torch.Tensor] = []
    for buffer in module.buffers():
        # Avoid re-synchronizing buffers in case of nested wrapping
        if not getattr(buffer, FSDP_SYNCED, False):
            setattr(buffer, FSDP_SYNCED, True)
            detached_buffer = buffer.detach()
            if is_traceable_wrapper_subclass(detached_buffer):
                # NOTE: Here we assume no nested subclasses, i.e. at most one
                # level of subclass in both the model's buffers and params
                attrs, _ = detached_buffer.__tensor_flatten__()  # type: ignore[attr-defined]
                inner_buffers = [getattr(detached_buffer, attr) for attr in attrs]
                module_states.extend(inner_buffers)
            else:
                module_states.append(detached_buffer)

    for param in params:
        detached_param = param.detach()
        if is_traceable_wrapper_subclass(detached_param):
            attrs, _ = detached_param.__tensor_flatten__()  # type: ignore[attr-defined]
            inner_params = [getattr(detached_param, attr) for attr in attrs]
            module_states.extend(inner_params)
        else:
            module_states.append(detached_param)

    _check_module_states_for_sync_module_states(module_states)
    _sync_params_and_buffers(
        process_group,
        module_states,
        PARAM_BROADCAST_BUCKET_SIZE,
        src=0,
    )

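# Editorial illustration (hedged): `sync_module_states=True` makes rank 0 the
# source of truth at wrap time. A minimal sketch, assuming an initialized
# process group and a module already on GPU:
#
#   >>> module = nn.Linear(8, 8).cuda()
#   >>> _sync_module_params_and_buffers(
#   ...     module, list(module.parameters()), _get_default_group()
#   ... )
#   >>> # Every rank now holds rank 0's values for `module`'s params/buffers.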

def _sync_module_states(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    process_group: dist.ProcessGroup,
) -> None:
    # Assumes that each call to this method passes in disjoint `params` and
    # `buffers` across calls, so there is no chance of re-synchronizing
    params_and_buffers = [param.detach() for param in params] + [
        buffer.detach() for buffer in buffers
    ]
    _check_module_states_for_sync_module_states(params_and_buffers)
    _sync_params_and_buffers(
        process_group,
        params_and_buffers,
        PARAM_BROADCAST_BUCKET_SIZE,
        src=0,
    )


def _check_module_states_for_sync_module_states(
    module_states: List[torch.Tensor],
) -> None:
    if module_states and any(
        tensor.device == torch.device("cpu") for tensor in module_states
    ):
        raise ValueError(
            "The module has CPU parameters or buffers when `sync_module_states=True`, "
            "which requires them to be on GPU. Please specify the `device_id` argument "
            "or move the module to GPU before passing it to FSDP."
        )


def _get_orig_params(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
) -> Iterator[nn.Parameter]:
    """
    Return an iterator over the original parameters in ``module``.

    The iterator does not return the parameters in ``ignored_params``, any
    ``FlatParameter`` s (which may be present due to nested FSDP wrapping), or
    any original parameters already flattened (only relevant when
    ``use_orig_params=True``).
    """
    param_gen = module.parameters()
    try:
        while True:
            param = next(param_gen)
            if param not in ignored_params and not _is_fsdp_flattened(param):
                yield param
    except StopIteration:
        pass


def _check_orig_params_flattened(
    fsdp_module,
    ignored_params: Set[nn.Parameter],
) -> None:
    """
    Check that original parameters in ``fsdp_module`` have been flattened.

    The flattened parameters are made invisible to ``named_parameters()`` for
    the module hierarchy rooted at ``fsdp_module``. This should be called as a
    sanity check after flattening the wrapped module's parameters.
    """
    for param_name, param in _named_parameters_with_duplicates(fsdp_module):
        if param not in ignored_params and not _is_fsdp_flattened(param):
            raise RuntimeError(
                f"Found an unflattened parameter: {param_name}; "
                f"{param.size()} {param.__class__}"
            )


def _get_default_comm_hook(sharding_strategy: ShardingStrategy):
    return (
        default_hooks.allreduce_hook
        if sharding_strategy == ShardingStrategy.NO_SHARD
        else default_hooks.reduce_scatter_hook
    )


def _get_default_comm_hook_state(
    process_group: dist.ProcessGroup,
) -> default_hooks.DefaultState:
    return default_hooks.DefaultState(process_group=process_group)