from abc import ABC, abstractmethod
from contextlib import contextmanager, nullcontext
from copy import copy
from dataclasses import dataclass
from functools import partial, wraps
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union

from functorch import make_fx

import torch
import torch.distributed as dist

# We need to import _functional_collectives to trigger op registration
import torch.distributed._functional_collectives
import torch.nn as nn
import torch.utils._pytree as pytree

from torch import fx
from torch._decomp.decompositions import native_layer_norm_backward

from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._spmd.data_parallel import gradients_tagging
from torch.distributed._spmd.parallel_mode import (
    DataParallel,
    DTensorExpandMode,
    ParallelMode,
)
from torch.distributed._tensor import Placement
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo, CodeGen
from torch.nn.utils import stateless
from torch.nn.utils._named_member_accessor import NamedMemberAccessor


class Override(ABC):
    r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.

    This is useful when any part of the model is not traceable or if you prefer
    to not trace it for any reason. More specifically, users can implement
    :meth:`torch.distributed._spmd.Override.replacement` to replace an original
    submodule with the returned new submodule. The new submodule contains
    operations that users prefer to be traced, which may simply be a dummy
    placeholder operator. After tracing, users can implement
    :meth:`torch.distributed._spmd.Override.transform` to transform the traced
    graph, where the dummy placeholder operator serves as an anchor to insert
    new sub-graphs.
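
    Example (a minimal sketch; ``CustomMoE`` and ``DummyMoE`` below are
    hypothetical user-defined modules, not part of this API)::

        class MoEOverride(Override):
            def replacement(self, fqn, orig_submodule):
                # swap a non-traceable expert module for a dummy placeholder
                if isinstance(orig_submodule, CustomMoE):
                    return DummyMoE()
                return orig_submodule

            def transform(self, gm, flat_state):
                # use the dummy placeholder node in ``gm`` as an anchor to
                # splice in the real sub-graph, keeping ``flat_state`` and
                # the graph placeholders consistent
                return gm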
    """

    @abstractmethod
    def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
        r"""Implement this method to return a new :class:`nn.Module` instance to replace the ``orig_submodule``
        argument in the model.

        This helps if ``orig_submodule`` is not traceable or should not be traced.

        Args:
            fqn (str): fully qualified name of the submodule.
            orig_submodule (:class:`nn.Module`): original submodule instance to replace.

        Returns:
            A new :class:`nn.Module` instance to replace the original one.

        """
        pass

    @abstractmethod
    def transform(
        self,
        gm: fx.GraphModule,
        flat_state: List[torch.Tensor],
    ) -> fx.GraphModule:
        r"""
        Given a DTensor-expanded graph and a sharding schema for every node,
        conduct additional transformation for the sub-graph from the :class:`nn.Module`
        returned by :meth:`torch.distributed._spmd.Override.replacement` if
        necessary.

        Args:
            gm (:class:`fx.GraphModule`): a DTensor-expanded graph.
            flat_state (List[:class:`Tensor`]): a reference to the list of
                flattened state. The elements in ``flat_state`` map to the first
                ``len(flat_state)`` placeholders in the graph. The transformation
                can add state to or remove state from ``flat_state`` as long as
                it keeps ``flat_state`` and the placeholders consistent.

        Returns:
            The :class:`fx.GraphModule` after transformation.

        """
        pass


class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):
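    # A codegen that applies pytree handling to outputs only: inputs are passed
    # through unchanged (the caller supplies already-flattened arguments), while
    # outputs are still unflattened according to ``pytree_info.out_spec``.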
    # pyre-ignore[3]
    def process_inputs(self, *args: Any) -> Any:
        return args

    # pyre-ignore[2, 3]
    def gen_fn_def(self, free_vars, maybe_return_annotation):
        return CodeGen.gen_fn_def(self, free_vars, maybe_return_annotation)


def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Move the responsibility of flattening the input arguments from the graph module to the caller.

    Example:

        output = gm(my_struct)

        gm = _to_caller_flattened_graph_module(gm)

        output = gm(*pytree.flatten(my_struct)[0])

    """
    # pyre-ignore[16]
    gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
        pytree_info=_PyTreeInfo(
            # pyre-ignore[6]
            orig_args=None,  # type: ignore[arg-type]
            # pyre-ignore[6]
            in_spec=None,  # type: ignore[arg-type]
            # pyre-ignore[16]
            out_spec=gm._graph._codegen.pytree_info.out_spec,
        )
    )
    gm.recompile()
    return gm


# Use a dtensor expand mode for now to preserve the old behavior
# and avoid breaking existing code
dtensor_expand_mode = DTensorExpandMode()


def _override_placements(t: torch.Tensor, placements: List[Placement]):
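    # Record the desired placements for ``t`` (keyed by ``id(t)``) so that the
    # DTensor expand mode can pick them up when expanding the traced graph.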
    global dtensor_expand_mode
    dtensor_expand_mode._placements_override[id(t)] = placements


@contextmanager
def _rematerialize_optimizer(
    opt: torch.optim.Optimizer,
    named_states: Dict[str, Any],
    params: Dict[str, nn.Parameter],
):
    assert opt is not None

    # update opt.state with proxy tensors
    orig_states = copy(opt.state)
    for n in named_states:
        # named_states is keyed by FQN strings, while the optimizer state is
        # keyed by Parameter objects
        opt.state[params[n]] = named_states[n]  # type: ignore[index]

    # FIXME: support multiple parameter groups
    param_group = opt.param_groups[0]
    orig_params = param_group["params"]
    param_group["params"] = params.values()

    try:
        yield
    finally:
        param_group["params"] = orig_params
        opt.state = orig_states


aten = torch.ops.aten  # pyre-ignore


@contextmanager
def _enable_compile():
    # The return value of torch._utils.is_compiling changes optimizer behavior.
    # We need that function to return True to include optimizer in the graph.
    # See: https://github.com/pytorch/pytorch/blob/a524123c91ab399c9dd6882c1189596dd77e7734/torch/optim/optimizer.py#L41
    def f_true():
        return True

    orig_is_compiling_code = torch._utils.is_compiling.__code__
    torch._utils.is_compiling.__code__ = f_true.__code__
    try:
        yield
    finally:
        torch._utils.is_compiling.__code__ = orig_is_compiling_code


def _foreach_add_decomp(self, other, alpha=1):
    self_updated = aten._foreach_add.List(self, other, alpha=alpha)
    for s, s_u in zip(self, self_updated):
        s.copy_(s_u)


def _foreach_unaop_decomp(op, self):
    self_updated = op(self)
    for s, s_u in zip(self, self_updated):
        s.copy_(s_u)


def _foreach_binop_list_decomp(op, self, other):
    self_updated = op(self, other)
    for s, s_u in zip(self, self_updated):
        s.copy_(s_u)


def _foreach_binop_scalar_decomp(op, self, scalar=1):
    self_updated = op(self, scalar)
    for s, s_u in zip(self, self_updated):
        s.copy_(s_u)


def _foreach_addcop_scalar_decomp(op, self, tensor1, tensor2, scalar=1):
    self_updated = op(self, tensor1, tensor2, scalar)
    for s, s_u in zip(self, self_updated):
        s.copy_(s_u)


def _fused_adam_decomp(
    self,
    grads,
    exp_avgs,
    exp_avg_sqs,
    max_exp_avg_sqs,
    state_steps,
    *,
    lr=1,
    beta1=1,
    beta2=1,
    weight_decay=1,
    eps=1,
    amsgrad=True,
    maximize=True,
    grad_scale=None,
    found_inf=None,
):
    orig_tuple = (self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)
    updated_tuple = aten._fused_adam.default(
        self,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        weight_decay=weight_decay,
        eps=eps,
        amsgrad=amsgrad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )

    for idx, (orig, updated) in enumerate(zip(orig_tuple, updated_tuple)):
        if idx == 1:
            # skip gradient copying as we don't need to copy gradients back
            continue
        for o, u in zip(orig, updated):
            o.copy_(u)


SPMD_DECOMP_TABLE = {
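    # Each entry below decomposes an in-place foreach/fused optimizer op into
    # its out-of-place variant followed by explicit ``copy_`` calls back into
    # the original tensors (see the ``*_decomp`` helpers above), so the traced
    # graph contains no in-place foreach variants.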
    aten._foreach_add_.List: _foreach_add_decomp,
    aten._foreach_add_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_add.Scalar
    ),
    aten._foreach_addcdiv_.Scalar: partial(
        _foreach_addcop_scalar_decomp, aten._foreach_addcdiv.Scalar
    ),
    aten._foreach_addcmul_.Scalar: partial(
        _foreach_addcop_scalar_decomp, aten._foreach_addcmul.Scalar
    ),
    aten._foreach_div_.List: partial(
        _foreach_binop_list_decomp, aten._foreach_div.List
    ),
    aten._foreach_mul_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_mul.Scalar
    ),
    aten._foreach_div_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_div.Scalar
    ),
    aten._foreach_neg_.default: partial(
        _foreach_unaop_decomp, aten._foreach_neg.default
    ),
    aten._foreach_reciprocal_.default: partial(
        _foreach_unaop_decomp, aten._foreach_reciprocal.default
    ),
    aten._foreach_sqrt_.default: partial(
        _foreach_unaop_decomp, aten._foreach_sqrt.default
    ),
    aten._foreach_sub_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_sub.Scalar
    ),
    aten._fused_adam_.default: _fused_adam_decomp,
    aten.native_layer_norm_backward.default: native_layer_norm_backward,
}


DEDUP_TARGETS: Set[torch._ops.OpOverload] = {
    torch.ops.c10d_functional.all_reduce.default,
    torch.ops.c10d_functional.wait_tensor.default,
}


def _dedup_collectives(gm: fx.GraphModule) -> fx.GraphModule:
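    # Functional collectives with identical (target, flattened args) produce
    # identical results, so every duplicate can be redirected to the first
    # occurrence and erased.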
    args_to_node: Dict[Tuple[Any, ...], fx.Node] = {}

    for node in gm.graph.nodes:
        # replace all args with the results from the first unique comm op
        args = pytree.arg_tree_leaves(*node.args)

        if node.target in DEDUP_TARGETS:
            args_key = (node.target, *args)
            unique_node = args_to_node.get(args_key, None)
            if unique_node is None:
                # first time seeing this combination, remember it
                args_to_node[args_key] = node
            else:
                # the current node is a duplicate, replace it
                node.replace_all_uses_with(unique_node)
                gm.graph.erase_node(node)

    gm.recompile()

    return gm


@dataclass
class _CompiledResult:
    gm: fx.GraphModule
    mod: nn.Module
    opt: Optional[torch.optim.Optimizer]
    flat_state: List[torch.Tensor]


def _compile(
    func: Callable,
    module_override: Optional[List[Override]],
    parallel_mode: ParallelMode,
    *args: Any,
    **kwargs: Any,
) -> _CompiledResult:
    # 1. Extract nn.Module and Optimizer from args and kwargs
    # FIXME(@mrshenli): support multiple nn.Module instances
    # FIXME(@mrshenli): support multiple Optimizer instances
    # FIXME(@mrshenli): need to broadcast model to sync parameters
    mod, opt = None, None
    for arg in pytree.arg_tree_leaves(*args, **kwargs):
        if isinstance(arg, nn.Module):
            assert mod is None, "Only support single nn.Module for now"
            mod = arg
        if isinstance(arg, torch.optim.Optimizer):
            assert opt is None, "Only support single Optimizer for now"
            opt = arg

    assert mod is not None, "Couldn't find nn.Module instances from the arguments."

    # 2. Override target submodules (e.g., MoE) with dummy replacements
    if module_override:
        accessor = NamedMemberAccessor(mod)

        def swap(fqn_prefix: str, module: torch.nn.Module) -> None:
            for override in module_override:  # type: ignore[union-attr]
                for name, child in module.named_children():
                    if len(name) == 0:
                        continue
                    fqn = fqn_prefix + "." + name if fqn_prefix != "" else name
                    new_child = override.replacement(fqn, child)
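                    # An unchanged child (the override returned the same
                    # object) is recursed into; a new child is swapped into
                    # the model in its place.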
                    if id(new_child) == id(child):
                        swap(fqn, new_child)
                    else:
                        accessor.swap_submodule(fqn, new_child)

        swap("", mod)

    # 3. Trace the stateless version of the train_step
    params = dict(mod.named_parameters(remove_duplicate=False))
    buffers = dict(mod.named_buffers(remove_duplicate=False))

    named_states = {}
    if opt is not None:
        # Pass named_states instead of opt.state to stateless_func, because
        # the latter uses nn.Parameter as keys. During tracing, we need to
        # make sure optimizers can find the states using proxy tensors.
        for n, p in params.items():
            if p in opt.state:
                # named_states is keyed by FQN strings, while the optimizer
                # state is keyed by Parameter objects
                named_states[n] = opt.state[p]  # type: ignore[index]

    is_data_parallel_mode = isinstance(parallel_mode, DataParallel)

    # Lift states and parameters as function arguments so that make_fx
    # can trace operations applied to them.
    def stateless_func(func, params, buffers, named_states, args, kwargs):
        with stateless._reparametrize_module(
            mod, {**params, **buffers}
        ), _rematerialize_optimizer(
            opt, named_states, params
        ) if opt else nullcontext():
            # For DataParallel mode, install hooks first to tag the gradients
            with gradients_tagging(params) if is_data_parallel_mode else nullcontext():
                ret = func(*args, **kwargs)

            # make sure updated parameters are returned
            return ret, list(mod.parameters()), list(named_states.values())  # type: ignore[union-attr]

    # FIXME: Using symbolic tracing as a workaround in DTensor expand mode.
    # Otherwise it hits a shape mismatch error, as we use local inputs to
    # trace the local graph and use DTensor to expand operators, where
    # DTensor's shape is the global shape.
    tracing_mode = "fake" if is_data_parallel_mode else "symbolic"

    if is_data_parallel_mode:
        fake_mode = FakeTensorMode()
        data_parallel_mode = cast(DataParallel, parallel_mode)

        def _get_full_batch_arg(arg: torch.Tensor) -> torch.Tensor:
            # since compilation happens in the first iteration and we receive
            # mini-batch inputs, convert them to full-batch fake tensor inputs
            # first for data parallel sharding propagation
            fake_arg = fake_mode.from_tensor(arg)
            arg_dims = [1] * arg.ndim
            # expand the tensor to full batch size on its batch dim
            arg_dims[data_parallel_mode.input_batch_dim] *= dist.get_world_size()
            return fake_arg.repeat(arg_dims)

        args = pytree.tree_map_only(
            torch.Tensor,
            _get_full_batch_arg,
            args,
        )
        kwargs = pytree.tree_map_only(
            torch.Tensor,
            _get_full_batch_arg,
            kwargs,
        )

    with _enable_compile(), torch.autograd.detect_anomaly(check_nan=False):
        # FIXME(@mrshenli): functionalization does not work for our use
        # case yet. Use explicit decompositions for foreach ops.
        # Remove this when the following issue is addressed.
        # Issue: https://github.com/pytorch/pytorch/issues/97852
        gm = make_fx(
            partial(stateless_func, func),
            tracing_mode=tracing_mode,
            decomposition_table=SPMD_DECOMP_TABLE,
            _allow_non_fake_inputs=False,
        )(params, buffers, named_states, args, kwargs)

    params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = {
        **params,
        **buffers,
    }

    # 4. Use the parallel mode to expand the single-device graph to a distributed graph
    gm = parallel_mode.partition(
        gm,
        mod,
        opt,
        params_and_buffers,
        named_states,
        args,
        kwargs,
    )

    # 5. Move the responsibility of flattening the input arguments from the
    # graph module to the caller. This serves two purposes:
    #   - Transformations that add/remove state need to manipulate a state
    #   container that maintains the state tensors in the same order as they
    #   appear in graph placeholders.
    #   - Reduced runtime cost. The state container is only flattened once upfront.
    flat_state = pytree.tree_leaves([params_and_buffers, named_states])
    gm = _to_caller_flattened_graph_module(gm)

    # 6. Dedup comm operators.
    # The duplication could come from DTensor args and kwargs redistribution.
    # Suppose one operator produces a Partial gradient tensor and model
    # parameters are replicated. In this case, every optimizer operation using
    # that Partial gradient tensor would trigger an allreduce. This is because
    # DTensor only has local information on individual tensor/operator, which is
    # not sufficient to detect duplications in the graph. This situation can
    # also happen when inserting FSDP allgather if a parameter is used multiple
    # times in the forward method.
    # TODO(@mrshenli): @yifuwang has a suggestion of conducting expansion and
    # dedup at tracer-level to avoid multiple graph passes.
    gm = _dedup_collectives(gm)

    # 7. Replace the previously inserted dummy placeholders with real sub-graphs.
    if module_override:
        for override in module_override:
            gm = override.transform(gm, flat_state)

    return _CompiledResult(gm, mod, opt, flat_state)


# Note that the Python convention of __dict__ requires the key to be str.
# TODO: ensure the key is unique.
COMPILED_OBJECT_KEY = "_compiled_obj"


def compile(
    module_override: Optional[List[Override]] = None,
    gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
    parallel_mode: Optional[ParallelMode] = None,
):
    r"""Compile and optimize a callable, which can be a train step within a training loop.

    This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
    instances from the input arguments and trace operations applied to their
    parameters and states.

    Args:
        module_override (Optional[List[Override]]): a list of Override instances
            that will be applied to the module in order. The :class:`Override`
            objects provide :class:`nn.Module` replacements during tracing and a
            graph transformation function after tracing. (Default: ``None``)
        gm_transformation (Optional[Callable[[fx.GraphModule], fx.GraphModule]]):
            a callback that will be called after the original callable is
            compiled and distributed (usually after the first iteration) to
            transform the compiled GraphModule into a new optimized one.
            (Default: ``None``)
        parallel_mode (Optional[ParallelMode]): a :class:`ParallelMode` object
            that specifies how to parallelize the callable. Each ParallelMode
            has its own strategy to partition the model and the captured
            graph. (Default: ``None``)
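
    Example (a minimal usage sketch; ``model``, ``optim``, and ``batch`` are
    hypothetical user-defined objects)::

        @compile()
        def train_step(model, optim, batch):
            loss = model(batch).sum()
            loss.backward()
            optim.step()
            optim.zero_grad()
            return loss

        # the first call traces, expands, and compiles the graph;
        # subsequent calls reuse the compiled graph
        loss = train_step(model, optim, batch)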

    """

    def inner(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_train_step = kwargs.pop("last_train_step", False) if kwargs else False
            first_iter = False
            # Put the COMPILED_OBJECT_KEY in ``wrapper`` instead of ``func`` as
            # ``wrapper`` is the one that users will get.
            compiled_obj = wrapper.__dict__.get(COMPILED_OBJECT_KEY, None)
            if compiled_obj is None:
                first_iter = True
                global dtensor_expand_mode
                mode: ParallelMode = (
                    dtensor_expand_mode if parallel_mode is None else parallel_mode
                )

                compiled_obj = _compile(func, module_override, mode, *args, **kwargs)
                wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj

            flat_inps = compiled_obj.flat_state + pytree.arg_tree_leaves(
                *args, **kwargs
            )
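            # ``flat_inps`` matches the graph placeholders: the flattened state
            # first (see ``_to_caller_flattened_graph_module``), followed by
            # the flattened call arguments.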

            with torch.no_grad():
                # N.B.: we don't need autograd as backward has already been
                # captured in the graph.
                if first_iter and gm_transformation:
                    # TODO: SPMD should provide a default and configurable
                    # transformation.
                    compiled_obj.gm = gm_transformation(compiled_obj.gm)
                if not last_train_step:
                    output = compiled_obj.gm(*flat_inps)[0]
                else:
                    # This is the last train step. Call IterGraphModule.forward()
                    # with the `last_iter` argument and catch the exception in
                    # case the compiled_obj is not wrapped with IterGraphModule.
                    try:
                        output = compiled_obj.gm(*flat_inps, last_iter=last_train_step)[
                            0
                        ]
                    except TypeError as e:
                        if "last_iter" not in str(e):
                            raise e
                        output = compiled_obj.gm(*flat_inps)[0]

                return output

        return wrapper

    return inner