pytorch
# Owner(s): ["module: fx"]

import os
import tempfile

import torch
from torch.fx import subgraph_rewriter, symbolic_trace
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.testing._internal.common_utils import TestCase

11
if __name__ == "__main__":
    # Direct execution of this module is refused; the error message points the
    # user at running the test through test/test_fx.py instead.
    _run_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_fx.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_run_hint)
19
20class TestGraphTransformObserver(TestCase):21def test_graph_transform_observer(self):22class M(torch.nn.Module):23def forward(self, x):24val = torch.neg(x)25return torch.add(val, val)26
27def pattern(x):28return torch.neg(x)29
30def replacement(x):31return torch.relu(x)32
33traced = symbolic_trace(M())34
35log_url = tempfile.mkdtemp()36
37with GraphTransformObserver(traced, "replace_neg_with_relu", log_url) as ob:38subgraph_rewriter.replace_pattern(traced, pattern, replacement)39
40self.assertTrue("relu" in ob.created_nodes)41self.assertTrue("neg" in ob.erased_nodes)42
43current_pass_count = GraphTransformObserver.get_current_pass_count()44
45self.assertTrue(46os.path.isfile(47os.path.join(48log_url,49f"pass_{current_pass_count}_replace_neg_with_relu_input_graph.dot",50)51)52)53self.assertTrue(54os.path.isfile(55os.path.join(56log_url,57f"pass_{current_pass_count}_replace_neg_with_relu_output_graph.dot",58)59)60)61