# Owner(s): ["module: fx"]

from typing import Optional, Set, Type

import torch
import torch.fx

from torch.testing._internal.common_utils import TestCase
class TestDCE(TestCase):
11
def _has_nodes_without_users(self, m: torch.fx.GraphModule):
12
for node in m.graph.nodes:
15
if len(node.users) == 0:
19
def _get_num_placeholders(self, m: torch.fx.GraphModule) -> int:
21
for node in m.graph.nodes:
22
if node.op == "placeholder":
26
def _run_dce_and_test(
29
expect_dce_changes: bool,
30
modules_to_be_leafs: Set[Type] = None,
32
class TestTracer(torch.fx.Tracer):
33
def is_leaf_module(self, m, qualname):
34
if modules_to_be_leafs and type(m) in modules_to_be_leafs:
36
return super().trace(m, qualname)
38
traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
39
print(str(traced.graph))
41
# Verify there are nodes without users (if expected).
42
has_nodes_without_users = self._has_nodes_without_users(traced)
43
if expect_dce_changes:
44
self.assertTrue(has_nodes_without_users)
46
self.assertFalse(has_nodes_without_users)
48
# Get the original number of placeholders to verify it doesn't change
50
orig_num_phs = self._get_num_placeholders(traced)
51
changed = traced.graph.eliminate_dead_code()
53
self.assertTrue(changed if expect_dce_changes else not changed)
55
# Verify there are no nodes without users after DCE is run.
56
self.assertFalse(self._has_nodes_without_users(traced))
57
new_num_phs = self._get_num_placeholders(traced)
58
self.assertEqual(orig_num_phs, new_num_phs)
61
# Make sure we run and get the same results before/after DCE.
62
inputs = [torch.tensor([1.5])] * new_num_phs
63
self.assertTrue(torch.equal(m(*inputs), traced(*inputs)))
65
def test_simple(self):
67
Tests that a single node in the graph is DCE'd correctly.
70
class TestModule(torch.nn.Module):
73
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
77
return x + self.attr_1
79
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
81
def test_dead_chain(self):
83
Tests that a chain of two nodes in the graph are DCE'd correctly.
86
class TestModule(torch.nn.Module):
89
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
94
return x + self.attr_1
96
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
98
def test_dead_getattr(self):
100
Tests that a getatrr in the graph is DCE'd correctly.
103
class TestModule(torch.nn.Module):
106
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
108
def forward(self, x):
113
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
115
def test_dead_placeholder(self):
117
Tests that a placeholder in the graph is not DCE'd, as that would change
118
the function signature.
121
class TestModule(torch.nn.Module):
122
def forward(self, x, y):
125
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
127
def test_dead_placeholder_with_user(self):
129
Tests that a placeholder in the graph is not DCE'd, as that would change
130
the function signature. Also verifies that a dead node that uses the
131
placeholder is DCE'd.
135
class TestModule(torch.nn.Module):
136
def forward(self, x, y):
140
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
142
def test_keep_module_with_side_effects(self):
144
Test that DCE doesn't remove a module if it's specified as having side effects.
147
class ReLUImpure(torch.nn.ReLU):
150
class TestModule(torch.nn.Module):
153
self.relu = ReLUImpure()
155
def forward(self, a: torch.Tensor) -> torch.Tensor:
159
self._run_dce_and_test(
160
TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure}
163
def test_keep_torch_assert(self):
165
Test that DCE doesn't remove torch._assert since it has side effects.
168
class TestModule(torch.nn.Module):
169
def forward(self, a: torch.Tensor) -> torch.Tensor:
170
torch._assert(torch.equal(a, a), "a must equal a")
173
# Note: Don't need to specify torch._assert as having side effects
174
# because it's known to.
175
self._run_dce_and_test(TestModule(), expect_dce_changes=False)