# pytorch
# 575 lines · 20.5 KB
1from abc import ABC, abstractmethod2from contextlib import contextmanager, nullcontext3from copy import copy4from dataclasses import dataclass5from functools import partial, wraps6from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union7
8from functorch import make_fx9
10import torch11import torch.distributed as dist12
13# We need to import _functional_collectives to trigger op registration
14import torch.distributed._functional_collectives15import torch.nn as nn16import torch.utils._pytree as pytree17
18from torch import fx19from torch._decomp.decompositions import native_layer_norm_backward20
21from torch._subclasses.fake_tensor import FakeTensorMode22from torch.distributed._spmd.data_parallel import gradients_tagging23from torch.distributed._spmd.parallel_mode import (24DataParallel,25DTensorExpandMode,26ParallelMode,27)
28from torch.distributed._tensor import Placement29from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo, CodeGen30from torch.nn.utils import stateless31from torch.nn.utils._named_member_accessor import NamedMemberAccessor32
33
class Override(ABC):
    r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.

    This is useful when any part of the model is not traceable, or when you
    prefer not to trace it for any reason. More specifically, users can
    implement :meth:`torch.distributed._spmd.Override.replacement` to replace
    an original submodule with the returned new submodule. The new submodule
    contains the operations that users prefer to be traced, which can simply
    be a dummy placeholder operator. After tracing, users can implement
    :meth:`torch.distributed._spmd.Override.transform` to transform the traced
    graph, where the dummy placeholder operator serves as an anchor to insert
    new sub-graphs.
    """

    @abstractmethod
    def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
        r"""Return a new :class:`nn.Module` instance to replace ``orig_submodule`` in the model.

        This helps if ``orig_submodule`` is not traceable or should not be
        traced. Returning ``orig_submodule`` itself (the very same object)
        keeps it in the model and lets its own children be considered for
        replacement.

        Args:
            fqn (str): fully qualified name of the submodule.
            orig_submodule (:class:`nn.Module`): original submodule instance to replace.

        Returns:
            A new :class:`nn.Module` instance to replace the original one.
        """

    @abstractmethod
    def transform(
        self,
        gm: fx.GraphModule,
        flat_state: List[torch.Tensor],
    ) -> fx.GraphModule:
        r"""Transform the traced graph around the replacement submodule.

        Given a DTensor-expanded graph, conduct additional transformation for
        the sub-graph produced by the :class:`nn.Module` returned from
        :meth:`torch.distributed._spmd.Override.replacement`, if necessary.

        Args:
            gm (:class:`fx.GraphModule`): a DTensor-expanded graph module.
            flat_state (List[:class:`Tensor`]): a reference to the list of
                flattened state. The elements in ``flat_state`` map to the first
                ``len(flat_state)`` placeholders in the graph. The transformation
                can add state to or remove state from ``flat_state`` as long as
                it keeps ``flat_state`` and the placeholders consistent.

        Returns:
            The :class:`fx.GraphModule` after transformation.
        """
91
class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):
    """A :class:`_PyTreeCodeGen` variant that leaves the inputs alone.

    Positional inputs are handed back untouched (the caller is expected to
    pass them already flattened), and the generated function header is the
    plain :class:`CodeGen` one instead of the pytree-aware variant. Output
    processing is inherited from :class:`_PyTreeCodeGen` unchanged.
    """

    # pyre-ignore[3]
    def process_inputs(self, *args: Any) -> Any:
        # No input unflattening: the tuple of positional args is the result.
        passthrough = args
        return passthrough

    # pyre-ignore[2, 3]
    def gen_fn_def(self, free_vars, maybe_return_annotation):
        # Bypass _PyTreeCodeGen.gen_fn_def on purpose and use the base
        # implementation so the signature takes flat positional arguments.
        base_impl = CodeGen.gen_fn_def
        return base_impl(self, free_vars, maybe_return_annotation)
101
def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Move the responsibility of flattening the input arguments from the graph module to the caller.

    ``gm`` is modified in place: its codegen is replaced, it is recompiled,
    and the same object is returned. Only the ``out_spec`` of the current
    codegen's ``pytree_info`` is carried over; ``orig_args`` and ``in_spec``
    are cleared. The current codegen must therefore expose ``pytree_info``,
    otherwise the attribute access below raises.

    Example::

        output = gm(my_struct)

        gm = _to_caller_flattened_graph_module(gm)

        output = gm(*pytree.tree_flatten(my_struct)[0])

    """
    # pyre-ignore[16]
    gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
        pytree_info=_PyTreeInfo(
            # pyre-ignore[6]
            orig_args=None,  # type: ignore[arg-type]
            # pyre-ignore[6]
            in_spec=None,  # type: ignore[arg-type]
            # pyre-ignore[16]
            out_spec=gm._graph._codegen.pytree_info.out_spec,
        )
    )
    # Regenerate forward() so the new calling convention takes effect.
    gm.recompile()
    return gm
128
# Default parallel mode used by ``compile`` when the caller does not pass a
# ``parallel_mode``. Keeping a DTensor expand mode as the default preserves
# the old behavior and avoids breaking existing code.
dtensor_expand_mode = DTensorExpandMode()
133
def _override_placements(t: torch.Tensor, placements: List[Placement]) -> None:
    """Record ``placements`` as the placement override for tensor ``t``.

    The override is stored in ``_placements_override`` of the module-level
    ``dtensor_expand_mode``, keyed by ``id(t)``. Note that Python may reuse
    an ``id`` once ``t`` has been garbage collected.
    """
    # Only an attribute of the module-level object is mutated, so the name
    # is never rebound and no ``global`` declaration is required.
    overrides = dtensor_expand_mode._placements_override
    overrides[id(t)] = placements
138
@contextmanager
def _rematerialize_optimizer(
    opt: torch.optim.Optimizer,
    named_states: Dict[str, Any],
    params: Dict[str, nn.Parameter],
):
    """Temporarily point ``opt`` at the given (e.g. proxy) parameters and states.

    Inside the context, ``opt.state`` additionally maps each tensor in
    ``params`` to the state stored under the same name in ``named_states``,
    and the first parameter group's ``"params"`` is replaced by a list of
    ``params``' values. Both the original ``opt.state`` and the original
    parameter list are restored on exit, including when installing the
    temporary entries fails part-way (e.g. a name in ``named_states`` is
    missing from ``params``, which raises ``KeyError``).

    Args:
        opt: the optimizer to patch; must not be ``None``.
        named_states: optimizer state keyed by parameter name.
        params: tensors to use as parameters, keyed by the same names.
    """
    assert opt is not None

    # Snapshot everything before mutating anything so the ``finally`` below
    # can always restore a consistent optimizer.
    orig_states = copy(opt.state)
    # FIXME: support multiple parameter groups
    param_group = opt.param_groups[0]
    orig_params = param_group["params"]

    try:
        # named_states is keyed by parameter name, while the optimizer looks
        # up its state by the parameter tensor itself.
        for n, state in named_states.items():
            opt.state[params[n]] = state  # type: ignore[index]
        # Use a real list (like the optimizer's own groups) rather than a
        # dict view, which does not support indexing.
        param_group["params"] = list(params.values())
        yield
    finally:
        param_group["params"] = orig_params
        opt.state = orig_states
164
165aten = torch.ops.aten # pyre-ignore166
167
@contextmanager
def _enable_compile():
    """Make ``torch._utils.is_compiling()`` return ``True`` inside the context.

    Optimizer behavior depends on the return value of
    ``torch._utils.is_compiling``; it has to report ``True`` so that the
    optimizer step is included in the traced graph. See:
    https://github.com/pytorch/pytorch/blob/a524123c91ab399c9dd6882c1189596dd77e7734/torch/optim/optimizer.py#L41

    The function's code object is swapped in place rather than rebinding the
    attribute, presumably so that modules holding their own reference to the
    function also observe the patched behavior.
    """
    saved_code = torch._utils.is_compiling.__code__

    def f_true():
        return True

    torch._utils.is_compiling.__code__ = f_true.__code__
    try:
        yield
    finally:
        # Put the genuine implementation back, even if the body raised.
        torch._utils.is_compiling.__code__ = saved_code
183
def _foreach_add_decomp(self, other, alpha=1):
    """Decompose in-place ``_foreach_add_.List`` into its out-of-place form.

    The sum is computed with ``aten._foreach_add.List`` (same ``alpha``) and
    written back tensor-by-tensor into ``self``.
    """
    results = aten._foreach_add.List(self, other, alpha=alpha)
    for dst, src in zip(self, results):
        dst.copy_(src)
189
def _foreach_unaop_decomp(op, self):
    """Apply the out-of-place unary foreach ``op`` and copy results into ``self``.

    Args:
        op: functional foreach overload called as ``op(self)``.
        self: list of tensors updated in place.
    """
    fresh = op(self)
    for target, value in zip(self, fresh):
        target.copy_(value)
195
def _foreach_binop_list_decomp(op, self, other):
    """Apply the out-of-place list-list foreach ``op`` and copy results into ``self``.

    Args:
        op: functional foreach overload called as ``op(self, other)``.
        self: list of tensors updated in place.
        other: second operand list; left unmodified here.
    """
    fresh = op(self, other)
    for target, value in zip(self, fresh):
        target.copy_(value)
201
def _foreach_binop_scalar_decomp(op, self, scalar=1):
    """Apply the out-of-place tensor-list/scalar foreach ``op`` and copy results into ``self``.

    Args:
        op: functional foreach overload called as ``op(self, scalar)``.
        self: list of tensors updated in place.
        scalar: scalar operand, ``1`` by default.
    """
    fresh = op(self, scalar)
    for target, value in zip(self, fresh):
        target.copy_(value)
207
def _foreach_addcop_scalar_decomp(op, self, tensor1, tensor2, scalar=1):
    """Apply an out-of-place ``_foreach_addc*`` ``op`` and copy results into ``self``.

    Args:
        op: functional foreach overload called as
            ``op(self, tensor1, tensor2, scalar)``.
        self: list of tensors updated in place.
        tensor1, tensor2: operand lists; left unmodified here.
        scalar: scalar multiplier, ``1`` by default.
    """
    fresh = op(self, tensor1, tensor2, scalar)
    for target, value in zip(self, fresh):
        target.copy_(value)
213
def _fused_adam_decomp(
    self,
    grads,
    exp_avgs,
    exp_avg_sqs,
    max_exp_avg_sqs,
    state_steps,
    *,
    lr=1,
    beta1=1,
    beta2=1,
    weight_decay=1,
    eps=1,
    amsgrad=True,
    maximize=True,
    grad_scale=None,
    found_inf=None,
):
    """Decompose in-place ``_fused_adam_`` into ``_fused_adam`` plus copy-backs.

    The functional op's outputs are paired positionally with
    ``(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)``; every group
    except ``grads`` is copied back into the corresponding input tensors.
    ``state_steps`` is not copied back here.

    NOTE(review): the keyword defaults (``lr=1``, ``amsgrad=True``, ...) look
    like placeholders; presumably the dispatcher always supplies these
    arguments — confirm before relying on them.
    """
    targets = (self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)
    results = aten._fused_adam.default(
        self,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        weight_decay=weight_decay,
        eps=eps,
        amsgrad=amsgrad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )

    # Position 1 holds the gradients, which do not need to be copied back.
    to_copy = [
        pair for position, pair in enumerate(zip(targets, results)) if position != 1
    ]
    for dst_group, src_group in to_copy:
        for dst, src in zip(dst_group, src_group):
            dst.copy_(src)
258
# Explicit decompositions used while tracing. Each in-place foreach /
# fused-Adam op is rewritten as its out-of-place variant followed by copies
# back into the inputs; ``native_layer_norm_backward`` uses the stock
# decomposition from ``torch._decomp``.
SPMD_DECOMP_TABLE: Dict[torch._ops.OpOverload, Callable[..., Any]] = {
    aten._foreach_add_.List: _foreach_add_decomp,
    aten._foreach_div_.List: partial(
        _foreach_binop_list_decomp, aten._foreach_div.List
    ),
    # tensor-list (op) scalar
    **{
        inplace_op: partial(_foreach_binop_scalar_decomp, functional_op)
        for inplace_op, functional_op in (
            (aten._foreach_add_.Scalar, aten._foreach_add.Scalar),
            (aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar),
            (aten._foreach_div_.Scalar, aten._foreach_div.Scalar),
            (aten._foreach_sub_.Scalar, aten._foreach_sub.Scalar),
        )
    },
    # self + scalar * (tensor1 (op) tensor2)
    **{
        inplace_op: partial(_foreach_addcop_scalar_decomp, functional_op)
        for inplace_op, functional_op in (
            (aten._foreach_addcdiv_.Scalar, aten._foreach_addcdiv.Scalar),
            (aten._foreach_addcmul_.Scalar, aten._foreach_addcmul.Scalar),
        )
    },
    # unary ops
    **{
        inplace_op: partial(_foreach_unaop_decomp, functional_op)
        for inplace_op, functional_op in (
            (aten._foreach_neg_.default, aten._foreach_neg.default),
            (aten._foreach_reciprocal_.default, aten._foreach_reciprocal.default),
            (aten._foreach_sqrt_.default, aten._foreach_sqrt.default),
        )
    },
    aten._fused_adam_.default: _fused_adam_decomp,
    aten.native_layer_norm_backward.default: native_layer_norm_backward,
}
294
295
296DEDUP_TARGETS: Set[torch._ops.OpOverload] = {297torch.ops.c10d_functional.all_reduce.default,298torch.ops.c10d_functional.wait_tensor.default,299}
300
301
def _dedup_collectives(gm: fx.GraphModule) -> fx.GraphModule:
    """Remove duplicated communication nodes from ``gm`` in place.

    Two nodes are duplicates when they call the same op from
    ``DEDUP_TARGETS`` with identical flattened positional arguments and
    identical keyword arguments (names and flattened values). Every later
    duplicate has its uses redirected to the first occurrence and is then
    erased. Nodes are visited in graph order, so once a collective has been
    collapsed, the ``wait_tensor`` nodes that consumed it end up with equal
    arguments and are collapsed as well.

    Returns:
        The same ``gm`` object, recompiled.
    """
    first_seen: Dict[Tuple[Any, ...], fx.Node] = {}

    for node in gm.graph.nodes:
        if node.target not in DEDUP_TARGETS:
            continue

        # Keyword arguments are part of the key (together with their names)
        # so calls that differ only in kwargs are never merged.
        key = (
            node.target,
            *pytree.arg_tree_leaves(*node.args),
            *(
                (name, *pytree.tree_leaves(value))
                for name, value in sorted(node.kwargs.items())
            ),
        )
        unique_node = first_seen.get(key)
        if unique_node is None:
            # First time seeing this combination: remember it.
            first_seen[key] = node
        else:
            # Duplicate: reroute its users to the first occurrence and drop it.
            # NOTE: this erases the node currently being visited and relies on
            # torch.fx's node-list iteration tolerating that.
            node.replace_all_uses_with(unique_node)
            gm.graph.erase_node(node)

    gm.recompile()

    return gm
324
@dataclass
class _CompiledResult:
    """Everything ``compile`` needs to run a traced train step again.

    Fields, in constructor order:
        gm: the compiled graph module. It is called with ``flat_state``
            followed by the flattened user arguments as positional inputs.
        mod: the single ``nn.Module`` found among the traced call's arguments.
        opt: the optimizer found among those arguments, or ``None``.
        flat_state: flattened parameters, buffers and optimizer states; the
            i-th element feeds the graph's i-th placeholder.
    """

    gm: fx.GraphModule
    mod: nn.Module
    opt: Optional[torch.optim.Optimizer]
    flat_state: List[torch.Tensor]
332
def _compile(
    func: Callable,
    module_override: Optional[List[Override]],
    parallel_mode: ParallelMode,
    *args: Any,
    **kwargs: Any,
) -> _CompiledResult:
    """Trace ``func`` into a single graph and expand it with ``parallel_mode``.

    The numbered comments below follow the pipeline:

    1. Find the ``nn.Module`` and the optional ``Optimizer`` among the arguments.
    2. Apply the ``module_override`` replacements.
    3. Trace a stateless version of ``func`` with ``make_fx``.
    4. Partition the graph via ``parallel_mode.partition``.
    5. Move input flattening to the caller.
    6. Dedup collectives.
    7. Run each override's ``transform``.

    Args:
        func: the train-step callable to trace.
        module_override: optional list of :class:`Override` objects.
        parallel_mode: strategy used to partition the traced graph.
        *args, **kwargs: the arguments ``func`` is called with. Their pytree
            leaves must contain exactly one ``nn.Module`` and at most one
            ``torch.optim.Optimizer`` (enforced with ``assert``).

    Returns:
        A :class:`_CompiledResult` holding the graph module, the module, the
        optimizer, and the flattened state. The order of the flattened state
        matches the graph's leading placeholders.
    """
    # 1. Extract nn.Module and Optimizer from args and kwargs
    # FIXME(@mrshenli): support multiple nn.Module instances
    # FIXME(@mrshenli): support multiple Optimizer instances
    # FIXME(@mrshenli): need to broadcast model to sync parameters
    mod, opt = None, None
    for arg in pytree.arg_tree_leaves(*args, **kwargs):
        if isinstance(arg, nn.Module):
            assert mod is None, "Only support single nn.Module for now"
            mod = arg
        if isinstance(arg, torch.optim.Optimizer):
            assert opt is None, "Only support single Optimizer for now"
            opt = arg

    assert mod is not None, "Couldn't find nn.Module instances from the arguments."

    # 2. Override target submodules (e.g., MoE) with dummy replacements
    if module_override:
        accessor = NamedMemberAccessor(mod)

        def swap(fqn_prefix: str, module: torch.nn.Module) -> None:
            # Offer every child of ``module`` to each override. A child that
            # is returned unchanged (same object) is recursed into; otherwise
            # the returned module is installed at the child's FQN in ``mod``.
            # NOTE(review): the override loop is outside the children loop and
            # the recursion happens inside it, so with several overrides a
            # subtree is visited once per override at every level, and later
            # visits may offer already-replaced modules to an override again.
            # Confirm this is intended.
            for override in module_override:  # type: ignore[union-attr]
                for name, child in module.named_children():
                    if len(name) == 0:
                        continue
                    fqn = fqn_prefix + "." + name if fqn_prefix != "" else name
                    new_child = override.replacement(fqn, child)
                    if id(new_child) == id(child):
                        swap(fqn, new_child)
                    else:
                        accessor.swap_submodule(fqn, new_child)

        swap("", mod)

    # 3. Trace stateless version of the train_step
    params = dict(mod.named_parameters(remove_duplicate=False))
    buffers = dict(mod.named_buffers(remove_duplicate=False))

    named_states = {}
    if opt is not None:
        # Pass named_states instead of opt.state to stateless_func, because
        # the latter uses nn.Parameter as key. During tracing, we need to
        # make sure optimizers can find the states using proxy tensors.
        for n, p in params.items():
            if p in opt.state:
                # named_states is keyed by parameter name (str), while
                # opt.state is keyed by the Parameter itself.
                named_states[n] = opt.state[p]  # type: ignore[index]

    is_data_parallel_mode = isinstance(parallel_mode, DataParallel)

    # Lift states and parameters as function arguments so that make_fx
    # can trace operations applied to them.
    # The parameters intentionally shadow the enclosing ``func``, ``params``,
    # ``buffers`` and ``named_states`` so the values passed in by make_fx are
    # the ones used inside.
    def stateless_func(func, params, buffers, named_states, args, kwargs):
        with stateless._reparametrize_module(
            mod, {**params, **buffers}
        ), _rematerialize_optimizer(
            opt, named_states, params
        ) if opt else nullcontext():
            # For DataParallel mode, install hooks first to tag the gradients
            with gradients_tagging(params) if is_data_parallel_mode else nullcontext():
                ret = func(*args, **kwargs)

            # make sure updated parameters are returned
            # (still inside _reparametrize_module, so mod.parameters() yields
            # the values passed in as ``params``)
            return ret, list(mod.parameters()), list(named_states.values())  # type: ignore[union-attr]

    # FIXME: Using symbolic tracing to work around in DTensor expand mode.
    # Otherwise it hits shape mismatch error, as we use local inputs to
    # trace local graph and use DTensor to expand operators, where
    # DTensor's shape is the global shape.
    tracing_mode = "fake" if is_data_parallel_mode else "symbolic"

    if is_data_parallel_mode:
        fake_mode = FakeTensorMode()
        data_parallel_mode = cast(DataParallel, parallel_mode)

        def _get_full_batch_arg(arg: torch.Tensor) -> torch.Tensor:
            # Since compilation happens in the first iteration and we
            # receive mini-batch input, convert them to full-batch
            # fake tensor inputs first for data parallel sharding
            # propagations.
            fake_arg = fake_mode.from_tensor(arg)
            arg_dims = [1] * arg.ndim
            # expand the tensor to full batch size on its batch dim
            arg_dims[data_parallel_mode.input_batch_dim] *= dist.get_world_size()
            return fake_arg.repeat(arg_dims)

        args = pytree.tree_map_only(
            torch.Tensor,
            _get_full_batch_arg,
            args,
        )
        kwargs = pytree.tree_map_only(
            torch.Tensor,
            _get_full_batch_arg,
            kwargs,
        )

    with _enable_compile(), torch.autograd.detect_anomaly(check_nan=False):
        # FIXME(@mrshenli): functionalization does not work for our use
        # case yet. Use explicit decompositions for foreach ops.
        # Remove this when the following issue is addressed.
        # Issue: https://github.com/pytorch/pytorch/issues/97852
        gm = make_fx(
            partial(stateless_func, func),
            tracing_mode=tracing_mode,
            decomposition_table=SPMD_DECOMP_TABLE,
            _allow_non_fake_inputs=False,
        )(params, buffers, named_states, args, kwargs)

    params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = {
        **params,
        **buffers,
    }

    # 4. parallel mode to expand a single device graph to a distributed graph
    gm = parallel_mode.partition(
        gm,
        mod,
        opt,
        params_and_buffers,
        named_states,
        args,
        kwargs,
    )

    # 5. Move the responsibility of flattening the input arguments from the
    # graph module to the caller. This serves two purposes:
    #   - Transformations that add/remove state need to manipulate a state
    #     container that maintains the state tensors in the same order as they
    #     appear in graph placeholders.
    #   - Reduced runtime cost. The state container is only flattened once upfront.
    flat_state = pytree.tree_leaves([params_and_buffers, named_states])
    gm = _to_caller_flattened_graph_module(gm)

    # 6. dedup comm operators.
    # The duplication could come from DTensor args and kwargs redistribution.
    # Suppose one operator produces a Partial gradient tensor and model
    # parameters are replicated. In this case, every optimizer operation using
    # that Partial gradient tensor would trigger an allreduce. This is because
    # DTensor only has local information on individual tensor/operator, which is
    # not sufficient to detect duplications in the graph. This situation can
    # also happen when inserting FSDP allgather if a parameter is used multiple
    # times in the forward method.
    # TODO(@mrshenli): @yifuwang has a suggestion of conducting expansion and
    # dedup at tracer-level to avoid multiple graph passes.
    gm = _dedup_collectives(gm)

    # 7. Replace previously inserted dummy ones with real graphs.
    if module_override:
        for override in module_override:
            gm = override.transform(gm, flat_state)

    return _CompiledResult(gm, mod, opt, flat_state)
494
# Key under which ``compile`` caches its ``_CompiledResult`` in the wrapper
# function's ``__dict__``. By Python convention ``__dict__`` keys are strings.
# TODO: ensure the key is unique.
COMPILED_OBJECT_KEY = "_compiled_obj"
499
def compile(
    module_override: Optional[List[Override]] = None,
    gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
    parallel_mode: Optional[ParallelMode] = None,
):
    r"""Compile and optimize a callable, which can be a train step within a training loop.

    This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
    instances from the input arguments and trace operations applied to their
    parameters and states.

    The returned decorator compiles the wrapped callable on its first call
    and caches the result on the wrapper; later calls reuse the cached
    graph. A call may pass the extra keyword argument ``last_train_step``.
    It is removed from the keyword arguments. When it is truthy, it is
    forwarded to the compiled graph module as ``last_iter``. If the graph
    module rejects that keyword, it is called without it.

    Args:
        module_override (Optional[List[Override]]): a list of Override instances
            that will be applied to the module in order. The :class:`Override`
            objects provide :class:`nn.Module` replacements during tracing and a
            graph transformation function after tracing. (Default: ``None``)
        gm_transformation (Optional[Callable[[fx.GraphModule], fx.GraphModule]]):
            a callback that is called once, right after the original callable
            has been compiled and distributed on its first call, to transform
            the compiled GraphModule into a new optimized one. If it raises,
            nothing is cached and the next call compiles again.
            (Default: ``None``)
        parallel_mode (Optional[ParallelMode]): a :class:`ParallelMode` object
            that specifies how to parallelize the callable. Each ParallelMode
            would have its own strategy to partition the model and the captured
            graph. If ``None``, the module-level ``dtensor_expand_mode`` is
            used. (Default: ``None``)

    Returns:
        A decorator that wraps a callable with the compile-on-first-call logic.
    """

    def inner(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_train_step = kwargs.pop("last_train_step", False)
            # Put the COMPILED_OBJECT_KEY in ``wrapper`` instead of ``func`` as
            # ``wrapper`` is the one that users will get.
            compiled_obj = wrapper.__dict__.get(COMPILED_OBJECT_KEY, None)
            if compiled_obj is None:
                # Only reads the module-level default; no ``global`` needed.
                mode: ParallelMode = (
                    dtensor_expand_mode if parallel_mode is None else parallel_mode
                )
                compiled_obj = _compile(func, module_override, mode, *args, **kwargs)
                if gm_transformation is not None:
                    # TODO: SPMD should provide a default and configurable
                    # transformation.
                    with torch.no_grad():
                        compiled_obj.gm = gm_transformation(compiled_obj.gm)
                # Cache only a fully built result, so a failing transformation
                # cannot leave an untransformed graph behind for later calls.
                wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj

            flat_inps = compiled_obj.flat_state + pytree.arg_tree_leaves(
                *args, **kwargs
            )

            with torch.no_grad():
                # N.B.: we don't need autograd as backward has already been
                # captured in the graph.
                if not last_train_step:
                    return compiled_obj.gm(*flat_inps)[0]

                # This is the last train step. Call IterGraphModule.forward()
                # with the `last_iter` argument and catch the exception in
                # case the compiled_obj is not wrapped with IterGraphModule.
                try:
                    return compiled_obj.gm(*flat_inps, last_iter=last_train_step)[0]
                except TypeError as e:
                    if "last_iter" not in str(e):
                        raise
                    return compiled_obj.gm(*flat_inps)[0]

        return wrapper

    return inner