import collections
import functools
import inspect
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union

import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state,
    _override_module_mixed_precision,
)

from torch.distributed.fsdp.wrap import (
    _construct_wrap_fn,
    _or_policy,
    _Policy,
    _post_order_apply,
    _recursive_wrap,
    _run_mixed_precision_override_policy,
    _wrap_module_cls_individually,
)

def _auto_wrap(
    root_module: nn.Module,
    policy: Union[Callable, _Policy],
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    root_kwargs: Dict[str, Any],
    fsdp_fn: Callable,  # e.g. `FullyShardedDataParallel` or `fully_shard`
):
    """
    Auto wraps modules in ``root_module`` 's tree according to ``policy``
    following a post-order traversal.

    Precondition: ``root_kwargs`` should contain all arguments except
    ``module``. This function accepts the kwargs dict directly since it gets
    forwarded into the post-order traversal function.
    """
    mixed_precision = root_kwargs["mixed_precision"]
    is_wrapper = inspect.isclass(fsdp_fn)
    # TODO: We may relax this no-nested-wrapping constraint to support manual
    # wrapping followed by auto wrapping.
    _check_nested_wrapping(root_module)

    if isinstance(policy, _Policy):
        root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
        target_module_to_kwargs = policy._run_policy(
            root_module, ignored_modules, root_kwargs
        )
        if mixed_precision is not None:
            target_module_to_kwargs = _run_mixed_precision_override_policy(
                root_module,
                mixed_precision._module_classes_to_ignore,
                ignored_modules,
                root_kwargs,
                target_module_to_kwargs,
            )
            overridden_module_classes = _override_module_mixed_precision(
                root_module, mixed_precision._module_classes_to_ignore
            )
            _warn_on_overridden_mixed_precision(overridden_module_classes)
        use_orig_params = root_kwargs.get("use_orig_params", False)
        _validate_frozen_params(
            root_module,
            set(target_module_to_kwargs.keys()),
            ignored_params,
            use_orig_params,
        )
        wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
        _post_order_apply(root_module, wrap_fn)
        return

    recursive_wrap_kwargs = {
        "module": root_module,
        "auto_wrap_policy": policy,
        "wrapper_cls": fsdp_fn,
        "ignored_modules": ignored_modules,
        "ignored_params": ignored_params,
        "only_wrap_children": True,
    }
    if mixed_precision is not None:
        # Wrap modules of the ignored types separately and register forward
        # hooks to cast to fp32 and back to the original dtype, respectively
        overridden_module_classes = _override_module_mixed_precision(
            root_module, mixed_precision._module_classes_to_ignore
        )
        policy = functools.partial(
            _or_policy,
            policies=[
                policy,
                partial(
                    _wrap_module_cls_individually,
                    module_classes=mixed_precision._module_classes_to_ignore,
                ),
            ],
        )
        recursive_wrap_kwargs["auto_wrap_policy"] = policy
        _warn_on_overridden_mixed_precision(overridden_module_classes)
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]

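# A minimal sketch of how the two `policy` flavors accepted above are fed in
# through the public FSDP API (assumes an initialized process group;
# `MyTransformer` is a placeholder). A `_Policy` such as `ModuleWrapPolicy`
# takes the `_run_policy` branch; a plain callable such as
# `size_based_auto_wrap_policy` takes the `_recursive_wrap` branch.
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   from torch.distributed.fsdp.wrap import (
#       ModuleWrapPolicy,
#       size_based_auto_wrap_policy,
#   )
#
#   model = MyTransformer()  # hypothetical nn.Module
#   # `_Policy` branch: wrap each transformer layer as its own FSDP instance
#   fsdp_model = FSDP(
#       model, auto_wrap_policy=ModuleWrapPolicy({nn.TransformerEncoderLayer})
#   )
#   # Callable branch: wrap subtrees crossing a parameter-count threshold
#   fsdp_model = FSDP(
#       model,
#       auto_wrap_policy=functools.partial(
#           size_based_auto_wrap_policy, min_num_params=int(1e6)
#       ),
#   )
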
def _check_nested_wrapping(root_module: nn.Module):
    for module_name, module in root_module.named_modules():
        if _get_module_fsdp_state(module) is not None:
            raise ValueError(
                "FSDP auto wrapping requires modules to not already have "
                f"FSDP applied but found {module_name} in\n{root_module}"
            )

def _warn_on_overridden_mixed_precision(
    overridden_module_classes: Set[Type[nn.Module]],
):
    if len(overridden_module_classes) == 0:
        return
    warnings.warn(
        "Both mixed precision and an auto_wrap_policy were specified to FSDP, "
        f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
        "These modules will be wrapped as separate FSDP instances with mixed "
        "precision disabled."
    )

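# A minimal sketch of a configuration that triggers the warning above
# (assumes an initialized process group; `MyModel` is a placeholder). By
# default, `MixedPrecision._module_classes_to_ignore` includes batch norm
# classes, so those submodules get wrapped separately with mixed precision
# disabled.
#
#   import torch
#   from torch.distributed.fsdp import (
#       FullyShardedDataParallel as FSDP,
#       MixedPrecision,
#   )
#   from torch.distributed.fsdp.wrap import ModuleWrapPolicy
#
#   model = MyModel()  # hypothetical module containing nn.BatchNorm2d layers
#   fsdp_model = FSDP(
#       model,
#       auto_wrap_policy=ModuleWrapPolicy({nn.Conv2d}),
#       mixed_precision=MixedPrecision(param_dtype=torch.float16),
#   )
#   # -> UserWarning: batch norm submodules become separate FSDP instances
#   #    with mixed precision disabled
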
def _validate_frozen_params(
    root_module: nn.Module,
    modules_to_wrap: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    use_orig_params: bool,
):
    """
    This checks that, given ``modules_to_wrap``, each module would manage
    parameters that are uniformly frozen or non-frozen. This uniformity
    requirement is strict for ``use_orig_params=False`` (hard error) and
    highly recommended for ``use_orig_params=True`` (user warning).
    """
    post_order_named_modules = _get_post_order_named_modules(root_module)
    visited_modules: Set[nn.Module] = set()
    for module_name, module in post_order_named_modules:
        if module in modules_to_wrap:
            param_to_fqn = _get_managed_param_to_fqn(
                module, ignored_params, visited_modules, module_name
            )
            frozen_param_fqns: List[str] = []
            frozen_param_numel = 0
            nonfrozen_param_fqns: List[str] = []
            nonfrozen_param_numel = 0
            for param, fqn in param_to_fqn.items():
                if param.requires_grad:
                    nonfrozen_param_fqns.append(fqn)
                    nonfrozen_param_numel += param.numel()
                else:
                    frozen_param_fqns.append(fqn)
                    frozen_param_numel += param.numel()
            if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0:
                msg = f"{module_name} has both parameters with requires_grad=True and False."
                if use_orig_params:
                    total_param_numel = frozen_param_numel + nonfrozen_param_numel
                    msg += (
                        " We do not recommend wrapping such modules since "
                        "the gradient memory usage will be higher than expected "
                        f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel "
                        "before sharding via reduce-scatter). "
                    )
                else:
                    msg += " FSDP does not support wrapping such modules when use_orig_params=False. "
                msg += "If possible, wrap the frozen parameters with FSDP separately.\n"
                msg += (
                    f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n"
                    f"The following parameters have requires_grad=False:\n{frozen_param_fqns}"
                )
                if use_orig_params:
                    warnings.warn(msg)
                else:
                    raise ValueError(msg)

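# A minimal sketch of a wrap target that fails the uniformity check above
# (hypothetical model; not part of this module):
#
#   model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
#   for p in model[0].parameters():
#       p.requires_grad = False  # freeze the first layer only
#   _validate_frozen_params(model, {model}, set(), use_orig_params=False)
#   # -> ValueError: the single wrap target manages both frozen parameters
#   #    ("0.weight", "0.bias") and trainable ones ("1.weight", "1.bias")
#   _validate_frozen_params(model, {model}, set(), use_orig_params=True)
#   # -> UserWarning instead: allowed, but gradient memory is over-allocated
#   # Wrapping model[0] and model[1] as separate targets keeps each uniform.
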
def _get_post_order_named_modules(
    root_module: nn.Module,
) -> List[Tuple[str, nn.Module]]:
    """
    This returns the named modules following a post-order traversal, which is
    a valid reverse topological sort. We achieve this using the reverse of a
    stack-based DFS order instead of reversing ``root_module.named_modules()``
    since the former gives the modules in registration order at each level in
    the module tree (as opposed to the reverse), which allows us to error/warn
    on the first registered module that violates the condition.

    For example, consider the following module structure:
        M(
            S1(),
            S2(
                SS1(),
                SS2(),
            ),
            S3(),
        )
    The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse
    ``named_modules()`` order is [S3, SS2, SS1, S2, S1, M].
    """
    visited_modules = {root_module}
    stack = [("", root_module)]
    # Append and reverse at the end for linear-time algorithm
    reverse_post_order_named_modules: List[Tuple[str, nn.Module]] = []
    while stack:
        module_name, module = stack.pop()
        reverse_post_order_named_modules.append((module_name, module))
        for child_module_name, child_module in module.named_children():
            if child_module is None:  # only for overrides of `named_children()`
                continue
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                if module_name != "":
                    child_module_name = module_name + "." + child_module_name
                stack.append((child_module_name, child_module))
    post_order_named_modules = list(reversed(reverse_post_order_named_modules))
    return post_order_named_modules

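# A minimal sketch reproducing the two orders from the docstring above for a
# tree M(S1, S2(SS1, SS2), S3) (placeholder leaf class, for illustration):
#
#   class Leaf(nn.Module):
#       pass
#
#   m = nn.Module()
#   m.s1 = Leaf()
#   m.s2 = nn.Module()
#   m.s2.ss1 = Leaf()
#   m.s2.ss2 = Leaf()
#   m.s3 = Leaf()
#   [name for name, _ in _get_post_order_named_modules(m)]
#   # -> ["s1", "s2.ss1", "s2.ss2", "s2", "s3", ""]   (reverse DFS; root last)
#   [name for name, _ in reversed(list(m.named_modules()))]
#   # -> ["s3", "s2.ss2", "s2.ss1", "s2", "s1", ""]   (reversed pre-order)
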
def _get_managed_param_to_fqn(
    module_to_wrap: nn.Module,
    ignored_params: Set[nn.Parameter],
    visited_modules: Set[nn.Module],
    root_prefix: str,
) -> Dict[nn.Parameter, str]:
    """
    This returns a dict that maps managed parameter to its FQN for the given
    ``module_to_wrap``. The dict's keys are exactly the parameters that would
    be managed by the module, where this is achieved by calling this function
    on the modules to wrap in reverse topological order, destructively
    updating ``visited_modules``, and not traversing into those modules. The
    FQNs are prefixed from the root (via ``root_prefix``) to be more
    informative.

    NOTE: This function is meant to be called pre-wrapping and iteratively in
    reverse topological order to cover the full module tree. This differs
    from the ``_get_param_to_fqn()`` function meant to be called post-wrapping
    and on the full module tree in one shot. Given those differences, we do
    not try to unify the two.
    """
    param_to_fqn: Dict[nn.Parameter, str] = {}
    # Run BFS (or any tree traversal works)
    queue = collections.deque([(module_to_wrap, root_prefix)])
    visited_modules.add(module_to_wrap)
    while queue:
        module, prefix = queue.popleft()
        for param_name, param in module.named_parameters(recurse=False):
            if param not in ignored_params:
                fqn = param_name if prefix == "" else prefix + "." + param_name
                param_to_fqn[param] = fqn
        for child_module_name, child_module in module.named_children():
            if child_module is None:  # only for overrides of `named_children()`
                continue
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                child_prefix = (
                    child_module_name
                    if prefix == ""
                    else prefix + "." + child_module_name
                )
                queue.append((child_module, child_prefix))
    return param_to_fqn

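# A minimal sketch of the reverse-topological calling convention described in
# the docstring above: inner wrap targets claim their parameters first, so
# outer targets only manage what remains.
#
#   model = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Linear(2, 2)))
#   visited: Set[nn.Module] = set()
#   inner = _get_managed_param_to_fqn(model[1], set(), visited, "1")
#   # -> {model[1][0].weight: "1.0.weight", model[1][0].bias: "1.0.bias"}
#   outer = _get_managed_param_to_fqn(model, set(), visited, "")
#   # `model[1]` is already in `visited`, so the root only manages model[0]:
#   # -> {model[0].weight: "0.weight", model[0].bias: "0.bias"}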