# pytorch / batch_dim_utils.py (179 lines, 7.7 KB)

from typing import Callable, Dict, List, Set

import torch
import torch.fx as fx
import torch.utils._pytree as pytree
from torch import Tensor
from torch.distributed._tensor import DeviceMesh, Replicate, Shard
from torch.distributed._tensor.ops.view_ops import (
    DimSpec,
    InputDim,
    ops as view_op_rules,
)
from torch.distributed._tensor.placement_types import _Partial, DTensorSpec

aten = torch.ops.aten


class BatchDimAnalyzer:
    """This class is used to analyze the batch dimension of each tensor/node in the graph.

    We need to know the batch dimension of each tensor/node so that we know
    exactly the sharding layout of intermediate tensors.

    We should eventually evaluate using symbolic shapes to track the batch
    dimension. We can experiment with that later via dynamo integration (dynamo
    has a mark_dynamic API that allows marking only the batch dimension), or try
    to use FakeTensorMode to mark the batch dimension. For now, we just use the
    batch dimension of the first input tensor as the hint to track the batch
    dimension of all tensors/nodes in the graph.
    """
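
    # Batch-dim convention used by the analyzer (added note, not in the original
    # source): a non-negative value is the tensor dim that carries the batch,
    # -1 means the tensor is replicated, and -2 means it is partial (pending
    # reduction); see compute_batch_dim and compute_act_spec below.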

    def __init__(self, batch_dim: int = 0) -> None:
        self.batch_dim = batch_dim

        self.batch_dim_map: Dict[fx.Node, int] = {}
        # batch_dim_size tracks the batch dim size of the input tensor
        self.batch_dim_size = -1

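        # Map view-like aten overloads to their functional counterparts so that
        # the DTensor view-op rules (view_op_rules) can tell us how each output
        # dim is derived from the input dims.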
        self.dim_rule_map: Dict[torch._ops.OpOverload, Callable[..., torch.Tensor]] = {
            aten.squeeze.default: torch.squeeze,
            aten.squeeze.dim: torch.squeeze,
            aten.view.default: Tensor.view,
            aten.reshape.default: torch.reshape,
            aten._unsafe_view.default: Tensor.view,
            aten.unsqueeze.default: torch.unsqueeze,
            aten.expand.default: Tensor.expand,
            aten.permute.default: torch.permute,
            aten.repeat.default: Tensor.repeat,
            aten.transpose.int: torch.transpose,
        }

    def init_batch_dim_size(self, batch_dim_size: int) -> None:
        """Initialize the batch dim size based on the first input's batch size."""
        if self.batch_dim_size != -1 and self.batch_dim_size != batch_dim_size:
            raise RuntimeError(
                f"batch dim size is already initialized! "
                f"Found new batch size: {batch_dim_size} not "
                f"matching existing batch dim size: {self.batch_dim_size}!"
            )
        self.batch_dim_size = batch_dim_size

    def set_batch_dim(self, node: fx.Node, batch_dim: int) -> None:
        self.batch_dim_map[node] = batch_dim

    def get_batch_dim(self, node: fx.Node) -> int:
        if node not in self.batch_dim_map:
            raise RuntimeError(f"batch dim analysis failed on node: {node}!")
        return self.batch_dim_map[node]

    def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int:
        """Compute the batch dimension for the `node`."""
        assert self.batch_dim_size != -1, "batch dim size is not initialized!"

        if node in self.batch_dim_map:
            # if the batch dim is already computed, simply return it
            return self.batch_dim_map[node]

        if node.target in self.dim_rule_map:
            view_op_rule = view_op_rules[self.dim_rule_map[node.target]]  # type: ignore[index]
            args_val = pytree.tree_map_only(fx.Node, lambda n: n.meta["val"], node.args)
            kwargs_val = pytree.tree_map_only(
                fx.Node, lambda n: n.meta["val"], node.kwargs
            )
            output_dim_rules = view_op_rule.dim_map(*args_val, **kwargs_val)

            def collect_input_dim(cmd: DimSpec, input_dims: Set[int]):
                if isinstance(cmd, InputDim):
                    input_dims.add(cmd.input_dim)
                for inp in cmd.inputs():
                    collect_input_dim(inp, input_dims)

            output_dim_to_input_dims: List[Set[int]] = []
            for inp in output_dim_rules:
                input_dims: Set[int] = set()
                collect_input_dim(inp, input_dims=input_dims)
                output_dim_to_input_dims.append(input_dims)

            operand = node.all_input_nodes[0]
            operand_batch_dim = self.get_batch_dim(operand)
            for output_dim, input_dims in enumerate(output_dim_to_input_dims):
                if operand_batch_dim in input_dims:
                    self.set_batch_dim(node, output_dim)
                    # update the batch dim size before returning, because the
                    # batch dim size might change in the middle of the graph
                    self.batch_dim_size = node.meta["val"].shape[output_dim]
                    return output_dim

        # if there are no hints from the output_dim_rules, we infer from the
        # output shape whether there is a batch dim, and shard accordingly
        node_val = node.meta["val"]
        if isinstance(node_val, (list, tuple)):
            shapes = [val.shape for val in node_val]
        else:
            shapes = [node_val.shape]

        # For a reduction op that reduces over the sharded batch dim we don't
        # generate a partial placement, but rather a shard. This is because the
        # intention of data parallel is to never do a full reduction across the
        # batch dimension; the reduced activation stays sharded.
        full_reduction = False
        # loop through the dim sizes to find the output batch dim
        for shape in shapes:
            if len(shape) == 0:
                full_reduction = True

            for i, dim_size in enumerate(shape):
                if dim_size == self.batch_dim_size:
                    self.set_batch_dim(node, i)
                    return i

        operands = node.all_input_nodes
        if not operands:
            # if there are no operands, this must be a factory op; the tensor is
            # generated for computation and should be marked as replicated
            self.set_batch_dim(node, -1)
            # -1 means replicated
            return -1
        else:
            # if there are operands, check whether any operand has a batch dim;
            # if an operand has a batch dim but the output does not, it's either
            # a full reduction, where we should stay sharded, or a reduction over
            # the batch dim only, where we should produce a partial placement
            operand_batch_dim = -1
            for operand in operands:
                if operand in self.batch_dim_map:
                    operand_batch_dim = self.get_batch_dim(operand)
            if operand_batch_dim < 0:
                # if the operand does not have a batch dim, we also don't have one
                self.set_batch_dim(node, operand_batch_dim)
                return operand_batch_dim
            elif full_reduction:
                self.set_batch_dim(node, operand_batch_dim)
                return operand_batch_dim
            else:
                # if the operand has a batch dim but the output does not, it should
                # produce a partial placement; we use -2 to indicate partial
                self.set_batch_dim(node, -2)
                return -2

    def compute_act_spec(self, node: fx.Node, mesh: DeviceMesh) -> DTensorSpec:
        """Compute the batch dimension for the current node, then generate the sharding spec that shards on the batch dimension."""
        node_batch_dim = self.compute_batch_dim(node)
        if node_batch_dim == -1:
            # indicate this activation is replicated
            act_spec = DTensorSpec(mesh=mesh, placements=(Replicate(),))
        elif node_batch_dim == -2:
            # indicate this activation is partial
            act_spec = DTensorSpec(mesh=mesh, placements=(_Partial(),))
        else:
            # indicate this activation is sharded on the batch dim
            act_spec = DTensorSpec(mesh=mesh, placements=(Shard(node_batch_dim),))

        return act_spec

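
# Usage sketch (not part of the original module): how the analyzer might be
# driven over a traced graph. `gm` and `mesh` are assumed to be an
# fx.GraphModule whose nodes carry "val" metadata and an already-initialized
# DeviceMesh; names here are illustrative only.
#
#     analyzer = BatchDimAnalyzer(batch_dim=0)
#     first_input = next(n for n in gm.graph.nodes if n.op == "placeholder")
#     analyzer.init_batch_dim_size(first_input.meta["val"].shape[analyzer.batch_dim])
#     analyzer.set_batch_dim(first_input, analyzer.batch_dim)
#     for node in gm.graph.nodes:
#         if node.op == "call_function":
#             spec = analyzer.compute_act_spec(node, mesh)  # DTensorSpec for this activation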
