pytorch

Форк
0
/
test_jit_fuser.py 
991 строка · 40.3 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import unittest
4
import os
5
import sys
6
import torch
7
import torch.nn as nn
8
import torch.nn.functional as F
9
from torch.testing import FileCheck
10
from unittest import skipIf
11

12
from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
13
    enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
14
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \
15
    RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward
16
from textwrap import dedent
17
from itertools import product, permutations
18
from torch.testing._internal.common_cuda import with_tf32_off
19

20
from test_jit import backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
21
    LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
22

23
# At import time, when the selected graph executor is the profiling one,
# switch on both the profiling executor and profiling mode in the JIT.
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)
26

27

28
def strip_profiling_nodes(nodes):
    """Return a list of ``nodes`` without the profiling bail-out nodes.

    A node is dropped when its ``kind()`` is ``prim::BailoutTemplate`` or
    ``prim::BailOut``; all other nodes are kept in their original order.
    """
    kept = []
    for node in nodes:
        if node.kind() in ('prim::BailoutTemplate', 'prim::BailOut'):
            continue
        kept.append(node)
    return kept
31

32

33
def warmup_forward(f, *args, profiling_count=2):
    """Call ``f(*args)`` ``profiling_count`` times and return the last result.

    Args:
        f: callable to invoke.
        *args: positional arguments forwarded unchanged to every call.
        profiling_count: number of calls to make (keyword-only, default 2,
            which is the count this helper always used before).

    Returns:
        Whatever the final call to ``f`` returned.

    Raises:
        ValueError: if ``profiling_count`` is less than 1, because there
            would be no result to return.
    """
    if profiling_count < 1:
        raise ValueError(f"profiling_count must be >= 1, got {profiling_count}")

    for _ in range(profiling_count):
        results = f(*args)

    return results
39

40

41
@skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "skip due to SIGIOT failures, #67646")
42
class TestFuser(JitTestCase):
43
    def assertAllFused(self, graph, except_for=()):
        """Assert that ``graph`` is one fusion group plus allowed helper nodes.

        If ``graph`` contains a ``prim::DifferentiableGraph`` node there must
        be exactly one, and the check is applied to its ``Subgraph`` instead.

        Args:
            graph: graph object exposing ``nodes()``; each node exposes
                ``kind()`` (and ``g('Subgraph')`` for differentiable graphs).
            except_for: extra node kinds that are tolerated in addition to
                constants, fusion groups, bail-out nodes and tuple
                construction.
        """
        diff_graphs = [n for n in graph.nodes() if n.kind() == 'prim::DifferentiableGraph']
        if diff_graphs:
            self.assertEqual(len(diff_graphs), 1)
            graph = diff_graphs[0].g('Subgraph')

        allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::BailoutTemplate',
                         'prim::BailOut', 'prim::TupleConstruct'} | set(except_for)
        kinds = [node.kind() for node in graph.nodes()]
        self.assertTrue(all(kind in allowed_nodes for kind in kinds),
                        f'got {graph}')
        # assertEqual (rather than assertTrue on a comparison) reports the
        # actual number of fusion groups when the check fails.
        self.assertEqual(kinds.count('prim::FusionGroup'), 1)
55

56
    def _test_fused_abs(self, device='cpu'):
        """Script ``x.abs() * 2`` and assert that its optimized graph is fused.

        Args:
            device: device on which the 5-element random input is created.
        """
        def func(x):
            return x.abs() * 2

        sample = torch.randn(5, device=device)
        compiled = self.checkScript(func, (sample,))
        fused_graph = compiled.graph_for(sample)
        self.assertAllFused(fused_graph)
63

64
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
65
    @enable_cpu_fuser
66
    def test_abs_cpu(self):
        # Runs the shared abs-fusion check with its default device ('cpu').
        self._test_fused_abs()
68

69
    @unittest.skipIf(not IS_WINDOWS, "This is meant to be Windows-specific")
70
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
71
    @enable_cpu_fuser
72
    def test_abs_cpu_unicode_temp_dir(self):
        # Re-runs test_abs_cpu in a child interpreter whose TMP environment
        # variable points at a temporary directory with a non-ASCII suffix,
        # and requires the child to exit with status 0.
        with TemporaryDirectoryName(suffix='中文') as dname:
            shell_env = os.environ.copy()
            shell_env['TMP'] = dname
            # Resolve the path first: a relative __file__ without a directory
            # component would otherwise produce an empty cwd for the child.
            test_file = os.path.abspath(__file__)
            cmd = [sys.executable, os.path.basename(test_file), type(self).__name__ + '.test_abs_cpu']
            legacy_jit_flag = '--jit-executor=legacy'
            # Forward the legacy-executor flag when this process received it.
            if legacy_jit_flag in sys.argv:
                cmd.append(legacy_jit_flag)
            return_code = shell(cmd, cwd=os.path.dirname(test_file), env=shell_env)
            self.assertEqual(return_code, 0)
83

84
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
85
    def test_abs_cuda(self):
        # Runs the shared abs-fusion check with the input placed on "cuda".
        self._test_fused_abs(device="cuda")
87

88
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
89
    def test_zero_element_tensors(self):
        # Passes two empty CUDA tensors through a scripted atan2 via
        # checkScript; there are no further assertions on the result.
        def decode(sin_t, cos_t):
            theta = torch.atan2(sin_t.float(), cos_t.float())
            return theta

        empty_sin = torch.zeros(0, device="cuda")
        empty_cos = torch.zeros(0, device="cuda")
        self.checkScript(decode, [empty_sin, empty_cos])
98

99
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
100
    def test_arg_configurations_smoke_cuda(self):
        # A smoke test to make sure we won't use the same kernel for contiguous
        # and non-contiguous arguments.
        # TODO: add optionally enabled debug counters to the fuser to verify
        #       that we really can tell the difference between configurations
        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        lhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
        rhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
        traced_f = torch.jit.trace(f, (lhs, rhs,))

        # Same values, once as a contiguous copy of the transpose and once as
        # the (non-contiguous) transposed view itself.
        contiguous_result = traced_f(lhs.t().contiguous(), rhs)
        strided_result = traced_f(lhs.t(), rhs)
        self.assertEqual(contiguous_result, strided_result)
113

114
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
115
    def test_broadcast_cuda(self):
        # A (4, 4) input scaled and shifted by length-4 vectors must trace
        # into a fully fused graph.
        def scaleshift(x, scale, shift):
            return x * scale + shift

        def cuda_floats(*shape):
            return torch.randn(*shape, dtype=torch.float, device='cuda')

        inputs = [cuda_floats(4, 4), cuda_floats(4), cuda_floats(4)]
        traced = self.checkTrace(scaleshift, inputs)
        self.assertAllFused(traced.graph_for(*inputs))
126

127
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
128
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no bfloat support with profiling on")
129
    def test_cuda_bfloat16(self):
        # add followed by relu on bfloat16 CUDA tensors must be fully fused.
        def foo(x, y):
            return (x + y).relu()
        scripted = torch.jit.script(foo)
        lhs = torch.randn(65536).cuda().bfloat16()
        rhs = torch.randn_like(lhs)
        fused_graph = scripted.graph_for(lhs, rhs)
        self.assertAllFused(fused_graph)
136

137
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
138
    @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
139
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
140
    def test_cuda_half(self):
        # Compares traced functions run on half-precision inputs against the
        # same functions run eagerly on float copies, for outputs and grads.
        half_x = torch.randn(4, 4, dtype=torch.half, device='cuda')
        half_y = torch.randn(4, 4, dtype=torch.half, device='cuda')

        candidates = [
            self.fn_test_comparison_gt_lt,
            self.fn_test_relu,
            self.fn_test_exp
        ]

        # Note: Non fused inputs must be float to prevent loss of precision
        reference_base = (half_x.float(), half_y.float())
        fused_base = (half_x, half_y)
        for fn in candidates:
            ref_inputs = [t.clone().requires_grad_() for t in reference_base]
            fused_inputs = [t.clone().requires_grad_() for t in fused_base]

            # Verifies outputs
            fused_fn = torch.jit.trace(fn, fused_inputs, check_trace=False)
            ref_outputs = fn(*ref_inputs)
            fused_outputs = fused_fn(*fused_inputs)
            ref_outputs_half = [t.half() for t in ref_outputs]
            self.assertEqual(ref_outputs_half, fused_outputs)

            # Verifies gradients
            for ref_out, fused_out in zip(ref_outputs_half, fused_outputs):
                ref_grads = torch.autograd.grad(
                    ref_out.float().sum(), ref_inputs, allow_unused=True, retain_graph=True)
                fused_grads = torch.autograd.grad(
                    fused_out.sum(), fused_inputs, allow_unused=True, retain_graph=True)
                ref_grads_half = [t.half() for t in ref_grads]
                self.assertEqual(ref_grads_half, fused_grads)
172

173
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
174
    def test_checks_cat_inputs(self):
        # We shouldn't treat cat nodes as broadcasting. All their inputs
        # need to be checked for having the same map size, before we can
        # run the kernel.
        def f(x, y):
            return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)

        # NOTE: y is broadcastable to x, but output of f(x, y) should have
        # shape 3x4, and not 4x4.
        two_rows = torch.randn(2, 4, dtype=torch.float, device='cuda')
        one_row = torch.randn(1, 4, dtype=torch.float, device='cuda')

        scripted = self.checkScript(f, (two_rows, one_row))
        fused_graph = scripted.graph_for(two_rows, one_row)
        self.assertAllFused(fused_graph)
188

189
    @unittest.skipIf(not RUN_CUDA, "No CUDA")
190
    def test_remainder_cuda(self):
        # torch.remainder sandwiched between +1 and -1 must be fully fused.
        def cuda_rem(x, y):
            return 1 + torch.remainder(x, y) - 1

        dividend = torch.rand([512], dtype=torch.float).cuda()
        divisor = torch.rand([512], dtype=torch.float).cuda()
        inputs = [dividend, divisor]
        scripted = self.checkScript(cuda_rem, inputs)
        self.assertAllFused(scripted.graph_for(*inputs))
200

201
    @unittest.skipIf(not RUN_CUDA, "No CUDA")
202
    def test_chunk_cuda(self):
        # A 3-way chunk along dim 1 feeding elementwise ops must be fully
        # fused, and the graph text must mention the matching ConstantChunk.
        def fn(x):
            a, b, c = x.chunk(3, 1)
            return a * b + c

        inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]

        scripted = self.checkScript(fn, inputs)
        fused_graph = scripted.graph_for(*inputs)
        self.assertAllFused(fused_graph)

        checker = FileCheck()
        checker.check("prim::ConstantChunk[chunks=3, dim=1]")
        checker.run(str(fused_graph))
213

214
    @staticmethod
215
    def _test_chunk_correctness(self, device='cpu'):
        """Run ``checkScript`` on 4-way chunk-and-sum functions over several layouts.

        Every combination of chunk dimension (0, 1, 2) and input tensor
        (4x4x4, contiguous 12x8x16, transposed 12x8x16) is checked.

        Args:
            self: the test case providing ``checkScript``; callers pass it
                explicitly.
            device: device on which the input tensors are created.
        """
        def chunk_4_0(x):
            x0, x1, x2, x3 = x.chunk(4, 0)
            return x0 + x1 + x2 + x3

        def chunk_4_1(x):
            x0, x1, x2, x3 = x.chunk(4, 1)
            return x0 + x1 + x2 + x3

        def chunk_4_last(x):
            x0, x1, x2, x3 = x.chunk(4, 2)
            return x0 + x1 + x2 + x3

        # splitSize = 1
        unit_split = torch.randn(4, 4, 4, dtype=torch.float, device=device)
        # contiguous case
        contiguous = torch.randn(12, 8, 16, dtype=torch.float, device=device)
        # non-contiguous case
        transposed = torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2)

        chunk_fns = [chunk_4_0, chunk_4_1, chunk_4_last]
        for tensor, fn in product([unit_split, contiguous, transposed], chunk_fns):
            self.checkScript(fn, [tensor])
243

244
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
245
    @enable_cpu_fuser
246
    def test_chunk_correctness(self):
        # The helper is declared as a staticmethod, so this test case is
        # passed explicitly as its first argument; inputs live on 'cpu'.
        return self._test_chunk_correctness(self, 'cpu')
248

249
    @unittest.skipIf(not RUN_CUDA, "No CUDA")
250
    def test_chunk_correctness_cuda(self):
        # The helper is declared as a staticmethod, so this test case is
        # passed explicitly as its first argument; inputs live on 'cuda'.
        return self._test_chunk_correctness(self, 'cuda')
252

253
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
254
    def test_chunk_distributes_cuda(self):
        # Checks the optimized graph text: "broadcast_tensors", then a fusion
        # group, and exactly two occurrences of "ConstantChunk".
        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        lhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
        rhs = torch.randn(4, 4, dtype=torch.float, device='cuda')

        traced = self.checkTrace(f, (lhs, rhs))
        graph_text = str(traced.graph_for(lhs, rhs))

        checker = FileCheck()
        checker.check("broadcast_tensors")
        checker.check('with prim::FusionGroup_')
        checker.check_count('ConstantChunk', 2, exactly=True)
        checker.run(graph_text)
266

267
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
268
    def test_chunk_motion_deduplicates_inputs(self):
        # Although x is used several times before the chunk, the single
        # resulting fusion group must take exactly one input.
        def func1(x):
            z = x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        def func2(x):
            z = x * x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        inputs = [torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float)]
        for func in (func1, func2):
            scripted = self.checkScript(func, inputs)
            forward_graph = scripted.graph_for(*inputs)
            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)

            # The fusion group is taken to be the last node of the graph.
            all_nodes = list(forward_graph.nodes())
            fusion_group = all_nodes[-1]
            group_inputs = list(fusion_group.inputs())
            self.assertEqual(len(group_inputs), 1)
288

289
    @unittest.skipIf(not RUN_CUDA, "No CUDA")
290
    def test_chunk_multiple_cuda(self):
        # The arguments are intentionally used out of order as a test to see
        # if the fusion compiler adds extra args in the correct order
        def fn(s, x, y, z):
            z1, z2 = z.chunk(2, 2)
            x1, x2, x3 = x.chunk(3, 1)
            y1, y2 = y.chunk(2, 0)
            return s + x1 + x2 + x3 + y1 + y2 + z1 + z2

        # Shapes for s, x, y and z respectively.
        shapes = [(5, 2, 3), (5, 6, 3), (10, 2, 3), (5, 2, 6)]
        inputs = [torch.randn(*shape, dtype=torch.float, device='cuda') for shape in shapes]

        scripted = self.checkScript(fn, inputs)
        self.assertAllFused(scripted.graph_for(*inputs))
308

309
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
310
    def test_minmax(self):
        # Elementwise max/min of a doubled tensor against a second operand
        # must fuse, including when that operand is a 0-dim NaN tensor.
        def tmax(a, b):
            return torch.max(2 * a, b)

        def tmin(a, b):
            return torch.min(2 * a, b)

        a = torch.randn(4, 4, dtype=torch.float, device="cuda")
        b = torch.randn(4, 4, dtype=torch.float, device="cuda")
        nan = torch.tensor(float('nan'), dtype=torch.float, device="cuda")

        input_sets = ([a, b], [a, nan], [b, nan])
        for f in (tmax, tmin):
            for inputs in input_sets:
                scripted = self.checkScript(f, inputs)
                self.assertAllFused(scripted.graph_for(*inputs))
326

327
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
328
    def test_clamp(self):
        # For several clamp signatures, both the forward graph and the
        # backward graph must be fused (a few size/grad helper ops excepted).
        def func2(a, b):
            return torch.clamp(a + b, min=0, max=2)

        def funcInf(a, b):
            return torch.clamp(a + b, min=0, max=float('inf'))

        def funcOptMin(a, b):
            return torch.clamp(a + b, max=2)

        def funcOptMax(a, b):
            return torch.clamp(a + b, min=0)

        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
        nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

        input_pairs = [[a, b], [a, nan]]
        for f in (func2, funcInf, funcOptMin, funcOptMax):
            for inp1, inp2 in input_pairs:
                f.__disable_jit_function_caching__ = True
                scripted = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
                forward = scripted.graph_for(inp1, inp2)
                self.assertAllFused(forward, except_for={'aten::size', 'aten::_size_if_not_equal'})

                out = scripted(inp1, inp2)
                with enable_profiling_mode_for_profiling_tests():
                    warmup_backward(out.sum())
                backward = backward_graph(scripted)
                self.assertAllFused(backward, except_for={'aten::Float', 'aten::_grad_sum_to_size'})
356

357
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
358
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
359
    def test_dropout(self):
        # The backward graph of dropout followed by relu must be fused,
        # apart from aten::div and constants.
        def func(x):
            x = torch.nn.functional.dropout(x)
            return torch.nn.functional.relu(x)

        sample = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        scripted = torch.jit.script(func)
        # The function is run twice; only the second result feeds backward.
        scripted(sample)
        out = scripted(sample)
        warmup_backward(out.sum())
        # skip_check to skip extra bailout nodes in between
        backward = backward_graph(scripted, skip_check=True)
        self.assertAllFused(backward, except_for={'aten::div', 'prim::Constant'})
372

373
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
374
    def test_comparison_eq_ne(self):
        # Masks built from == and != comparisons, mixed into arithmetic,
        # must trace into a fully fused graph.
        def f(x, y):
            mask = (x == 0).type_as(x)
            z = x * mask + y
            mask = (x != 0).type_as(x)
            z = z * mask + y
            return z

        lhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
        rhs = torch.randn(4, 4, dtype=torch.float, device='cuda')

        traced = self.checkTrace(f, (lhs, rhs))
        fused_graph = traced.graph_for(lhs, rhs)
        self.assertAllFused(fused_graph)
387

388
    @staticmethod
389
    def fn_test_comparison_gt_lt(x, y):
        """Combine ``x`` and ``y`` through masks derived from ``x > 0`` and ``x < 0``.

        Computes ``(x * (x > 0) + y) * (x < 0) + y`` with each mask cast to
        the type of ``x``.
        """
        positive = (x > 0).type_as(x)
        first_pass = x * positive + y
        negative = (x < 0).type_as(x)
        return first_pass * negative + y
395

396
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
397
    def test_comparison_gt_lt_cuda(self):
        # The shared > / < mask function must trace into a fully fused graph.
        lhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
        rhs = torch.randn(4, 4, dtype=torch.float, device='cuda')

        traced = self.checkTrace(self.fn_test_comparison_gt_lt, (lhs, rhs))
        fused_graph = traced.graph_for(lhs, rhs)
        self.assertAllFused(fused_graph)
403

404
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
405
    def test_comparison_ge_le_cuda(self):
        # Masks built from >= and <= must fuse, both for plain inputs and
        # after the inputs are switched to require gradients.
        def f(x, y):
            mask = (x >= 0).type_as(x)
            z = x * mask + y
            mask = (x <= 0).type_as(x)
            z = z * mask + y
            return z

        lhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
        rhs = torch.randn(4, 4, dtype=torch.float, device='cuda')

        traced = self.checkTrace(f, (lhs, rhs))
        self.assertAllFused(traced.graph_for(lhs, rhs))

        for tensor in (lhs, rhs):
            tensor.requires_grad_(True)
        size_ops = ("aten::size", "prim::BroadcastSizes", "aten::_size_if_not_equal")
        self.assertAllFused(traced.graph_for(lhs, rhs), except_for=size_ops)
422

423
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
424
    def test_addcmul_cuda(self):
        # addcmul over broadcastable operands must trace into a fused graph.
        def cuda_floats(*shape):
            return torch.randn(*shape, dtype=torch.float, device='cuda')

        t = cuda_floats(1, 4)
        t1 = cuda_floats(4, 1)
        t2 = cuda_floats(1, 4)

        # foo never reads t1, which is why checkTrace gets allow_unused=True.
        def foo(t, t1, t2):
            return t.addcmul(t + 1, t2, value=0.1)

        traced = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
        self.assertAllFused(traced.graph_for(t, t1, t2))
435

436
    # TODO: We leak CUDA memory here because the traced graph holds onto a
437
    # constant-ified tensor. Since the Python-global CompilationUnit is alive
438
    # until the end of the process, the memory is effectively leaked.
439
    # Removed `_cuda` suffix from this test which disables leak-checking.
440
    # If this is a real problem, we'll need to revisit Torchscript Function
441
    # lifetimes in Python.
442
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
443
    def test_lerp(self):
        # Both lerp overloads (Python scalar weight and tensor weight) must
        # trace into fully fused graphs.
        start = torch.randn(4, 1, dtype=torch.float, device='cuda')
        end = torch.randn(1, 4, dtype=torch.float, device='cuda')
        weight = torch.tensor(0.5, dtype=torch.float, device='cuda')

        # scalar weight overload
        def foo_weight_scalar(start, end):
            return torch.lerp(start + 1, end, 0.5)

        # tensor weight overload
        def foo_weight_tensor(start, end):
            return torch.lerp(start + 1, end, weight)

        for lerp_fn in (foo_weight_scalar, foo_weight_tensor):
            traced = self.checkTrace(lerp_fn, (start, end))
            self.assertAllFused(traced.graph_for(start, end))
463

464
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
465
    def test_concat_cuda(self):
        # Concatenating two elementwise results must fuse, with FusedConcat
        # immediately followed by the return in the graph text.
        hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
        cx = torch.randn(3, 20, dtype=torch.float, device='cuda')

        def foo(hx, cx):
            return torch.cat((hx + cx, hx * cx))

        traced = self.checkTrace(foo, (hx, cx))
        fused_graph = traced.graph_for(hx, cx)
        self.assertAllFused(fused_graph)

        checker = FileCheck()
        checker.check("FusedConcat")
        checker.check_next("return")
        checker.run(str(fused_graph))
476

477
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
478
    def test_concat_invariant_cuda(self):
        # Invariant: the output of prim::FusedConcat may
        # not be an input to any node inside the FusionGroup.
        def fn(x, y, z):
            x1 = x + y
            y1 = x - y
            w = torch.cat([x1, y1])
            return w + z

        def cuda_floats(*shape):
            return torch.randn(*shape, dtype=torch.float, device='cuda')

        x = cuda_floats(2, 2)
        y = cuda_floats(2, 2)
        z = cuda_floats(4, 2)

        traced = self.checkTrace(fn, (x, y, z))
        fused_graph = traced.graph_for(x, y, z)
        # aten::add is tolerated outside the fusion group.
        self.assertAllFused(fused_graph, except_for={'aten::add'})

        checker = FileCheck()
        checker.check("FusedConcat")
        checker.check_next("return")
        checker.run(str(fused_graph))
494

495
    @staticmethod
496
    def fn_test_exp(x, y):
        """Return the elementwise exponential of ``x + 0.5 * y``."""
        exponent = x + .5 * y
        return exponent.exp()
498

499
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
500
    def test_exp_cuda(self):
        # The shared exp function must trace into a fully fused graph.
        lhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
        rhs = torch.randn(4, 4, dtype=torch.float, device='cuda')

        traced = self.checkTrace(self.fn_test_exp, (lhs, rhs))
        fused_graph = traced.graph_for(lhs, rhs)
        self.assertAllFused(fused_graph)
506

507
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
508
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
509
    @torch._jit_internal._disable_emit_hooks_decorator
510
    @_inline_everything
511
    def test_fuse_decompose_normalization(self):
        # Checks, for BatchNorm2d and LayerNorm, that the optimized graph no
        # longer contains the monolithic normalization op, contains the
        # expected replacement op, and has exactly one fusion group whose
        # subgraph mentions the expected elementwise ops.

        # Residual-style wrapper: y + relu(norm(x)).
        class ResLike(torch.jit.ScriptModule):
            def __init__(self, norm_module):
                super().__init__()
                self.nm = norm_module

            @torch.jit.script_method
            def forward(self, x, y):
                return y + torch.relu(self.nm(x))

        # nm: normalization module to wrap.
        # in_opt_graph: substrings that must appear in the optimized graph.
        # not_in_opt_graph: substrings that must be absent from the optimized
        #     graph but present in the non-optimized one.
        # in_fusegraph: substrings that must appear in the fusion subgraph.
        def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph):
            model = ResLike(nm).cuda()
            model_noopt = ResLike(nm).cuda()
            # Give both models identical parameters and buffers.
            model_noopt.load_state_dict(model.state_dict())
            x = torch.randn(2, 16, 8, 8, device='cuda')
            y = torch.randn(2, 16, 8, 8, device='cuda')

            # FIXME: We need differentiation for CNNs for this optimization to trigger
            with torch.no_grad():
                out = model(x, y)
                graph = model.graph_for(x, y)
                rep = str(graph)

                # Reference run with graph optimization switched off.
                with torch.jit.optimized_execution(False):
                    out_noopt = model_noopt(x, y)
                    rep_noopt = str(model_noopt.graph_for(x, y))
                self.assertEqual(out, out_noopt, atol=3e-5)

            # Check that normalization op has really been decomposed
            for node_in_graph in in_opt_graph:
                self.assertIn(node_in_graph, rep)

            for node_not_in_graph in not_in_opt_graph:
                self.assertNotIn(node_not_in_graph, rep)
                self.assertIn(node_not_in_graph, rep_noopt)

            # Exactly one fusion group is expected at the top level.
            fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
            self.assertEqual(len(fusion_groups), 1)
            fused_graph = str(fusion_groups[0].g('Subgraph'))
            for node_in_fusegraph in in_fusegraph:
                self.assertIn(node_in_fusegraph, fused_graph)

        # test for batchnorm decompose
        bm = nn.BatchNorm2d(16)
        test_norm_decompose(bm, ['aten::batch_norm_update_stats'],
                            ['aten::batch_norm('], ['aten::sqrt'])

        # test for layernorm decompose
        lm = nn.LayerNorm(8)
        test_norm_decompose(lm, ['aten::batch_norm_stats'],
                            ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add'])
562

563
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
564
    def test_threshold(self):
        # torch.threshold followed by repeated additions must be fully fused.
        def f(x):
            return torch.threshold(x, 0, -10) + x + x + x

        values = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
        scripted = self.checkScript(f, (values,))
        fused_graph = scripted.graph_for(values)
        self.assertAllFused(fused_graph)
571

572
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
573
    def test_scalar_arg_cuda(self):
        # A float scalar argument multiplied into tensor arithmetic must
        # fuse, first without and then with gradients required on the tensor.
        def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
            return p * (x * x + x)

        tensor_arg = torch.randn(4, 4, dtype=torch.float, device='cuda')
        scalar_arg = 3
        scripted = self.checkScript(fn_test_scalar_arg, (tensor_arg, scalar_arg))
        self.assertAllFused(scripted.graph_for(tensor_arg, scalar_arg))

        tensor_arg.requires_grad_(True)

        # use another function otherwise we will bailout
        # and won't be able to do fused checks
        def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor:
            return p * (x * x + x)

        scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
        # Run once before inspecting the graph; the result itself is unused.
        scripted(tensor_arg, scalar_arg)
        size_ops = ("aten::size", "prim::BroadcastSizes", "aten::_size_if_not_equal")
        self.assertAllFused(scripted.graph_for(tensor_arg, scalar_arg), except_for=size_ops)
593

594
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
595
    @unittest.skip("deduplicating introduces aliasing in backward graph's outputs")
596
    @enable_cpu_fuser
597
    def test_fuser_deduplication(self):
        # See that fusion kernel outputs are deduplicated when removing  _grad_sum_to_size in the fuser's compilation
        # see the discussion in PR #14957.
        def f(x, y):
            return torch.sigmoid(x + y)

        # b is created before a, so the random draws keep their order.
        b = torch.randn(5, 5, requires_grad=True)
        a = torch.randn(5, 5, requires_grad=True)
        scripted = self.checkScript(f, (a, b))
        size_ops = {'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'}
        self.assertAllFused(scripted.graph_for(a, b), except_for=size_ops)

        out = scripted(a, b)
        grad_history = warmup_backward(out.sum(), [a, b])
        # Gradients from the most recent backward run.
        grad_a, grad_b = grad_history.pop()
        backward = backward_graph(scripted)
        self.assertAllFused(backward)
        # check that a, b share storage, i.e. were generated as a single output in the fuser
        self.assertEqual(grad_a.data_ptr(), grad_b.data_ptr())
616

617
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
618
    @enable_cpu_fuser
619
    @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
620
    def test_fuser_iou(self):
        # This checks if most of Intersection over Union is fused.
        # In particular, the backward contains many _grad_sum_to_size.
        def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
            ltx = torch.max(b1x1, b2x1)  # [N,M]
            lty = torch.max(b1y1, b2y1)
            rbx = torch.min(b1x2, b2x2)
            rby = torch.min(b1y2, b2y2)

            w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
            h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
            inter = w * h  # [N,M]

            # Box area is width * height: (x2 - x1) * (y2 - y1). The height
            # term previously subtracted y2 from itself, so both areas were
            # always zero.
            area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
            area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
            iou = inter / (area1 + area2 - inter)
            return iou

        box1 = torch.randn(5, 4, requires_grad=True)
        box2 = torch.randn(5, 4, requires_grad=True)
        # unsqueezing can currently not be fused
        b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
        b1y1 = box1[:, 1].unsqueeze(1)
        b1x2 = box1[:, 2].unsqueeze(1)
        b1y2 = box1[:, 3].unsqueeze(1)
        b2x1 = box2[:, 0].unsqueeze(0)  # [1,N]
        b2y1 = box2[:, 1].unsqueeze(0)
        b2x2 = box2[:, 2].unsqueeze(0)
        b2y2 = box2[:, 3].unsqueeze(0)

        s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
        self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
                            except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})

        # The backward graph is inspected with profiling mode enabled.
        with enable_profiling_mode_for_profiling_tests(True):
            c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
            warmup_backward(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
            graph = backward_graph(s)
            self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
659

660
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
661
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
662
    @enable_cpu_fuser
663
    def test_fusion_reuse_multi_gpu(self):
        # One scripted function is checked on CPU and then run with the same
        # values on CUDA devices 0 and 1.
        def fn(x, y):
            return x * y * x * y

        inputs_cpu = [
            torch.randn(4, 4, dtype=torch.float),
            torch.randn(4, 4, dtype=torch.float),
        ]
        # Copies of the CPU inputs on device 0, then on device 1.
        per_device_inputs = [[t.cuda(index) for t in inputs_cpu] for index in (0, 1)]

        # Should not crash; these should compile different kernels.
        scripted = self.checkScript(fn, inputs_cpu)
        self.assertAllFused(scripted.graph_for(*inputs_cpu))
        for device_inputs in per_device_inputs:
            scripted(*device_inputs)
679

680
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
681
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
682
    @enable_cpu_fuser
683
    def test_kernel_cache_multi_gpu(self):
        # Three identical fusion graphs on three devices must add only one
        # entry to the fuser's KernelSpec cache.
        def not_fusible(x):
            return x

        def fn(x, y, z):
            x_out = x * x * x * x * x  # fusion: lambda x. x * x * x * x * x
            y_out = y * y * y * y * y
            z_out = z * z * z * z * z
            return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)

        # One input on the CPU, one on cuda:0 and one on cuda:1.
        inputs = [
            torch.randn(4, 4, dtype=torch.float),
            torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
            torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
        ]

        specs_before = torch._C._jit_debug_fuser_num_cached_kernel_specs()

        # There are 3 FusionGroups. Because they have the same graph, they
        # should reuse the same KernelSpec in the KernelSpec cache.
        scripted = self.checkScript(fn, inputs)
        optimized = scripted.graph_for(*inputs)
        self.assertGraphContainsExactly(optimized, 'prim::FusionGroup', 3, True)
        specs_after = torch._C._jit_debug_fuser_num_cached_kernel_specs()
        # XXX: This assumes that the same kernel isn't already used by another test
        self.assertEqual(specs_after - specs_before, 1)
709

710
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
711
    def test_nonzero_device_cuda(self):
        # Fusion must also work when every input lives on CUDA device 1.
        device = 'cuda:1'
        lhs = torch.tensor([0.4], dtype=torch.float, device=device)
        rhs = torch.tensor([0.7], dtype=torch.float, device=device)

        def doit(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y) + x))

        traced = self.checkTrace(doit, (lhs, rhs))
        fused_graph = traced.graph_for(lhs, rhs)
        self.assertAllFused(fused_graph)
721

722
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
723
    def test_lstm_cuda(self):
        """Script ``LSTMCellS`` on CUDA training inputs via ``checkScript``.

        Only the ``checkScript`` call currently runs; the graph assertions
        below the early ``return`` are never executed.
        """
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        # NOTE(review): this bare return makes everything below unreachable.
        # Presumably a deliberate disablement of the graph checks -- confirm
        # before removing either the return or the dead code.
        return
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(
            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
        remaining_nodes = strip_profiling_nodes(forward_graph.nodes())
        self.assertTrue(len(remaining_nodes) == 2)
        # Everything is differentiable but TupleConstruct return
        checker = FileCheck().check("DifferentiableGraph")
        checker = checker.check_next("TupleConstruct").check_next("return")
        checker.run(str(forward_graph))

        with enable_profiling_mode_for_profiling_tests(True):
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
            backward = backward_graph(module)
        not_fused_ok = ("aten::t", "aten::mm", "aten::_grad_sum_to_size")
        self.assertAllFused(backward, except_for=not_fused_ok)
741

742
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
743
    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
744
    # We want float tensors to be computed at full precision in order to use the default precision
745
    @with_tf32_off
746
    def test_lstm_concat_cuda(self):
        """Trace ``LSTMCellC`` on CUDA and expect ``FusedConcat`` right before ``return``."""
        inputs = get_lstm_inputs('cuda')
        ge = self.checkTrace(LSTMCellC, inputs)
        graph_text = str(ge.graph_for(*inputs))
        # The fused concat must be the last thing printed ahead of the return.
        checker = FileCheck().check("FusedConcat")
        checker = checker.check_next("return")
        checker.run(graph_text)
751

752
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
753
    def test_lstm_gates_permutations_cuda(self):
        """Require exactly one FusionGroup for every ordering of the gate sum.

        lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
        Each permutation of those four terms is compiled from source with
        ``torch.jit.CompilationUnit`` and compared against the same source
        executed as plain Python.
        """
        terms = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
        template = dedent('''
        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
            gates = {} + {} + {} + {}
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
            return ingate * forgetgate * cellgate * outgate
        ''')
        for ordering in permutations(terms):
            code = template.format(*ordering)
            # Run the generated source eagerly to obtain the reference `cell`.
            eager_namespace = {}
            exec(code, globals(), eager_namespace)
            cu = torch.jit.CompilationUnit(code)

            inputs = get_lstm_inputs('cuda', training=False)
            scripted_result = cu.cell(*inputs)
            eager_result = eager_namespace['cell'](*inputs)
            self.assertEqual(scripted_result, eager_result)
            forward_graph = cu.cell.graph_for(*inputs)
            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
773

774
    # TODO: Fuser doesn't work at all when inputs require grad. Fix that
775
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
776
    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
777
    # We want float tensors to be computed at full precision in order to use the default precision
778
    @with_tf32_off
779
    def test_lstm_traced_cuda(self):
        """Trace ``LSTMCellF`` on CUDA and check the printed graph structure.

        No ``Chunk``, ``aten::sigmoid`` or ``aten::tanh`` may appear ahead of
        the ``FusionGroup``; the group must be followed directly by
        ``TupleConstruct`` and ``return``, and no ``FusionGroup_2`` may follow.
        """
        inputs = get_lstm_inputs('cuda')
        ge = self.checkTrace(LSTMCellF, inputs)
        graph_text = str(ge.graph_for(*inputs))

        # .check_not("aten::add") don't get pulled into FusionGroup because of BailOuts
        checker = FileCheck()
        for forbidden in ("Chunk", "aten::sigmoid", "aten::tanh"):
            checker = checker.check_not(forbidden)
        checker = checker.check("FusionGroup")
        for following in ("TupleConstruct", "return"):
            checker = checker.check_next(following)
        checker = checker.check_not("FusionGroup_2")
        checker.run(graph_text)
787

788
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
789
    @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
790
    @enable_cpu_fuser
791
    def test_lstm_traced_cpu(self):
        """Trace ``LSTMCellF`` on CPU and expect a ``FusionGroup`` in its graph.

        A ``RuntimeError`` whose message contains ``'Failed to compile'`` is
        downgraded to a warning plus ``unittest.SkipTest``; any other
        ``RuntimeError`` propagates unchanged.

        Raises:
            unittest.SkipTest: when the fused CPU kernel fails to compile.
        """
        # `warnings` is not imported at module level, so bring it in locally
        # instead of relying on an undefined name.
        import warnings

        inputs = get_lstm_inputs('cpu')
        try:
            ge = self.checkTrace(LSTMCellF, inputs)
            graph = ge.graph_for(*inputs)
            # FileCheck must be instantiated before its checks are chained.
            FileCheck().check("FusionGroup").run(str(graph))
        except RuntimeError as e:
            # str(e) is safe even when e.args is empty or holds a non-string.
            if 'Failed to compile' not in str(e):
                raise
            warnings.warn('CPU fuser test has failed! This is not a hard failure, '
                          'because the kernels sometimes trigger bugs in compilers '
                          '(most notably GCC 7.2).')
            raise unittest.SkipTest('Failed to compile') from e
805

806
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
807
    def test_milstm_cuda(self):
        """Script ``MiLSTMCell`` on CUDA, check its forward graph, then warm up backward."""
        inputs = get_milstm_inputs('cuda', training=True)
        module = self.checkScript(MiLSTMCell, inputs)
        forward_graph = module.graph_for(*inputs)

        # Exactly one fusion group, counting nodes nested inside subgraphs.
        self.assertGraphContainsExactly(
            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)

        checker = FileCheck().check("DifferentiableGraph")
        checker = checker.check_next("TupleConstruct").check_next("return")
        checker = checker.check("FusionGroup")
        checker.run(str(forward_graph))

        hy, cy = module(*inputs)
        loss = (hy + cy).sum()
        warmup_backward(loss)
817

818
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
819
    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
820
    def test_rand_cuda(self):
        """Check ``rand_like`` inside a scripted method on CUDA.

        With an all-zero input the output equals the random term alone, so
        two calls must differ, every value must lie in ``[0, 1)``, and the
        method's graph must be fully fused.
        """
        class M(torch.jit.ScriptModule):
            __constants__ = ['d']

            def __init__(self):
                super().__init__()
                self.d = torch.device('cuda')

            @torch.jit.script_method
            def create(self, x):
                return x * x + x + torch.rand_like(x)

        x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
        m = M()
        out1 = m.create(x)
        out2 = m.create(x)

        # Successive calls must draw different random values.
        self.assertNotEqual(out1, out2)
        for sample in (out1, out2):
            self.assertTrue(torch.all(sample >= 0))
            self.assertTrue(torch.all(sample < 1))
        self.assertAllFused(m.create.graph_for(x))
842

843
    @staticmethod
844
    def fn_test_relu(x, y):
        """Return ``relu(x + 0.5 * y)``."""
        half_y = .5 * y
        return F.relu(x + half_y)
846

847
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
848
    def test_relu_cuda(self):
        """Trace ``fn_test_relu`` on two 4x4 CUDA tensors and require full fusion."""
        x, y = (
            torch.randn(4, 4, dtype=torch.float, device='cuda')
            for _ in range(2)
        )

        ge = self.checkTrace(self.fn_test_relu, (x, y))
        traced_graph = ge.graph_for(x, y)
        self.assertAllFused(traced_graph)
854

855
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
856
    def test_erf_cuda(self):
        """Require ``relu(erf(x) - erfc(x))`` to fuse on CUDA, with and without grad.

        Once the input requires grad, the size bookkeeping nodes listed in
        ``size_ops`` are allowed to stay outside the fusion group.
        """
        def fn_test_erf(x):
            return F.relu(torch.erf(x) - torch.erfc(x))

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')

        # First pass: input does not require grad.
        ge = self.checkTrace(fn_test_erf, (x,))
        self.assertAllFused(ge.graph_for(x))

        # Second pass: same tensor, now requiring grad.
        x.requires_grad_(True)
        ge = self.checkTrace(fn_test_erf, (x,))
        size_ops = ("aten::size", "prim::BroadcastSizes", "aten::_size_if_not_equal")
        self.assertAllFused(ge.graph_for(x), except_for=size_ops)
867

868
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
869
    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
870
    def test_rand_broadcast_cuda(self):
        """Check fusion and broadcasting of ``rand_like`` in a scripted function.

        The function is run and checked for full fusion without and then with
        grad on ``x``. Finally a 1-D ``y`` is broadcast against a 2-D ``x`` of
        ones and the first two output rows are required to be equal.
        """
        def fn_test_rand(x, y):
            r = torch.rand_like(y)
            return r * x + x

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
        script_f = torch.jit.script(fn_test_rand)

        # Run once so an optimized graph exists for these inputs.
        script_f(x, y)
        self.assertAllFused(script_f.graph_for(x, y))

        x.requires_grad_(True)
        script_f(x, y)
        size_ops = ("aten::size", "prim::BroadcastSizes", "aten::_size_if_not_equal")
        self.assertAllFused(script_f.graph_for(x, y), except_for=size_ops)

        # test that broadcasting random produces correct results
        x = torch.ones(4, 4, dtype=torch.float, device='cuda')
        y = torch.ones(4, dtype=torch.float, device='cuda')
        out = script_f(x, y)
        first_row, second_row = out[0], out[1]
        self.assertEqual(first_row, second_row)
889

890
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
891
    @enable_cpu_fuser
892
    def test_scalar(self):
        """Script ``2 * x + y`` on zero-dimensional CPU tensors and require full fusion."""
        def fn(x, y):
            return 2 * x + y

        # Both operands are 0-dim float tensors on the CPU.
        scalar_kwargs = dict(dtype=torch.float, device='cpu')
        x = torch.tensor(0.1, **scalar_kwargs)
        y = torch.tensor(1, **scalar_kwargs)

        ge = self.checkScript(fn, (x, y))
        scripted_graph = ge.graph_for(x, y)
        self.assertAllFused(scripted_graph)
900

901
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
902
    def test_small_constant_cuda(self):
        """Trace an expression with very small float constants on CUDA and require full fusion."""
        def fn_test_small_constant(x, y):
            return (1e-8 * x + 5e-9 * y) * 1e8

        x, y = (
            torch.randn(4, 4, dtype=torch.float, device='cuda')
            for _ in range(2)
        )

        ge = self.checkTrace(fn_test_small_constant, (x, y))
        traced_graph = ge.graph_for(x, y)
        self.assertAllFused(traced_graph)
910

911
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
912
    def test_tensor_scalar_ops_cuda(self):
        """Check which tensor-scalar expressions get fused on CUDA.

        A constant scalar operand must yield a fully fused graph; a scalar
        obtained via ``int(z)`` from a tensor must yield no FusionGroup.
        """
        def should_fuse(x):
            z = 3.
            y = x + z
            return x * y

        # XXX: right now we only support fusing scalars if
        # they're constant (#9940)
        def should_not_fuse(x, z):
            y = x + int(z)
            return x * y

        # Constant-scalar case.
        fusible_inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
        ge = self.checkScript(should_fuse, fusible_inputs)
        self.assertAllFused(ge.graph_for(*fusible_inputs))

        # Scalar-from-tensor case.
        unfusible_inputs = [
            torch.randn(2, 2, dtype=torch.float, device='cuda'),
            torch.tensor(3., dtype=torch.float, device='cuda'),
        ]
        ge = self.checkScript(should_not_fuse, unfusible_inputs)
        unfused_graph = ge.graph_for(*unfusible_inputs)
        self.assertGraphContainsExactly(
            unfused_graph, 'prim::FusionGroup', 0, consider_subgraphs=True)
935

936
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
937
    @enable_cpu_fuser
938
    def test_where_and_typing(self):
        """Script a comparison plus ``torch.where`` on double CPU tensors.

        The function returns both the mask and the selected values; the graph
        must be fully fused apart from ``prim::TupleConstruct``.
        """
        def f(x, y):
            mask = x > y
            res = torch.where(mask, x, y)
            return mask, res

        x, y = (torch.randn(4, 4, dtype=torch.double) for _ in range(2))

        script_f = self.checkScript(f, (x, y))
        scripted_graph = script_f.graph_for(x, y)
        self.assertAllFused(scripted_graph, except_for={'prim::TupleConstruct'})
949

950
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
951
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
952
    def test_grad_sum_to_size_elimination(self):
        """Count ``aten::_grad_sum_to_size`` nodes in backward graphs of a broadcast add.

        ``(a + b) + c`` is scripted for three argument mixes built from a
        ``(5, 1)`` tensor ``s1`` and a ``(5, 5)`` tensor ``s2``. For each mix
        the gradients must match their inputs' shapes, exactly one previously
        unseen backward graph must appear, and that graph must contain the
        expected number of ``aten::_grad_sum_to_size`` nodes (0 for the first
        mix, 1 for the others).
        """
        def my_broadcasted_cell(a, b, c):
            return (a + b) + c

        s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
        s2 = torch.randn(5, 5, requires_grad=True, device='cuda')

        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
        forward_graph = module.graph_for(s1, s1, s1)
        self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                       "aten::_size_if_not_equal"))

        old_plans = set()
        for i in range(3):
            # if we have s2, then the s1 are _grad_sum_to_size'd
            args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
            # detach_() and requires_grad_() work in place and return the same
            # tensor, so `args` still holds the s1/s2 objects themselves.
            args = [a.detach_().requires_grad_() for a in args]
            # recompile, so we don't trigger bailouts
            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
            res = module(*args)
            warmup_backward(res.sum(), args)
            grads = torch.autograd.grad(res.sum(), args)
            for inp, gr in zip(args, grads):
                self.assertEqual(inp.shape, gr.shape)

            # Pick out the backward graph that was not seen in an earlier
            # iteration; the order of the returned graphs is not relied on.
            backward = None
            for g in all_backward_graphs(module):
                plan = str(g)
                if plan not in old_plans:
                    # A unittest assertion (unlike a bare assert) still runs
                    # under `python -O`.
                    self.assertIsNone(backward)
                    backward = g
                    old_plans.add(plan)
            # Fail with a clear message instead of an AttributeError below.
            self.assertIsNotNone(backward)

            num_grads = 1 if i > 0 else 0
            sum_to_size_nodes = [n for n in backward.nodes()
                                 if n.kind() == 'aten::_grad_sum_to_size']
            self.assertEqual(len(sum_to_size_nodes), num_grads)
988

989

990
# Script entry point: when executed directly, hand control to run_tests().
if __name__ == "__main__":
    run_tests()
992

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.