pytorch

Форк
0
606 строк · 22.1 Кб
1
# Copyright (c) Facebook, Inc. and its affiliates.
2
#
3
# This source code is licensed under the BSD license found in the
4
# LICENSE file in the root directory of this source tree.
5

6
import contextlib
7
import copy
8
from abc import ABC, abstractmethod
9
from typing import (
10
    Any,
11
    Callable,
12
    cast,
13
    Dict,
14
    Generator,
15
    Iterable,
16
    Optional,
17
    Sequence,
18
    Set,
19
    Tuple,
20
    Type,
21
    Union,
22
)
23

24
import torch.nn as nn
25

26
# Public API of this module.
__all__ = [
    # Function-style auto wrap policies
    "always_wrap_policy",
    "lambda_auto_wrap_policy",
    "transformer_auto_wrap_policy",
    "size_based_auto_wrap_policy",
    # Manual wrapping helpers
    "enable_wrap",
    "wrap",
    # Class-style policies
    "CustomPolicy",
    "ModuleWrapPolicy",
]
36

37

38
# NOTE: We intentionally keep this function simple and isolate the complexity
# to `fn` to enable using this function generically. We may move this to a
# non-FSDP-specific folder and/or make it public in the future.
def _post_order_apply(
    root_module: nn.Module,
    fn: Callable[[nn.Module], Optional[nn.Module]],
):
    """
    Visit every module reachable from ``root_module`` in post-order (children
    before their parent) and call ``fn`` on it.

    ``fn`` returns either ``None`` (leave the module as is) or an
    :class:`nn.Module`, which is installed on the parent under the same
    attribute name, replacing the visited module. A module that is reachable
    through several parents is visited only once.

    Raises:
        AssertionError: If ``fn`` returns a replacement for ``root_module``
            (it has no parent to install it on), or returns something other
            than ``None`` / an :class:`nn.Module`.
    """
    # The root is pre-marked so that a child referencing it is not re-entered.
    seen: Set[nn.Module] = {root_module}

    def _visit(
        current: nn.Module,
        attr_name: str,
        owner: Optional[nn.Module],
    ):
        # The membership check is done lazily, child by child, so that a module
        # reached while recursing into an earlier sibling is skipped later on.
        for name, child in current.named_children():
            if child in seen:
                continue
            seen.add(child)
            _visit(child, name, current)

        replacement = fn(current)
        if replacement is None:
            return
        assert isinstance(owner, nn.Module), (
            "Non-root modules should have their parent module set but got "
            f"{owner} for {current}"
        )
        assert attr_name, (
            "Non-root modules should have their module name set but got "
            f"an empty module name for {current}"
        )
        assert isinstance(
            replacement, nn.Module
        ), f"fn should return None or an nn.Module but got {replacement}"
        setattr(owner, attr_name, replacement)

    _visit(root_module, "", None)
80

81

82
def _construct_wrap_fn(
    root_module: nn.Module,
    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
    fsdp_fn: Callable,
) -> Callable[[nn.Module], Optional[nn.Module]]:
    """
    Build the per-module callback handed to :func:`_post_order_apply`.

    The returned callable maps a module to ``fsdp_fn(module, **kwargs)`` when
    the module is a key of ``target_module_to_kwargs`` (normally produced by a
    wrapping policy), and to ``None`` otherwise. ``root_module`` always maps to
    ``None`` because, for FSDP, the caller wraps the root itself.
    """

    def _maybe_wrap(module: nn.Module) -> Optional[nn.Module]:
        # Guard clause: skip the root (handled by the caller) and any module
        # the policy did not select.
        if module is root_module or module not in target_module_to_kwargs:
            return None
        return fsdp_fn(module, **target_module_to_kwargs[module])

    return _maybe_wrap
102

103

104
def _run_mixed_precision_override_policy(
    root_module: nn.Module,
    module_classes: Iterable[Type[nn.Module]],
    ignored_modules: Set[nn.Module],
    root_kwargs: Dict[str, Any],
    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
):
    """
    Force ``mixed_precision=None`` for every module of the given classes.

    For each module in ``root_module``'s tree that is not in
    ``ignored_modules`` and is an instance of one of ``module_classes``:

    - if it has no entry in ``target_module_to_kwargs`` yet, it gets a shallow
      copy of ``root_kwargs``;
    - its kwargs then have ``"mixed_precision"`` set to ``None``, overriding
      whatever an earlier policy chose.

    Args:
        root_module (nn.Module): Root of the module tree to scan.
        module_classes (Iterable[Type[nn.Module]]): Classes whose instances
            must not use mixed precision.
        ignored_modules (Set[nn.Module]): Modules to skip.
        root_kwargs (Dict[str, Any]): Kwargs inherited by modules without an
            existing entry. This dict is NOT modified.
        target_module_to_kwargs (Dict[nn.Module, Dict[str, Any]]): Mapping
            from target module to its kwargs; updated in place.

    Returns:
        The same ``target_module_to_kwargs`` dict, updated in place.
    """
    module_classes_tuple = tuple(set(module_classes))
    for module in root_module.modules():
        if module in ignored_modules:
            continue
        if not isinstance(module, module_classes_tuple):
            continue
        # This policy overrides any existing policy
        if module not in target_module_to_kwargs:
            # Only inherit from the root kwargs if not already specified.
            # Shallow copy so that the override below neither mutates the
            # caller's ``root_kwargs`` nor couples the kwargs of different
            # modules (consistent with the other policies).
            target_module_to_kwargs[module] = copy.copy(root_kwargs)
        target_module_to_kwargs[module]["mixed_precision"] = None
    return target_module_to_kwargs
122

123

124
def always_wrap_policy(*args, **kwargs) -> bool:
    """
    Recursive wrap policy whose answer is unconditionally ``True``.

    All arguments are ignored. When used with :func:`_recursive_wrap`, every
    submodule is therefore descended into and wrapped by the wrapper class.
    """
    del args, kwargs  # intentionally unused
    return True
131

132

133
class _Policy(ABC):
    """
    Abstract interface for a policy that selects the modules a module-level
    API (such as FSDP) is applied to, together with the kwargs for each one.
    """

    @abstractmethod
    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """
        Compute the policy's targets.

        Subclasses must return a dict ``target_module_to_kwargs`` whose keys
        are the modules to apply the API to and whose values are the kwargs
        to use for each of them.
        """
151

152

153
def _module_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    module_classes: Set[Type[nn.Module]],
) -> bool:
    """
    Auto wrap policy that wraps every instance of a class in
    ``module_classes`` as its own FSDP instance.

    The root module given by ``module`` is always wrapped as an FSDP instance
    regardless. Because wrapping proceeds bottom up, each FSDP instance ends
    up managing the parameters of its subtree that are not already managed by
    a child FSDP instance.

    Args:
        module (nn.Module): Module currently under consideration.
        recurse (bool): ``True`` while the DFS is still descending the module
            tree; ``False`` when the function must decide whether ``module``
            itself should be wrapped as an FSDP instance.
        nonwrapped_numel (int): Parameter numel not yet wrapped (unused).
        module_classes (Set[Type[nn.Module]]): Module classes whose instances
            are wrapped as FSDP instances.

    Returns:
        ``True`` when ``recurse`` is ``True``; otherwise whether ``module`` is
        an instance of one of ``module_classes``.
    """
    # Always keep descending; only the final "wrap or not" question depends
    # on the module's class.
    return True if recurse else isinstance(module, tuple(module_classes))
183

184

185
class ModuleWrapPolicy(_Policy):
    """
    Policy that targets every module that is an instance of one of the given
    module classes; each target receives (a shallow copy of) the root's
    kwargs.
    """

    def __init__(self, module_classes: Iterable[Type[nn.Module]]):
        self._module_classes = set(module_classes)
        # Precomputed once for ``__repr__``.
        self._module_classes_str = str(self._module_classes)

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        wanted_types = tuple(self._module_classes)
        # Each target gets its own shallow copy so that later per-module
        # changes do not leak into other modules.
        target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {
            submodule: copy.copy(root_kwargs)
            for submodule in root_module.modules()
            if submodule not in ignored_modules
            and isinstance(submodule, wanted_types)
        }
        return target_module_to_kwargs

    def __call__(self, module, recurse, *args, **kwargs):
        # Lets the policy also be used as a function-style auto wrap policy;
        # nonwrapped_numel is irrelevant to this policy.
        return _module_wrap_policy(
            module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes
        )

    def __repr__(self) -> str:
        base_repr = super().__repr__()
        return base_repr + f"({self._module_classes_str})"
220

221

222
class CustomPolicy(_Policy):
    """
    Policy driven by a user function that maps each ``nn.Module`` to
    ``False``, ``True``, or a kwarg dictionary:

    - ``False`` or an empty dictionary: the API is not applied to the module.
    - ``True``: the API is applied to the module with the root's kwargs.
    - A non-empty dictionary: the API is applied, and the dictionary's entries
      override the root's kwargs.

    Any other return value raises ``ValueError``.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = init_transformer_model(...)
        >>> def lambda_fn(module: nn.Module):
        >>>     if module is model.lm_head:
        >>>         return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
        >>>     elif isinstance(module, TransformerBlock):
        >>>         return True
        >>>     return False
        >>> policy = CustomPolicy(lambda_fn)
        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
    """

    def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
        self._lambda_fn = lambda_fn

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        selected: Dict[nn.Module, Dict[str, Any]] = {}
        for submodule in root_module.modules():
            if submodule in ignored_modules:
                continue
            decision = self._lambda_fn(submodule)
            if not isinstance(decision, (dict, bool)):
                raise ValueError(
                    "The lambda_fn passed to CustomPolicy should return "
                    f"False/True or a kwarg dict, but it returned {decision}"
                )
            # ``False`` and ``{}`` both mean "do not target this module".
            if not decision:
                continue
            # Start from a private copy of the root kwargs so modules do not
            # share (and accidentally mutate) one dict.
            module_kwargs = copy.copy(root_kwargs)
            if isinstance(decision, dict):
                module_kwargs.update(decision)
            selected[submodule] = module_kwargs
        return selected
275

276

277
def lambda_auto_wrap_policy(
    module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool:
    """
    Auto wrap policy that delegates the wrapping decision to an arbitrary
    user function: a submodule is wrapped as a ``wrapper_cls`` unit when
    ``lambda_fn(submodule)`` returns ``True``.

    The first three parameters are the ones required by
    :func:`_recursive_wrap`.

    Args:
        module (nn.Module): Module currently under consideration.
        recurse (bool): ``True`` while the DFS is still descending the module
            tree; ``False`` when the function must decide whether ``module``
            should be wrapped as an FSDP instance.
        nonwrapped_numel (int): Parameter numel not yet wrapped (unused).
        lambda_fn (Callable[[nn.Module], bool]): Returns ``True`` for modules
            that should be wrapped.

    Returns:
        ``True`` while recursing; otherwise ``lambda_fn(module)`` unchanged.
    """
    if not recurse:
        return lambda_fn(module)
    return True  # always recurse
303

304

305
def transformer_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    transformer_layer_cls: Set[Type[nn.Module]],
) -> bool:
    """
    Thin alias of :func:`_module_wrap_policy` in which ``transformer_layer_cls``
    plays the role of ``module_classes``.

    Shared parameters must be wrapped in the same FSDP instance, so this auto
    wrap policy can help keep shared embeddings of transformer models in one
    FSDP instance.
    """
    return _module_wrap_policy(
        module=module,
        recurse=recurse,
        nonwrapped_numel=nonwrapped_numel,
        module_classes=transformer_layer_cls,
    )
318

319

320
def _wrap_module_cls_individually(
    module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
):
    """
    Auto wrap policy that always recurses and, once recursion is done, wraps
    ``module`` if and only if it is an instance of one of ``module_classes``.
    Extra positional/keyword arguments are accepted and ignored.
    """
    if not recurse:
        # Final decision: wrap based on the module's type.
        return isinstance(module, tuple(module_classes))
    return True
330

331

332
def _or_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    policies,
) -> bool:
    """
    Logical OR of several auto wrap policies: ``True`` as soon as one policy
    in ``policies`` answers truthily for ``module``; ``False`` otherwise.

    Policies are evaluated in order and evaluation stops at the first truthy
    answer; each is called with the keyword arguments ``module``,
    ``recurse`` and ``nonwrapped_numel``.
    """
    for policy in policies:
        if policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel):
            return True
    return False
346

347

348
def size_based_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
    """
    Auto wrap policy based on how many parameters are still unwrapped.

    Args:
        module (nn.Module): Module currently under consideration.
        recurse (bool): ``True`` while the DFS is still descending the module
            tree; ``False`` when the function must decide whether ``module``
            should be wrapped as an FSDP instance.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        min_num_params (int): Size threshold, in numel, that the unwrapped
            parameters must reach before recursing or wrapping.
        force_leaf_modules (Set[Type[nn.Module]]): Module types kept as
            leaves, i.e. their children are never wrapped. ``None`` selects
            ``size_based_auto_wrap_policy.FORCE_LEAF_MODULES``.
        exclude_wrap_modules (Set[Type[nn.Module]]): Module types that are
            never wrapped themselves. ``None`` selects
            ``size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES``.

    Returns:
        Whether ``module`` should be recursed into (``recurse=True``) or
        wrapped (``recurse=False``).
    """
    if force_leaf_modules is None:
        force_leaf_modules = size_based_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore[attr-defined]
    if exclude_wrap_modules is None:
        exclude_wrap_modules = size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore[attr-defined]

    # ``min_num_params`` keeps its name for BC, but it really is the minimum
    # *non-wrapped* numel needed before anything is triggered.
    big_enough = nonwrapped_numel >= min_num_params
    if not big_enough:
        return False
    if recurse:
        # Descend unless this module type must stay a leaf.
        return not isinstance(module, tuple(force_leaf_modules))
    # Not recursing: wrap unless this module type is excluded from wrapping.
    return not isinstance(module, tuple(exclude_wrap_modules))


# Defaults used by ``size_based_auto_wrap_policy`` when the matching argument
# is ``None``; attached to the function so they are easy to import.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore[attr-defined]
405

406

407
@contextlib.contextmanager
def enable_wrap(
    *, wrapper_cls: Any, **wrapper_kwargs: Any
) -> Generator[None, None, None]:
    """
    Context manager inside which :func:`wrap` actually wraps modules.

    Handy when the same configuration arguments should be applied to every
    child module you wrap. An important use case is wrapping large layers so
    that they get sharded (in-place) during initialization, avoiding running
    out of system memory: large layers mark themselves with the ``wrap``
    annotation, and this context manager supplies the configuration for those
    nested instances. Nested ``enable_wrap`` contexts are not supported.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        wrapper_cls:
            Class that the ``wrap`` annotation wraps modules with, such as
            ``FullyShardedDataParallel``. Keyword-only.
        **wrapper_kwargs:
            Configuration settings passed to every ``wrap`` call made inside
            the context.
    """
    with _ConfigAutoWrap(wrapper_cls=wrapper_cls, **wrapper_kwargs):
        yield
441

442

443
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Mark ``module`` for wrapping.

    The module is only wrapped when called inside an :func:`enable_wrap`
    context; outside of it, ``module`` is returned unchanged. This lets a
    module be built with or without a wrapper without code changes.

    The wrapper class is the ``wrapper_cls`` given to ``enable_wrap``. Both
    ``enable_wrap`` and ``wrap`` accept kwargs for constructing the
    ``wrapper_cls`` instance; when the same key appears in both, the value
    passed to ``wrap`` wins.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that take priority over the
            values provided by the :func:`enable_wrap` context
    """
    if not _ConfigAutoWrap.in_autowrap_context:
        # Outside enable_wrap(): no-op.
        return module
    assert _ConfigAutoWrap.wrapper_cls is not None
    # Per-call overrides take precedence over the context-wide kwargs.
    merged_kwargs = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
    return _wrap(module, _ConfigAutoWrap.wrapper_cls, **merged_kwargs)
478

479

480
def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
    """
    Return ``wrapper_cls(module, **kwargs)``.

    If ``module`` carries a ``_wrap_overrides`` mapping, its entries take
    precedence over ``kwargs``. Currently this is only used to disable mixed
    precision for BatchNorm when auto wrapping.
    """
    assert wrapper_cls is not None
    effective_kwargs = kwargs
    if hasattr(module, "_wrap_overrides"):
        # Module-level overrides win over the caller-supplied config.
        effective_kwargs = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]
    return wrapper_cls(module, **effective_kwargs)
491

492

493
def _recursive_wrap(
    module: nn.Module,
    auto_wrap_policy: Callable,
    wrapper_cls: Callable,
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    only_wrap_children: bool = False,
    **kwargs: Any,
) -> Tuple[nn.Module, int]:
    """
    Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
    ``True`` with ``wrapper_cls``.

    The policy is queried twice per visited module. The first query uses
    ``recurse=True`` and decides whether to descend into the children. The
    second query happens after the children are processed, uses
    ``recurse=False`` and the numel the children left unwrapped, and decides
    whether to wrap the module itself.

    Args:
        module (nn.Module): Module to recursively wrap.
        auto_wrap_policy (Callable): A callable representing a policy that
            determines which modules to recursively wrap with ``wrapper_cls``.
            It is called with the keyword arguments ``module``, ``recurse``
            and ``nonwrapped_numel``.
        wrapper_cls (Callable): Class or factory used to wrap a module (called
            through :func:`_wrap`).
        ignored_modules (Set[torch.nn.Module]): Modules to ignore when
            wrapping.
        ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
            wrapping; these should be the parameters contained in the modules
            in ``ignored_modules``.
        only_wrap_children (bool): If ``True``, ``module`` itself is never
            wrapped, only its descendants. Not propagated to recursive calls.
        **kwargs: Forwarded to every ``wrapper_cls`` construction.

    Returns:
        (nn.Module, int):
            ``module`` after wrapping and the numel recursively wrapped.

    Raises:
        AssertionError: If ``auto_wrap_policy`` or ``wrapper_cls`` is ``None``,
            or if ``wrapper_cls`` is a class and a non-ignored submodule is
            already an instance of it.
    """
    assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
    assert wrapper_cls is not None, "Must specify wrapper_cls"

    # Make sure no child is already wrapped. This check only makes sense when
    # ``wrapper_cls`` is a class: for a plain function or ``functools.partial``,
    # ``isinstance`` would raise ``TypeError``. Decide once up front instead of
    # raising and swallowing that error for every submodule.
    if isinstance(wrapper_cls, type):
        for child in module.modules():
            if child in ignored_modules:
                continue
            assert not isinstance(
                child, wrapper_cls
            ), f"Expected no submodule to already be wrapped with {wrapper_cls}"

    # We count all params, assuming none of them are already wrapped.
    nonwrapped_numel = sum(
        p.numel() for p in module.parameters() if p not in ignored_params
    )

    if not auto_wrap_policy(
        module=module, recurse=True, nonwrapped_numel=nonwrapped_numel
    ):
        return module, 0

    total_wrapped_numel = 0
    # Iterate through the children, recursively wrap if necessary
    for name, child in module.named_children():
        if child in ignored_modules:
            continue
        wrapped_child, num_wrapped_params = _recursive_wrap(
            module=child,
            auto_wrap_policy=auto_wrap_policy,
            wrapper_cls=wrapper_cls,
            ignored_modules=ignored_modules,
            ignored_params=ignored_params,
            **kwargs,
        )
        setattr(module, name, wrapped_child)
        # Keep track of how many parameters have been wrapped
        total_wrapped_numel += num_wrapped_params

    # Decide whether to wrap the current module based on the numel that its
    # descendants left unwrapped.
    remainder = nonwrapped_numel - total_wrapped_numel
    if not only_wrap_children and auto_wrap_policy(
        module=module, recurse=False, nonwrapped_numel=remainder
    ):
        # Leaf node or final wrapping of the remainder both happen here.
        return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
    return module, total_wrapped_numel
565

566

567
class _ConfigAutoWrap:
    """
    Helper class to wrap modules based on default config args via a context manager.
    See :func:`enable_wrap` for more information.

    The active context's state (``in_autowrap_context``, ``wrapper_cls`` and
    ``kwargs``) is stored on the class itself, so only one context can be
    active at a time. The kwargs given to ``__init__`` are stored on the
    instance and must include ``wrapper_cls``.
    """

    in_autowrap_context: bool = False  # Context flag
    wrapper_cls: Optional[Callable] = None  # The wrapper class
    kwargs: Dict[str, Any] = {}  # Wrapper's args

    def __init__(self, **kwargs: Any):
        # Instance attribute shadowing the class-level ``kwargs``; the latter
        # holds the *active* context's kwargs (without ``wrapper_cls``).
        self.kwargs = kwargs

    @staticmethod
    def enable_autowrap_context(kwargs: Any) -> None:
        """
        Activate the autowrap context from ``kwargs``.

        Raises:
            NotImplementedError: If a context is already active.
            AssertionError: If ``kwargs`` has no ``wrapper_cls`` entry.
        """
        if _ConfigAutoWrap.in_autowrap_context:
            raise NotImplementedError(
                "You are already within an autowrap context and we currently do not supported nested autowrap."
            )
        # Validate before touching any class-level state: when ``__enter__``
        # raises, ``__exit__`` is never run, so a flag set too early would
        # stay stuck at ``True`` and block every later context.
        assert (
            "wrapper_cls" in kwargs
        ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
        # Work on a copy so the caller's dict (``self.kwargs``) keeps its
        # ``wrapper_cls`` entry and the instance can be entered again.
        remaining_kwargs = dict(kwargs)
        wrapper_cls = remaining_kwargs.pop("wrapper_cls")
        _ConfigAutoWrap.in_autowrap_context = True
        # Get and save the wrapper cls for the context.
        _ConfigAutoWrap.wrapper_cls = cast(Callable, wrapper_cls)
        # Save the rest.
        _ConfigAutoWrap.kwargs = remaining_kwargs

    @staticmethod
    def disable_autowrap_context() -> None:
        """Reset the class-level context state to its defaults."""
        _ConfigAutoWrap.in_autowrap_context = False
        _ConfigAutoWrap.wrapper_cls = None
        _ConfigAutoWrap.kwargs = {}

    def __enter__(self) -> None:
        self.enable_autowrap_context(self.kwargs)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.disable_autowrap_context()
607

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.