pytorch / constant_folding.py
264 lines · 8.6 KB
import collections
from typing import Any, Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten

# We would like to split modules into two subgraphs for runtime weight updates to work correctly.
# The use case and more information can be found at:
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
META_TAG = "MODULE_TYPE"
MODULE_TAG = "_MAIN_MODULE"
CONST_MODULE_TAG = "_CONST_MODULE"


def replace_node_with_constant(gm, node, constant, name=None):
    g = gm.graph

    if name:
        qualname = name
    else:
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0
        i = gm._frozen_param_count

        while True:
            qualname = f"_frozen_param{i}"
            if not hasattr(gm, qualname):
                break
            i += 1

        gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)

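
# Illustrative sketch (not part of the original file): folding one node by hand.
# `_example_replace_single_node` is a hypothetical helper added here for
# demonstration; it assumes torch.fx.symbolic_trace lifts the tensor literal into
# a `_tensor_constant0` get_attr node, which is then swapped for a frozen param.
def _example_replace_single_node():
    class M(torch.nn.Module):
        def forward(self, x):
            return x + torch.arange(3, dtype=torch.float32)

    gm = torch.fx.symbolic_trace(M())
    # Pick the get_attr node that holds the traced tensor constant.
    node = next(n for n in gm.graph.nodes if n.op == "get_attr")
    replace_node_with_constant(gm, node, torch.arange(3, dtype=torch.float32))
    gm.recompile()
    return gm
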

class ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm,
        skip_constructors=False,
    ):
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()

    def is_impure(self, node: torch.fx.node.Node):
        if node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq, we only fold
            # fp32_weight -> q into an int8 weight, and leave dq in the graph
            # to be fused
            return True
        return False

    def node_to_last_non_output_use(self):
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp):
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            pytree.tree_map_only(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node):
        if node.target == "output":
            # because we remove nodes from env on their last non-output use,
            # re-define them now or we'll get an error in the interpreter
            def set_env(arg):
                self.env[arg] = self.unknown_value

            pytree.tree_map_only(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        if self.unknown_value in flattened_inputs:
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # TODO - constant folding a triton kernel returns the inputs -- fix this
        if (
            node.op == "call_function"
            and node.name == "triton_kernel_wrapper_functional_proxy"
        ):
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning them into tensors would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or be on inputs which we did not make constant
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self):
        env = {}
        for n in self.module.graph.nodes:
            if n.op == "placeholder":
                env[n] = self.unknown_value
        return super().run(initial_env=env)

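
# Illustrative sketch (not part of the original file): run the folder directly,
# without rewriting the graph, just to see which nodes evaluate to constants.
# `_example_inspect_foldable_nodes` is a hypothetical helper added for demonstration
# and uses a plain torch.fx.symbolic_trace graph rather than an Inductor ATen graph.
def _example_inspect_foldable_nodes():
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("scale", torch.full((4,), 0.5))

        def forward(self, x):
            # `self.scale + 1.0` depends only on the buffer, so it is foldable
            return x * (self.scale + 1.0)

    gm = torch.fx.symbolic_trace(M())
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()
    for node, value in cf.node_replacements.items():
        print(node.format_node(), value.shape)
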

@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node, constant in cf.node_replacements.items():
        if constraint_fn is not None and not constraint_fn(node):
            continue
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        if node.op == "get_attr" and len(node.users) == 0:
            if hasattr(gm, node.target):
                delattr(gm, node.target)
            erased_params.append(node)

    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()

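
# Illustrative sketch (not part of the original file): fold away computation that
# depends only on module state.  `_example_constant_fold` is a hypothetical helper
# added for demonstration; it assumes a plain torch.fx.symbolic_trace graph.
def _example_constant_fold():
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("w", torch.ones(3))

        def forward(self, x):
            # `self.w * 2 + 1` has no dependency on `x`, so it can be folded
            return x + (self.w * 2 + 1)

    gm = torch.fx.symbolic_trace(M())
    constant_fold(gm)
    # After folding, the precomputed tensor lives on the module as a frozen
    # attribute (e.g. `_frozen_param0`) referenced by a single get_attr node;
    # only the data-dependent add on `x` remains as real computation.
    print(gm.graph)
    return gm
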

@torch.utils._python_dispatch._disable_current_modes()
def constant_graph_tag(gm: torch.fx.GraphModule):
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node in gm.graph.nodes:
        if (
            node.op == "get_attr"
            or node in cf.node_replacements
            or node in cf.replaced_uses
        ):
            node.meta[META_TAG] = CONST_MODULE_TAG
        else:
            node.meta[META_TAG] = MODULE_TAG

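
# Illustrative sketch (not part of the original file): tag a traced graph and print
# which partition each node landed in.  `_example_tag_inspection` is a hypothetical
# helper added for demonstration.
def _example_tag_inspection():
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("w", torch.ones(2))

        def forward(self, x):
            return x + self.w.sum()

    gm = torch.fx.symbolic_trace(M())
    constant_graph_tag(gm)
    for node in gm.graph.nodes:
        print(node.name, node.meta[META_TAG])
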

def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Construct a GraphModule which corresponds to the part of the provided gm
    that could be constant folded.
    """

    constant_graph_tag(gm)
    # We rewrite the tags: if a constant is directly consumed without any
    # folding opportunity, we keep it in the main gm.
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            used_to_fold = False
            for u in node.users:
                if u.meta[META_TAG] == CONST_MODULE_TAG:
                    used_to_fold = True
                    break
            if not used_to_fold:
                node.meta[META_TAG] = MODULE_TAG

    new_graph = torch.fx.Graph()

    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    output_nodes = []
    for node in gm.graph.nodes:
        if node.meta[META_TAG] == MODULE_TAG:
            continue

        new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
        node_remapping[node] = new_node

        for user in node.users:
            if user.meta[META_TAG] == MODULE_TAG:
                output_nodes.append(new_node)
                break

    new_graph.output(tuple(output_nodes))
    new_graph.lint()
    new_gm = torch.fx.GraphModule(gm, new_graph)

    return new_gm
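
# Illustrative sketch (not part of the original file): extract the foldable
# subgraph so its outputs can be recomputed when weights change at runtime.
# `_example_split_constant_graph` is a hypothetical helper added for demonstration.
def _example_split_constant_graph():
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("w", torch.ones(3))

        def forward(self, x):
            return x + (self.w * 2)

    gm = torch.fx.symbolic_trace(M())
    const_gm = run_and_get_constant_graph(gm)
    # `const_gm` holds only the get_attr on `w` and the multiply; its output
    # tuple carries the values the main module consumes from the constant part.
    print(const_gm.graph)
    return const_gm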
