6
from torch.fx import symbolic_trace
7
from torch.fx.experimental.proxy_tensor import make_fx
8
from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops
9
from torch.testing._internal.common_utils import run_tests, TestCase
12
banned_ops = get_CSE_banned_ops()
13
P_default = CSEPass(banned_ops=banned_ops)
16
def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
18
check if the CSE modified graph of ``f``
19
1) has delta less nodes, and
20
2) do not reduce the number of nodes further on a second pass, and
21
3) modified returned is true only if the number of nodes decreases.
24
f: function to be checked
25
t: tensor to be passed to f
26
delta: an integer >= -1.
27
If delta = -1, it only checks if the new graph has less or equal number of nodes
28
check_val: if True, check if the output of f is correct
29
graph_input: True is f is type GraphModule
30
P: the pass to use. If None, use P_default
41
new_g = res.graph_module
42
new_graph = new_g.graph
43
modified = res.modified
46
old_num_nodes = len(fx_g.graph.nodes)
47
new_num_nodes = len(new_graph.nodes)
50
new_num_nodes < old_num_nodes
51
) == modified, "modified should be True if the number of nodes decrease"
55
old_num_nodes >= new_num_nodes,
56
(f"number of nodes increased {old_num_nodes}, {new_num_nodes}"),
60
old_num_nodes == new_num_nodes + delta,
62
f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
68
pass_2_graph = res.graph_module.graph
69
pass_2_num_nodes = len(pass_2_graph.nodes)
71
pass_2_num_nodes == new_num_nodes,
73
f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
81
if true_result is None:
83
our_result is None, f"true result is None, CSE result is {our_result}"
87
torch.all(true_result == our_result),
88
(f"results are different {true_result}, {our_result}"),
92
class TestCSEPass(TestCase):
93
def test_nochange(self):
101
t = torch.randn(2, 2)
104
def test_empty(self):
108
t = torch.randn(2, 2)
111
def test_immutable_list_type(self):
119
t = torch.randn(2, 2)
122
def test_immutable_list_multiple_entries(self):
124
a = x.sum(dim=[0, 1])
125
b = x.sum(dim=[0, 1])
130
t = torch.randn(2, 2)
133
def test_simple(self):
141
t = torch.randn(2, 2)
144
def test_simple_2(self):
155
def test_two_args_default(self):
158
b = x.sum(dim=1, keepdim=False)
159
c = x.sum(dim=1, keepdim=False)
163
t = torch.randn(2, 2)
166
def test_two_args(self):
169
b = x.sum(dim=1, keepdim=True)
170
c = x.sum(dim=1, keepdim=True)
174
t = torch.randn(2, 2)
177
def test_simple_multiple_same_ops(self):
185
t = torch.randn(2, 2)
188
def test_nested_immutable_list_type(self):
190
a = torch.cat((x, x))
191
b = torch.cat((x, x))
194
t = torch.randn(2, 2)
197
def test_kwarg(self):
199
a = torch.ones_like(x)
200
b = torch.ones_like(x)
203
t = torch.randn(2, 2)
207
Generate function with random ops and check if the result is the same
210
def test_random(self):
213
ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu]
215
new_val = random.choice(ops)(random.choice(vals))
219
fx_g = symbolic_trace(f)
220
fx_g.graph.eliminate_dead_code()
222
t = torch.randn(2, 2)
225
check(self, fx_g, t, -1, graph_input=True)
228
Test that banned list ban ops as expected.
231
def test_banned_list(self):
237
t = torch.randn(2, 2)
238
P_ban_add = P = CSEPass(banned_ops=[torch.ops.aten.add])
239
check(self, f, t, 0, P=P_ban_add)
242
def test_rand_like(self):
244
a = torch.rand_like(x)
245
b = torch.rand_like(x)
248
t = torch.randn(2, 2)
249
check(self, f, t, 0, check_val=False)
251
def test_rand_n(self):
257
t = torch.randn(2, 2)
258
check(self, f, t, 0, check_val=False)
261
if __name__ == "__main__":