pytorch

Форк
0
/
test_cse_pass.py 
262 строки · 6.8 Кб
1
# Owner(s): ["oncall: fx"]
2

3
import random
4

5
import torch
6
from torch.fx import symbolic_trace
7
from torch.fx.experimental.proxy_tensor import make_fx
8
from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops
9
from torch.testing._internal.common_utils import run_tests, TestCase
10

11

12
# Ops the CSE pass must never merge, as reported by the pass module itself.
# test_rand_like / test_rand_n rely on the default pass leaving random ops alone.
banned_ops = get_CSE_banned_ops()

# Shared pass instance that ``check`` falls back to when no ``P`` is supplied.
P_default = CSEPass(banned_ops=banned_ops)
14

15

16
def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
    """
    Check that the CSE-modified graph of ``f``

    1) has exactly ``delta`` fewer nodes than the original graph,
    2) is not reduced any further by a second pass, and
    3) reports ``modified=True`` if and only if the number of nodes decreased
       (so the second pass must report ``modified=False``).

    Args:
        self: the test case; its ``assertTrue`` / ``assertFalse`` methods are
            used to report failures.
        f: function to be checked, or an already-built ``GraphModule`` when
            ``graph_input`` is True.
        t: tensor to be passed to ``f``.
        delta: an integer >= -1, the expected number of removed nodes.
            If ``delta == -1``, only check that the new graph has fewer or
            the same number of nodes.
        check_val: if True, check that the CSE graph returns exactly the same
            result as the original graph on ``t``.
        graph_input: True if ``f`` is a ``GraphModule``; otherwise ``f`` is
            traced with ``make_fx(f)(t)``.
        P: the pass to use. If None, use ``P_default``.

    Raises:
        ValueError: if ``delta`` is less than -1.
    """
    if delta < -1:
        raise ValueError(f"delta must be an integer >= -1, got {delta}")

    fx_g = f if graph_input else make_fx(f)(t)

    if P is None:
        P = P_default

    res = P(fx_g)
    new_g = res.graph_module
    new_graph = new_g.graph
    modified = res.modified

    # the number of nodes should decrease or stay the same
    old_num_nodes = len(fx_g.graph.nodes)
    new_num_nodes = len(new_graph.nodes)

    # Use the test case's assertion instead of a bare ``assert`` so this check
    # is not stripped when Python runs with ``-O``.
    self.assertTrue(
        (new_num_nodes < old_num_nodes) == modified,
        "modified should be True if the number of nodes decrease",
    )

    if delta == -1:
        self.assertTrue(
            old_num_nodes >= new_num_nodes,
            f"number of nodes increased {old_num_nodes}, {new_num_nodes}",
        )
    else:
        self.assertTrue(
            old_num_nodes == new_num_nodes + delta,
            f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}",
        )

    # A second pass should not remove any more nodes, and (per property 3)
    # must therefore report that it did not modify the graph.
    res_2 = P(new_g)
    pass_2_graph = res_2.graph_module.graph
    pass_2_num_nodes = len(pass_2_graph.nodes)
    self.assertTrue(
        pass_2_num_nodes == new_num_nodes,
        f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}",
    )
    self.assertFalse(
        res_2.modified,
        f"second pass reported modified=True without removing nodes\n {pass_2_graph}",
    )

    # check correctness
    if check_val:
        true_result = fx_g(t)
        our_result = new_g(t)
        if true_result is None:  # both return None
            self.assertTrue(
                our_result is None, f"true result is None, CSE result is {our_result}"
            )
        else:
            # torch.equal also requires matching shapes; torch.all(a == b)
            # would let a broadcastable shape mismatch pass unnoticed.
            self.assertTrue(
                torch.equal(true_result, our_result),
                f"results are different {true_result}, {our_result}",
            )
90

91

92
class TestCSEPass(TestCase):
    """Tests for ``CSEPass``, driven through the module-level ``check`` helper."""

    def test_nochange(self):
        def f(x):
            # x + 1, x + (x + 1), x + x and b + d all differ in their inputs,
            # so there is nothing to merge.
            a = x + 1
            b = x + a
            a = x
            d = x + a
            return b + d

        t = torch.randn(2, 2)
        check(self, f, t, 0)

    def test_empty(self):
        def f(x):
            pass

        t = torch.randn(2, 2)
        check(self, f, t, 0)

    def test_immutable_list_type(self):
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1)
            c = x.sum()
            d = x.sum()
            return a + b + c + d

        t = torch.randn(2, 2)
        check(self, f, t, 2)

    def test_immutable_list_multiple_entries(self):
        def f(x):
            a = x.sum(dim=[0, 1])
            b = x.sum(dim=[0, 1])
            c = x.sum(dim=1)
            d = x.sum(dim=1)
            return a + b + c + d

        t = torch.randn(2, 2)
        check(self, f, t, 2)

    def test_simple(self):
        def f(x):
            # Merging the duplicate cos makes a + a and b + b identical too.
            a = x.cos()
            b = x.cos()
            c = a + a
            d = b + b
            return c + d

        t = torch.randn(2, 2)
        check(self, f, t, 2)

    def test_simple_2(self):
        def f(x):
            a = x.cos().sin()
            b = x.cos().sin()
            c = a + a
            d = b + b
            return c + d

        t = torch.randn(1)
        check(self, f, t, 3)

    def test_two_args_default(self):
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1, keepdim=False)
            c = x.sum(dim=1, keepdim=False)
            d = x.sum(dim=1)
            return a + b + c + d

        t = torch.randn(2, 2)
        check(self, f, t, 3)

    def test_two_args(self):
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1, keepdim=True)
            c = x.sum(dim=1, keepdim=True)
            d = x.sum(dim=1)
            return a + b + c + d

        t = torch.randn(2, 2)
        check(self, f, t, 2)

    def test_simple_multiple_same_ops(self):
        def f(x):
            a = x.sum()
            b = x.sum()
            c = x.sum()
            d = x.sum()
            return a + b + c + d

        t = torch.randn(2, 2)
        check(self, f, t, 3)

    def test_nested_immutable_list_type(self):
        def f(x):
            a = torch.cat((x, x))
            b = torch.cat((x, x))
            return a + b

        t = torch.randn(2, 2)
        check(self, f, t, 1)

    def test_kwarg(self):
        def f(x):
            a = torch.ones_like(x)
            b = torch.ones_like(x)
            return a + b

        t = torch.randn(2, 2)
        check(self, f, t, 1)

    def test_random(self):
        """Generate functions with random ops and check that the results are the same."""

        def f(x):
            vals = [x]
            ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu]
            for _ in range(100):
                new_val = random.choice(ops)(random.choice(vals))
                vals.append(new_val)
            return vals[-1]

        t = torch.randn(2, 2)

        for _ in range(30):
            # Re-trace on every iteration: ``f`` picks its ops while being
            # traced, so each trace yields a different random graph. Tracing
            # once outside the loop would just re-check one graph 30 times.
            fx_g = symbolic_trace(f)
            fx_g.graph.eliminate_dead_code()
            fx_g.recompile()
            check(self, fx_g, t, -1, graph_input=True)

    def test_banned_list(self):
        """Test that the banned list bans ops as expected."""

        def f(x):
            a = x + 1
            b = x + 1
            return a + b

        t = torch.randn(2, 2)
        P_ban_add = CSEPass(banned_ops=[torch.ops.aten.add])
        check(self, f, t, 0, P=P_ban_add)  # check that add is banned
        check(self, f, t, 1)  # check that add is not banned by default

    def test_rand_like(self):
        # Random ops must not be merged by the default pass; their values are
        # random, so the outputs are not compared.
        def f(x):
            a = torch.rand_like(x)
            b = torch.rand_like(x)
            return a + b

        t = torch.randn(2, 2)
        check(self, f, t, 0, check_val=False)

    def test_rand_n(self):
        def f(x):
            a = torch.randn(4)
            b = torch.randn(4)
            return a + b

        t = torch.randn(2, 2)
        check(self, f, t, 0, check_val=False)
259

260

261
if __name__ == "__main__":
    # Hand control to PyTorch's test runner only when executed as a script.
    run_tests()
263

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.