from typing import Callable, Dict, List, Set

import torch
import torch.fx as fx
import torch.utils._pytree as pytree

from torch import Tensor
from torch.distributed._tensor import DeviceMesh, Replicate, Shard
from torch.distributed._tensor.ops.view_ops import (
    DimSpec,
    InputDim,
    ops as view_op_rules,
)
from torch.distributed._tensor.placement_types import _Partial, DTensorSpec

aten = torch.ops.aten


class BatchDimAnalyzer:
    """Analyze the batch dimension of each tensor/node in the graph.

    We need to know the batch dimension of each tensor/node so that we know
    exactly the sharding layout of intermediate tensors.

    We should eventually evaluate using symbolic shapes to track the batch
    dimension. We can experiment with that later via dynamo integration (dynamo
    has a mark_dynamic API that allows marking only the batch dimension), or try
    to use FakeTensorMode to mark the batch dimension. For now, we simply use the
    batch dimension of the first input tensor as the hint to track the batch
    dimension of all tensors/nodes in the graph.
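
    A rough usage sketch (``activation_nodes``, ``mesh`` and ``batch_size`` are
    hypothetical placeholders; it assumes the graph was traced with fake tensors
    so that every node carries a value in ``node.meta["val"]``)::

        analyzer = BatchDimAnalyzer(batch_dim=0)
        analyzer.init_batch_dim_size(batch_size)
        for node in activation_nodes:
            # returns a DTensorSpec that is Shard(batch_dim), Replicate() or
            # _Partial() depending on how the batch dim flows through the node
            act_spec = analyzer.compute_act_spec(node, mesh)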
34"""
35
36def __init__(self, batch_dim: int = 0) -> None:
37self.batch_dim = batch_dim
38
39self.batch_dim_map: Dict[fx.Node, int] = {}
40# batch dim size is used to track the batch dim size of the input tensor
41self.batch_dim_size = -1
42
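        # dim_rule_map maps view-like aten op overloads to the tensor-level
        # callables registered in view_op_rules, so the existing DimSpec rules
        # can be reused to propagate the batch dim through these ops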
        self.dim_rule_map: Dict[torch._ops.OpOverload, Callable[..., torch.Tensor]] = {
            aten.squeeze.default: torch.squeeze,
            aten.squeeze.dim: torch.squeeze,
            aten.view.default: Tensor.view,
            aten.reshape.default: torch.reshape,
            aten._unsafe_view.default: Tensor.view,
            aten.unsqueeze.default: torch.unsqueeze,
            aten.expand.default: Tensor.expand,
            aten.permute.default: torch.permute,
            aten.repeat.default: Tensor.repeat,
            aten.transpose.int: torch.transpose,
        }

    def init_batch_dim_size(self, batch_dim_size: int) -> None:
        """Initialize the batch dim size based on the first input batch size."""
        if self.batch_dim_size != -1 and self.batch_dim_size != batch_dim_size:
            raise RuntimeError(
                f"batch dim size is already initialized! "
                f"Found new batch size: {batch_dim_size} not "
                f"matching existing batch dim size: {self.batch_dim_size}!"
            )
        self.batch_dim_size = batch_dim_size

    def set_batch_dim(self, node: fx.Node, batch_dim: int) -> None:
        self.batch_dim_map[node] = batch_dim

    def get_batch_dim(self, node: fx.Node) -> int:
        if node not in self.batch_dim_map:
            raise RuntimeError(f"batch dim analysis failed on node: {node}!")
        return self.batch_dim_map[node]

    def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int:
        """Compute the batch dimension for the `node`."""
        assert self.batch_dim_size != -1, "batch dim size is not initialized!"

        if node in self.batch_dim_map:
            # if batch dim already computed, simply return it
            return self.batch_dim_map[node]

        if node.target in self.dim_rule_map:
            view_op_rule = view_op_rules[self.dim_rule_map[node.target]]  # type: ignore[index]
            args_val = pytree.tree_map_only(fx.Node, lambda n: n.meta["val"], node.args)
            kwargs_val = pytree.tree_map_only(
                fx.Node, lambda n: n.meta["val"], node.kwargs
            )
            output_dim_rules = view_op_rule.dim_map(*args_val, **kwargs_val)

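            # dim_map describes, for every output dim, a DimSpec tree of the
            # input dims it is derived from; walk that tree to collect the
            # contributing input dims for each output dim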
            def collect_input_dim(cmd: DimSpec, input_dims: Set[int]):
                if isinstance(cmd, InputDim):
                    input_dims.add(cmd.input_dim)
                for inp in cmd.inputs():
                    collect_input_dim(inp, input_dims)

            output_dim_to_input_dims: List[Set[int]] = []
            for inp in output_dim_rules:
                input_dims: Set[int] = set()
                collect_input_dim(inp, input_dims=input_dims)
                output_dim_to_input_dims.append(input_dims)

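            # check whether the operand's batch dim flows into one of the
            # output dims; if it does, that output dim is this node's batch dim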
            operand = node.all_input_nodes[0]
            operand_batch_dim = self.get_batch_dim(operand)
            for output_dim, input_dims in enumerate(output_dim_to_input_dims):
                if operand_batch_dim in input_dims:
                    self.set_batch_dim(node, output_dim)
                    # update the batch dim size before returning, because the
                    # batch dim size might change in the middle of the graph
                    # (e.g. a view/reshape can fold the batch dim into another dim)
                    self.batch_dim_size = node.meta["val"].shape[output_dim]
                    return output_dim

        # if there are no hints from the output_dim_rules, infer from the output
        # shape whether there is a batch dim, and shard accordingly
        node_val = node.meta["val"]
        if isinstance(node_val, (list, tuple)):
            shapes = [val.shape for val in node_val]
        else:
            shapes = [node_val.shape]

        # for a reduction op that reduces over the sharded batch dim we don't
        # generate a partial placement, but rather a shard: the intention of
        # data parallel is to never do a full reduction across the batch
        # dimension, so the reduced activation stays sharded
        full_reduction = False
        # loop through the dim sizes to find the output batch dim
        for shape in shapes:
            if len(shape) == 0:
                full_reduction = True

            for i, dim_size in enumerate(shape):
                if dim_size == self.batch_dim_size:
                    self.set_batch_dim(node, i)
                    return i

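        # no output dim matches the current batch dim size, so the batch dim
        # must have been removed by this op; fall back to the operands to decide
        # whether this node is replicated, partial, or still sharded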
        operands = node.all_input_nodes
        if not operands:
            # if there are no operands, this must be a factory op; the tensor it
            # produces is generated for computation and should be marked replicated
            self.set_batch_dim(node, -1)
            # -1 means replicated
            return -1
        else:
            # if there are operands, check whether any of them has a batch dim;
            # if an operand has a batch dim but the output does not, it is either
            # a full reduction, where we should stay sharded, or a reduction over
            # the batch dim only, where we should produce a partial placement
            operand_batch_dim = -1
            for operand in operands:
                if operand in self.batch_dim_map:
                    operand_batch_dim = self.get_batch_dim(operand)
            if operand_batch_dim < 0:
                # if the operands do not have a batch dim, this node does not either
                self.set_batch_dim(node, operand_batch_dim)
                return operand_batch_dim
            elif full_reduction:
                self.set_batch_dim(node, operand_batch_dim)
                return operand_batch_dim
            else:
                # the operands have a batch dim but the output does not, so this
                # node should produce a partial placement; -2 indicates partial
                self.set_batch_dim(node, -2)
                return -2

    def compute_act_spec(self, node: fx.Node, mesh: DeviceMesh) -> DTensorSpec:
        """Compute the batch dimension for the current node, then generate the sharding spec that shards on the batch dimension."""
        node_batch_dim = self.compute_batch_dim(node)
        if node_batch_dim == -1:
            # indicate this activation is replicated
            act_spec = DTensorSpec(mesh=mesh, placements=(Replicate(),))
        elif node_batch_dim == -2:
            # indicate this activation is partial
            act_spec = DTensorSpec(mesh=mesh, placements=(_Partial(),))
        else:
            # indicate this activation is sharded on the batch dim
            act_spec = DTensorSpec(mesh=mesh, placements=(Shard(node_batch_dim),))

        return act_spec