_wrap_utils.py
import collections
import functools
import inspect
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union

import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state,
    _override_module_mixed_precision,
)

from torch.distributed.fsdp.wrap import (
    _construct_wrap_fn,
    _or_policy,
    _Policy,
    _post_order_apply,
    _recursive_wrap,
    _run_mixed_precision_override_policy,
    _wrap_module_cls_individually,
)


def _auto_wrap(
    root_module: nn.Module,
    policy: Union[Callable, _Policy],
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    root_kwargs: Dict[str, Any],
    fsdp_fn: Callable,  # e.g. `FullyShardedDataParallel` or `fully_shard`
):
    """
    Auto wraps modules in ``root_module`` 's tree according to ``policy``
    following a post-order traversal.

    Precondition: ``root_kwargs`` should contain all arguments except
    ``module``. This function accepts the kwargs dict directly since it gets
    forwarded into the post-order traversal function.
    """
    mixed_precision = root_kwargs["mixed_precision"]
    is_wrapper = inspect.isclass(fsdp_fn)
    # TODO: We may relax this no-nested-wrapping constraint to support manual
    # wrapping followed by auto wrapping.
    _check_nested_wrapping(root_module)

    if isinstance(policy, _Policy):
        root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
        target_module_to_kwargs = policy._run_policy(
            root_module, ignored_modules, root_kwargs
        )
        if mixed_precision is not None:
            target_module_to_kwargs = _run_mixed_precision_override_policy(
                root_module,
                mixed_precision._module_classes_to_ignore,
                ignored_modules,
                root_kwargs,
                target_module_to_kwargs,
            )
            overridden_module_classes = _override_module_mixed_precision(
                root_module, mixed_precision._module_classes_to_ignore
            )
            _warn_on_overridden_mixed_precision(overridden_module_classes)
        use_orig_params = root_kwargs.get("use_orig_params", False)
        _validate_frozen_params(
            root_module,
            set(target_module_to_kwargs.keys()),
            ignored_params,
            use_orig_params,
        )
        wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
        _post_order_apply(root_module, wrap_fn)
        return

    recursive_wrap_kwargs = {
        "module": root_module,
        "auto_wrap_policy": policy,
        "wrapper_cls": fsdp_fn,
        "ignored_modules": ignored_modules,
        "ignored_params": ignored_params,
        "only_wrap_children": True,
    }
    if mixed_precision is not None:
        # Wrap modules of the ignored types separately and register forward
        # hooks to cast to fp32 and back to the original dtype, respectively
        overridden_module_classes = _override_module_mixed_precision(
            root_module, mixed_precision._module_classes_to_ignore
        )
        policy = functools.partial(
            _or_policy,
            policies=[
                policy,
                partial(
                    _wrap_module_cls_individually,
                    module_classes=mixed_precision._module_classes_to_ignore,
                ),
            ],
        )
        recursive_wrap_kwargs["auto_wrap_policy"] = policy
        _warn_on_overridden_mixed_precision(overridden_module_classes)
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
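

# Editorial sketch (not part of the original file): how the two `policy`
# branches in `_auto_wrap` are typically reached through the public API.
# `ModuleWrapPolicy` is a `_Policy` subclass, so it takes the `_run_policy`
# branch above; a plain callable such as `size_based_auto_wrap_policy` takes
# the `_recursive_wrap` branch instead. Assumes a default process group has
# already been initialized (e.g. via torchrun); the function is defined but
# never called here.
def _example_auto_wrap_usage() -> None:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import (
        ModuleWrapPolicy,
        size_based_auto_wrap_policy,
    )

    # `_Policy` branch: wrap each `nn.TransformerEncoderLayer` instance
    model = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=64, nhead=4), num_layers=2
    )
    FSDP(model, auto_wrap_policy=ModuleWrapPolicy({nn.TransformerEncoderLayer}))
    # Callable branch: wrap any subtree with at least 1e6 parameters (a fresh
    # model is needed since `_check_nested_wrapping` rejects re-wrapping)
    model = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=64, nhead=4), num_layers=2
    )
    FSDP(
        model,
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy, min_num_params=int(1e6)
        ),
    )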


def _check_nested_wrapping(root_module: nn.Module):
    for module_name, module in root_module.named_modules():
        if _get_module_fsdp_state(module) is not None:
            raise ValueError(
                "FSDP auto wrapping requires modules to not already have "
                f"FSDP applied but found {module_name} in\n{root_module}"
            )


def _warn_on_overridden_mixed_precision(
    overridden_module_classes: Set[Type[nn.Module]],
):
    if len(overridden_module_classes) == 0:
        return
    warnings.warn(
        "Both mixed precision and an auto_wrap_policy were specified to FSDP, "
        f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
        "These modules will be wrapped as separate FSDP instances with mixed "
        "precision disabled."
    )


def _validate_frozen_params(
    root_module: nn.Module,
    modules_to_wrap: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    use_orig_params: bool,
):
    """
    This checks that, given ``modules_to_wrap``, each module would manage
    parameters that are uniformly frozen or non-frozen. This uniformity
    requirement is strict for ``use_orig_params=False`` (hard error) and highly
    recommended for ``use_orig_params=True`` (user warning).
    """
    post_order_named_modules = _get_post_order_named_modules(root_module)
    visited_modules: Set[nn.Module] = set()
    for module_name, module in post_order_named_modules:
        if module in modules_to_wrap:
            param_to_fqn = _get_managed_param_to_fqn(
                module, ignored_params, visited_modules, module_name
            )
            frozen_param_fqns: List[str] = []
            frozen_param_numel = 0
            nonfrozen_param_fqns: List[str] = []
            nonfrozen_param_numel = 0
            for param, fqn in param_to_fqn.items():
                if param.requires_grad:
                    nonfrozen_param_fqns.append(fqn)
                    nonfrozen_param_numel += param.numel()
                else:
                    frozen_param_fqns.append(fqn)
                    frozen_param_numel += param.numel()
            if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0:
                msg = f"{module_name} has both parameters with requires_grad=True and False."
                if use_orig_params:
                    total_param_numel = frozen_param_numel + nonfrozen_param_numel
                    msg += (
                        " We do not recommend wrapping such modules since "
                        "the gradient memory usage will be higher than expected "
                        f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel "
                        "before sharding via reduce-scatter). "
                    )
                else:
                    msg += " FSDP does not support wrapping such modules when use_orig_params=False. "
                msg += "If possible, wrap the frozen parameters with FSDP separately.\n"
                msg += (
                    f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n"
                    f"The following parameters have requires_grad=False:\n{frozen_param_fqns}"
                )
                if use_orig_params:
                    warnings.warn(msg)
                else:
                    raise ValueError(msg)
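

# Editorial sketch (not part of the original file): a minimal reproduction of
# the uniformity check above. Freezing only `model[0]` makes the root manage a
# mix of frozen and non-frozen parameters, which warns under
# `use_orig_params=True`; wrapping each child separately keeps each managed
# parameter set uniform, so no diagnostic fires.
def _example_validate_frozen_params() -> None:
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    for param in model[0].parameters():
        param.requires_grad = False
    # Wrapping only the root mixes requires_grad=True and False -> warns
    _validate_frozen_params(model, {model}, set(), use_orig_params=True)
    # Wrapping each child separately is uniform -> no warning or error
    _validate_frozen_params(
        model, {model[0], model[1]}, set(), use_orig_params=False
    )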


def _get_post_order_named_modules(
    root_module: nn.Module,
) -> List[Tuple[str, nn.Module]]:
    """
    This returns the named modules following a post-order traversal, which is a
    valid reverse topological sort. We achieve this using the reverse of a
    stack-based DFS order instead of reversing ``root_module.named_modules()``
    since the former gives the modules in registration order at each level in
    the module tree (as opposed to the reverse), which allows us to error/warn
    on the first registered module that violates the condition.

    For example, consider the following module structure:
        M(
          S1(),
          S2(
            SS1(),
            SS2(),
          ),
          S3(),
        )
    The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse
    ``named_modules()`` order is [S3, SS2, SS1, S2, S1, M].
    """
    visited_modules = {root_module}
    stack = [("", root_module)]
    # Append and reverse at the end for a linear-time algorithm
    reverse_post_order_named_modules: List[Tuple[str, nn.Module]] = []
    while stack:
        module_name, module = stack.pop()
        reverse_post_order_named_modules.append((module_name, module))
        for child_module_name, child_module in module.named_children():
            if child_module is None:  # only for overrides of `named_children()`
                continue
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                if module_name != "":
                    child_module_name = module_name + "." + child_module_name
                stack.append((child_module_name, child_module))
    post_order_named_modules = list(reversed(reverse_post_order_named_modules))
    return post_order_named_modules
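

# Editorial sketch (not part of the original file): the traversal order for a
# tree shaped like the docstring's example. Children appear before parents and
# siblings keep registration order, so the root (named "") comes last.
def _example_post_order_traversal() -> None:
    root = nn.Module()
    root.s1 = nn.Linear(2, 2)
    root.s2 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    root.s3 = nn.Linear(2, 2)
    names = [name for name, _ in _get_post_order_named_modules(root)]
    assert names == ["s1", "s2.0", "s2.1", "s2", "s3", ""]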


def _get_managed_param_to_fqn(
    module_to_wrap: nn.Module,
    ignored_params: Set[nn.Parameter],
    visited_modules: Set[nn.Module],
    root_prefix: str,
) -> Dict[nn.Parameter, str]:
    """
    This returns a dict that maps each managed parameter to its FQN for the
    given ``module_to_wrap``. The dict's keys are exactly the parameters that
    would be managed by the module, where this is achieved by calling this
    function on the modules to wrap in reverse topological order, destructively
    updating ``visited_modules``, and not traversing into those modules. The
    FQNs are prefixed from the root (via ``root_prefix``) to be more
    informative.

    NOTE: This function is meant to be called pre-wrapping and iteratively in
    reverse topological order to cover the full module tree. This differs from
    the ``_get_param_to_fqn()`` function meant to be called post-wrapping and
    on the full module tree in one shot. Given those differences, we do not try
    to unify the two.
    """
    param_to_fqn: Dict[nn.Parameter, str] = {}
    # Run BFS (or any tree traversal works)
    queue = collections.deque([(module_to_wrap, root_prefix)])
    visited_modules.add(module_to_wrap)
    while queue:
        module, prefix = queue.popleft()
        for param_name, param in module.named_parameters(recurse=False):
            if param not in ignored_params:
                fqn = param_name if prefix == "" else prefix + "." + param_name
                param_to_fqn[param] = fqn
        for child_module_name, child_module in module.named_children():
            if child_module is None:  # only for overrides of `named_children()`
                continue
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                child_prefix = (
                    child_module_name
                    if prefix == ""
                    else prefix + "." + child_module_name
                )
                queue.append((child_module, child_prefix))
    return param_to_fqn
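

# Editorial sketch (not part of the original file): marking a submodule as
# already visited (as if it were to be wrapped separately) excludes its
# parameters from the managed set computed for the outer module.
def _example_managed_param_to_fqn() -> None:
    model = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Linear(2, 2)))
    visited: Set[nn.Module] = {model[1]}  # as if `model[1]` were wrapped alone
    param_to_fqn = _get_managed_param_to_fqn(model, set(), visited, "")
    assert sorted(param_to_fqn.values()) == ["0.bias", "0.weight"]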