pytorch

Форк
0
/
test_profiler.py 
286 строк · 10.1 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import os
4
import sys
5

6
import torch
7
from torch.testing._internal.common_utils import skipIfTorchDynamo
8

9

10
# Make the helper files in test/ importable: resolve this file's real path
# (following symlinks), take the parent of the directory that contains it,
# and append that directory to the module search path.
pytorch_test_dir = os.path.split(os.path.split(os.path.realpath(__file__))[0])[0]
sys.path.append(pytorch_test_dir)
13
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward
14

15

16
# This module only defines test cases; executing it directly is an error and
# the message points at the test/test_jit.py driver instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
22

23

24
@skipIfTorchDynamo()
class TestProfiler(JitTestCase):
    """Profiling-executor tests for TensorExpr fusion and type specialization.

    Every test runs with the profiling executor and graph-executor
    optimization enabled, the TensorExpr fuser switched on (with CPU fusion
    and reductions allowed), autodiff-subgraph and fusion-group inlining
    disabled, LLVM not required for TE CPU codegen, and ``torch.double`` as
    the default dtype. ``tearDown`` restores every one of these settings.
    """

    def setUp(self):
        """Switch the global JIT/fuser configuration used by these tests.

        Where a setter's return value is stored, it is used as the previous
        setting to restore in ``tearDown``; otherwise the current value is
        read through the matching getter before being overwritten.
        """
        self.prev_exec = torch._C._jit_set_profiling_executor(True)
        self.prev_profiling = torch._C._get_graph_executor_optimize(True)
        self.inline_autodiff = torch._C._debug_set_autodiff_subgraph_inlining(False)
        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
        self.can_fuse_on_cpu = torch._C._jit_can_fuse_on_cpu()
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        self.default_dtype = torch.get_default_dtype()
        self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True)
        torch.set_default_dtype(torch.double)
        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)
        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
        torch._C._jit_set_te_must_use_llvm_cpu(False)

    def tearDown(self):
        """Restore every global setting that ``setUp`` changed."""
        torch._C._jit_set_profiling_executor(self.prev_exec)
        torch._C._get_graph_executor_optimize(self.prev_profiling)
        torch._C._debug_set_autodiff_subgraph_inlining(self.inline_autodiff)
        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
        torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)
        torch.set_default_dtype(self.default_dtype)
        torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)

    def test_tensor_type_not_determined_by_inputs(self):
        """Ops whose result type is not fixed by input types stay unfused."""

        @torch.jit.script
        def scalar_type_input(x, y, z):
            return x + y + 4 + z.item()

        x = torch.tensor([2, 2])
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1.0))
        g = torch.jit.last_executed_optimized_graph()

        # item & add should not get pulled into the fusion group -
        # we expect to see Fusion Group (item / add) Fusion Group in ir dump
        FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next(
            "Tensor = aten::add"
        ).check("TensorExpr").run(g)

        @torch.jit.script
        def non_const_dtype(x, y, cond: bool):
            dtype = torch.int16 if cond else torch.int32
            return (x + y + 3).sum(dtype=dtype)

        non_const_dtype(x, x, True)
        non_const_dtype(x, x, True)
        g = torch.jit.last_executed_optimized_graph()
        # because dtype is non-const, sum should not get pulled into the Fusion Group
        FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(
            g
        )

    def test_specialize_backward(self):
        """Backward graphs drop _grad_sum_to_size only when nothing broadcasts."""

        def test_fuse(a, b):
            c = a * b
            d = c * b
            return d

        # The same Python function is scripted twice below; disable caching so
        # each torch.jit.script call produces a separate compiled function.
        test_fuse.__disable_jit_function_caching__ = True

        scripted_f = torch.jit.script(test_fuse)
        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        scripted_f(x, y)
        b = scripted_f(x, y)
        warmup_backward(b)
        g = torch.jit.last_executed_optimized_graph()
        # Backward has an if node guarding specializations,
        # within the if node true block there is only one if node
        # that guards a tensorexpr group
        optimized_block = next(g.findNode("prim::If").blocks())
        if_nodes = list(optimized_block.findAllNodes("prim::If"))

        self.assertEqual(len(if_nodes), 1)
        FileCheck().check("Group[Subgraph").run(str(if_nodes[0]))
        # no broadcasts occurred, sum_to_size have been specialized out
        self.assertIsNone(optimized_block.findNode("aten::_grad_sum_to_size"))

        broadcast_f = torch.jit.script(test_fuse)
        x = torch.ones([2, 2], requires_grad=True)
        y = torch.ones([1], requires_grad=True)
        broadcast_f(x, y)
        b = broadcast_f(x, y)
        # b is a [2, 2] tensor, so backward needs an explicit grad_output. It
        # is run twice (keeping the graph alive for the second call),
        # presumably so the profiling executor records and then optimizes
        # the backward graph.
        b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([2, 2], dtype=torch.float))
        g = torch.jit.last_executed_optimized_graph()
        optimized_block = next(g.findNode("prim::If").blocks())
        # broadcasts occurred, currently expect to see aten::_grad_sum_to_size
        self.assertIsNotNone(optimized_block.findNode("aten::_grad_sum_to_size"))

    def test_specialized_types(self):
        """TypeCheck and fusion outputs keep specialized Double types."""

        @torch.jit.script
        def test_fuse(a, b):
            c = a * b
            d = c * b
            return d

        x = torch.tensor([0.5])
        for _ in range(3):
            test_fuse(x, x)

        g = torch.jit.last_executed_optimized_graph()
        # Types should remain specialized for typecheck outputs & fusion outputs
        FileCheck().check("Double(").check_same("prim::TypeCheck").check_same(
            "\n"
        ).check("Double").check_same("TensorExpr").run(g)

        # other outputs should not be specialized
        FileCheck().check("Tensor = prim::If").run(g)

    def test_aliasing_merge(self):
        """An in-place add_ on an intermediate splits fusion around it."""

        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            d.add_(b)
            e = d * b
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        foo(x, y)
        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(len(list(g.findAllNodes("prim::TypeCheck"))), 2)
        FileCheck().check("TensorExpr").check("aten::add_").check("TensorExpr").run(g)

    def test_use_not_profiled(self):
        """A use of t1 that is never profiled does not stop the adds fusing."""

        def foo(t1, t2, t3, t4, t: float):
            h = t1 + t2 + t3 + t4
            if t > 0.5:
                # Putting a use of t1 in a never-executed conditional leaves
                # that use without profiling information; the check below
                # expects the adds computing h to be fused regardless.
                return t1 + 1
            return h

        t = torch.rand(8, dtype=torch.float)

        foo_script = torch.jit.script(foo)
        for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
            foo_script(t, t, t, t, 0.1)

        self.assertEqual(foo(t, t, t, t, 0.1), foo_script(t, t, t, t, 0.1))
        g = torch.jit.last_executed_optimized_graph()
        # all adds fused
        FileCheck().check("graph").check_not("aten::add").check("prim::If").run(g)

    def test_not_fusing_scalar_ops(self):
        """Pure int (non-tensor) arithmetic never forms a TensorExpr group."""

        @torch.jit.script
        def foo(x: int, y: int):
            return x + y + 2 + 4 + 5 + 6

        foo(1, 2)
        foo(2, 3)
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("TensorExpr").run(g)

    def test_not_optimizing_property(self):
        """x.size() stays as aten::size and stays correct for new shapes."""

        @torch.jit.script
        def foo(x, y):
            return x + y + 1 + 2 + 3, x.size()

        x = torch.ones(1)
        foo(x, x)
        foo(x, x)
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("aten::size").run(g)
        x = torch.ones([2, 3, 5])
        self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size()))

    def test_fallback_graph_not_specialized(self):
        """The fallback CallFunction result is unpacked as a generic Tensor."""

        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        foo(x, y)
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(
            g
        )

    def test_autograd_fallback_graph(self):
        """The optimized backward graph has a fallback_function + CallFunction."""

        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        foo(x, y)
        b = foo(x, y)
        b.backward(torch.ones([1], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([1], dtype=torch.float))

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("fallback_function").check_next("CallFunction").run(g)

    def test_tensor_constant(self):
        """A tensor constant in the body gives the eager result, with exactly
        two aten::add occurrences in the printed optimized graph."""

        def foo(a, b):
            return a + b + torch.tensor([2])

        x = torch.ones(1, requires_grad=False)
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)

        self.assertEqual(foo_script(x, x), foo(x, x))
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_count("aten::add", 2, exactly=True).run(g)

    def test_local_fusion_strategy(self):
        """Fusion strategy changes apply to a function already being profiled."""

        @torch.jit.script
        def foo(x):
            return x + x + x

        # set_fusion_strategy changes a global JIT setting and returns the
        # previous strategy; restore it afterwards so this test cannot leak
        # its STATIC configuration into tests that run later.
        old_strategy = torch.jit.set_fusion_strategy([("STATIC", 1)])
        try:
            for _ in range(3):
                foo(torch.rand([10]))

            torch.jit.set_fusion_strategy([("STATIC", 10)])

            for i in range(10):
                foo(torch.rand([i]))
                foo(torch.rand([i]))

            g = torch.jit.last_executed_optimized_graph()
            FileCheck().check_count(":TensorExprGroup", 2, exactly=True).run(g)
        finally:
            torch.jit.set_fusion_strategy(old_strategy)

    def test_iterative_fusion(self):
        """Fusion keeps scanning after a merge, so both groups are found."""

        @torch.jit.script
        def foo(a, b, c, d):
            a = a + b
            b.add_(3)
            c = c + b + d
            a = a + 1
            return a, c

        x = torch.ones(1, requires_grad=False)
        foo(x, x, x, x)
        foo(x, x, x, x)

        # when we iterate through the block, we will start
        # by fusing a = a + b with a = a + 1
        # if we were to continue iteration from that fusion point,
        # we would miss the fusion opportunity of c = c + b + d

        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2)
287

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.