import collections
from typing import Any, Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree

# Shorthand for the ATen operator namespace used below.
aten = torch.ops.aten

# Metadata key and tags used to partition graph nodes into the
# constant-foldable subgraph vs. the part that must stay in the runtime module.
META_TAG = "MODULE_TYPE"
MODULE_TAG = "_MAIN_MODULE"
CONST_MODULE_TAG = "_CONST_MODULE"
def replace_node_with_constant(gm, node, constant, name=None):
    """Replace ``node`` in ``gm`` with a ``get_attr`` to a frozen constant.

    Registers ``constant`` on ``gm`` (under ``name`` if given, otherwise a
    fresh ``_frozen_param{i}`` attribute), rewires every use of ``node`` to a
    new ``get_attr`` node, and erases ``node`` from the graph.
    """
    g = gm.graph

    if name:
        qualname = name
    else:
        # Allocate a fresh _frozen_param{i} name, skipping any attribute
        # names that already exist on the module.
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0
        i = gm._frozen_param_count

        while True:
            qualname = f"_frozen_param{i}"
            if not hasattr(gm, qualname):
                break
            i += 1

        gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        # Preserve metadata (e.g. tensor meta, tags) on the replacement node.
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # Register as a buffer so FX codegen does not warn about an attribute
    # that references no nn.Module / nn.Parameter / buffer.
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)
class ConstantFolder(torch.fx.Interpreter):
50
skip_constructors=False,
53
self.node_replacements: Dict[torch.fx.Node, Any] = {}
54
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
55
self.unknown_value = object()
56
self.skip_constructors: bool = skip_constructors
60
self.user_to_last_uses = self.node_to_last_non_output_use()
62
def is_impure(self, node: torch.fx.node.Node):
64
torch.ops.quantized_decomposed.dequantize_per_channel.default,
65
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
66
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
74
def node_to_last_non_output_use(self):
75
last_non_output_use = collections.defaultdict(list)
77
output_node = next(iter(reversed(self.module.graph.nodes)))
79
for node in reversed(self.module.graph.nodes):
80
if node.target == "output":
88
last_non_output_use[node].append(inp)
90
pytree.tree_map_only(torch.fx.Node, add_use, (node.args, node.kwargs))
93
if len(node.users) == 1 and output_node in node.users:
94
last_non_output_use[node].append(node)
96
return last_non_output_use
98
def run_node(self, node):
99
if node.target == "output":
103
self.env[arg] = self.unknown_value
105
pytree.tree_map_only(torch.fx.Node, set_env, node.args)
106
return super().run_node(node)
108
args, kwargs = self.fetch_args_kwargs_from_env(node)
109
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
111
if self.unknown_value in flattened_inputs:
112
return self.unknown_value
116
node.op == "call_function"
117
and node.target == aten._efficientzerotensor.default
119
return self.unknown_value
123
node.op == "call_function"
124
and node.name == "triton_kernel_wrapper_functional_proxy"
126
return self.unknown_value
132
self.skip_constructors
133
and node.op != "get_attr"
134
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
136
return self.unknown_value
140
isinstance(node.target, torch._ops.OpOverload)
141
and torch.Tag.nondeterministic_seeded in node.target.tags
143
return self.unknown_value
145
out = super().run_node(node)
147
if node.op != "get_attr" and isinstance(out, torch.Tensor):
148
if not self.insertable_tensor_check(out):
151
if self.is_impure(node):
152
return self.unknown_value
154
self.add_node_replacement(node, out)
156
flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
158
for n in flattened_node_inps:
159
if not isinstance(n, torch.fx.Node):
162
self.replaced_uses[n] += 1
164
for to_delete in self.user_to_last_uses.get(node, []):
165
if self.replaced_uses[to_delete] == len(to_delete.users):
166
self.node_replacements.pop(to_delete, None)
170
def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
173
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
174
self.node_replacements[node] = tensor
178
for n in self.module.graph.nodes:
179
if n.op == "placeholder":
180
env[n] = self.unknown_value
181
return super().run(initial_env=env)
@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    """Fold all constant-computable nodes of ``gm`` into frozen attributes.

    ``constraint_fn``, when given, can veto folding of individual nodes.
    Mutates ``gm`` in place: replaces folded nodes, deletes now-unused
    ``get_attr`` nodes (and their module attributes), and recompiles.
    """
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node, constant in cf.node_replacements.items():
        if constraint_fn is not None and not constraint_fn(node):
            continue
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        if node.op == "get_attr" and len(node.users) == 0:
            # Drop the backing attribute too, so the constant can be freed.
            if hasattr(gm, node.target):
                delattr(gm, node.target)
            erased_params.append(node)

    # Erase after the scan — mutating the node list while iterating is unsafe.
    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
@torch.utils._python_dispatch._disable_current_modes()
210
def constant_graph_tag(gm: torch.fx.GraphModule):
211
cf = ConstantFolder(gm, skip_constructors=True)
214
for node in gm.graph.nodes:
216
node.op == "get_attr"
217
or node in cf.node_replacements
218
or node in cf.replaced_uses
220
node.meta[META_TAG] = CONST_MODULE_TAG
222
node.meta[META_TAG] = MODULE_TAG
225
def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
227
Construct a GraphModule which corresponds to the part which could be
228
constant folded in provided gm.
231
constant_graph_tag(gm)
234
for node in gm.graph.nodes:
235
if node.op == "get_attr":
238
if u.meta[META_TAG] == CONST_MODULE_TAG:
242
node.meta[META_TAG] = MODULE_TAG
244
new_graph = torch.fx.Graph()
246
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
248
for node in gm.graph.nodes:
249
if node.meta[META_TAG] == MODULE_TAG:
252
new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
253
node_remapping[node] = new_node
255
for user in node.users:
256
if user.meta[META_TAG] == MODULE_TAG:
257
output_nodes.append(new_node)
260
new_graph.output(tuple(output_nodes))
262
new_gm = torch.fx.GraphModule(gm, new_graph)