pytorch
606 строк · 22.1 Кб
1# Copyright (c) Facebook, Inc. and its affiliates.
2#
3# This source code is licensed under the BSD license found in the
4# LICENSE file in the root directory of this source tree.
5
6import contextlib
7import copy
8from abc import ABC, abstractmethod
9from typing import (
10Any,
11Callable,
12cast,
13Dict,
14Generator,
15Iterable,
16Optional,
17Sequence,
18Set,
19Tuple,
20Type,
21Union,
22)
23
24import torch.nn as nn
25
# Public names exported by ``from <this module> import *``.
__all__ = [
    "always_wrap_policy",
    "lambda_auto_wrap_policy",
    "transformer_auto_wrap_policy",
    "size_based_auto_wrap_policy",
    "enable_wrap",
    "wrap",
    "CustomPolicy",
    "ModuleWrapPolicy",
]
36
37
38# NOTE: We intentionally keep this function simple and isolate the complexity
39# to `fn` to enable using this function generically. We may move this to a
40# non-FSDP-specific folder and/or make it public in the future.
41def _post_order_apply(
42root_module: nn.Module,
43fn: Callable[[nn.Module], Optional[nn.Module]],
44):
45"""
46This applies ``fn`` to every module in the module tree of ``root_module``
47following a post-order traversal. If ``fn`` returns an :class:`nn.Module`,
48then this replaces the original module with the newly returned one in the
49tree. Otherwise, ``fn`` should return ``None``, in which case the module is
50not changed.
51"""
52# Track visited modules to avoid visiting shared modules multiple times
53visited_modules: Set[nn.Module] = {root_module}
54
55def _post_order_apply_inner(
56module: nn.Module,
57module_name: str,
58parent_module: Optional[nn.Module],
59):
60for child_module_name, child_module in module.named_children():
61if child_module not in visited_modules:
62visited_modules.add(child_module)
63_post_order_apply_inner(child_module, child_module_name, module)
64optional_module = fn(module)
65if optional_module is not None:
66assert isinstance(parent_module, nn.Module), (
67"Non-root modules should have their parent module set but got "
68f"{parent_module} for {module}"
69)
70assert module_name, (
71"Non-root modules should have their module name set but got "
72f"an empty module name for {module}"
73)
74assert isinstance(
75optional_module, nn.Module
76), f"fn should return None or an nn.Module but got {optional_module}"
77setattr(parent_module, module_name, optional_module)
78
79_post_order_apply_inner(root_module, "", None)
80
81
82def _construct_wrap_fn(
83root_module: nn.Module,
84target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
85fsdp_fn: Callable,
86) -> Callable[[nn.Module], Optional[nn.Module]]:
87"""
88This constructs the "wrap" function to pass to :func:`_post_order_apply`
89based on ``target_module_to_kwargs``, which should be constructed from the
90wrapping policy.
91"""
92
93def fn(module: nn.Module) -> Optional[nn.Module]:
94# Explicitly avoid wrapping the root module since for FSDP, it is
95# handled by the caller
96if module in target_module_to_kwargs and module is not root_module:
97kwargs = target_module_to_kwargs[module]
98return fsdp_fn(module, **kwargs)
99return None
100
101return fn
102
103
104def _run_mixed_precision_override_policy(
105root_module: nn.Module,
106module_classes: Iterable[Type[nn.Module]],
107ignored_modules: Set[nn.Module],
108root_kwargs: Dict[str, Any],
109target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
110):
111module_classes_tuple = tuple(set(module_classes))
112for module in root_module.modules():
113if module in ignored_modules:
114continue
115elif isinstance(module, module_classes_tuple):
116# This policy overrides any existing policy
117if module not in target_module_to_kwargs:
118# Only inherit from the root kwargs if not already specified
119target_module_to_kwargs[module] = root_kwargs
120target_module_to_kwargs[module]["mixed_precision"] = None
121return target_module_to_kwargs
122
123
def always_wrap_policy(*args, **kwargs) -> bool:
    """
    Recursive wrap policy that unconditionally answers ``True``, whatever
    arguments it receives. Used with :func:`_recursive_wrap`, this results in
    every submodule being wrapped by the wrapper class.
    """
    should_wrap = True
    return should_wrap
131
132
133class _Policy(ABC):
134"""
135This defines an abstract base class that represents a policy for applying
136a module-level API.
137"""
138
139@abstractmethod
140def _run_policy(
141self,
142root_module: nn.Module,
143ignored_modules: Set[nn.Module],
144root_kwargs: Dict[str, Any],
145) -> Dict[nn.Module, Dict[str, Any]]:
146"""
147This should return a dict ``target_module_to_kwargs`` that maps from
148each target module to wrap to its kwargs.
149"""
150...
151
152
153def _module_wrap_policy(
154module: nn.Module,
155recurse: bool,
156nonwrapped_numel: int,
157module_classes: Set[Type[nn.Module]],
158) -> bool:
159"""
160This auto wrap policy wraps every module that is an instance of any type in
161``module_classes`` as its own FSDP instance. The root module given by
162``module`` is always wrapped as an FSDP instance regardless. Since the
163wrapping proceeds bottom up, each FSDP instance manages the parameters in
164its subtree excluding any already managed by a child FSDP instance.
165
166Args:
167module (nn.Module): Current module being considered.
168recurse (bool): If ``False``, then this function must decide whether
169``module`` should be wrapped as an FSDP instance or not. If
170``True``, then the function is still recursing down the module
171tree as a part of the DFS.
172nonwrapped_numel (int): Parameter numel not yet wrapped.
173module_classes (Set[Type[nn.Module]]): Set of module classes that are
174wrapped as FSDP instances.
175
176Returns:
177``True`` if ``recurse=True``, and whether ``module`` should be wrapped
178if ``recurse=False``.
179"""
180if recurse:
181return True # always recurse
182return isinstance(module, tuple(module_classes))
183
184
class ModuleWrapPolicy(_Policy):
    """
    Policy that targets every module whose type is one of the given module
    classes; each targeted module receives the kwargs given to the root.
    """

    def __init__(self, module_classes: Iterable[Type[nn.Module]]):
        unique_classes = set(module_classes)
        self._module_classes = unique_classes
        # Cached once so that ``__repr__`` does not rebuild the string
        self._module_classes_str = str(unique_classes)

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """
        Map each non-ignored module of a targeted class to its own shallow
        copy of ``root_kwargs``.
        """
        wrap_types = tuple(self._module_classes)
        # Each target gets a separate shallow copy so that later changes to
        # one module's kwargs are not coupled to the others
        return {
            submodule: copy.copy(root_kwargs)
            for submodule in root_module.modules()
            if submodule not in ignored_modules and isinstance(submodule, wrap_types)
        }

    def __call__(self, module, recurse, *args, **kwargs):
        """
        Callable form of the policy: delegates to :func:`_module_wrap_policy`.
        """
        # ``nonwrapped_numel`` is irrelevant to this policy; pass a dummy value
        return _module_wrap_policy(
            module,
            recurse,
            nonwrapped_numel=-1,
            module_classes=self._module_classes,
        )

    def __repr__(self) -> str:
        base_repr = super().__repr__()
        return base_repr + f"({self._module_classes_str})"
220
221
class CustomPolicy(_Policy):
    """
    Policy driven by a user function that maps each ``nn.Module`` to
    ``False``, ``True``, or a kwarg dictionary.

    - ``False`` or an empty dictionary: the API is not applied to the module.
    - ``True``: the API is applied to the module using the root's kwargs.
    - A non-empty dictionary: the API is applied to the module, with the
      dictionary's entries overriding the root's kwargs.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = init_transformer_model(...)
        >>> def lambda_fn(module: nn.Module):
        >>>     if module is model.lm_head:
        >>>         return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
        >>>     elif isinstance(module, TransformerBlock):
        >>>         return True
        >>>     return False
        >>> policy = CustomPolicy(lambda_fn)
        >>> fsdp_model = FSDP(model, auto_wrap_policy=policy)
    """

    def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]):
        self._lambda_fn = lambda_fn

    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        """
        Evaluate the user function on every non-ignored module and collect the
        kwargs for the modules it selects.

        Raises:
            ValueError: If the user function returns something other than a
                ``bool`` or a ``dict``.
        """
        selected: Dict[nn.Module, Dict[str, Any]] = {}
        for submodule in root_module.modules():
            if submodule in ignored_modules:
                continue
            verdict = self._lambda_fn(submodule)
            if not isinstance(verdict, (dict, bool)):
                raise ValueError(
                    "The lambda_fn passed to CustomPolicy should return "
                    f"False/True or a kwarg dict, but it returned {verdict}"
                )
            if verdict:
                # Start from the root's kwargs; a dict verdict overrides them
                module_kwargs = copy.copy(root_kwargs)
                if isinstance(verdict, dict):
                    module_kwargs.update(verdict)
                selected[submodule] = module_kwargs
        return selected
275
276
def lambda_auto_wrap_policy(
    module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool:
    """
    Auto wrap policy that defers the wrap decision to an arbitrary user
    function: if ``lambda_fn(submodule)`` returns ``True``, the submodule is
    wrapped as a ``wrapper_cls`` unit.

    The first three parameters are required by :func:`_recursive_wrap`.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): ``True`` while the traversal is still descending the
            module tree as part of the DFS; ``False`` when the function must
            decide whether ``module`` itself is wrapped.
        nonwrapped_numel (int): Parameter numel not yet wrapped (not used by
            this policy).
        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``,
            then this module will be wrapped.

    Returns:
        ``True`` whenever ``recurse`` is set; otherwise the result of
        ``lambda_fn(module)``.
    """
    # Recursion is never cut short; ``lambda_fn`` only sees wrap decisions
    return True if recurse else lambda_fn(module)
303
304
def transformer_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    transformer_layer_cls: Set[Type[nn.Module]],
) -> bool:
    """
    Thin alias of :func:`_module_wrap_policy` in which
    ``transformer_layer_cls`` plays the role of ``module_classes``. Since
    shared parameters must be wrapped in the same FSDP instance, this policy
    can help wrap shared embeddings into the same FSDP instance for
    transformer models.
    """
    return _module_wrap_policy(
        module=module,
        recurse=recurse,
        nonwrapped_numel=nonwrapped_numel,
        module_classes=transformer_layer_cls,
    )
318
319
320def _wrap_module_cls_individually(
321module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs
322):
323if recurse:
324# always recurse
325return True
326else:
327# if not recursing, decide whether we should wrap based on whether the type of module
328# is in `module_classes`.
329return isinstance(module, tuple(module_classes))
330
331
332def _or_policy(
333module: nn.Module,
334recurse: bool,
335nonwrapped_numel: int,
336policies,
337) -> bool:
338"""
339A policy that wraps ``module`` if any policy in the passed in iterable of
340``policies`` returns ``True``.
341"""
342return any(
343policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
344for policy in policies
345)
346
347
def size_based_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
    """
    Auto wrap policy based on the amount of not-yet-wrapped parameters.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): ``True`` while the traversal is still descending the
            module tree as part of the DFS; ``False`` when the function must
            decide whether ``module`` itself is wrapped.
        nonwrapped_numel (int): Parameter numel not yet wrapped.

        min_num_params (int): Size threshold, in units of numel, at or above
            which a module is ready to be wrapped.
        force_leaf_modules (Set[Type[nn.Module]]): Module types to keep as
            leaves, i.e. their children will never be wrapped. Defaults to
            ``size_based_auto_wrap_policy.FORCE_LEAF_MODULES``.
        exclude_wrap_modules (Set[Type[nn.Module]]): Module types excluded
            from wrapping. Defaults to
            ``size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES``.

    Returns:
        Whether to keep recursing (``recurse=True``) or whether ``module``
        should be wrapped (``recurse=False``).
    """
    if force_leaf_modules is None:
        force_leaf_modules = size_based_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore[attr-defined]
    if exclude_wrap_modules is None:
        exclude_wrap_modules = size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore[attr-defined]

    # The argument keeps the name `min_num_params` for BC, but what it really
    # expresses is the minimum non-wrapped *numel* that triggers a wrapping
    is_large = nonwrapped_numel >= min_num_params

    # When recursing, only force-leaf types block the descent; when deciding
    # on a wrap, only excluded types block it. Size is required either way.
    blocked_types = force_leaf_modules if recurse else exclude_wrap_modules
    return is_large and not isinstance(module, tuple(blocked_types))


# Attach the default type sets to the function object itself so that they are
# easy to import together with the policy.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}  # type: ignore[attr-defined]
405
406
@contextlib.contextmanager
def enable_wrap(
    *, wrapper_cls: Any, **wrapper_kwargs: Any
) -> Generator[None, None, None]:
    """
    Context manager inside which :func:`wrap` actually wraps modules.

    It is handy when the same configuration arguments should apply to all
    child modules that you wrap. A particularly important use case is wrapping
    large layers so that they get sharded (in-place) during initialization, to
    avoid running out of system memory. Large layers can mark themselves for
    sharding via the ``wrap`` annotation, and this context manager supplies
    the exact configuration for these nested instances.

    Usage::

        with enable_wrap(wrapper_cls, **params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        wrapper_cls:
            Class that the ``wrap`` annotation will wrap modules with, such as
            ``FullyShardedDataParallel``.
        **wrapper_kwargs:
            Configuration settings that will be passed to all ``wrap``
            instances inside the context.
    """
    # ``wrapper_cls`` travels together with the remaining settings; the config
    # object separates them again when the context is entered
    with _ConfigAutoWrap(wrapper_cls=wrapper_cls, **wrapper_kwargs):
        yield
441
442
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Mark a module as one that should be wrapped. The wrapping only happens
    when the call is made inside an :func:`enable_wrap` context manager;
    outside of one the module is returned untouched. This lets the same model
    code be initialized both with and without a wrapper.

    The wrapper applied is the ``wrapper_cls`` argument that was passed to
    ``enable_wrap``. Both ``enable_wrap`` and ``wrap`` accept kwargs that
    specify how to construct the ``wrapper_cls`` instance; when the same kwarg
    is given to both, the value passed to ``wrap`` is respected.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context

    Returns:
        The wrapped module when inside an :func:`enable_wrap` context,
        otherwise ``module`` itself.
    """
    if not _ConfigAutoWrap.in_autowrap_context:
        return module

    assert _ConfigAutoWrap.wrapper_cls is not None

    # Context-wide settings first, so that per-call overrides win on conflict
    effective_kwargs = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
    return _wrap(module, _ConfigAutoWrap.wrapper_cls, **effective_kwargs)
478
479
480def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
481assert wrapper_cls is not None
482if hasattr(module, "_wrap_overrides"):
483# If module has a _wrap_overrides attribute, we force overriding the
484# FSDP config with these attributes for this module. Currently this
485# is only used to disable mixed precision for BatchNorm when
486# auto_wrapping.
487overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type]
488return wrapper_cls(module, **overrides)
489
490return wrapper_cls(module, **kwargs)
491
492
493def _recursive_wrap(
494module: nn.Module,
495auto_wrap_policy: Callable,
496wrapper_cls: Callable,
497ignored_modules: Set[nn.Module],
498ignored_params: Set[nn.Parameter],
499only_wrap_children: bool = False,
500**kwargs: Any,
501) -> Tuple[nn.Module, int]:
502"""
503Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
504``True`` with ``wrapper_cls``.
505
506Args:
507module (nn.Module): Module to recursively wrap.
508auto_wrap_policy (Callable): A callable representing a policy that
509determines which modules to recursively wrap with ``wrapper_cls``.
510ignored_modules (Set[torch.nn.Module]): Modules to ignore when
511wrapping.
512ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
513wrapping; these should be the parameters contained in the modules
514in ``ignored_modules``.
515Returns:
516(nn.Module, int):
517``module`` after wrapping and the numel recursively wrapped.
518"""
519assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
520assert wrapper_cls is not None, "Must specify wrapper_cls"
521# Make sure no child is already wrapped.
522for _, child in module.named_modules():
523if child in ignored_modules:
524continue
525try:
526assert not isinstance(child, cast(type, wrapper_cls))
527except TypeError:
528# wrapper_cls is a function as opposed to a class type, just bypass above check.
529pass
530
531# We count all params, assuming none of them are already wrapped.
532nonwrapped_numel = sum(
533p.numel() for p in module.parameters() if p not in ignored_params
534)
535
536assert auto_wrap_policy is not None
537if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
538total_wrapped_numel = 0
539# Iterate through the children, recursively wrap if necessary
540for name, child in module.named_children():
541if child in ignored_modules:
542continue
543wrapped_child, num_wrapped_params = _recursive_wrap(
544module=child,
545auto_wrap_policy=auto_wrap_policy,
546wrapper_cls=wrapper_cls,
547ignored_modules=ignored_modules,
548ignored_params=ignored_params,
549**kwargs,
550)
551setattr(module, name, wrapped_child)
552# Keep track of how many parameters have been wrapped
553total_wrapped_numel += num_wrapped_params
554# decide if we need to wrap the current module,
555# since the left over parameters exceed the number of params to wrap
556remainder = nonwrapped_numel - total_wrapped_numel
557if not only_wrap_children and auto_wrap_policy(
558module=module, recurse=False, nonwrapped_numel=remainder
559):
560# Leaf node or final wrapping of the remainder both happen here.
561return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
562else:
563return module, total_wrapped_numel
564return module, 0
565
566
567class _ConfigAutoWrap:
568"""
569Helper class to wrap modules based on default config args via a context manager.
570See :func:`enable_wrap` for more information.
571"""
572
573in_autowrap_context: bool = False # Context flag
574wrapper_cls: Optional[Callable] = None # The wrapper class
575kwargs: Dict[str, Any] = {} # Wrapper's args
576
577def __init__(self, **kwargs: Dict[str, Any]):
578self.kwargs = kwargs
579
580@staticmethod
581def enable_autowrap_context(kwargs: Any) -> None:
582if _ConfigAutoWrap.in_autowrap_context:
583raise NotImplementedError(
584"You are already within an autowrap context and we currently do not supported nested autowrap."
585)
586_ConfigAutoWrap.in_autowrap_context = True
587# Get and save the wrapper cls for the context.
588assert (
589"wrapper_cls" in kwargs.keys()
590), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
591_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
592del kwargs["wrapper_cls"]
593# Save the rest.
594_ConfigAutoWrap.kwargs = kwargs
595
596@staticmethod
597def disable_autowrap_context() -> None:
598_ConfigAutoWrap.in_autowrap_context = False
599_ConfigAutoWrap.wrapper_cls = None
600_ConfigAutoWrap.kwargs = {}
601
602def __enter__(self) -> None:
603self.enable_autowrap_context(self.kwargs)
604
605def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
606self.disable_autowrap_context()
607