3
from collections import defaultdict
4
from typing import Dict, List, Tuple
7
from torch.fx.passes.split_utils import split_by_tags
8
from torch.testing._internal.common_utils import TestCase
11
class TestFXSplit(TestCase):
12
def test_split_preserve_node_meta(self):
13
class TestModule(torch.nn.Module):
14
def forward(self, x, y):
19
gm = torch.fx.symbolic_trace(TestModule())
20
for node in gm.graph.nodes:
21
node.meta["name"] = node.name
22
if node.name == "add":
24
elif node.name == "mul":
26
elif node.name == "sub":
29
split_gm = split_by_tags(gm, ["a", "b", "c"])
30
for m in split_gm.children():
31
for n in m.graph.nodes:
33
self.assertIn("name", n.meta)
34
self.assertEqual(n.meta["name"], n.name)
37
for node in split_gm.graph.nodes:
38
if node.op == "placeholder":
39
self.assertIn("name", node.meta)
40
self.assertEqual(node.meta["name"], node.name)
43
class TestSplitByTags(TestCase):
44
class TestModule(torch.nn.Module):
45
def __init__(self) -> None:
47
self.linear1 = torch.nn.Linear(2, 3)
48
self.linear2 = torch.nn.Linear(4, 5)
49
self.linear3 = torch.nn.Linear(6, 7)
50
self.linear4 = torch.nn.Linear(8, 6)
61
v4 = torch.cat([v1, v2, v3])
62
return self.linear4(v4)
66
module: torch.nn.Module, tags: List[str]
67
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
69
Test simple gm consists of nodes with tag (only show call_module nodes here):
72
linear3, linear4 - tag: "green"
74
At the beginning we have:
81
split_gm = split_by_tags(gm, tags)
93
tag_node = defaultdict(list)
94
gm: torch.fx.GraphModule = torch.fx.symbolic_trace(module)
97
for node in gm.graph.nodes:
98
if "linear1" in node.name:
100
tag_node[tags[0]].append(node.name)
101
elif "linear2" in node.name:
103
tag_node[tags[1]].append(node.name)
106
if node.op == "call_module":
107
tag_node[tags[2]].append(node.name)
110
def test_split_by_tags(self) -> None:
111
tags = ["red", "blue", "green"]
112
module = TestSplitByTags.TestModule()
113
gm, tag_node = TestSplitByTags.trace_and_tag(module, tags)
114
split_gm, orig_to_split_fqn_mapping = split_by_tags(
115
gm, tags, return_fqn_mapping=True
119
for idx, (name, _) in enumerate(split_gm.named_children()):
123
f"split_gm has an incorrect submodule named {name}",
130
for sub_name, sub_graph_module in split_gm.named_children():
132
for node in sub_graph_module.graph.nodes:
133
if node.op != "call_module":
136
node.name == tag_node[f"{sub_name}"][node_idx],
138
f"{sub_name} has incorrectly include {node.name}",
144
orig_to_split_fqn_mapping,
146
"linear1": "red.linear1",
147
"linear2": "blue.linear2",
148
"linear3": "green.linear3",
149
"linear4": "green.linear4",
151
f"{orig_to_split_fqn_mapping=}",
155
class TestSplitOutputType(TestCase):
156
class TestModule(torch.nn.Module):
157
def __init__(self) -> None:
159
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
160
self.relu = torch.nn.ReLU()
162
def forward(self, x):
165
relu = self.relu(conv)
170
module: torch.nn.Module, inputs: torch.Tensor, tags: List[str]
171
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
173
Test simple gm consists of nodes with tag (only show call_module nodes here):
178
At the beginning we have:
184
split_gm = split_by_tags(gm, tags)
195
tag_node = defaultdict(list)
196
gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module()
198
for node in gm.graph.nodes:
199
if "conv" in node.name:
201
tag_node[tags[0]].append(node.name)
202
elif "mul" in node.name:
204
tag_node[tags[1]].append(node.name)
207
if node.op == "call_module":
208
tag_node[tags[2]].append(node.name)
211
def test_split_by_tags(self) -> None:
212
tags = ["red", "blue", "green"]
213
module = TestSplitOutputType.TestModule()
215
inputs = torch.randn((1, 3, 224, 224))
217
gm, tag_node = TestSplitOutputType.trace_and_tag(module, inputs, tags)
218
split_gm, orig_to_split_fqn_mapping = split_by_tags(
219
gm, tags, return_fqn_mapping=True
222
gm_output = module(inputs)
223
split_gm_output = split_gm(inputs)
225
self.assertTrue(type(gm_output) == type(split_gm_output))
226
self.assertTrue(torch.equal(gm_output, split_gm_output))