pytorch

Форк
0
/
test_fx_split.py 
226 строк · 7.5 Кб
1
# Owner(s): ["module: fx"]
2

3
from collections import defaultdict
4
from typing import Dict, List, Tuple
5

6
import torch
7
from torch.fx.passes.split_utils import split_by_tags
8
from torch.testing._internal.common_utils import TestCase
9

10

11
class TestFXSplit(TestCase):
12
    def test_split_preserve_node_meta(self):
13
        class TestModule(torch.nn.Module):
14
            def forward(self, x, y):
15
                x = x + x
16
                y = y * y
17
                return x - y
18

19
        gm = torch.fx.symbolic_trace(TestModule())
20
        for node in gm.graph.nodes:
21
            node.meta["name"] = node.name
22
            if node.name == "add":
23
                node.tag = "a"
24
            elif node.name == "mul":
25
                node.tag = "b"
26
            elif node.name == "sub":
27
                node.tag = "c"
28

29
        split_gm = split_by_tags(gm, ["a", "b", "c"])
30
        for m in split_gm.children():
31
            for n in m.graph.nodes:
32
                if n.op != "output":
33
                    self.assertIn("name", n.meta)
34
                    self.assertEqual(n.meta["name"], n.name)
35

36
        # Validate that metadata is copied correctly for graph placeholder nodes
37
        for node in split_gm.graph.nodes:
38
            if node.op == "placeholder":
39
                self.assertIn("name", node.meta)
40
                self.assertEqual(node.meta["name"], node.name)
41

42

43
class TestSplitByTags(TestCase):
44
    class TestModule(torch.nn.Module):
45
        def __init__(self) -> None:
46
            super().__init__()
47
            self.linear1 = torch.nn.Linear(2, 3)
48
            self.linear2 = torch.nn.Linear(4, 5)
49
            self.linear3 = torch.nn.Linear(6, 7)
50
            self.linear4 = torch.nn.Linear(8, 6)
51

52
        def forward(
53
            self,
54
            x1: torch.Tensor,
55
            x2: torch.Tensor,
56
            x3: torch.Tensor,
57
        ) -> torch.Tensor:
58
            v1 = self.linear1(x1)
59
            v2 = self.linear2(x2)
60
            v3 = self.linear3(x3)
61
            v4 = torch.cat([v1, v2, v3])
62
            return self.linear4(v4)
63

64
    @staticmethod
65
    def trace_and_tag(
66
        module: torch.nn.Module, tags: List[str]
67
    ) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
68
        """
69
        Test simple gm consists of nodes with tag (only show call_module nodes here):
70
            linear1 - tag: "red"
71
            linear2 - tag: "blue"
72
            linear3, linear4 - tag: "green"
73

74
        At the beginning we have:
75
            gm:
76
                linear1
77
                linear2
78
                linear3
79
                linear4
80

81
        split_gm = split_by_tags(gm, tags)
82

83
        Then we have:
84
            split_gm:
85
                red:
86
                    linear1
87
                blue:
88
                    linear2
89
                green:
90
                    linear3
91
                    linear4
92
        """
93
        tag_node = defaultdict(list)
94
        gm: torch.fx.GraphModule = torch.fx.symbolic_trace(module)
95

96
        # Add tag to all nodes and build dictionary record tag to call_module nodes
97
        for node in gm.graph.nodes:
98
            if "linear1" in node.name:
99
                node.tag = tags[0]
100
                tag_node[tags[0]].append(node.name)
101
            elif "linear2" in node.name:
102
                node.tag = tags[1]
103
                tag_node[tags[1]].append(node.name)
104
            else:
105
                node.tag = tags[2]
106
                if node.op == "call_module":
107
                    tag_node[tags[2]].append(node.name)
108
        return gm, tag_node
109

110
    def test_split_by_tags(self) -> None:
111
        tags = ["red", "blue", "green"]
112
        module = TestSplitByTags.TestModule()
113
        gm, tag_node = TestSplitByTags.trace_and_tag(module, tags)
114
        split_gm, orig_to_split_fqn_mapping = split_by_tags(
115
            gm, tags, return_fqn_mapping=True
116
        )
117
        # Ensure split_gm has (and only has) ordered submodules named
118
        # red_0, blue_1, green_2
119
        for idx, (name, _) in enumerate(split_gm.named_children()):
120
            if idx < len(tags):
121
                self.assertTrue(
122
                    name == tags[idx],
123
                    f"split_gm has an incorrect submodule named {name}",
124
                )
125

126
        # Ensure each submodule has expected (ordered) call_module node(s).
127
        # For example, a submodule named split_gm.red_0 has (and only has) linear1;
128
        # split_gm.green_2 has (and only has) linear3 and linear4 with order
129
        sub_graph_idx = 0
130
        for sub_name, sub_graph_module in split_gm.named_children():
131
            node_idx = 0
132
            for node in sub_graph_module.graph.nodes:
133
                if node.op != "call_module":
134
                    continue
135
                self.assertTrue(
136
                    node.name == tag_node[f"{sub_name}"][node_idx],
137
                    # pyre-fixme[61]: `name` is undefined, or not always defined.
138
                    f"{sub_name} has incorrectly include {node.name}",
139
                )
140
                node_idx += 1
141
            sub_graph_idx += 1
142

143
        self.assertEqual(
144
            orig_to_split_fqn_mapping,
145
            {
146
                "linear1": "red.linear1",
147
                "linear2": "blue.linear2",
148
                "linear3": "green.linear3",
149
                "linear4": "green.linear4",
150
            },
151
            f"{orig_to_split_fqn_mapping=}",
152
        )
153

154

155
class TestSplitOutputType(TestCase):
156
    class TestModule(torch.nn.Module):
157
        def __init__(self) -> None:
158
            super().__init__()
159
            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
160
            self.relu = torch.nn.ReLU()
161

162
        def forward(self, x):
163
            conv = self.conv(x)
164
            conv = conv * 0.5
165
            relu = self.relu(conv)
166
            return relu
167

168
    @staticmethod
169
    def trace_and_tag(
170
        module: torch.nn.Module, inputs: torch.Tensor, tags: List[str]
171
    ) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
172
        """
173
        Test simple gm consists of nodes with tag (only show call_module nodes here):
174
            conv - tag: "red"
175
            mul - tag: "blue"
176
            relu - tag: "green"
177

178
        At the beginning we have:
179
            gm:
180
                conv
181
                mul
182
                relu
183

184
        split_gm = split_by_tags(gm, tags)
185

186
        Then we have:
187
            split_gm:
188
                red:
189
                    conv
190
                blue:
191
                    mul
192
                green:
193
                    relu
194
        """
195
        tag_node = defaultdict(list)
196
        gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module()
197
        # Add tag to all nodes and build dictionary record tag to call_module nodes
198
        for node in gm.graph.nodes:
199
            if "conv" in node.name:
200
                node.tag = tags[0]
201
                tag_node[tags[0]].append(node.name)
202
            elif "mul" in node.name:
203
                node.tag = tags[1]
204
                tag_node[tags[1]].append(node.name)
205
            else:
206
                node.tag = tags[2]
207
                if node.op == "call_module":
208
                    tag_node[tags[2]].append(node.name)
209
        return gm, tag_node
210

211
    def test_split_by_tags(self) -> None:
212
        tags = ["red", "blue", "green"]
213
        module = TestSplitOutputType.TestModule()
214

215
        inputs = torch.randn((1, 3, 224, 224))
216

217
        gm, tag_node = TestSplitOutputType.trace_and_tag(module, inputs, tags)
218
        split_gm, orig_to_split_fqn_mapping = split_by_tags(
219
            gm, tags, return_fqn_mapping=True
220
        )
221

222
        gm_output = module(inputs)
223
        split_gm_output = split_gm(inputs)
224

225
        self.assertTrue(type(gm_output) == type(split_gm_output))
226
        self.assertTrue(torch.equal(gm_output, split_gm_output))
227

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.