# Source: pytorch (forks: 0) / test_jit_fuser_te.py — 3047 lines, 107.0 KB
1
# Owner(s): ["NNC"]
2

3
import contextlib
4
import math
5
import operator
6
import os
7
import unittest
8
import warnings
9
from typing import List
10

11
import torch
12
import torch.nn.functional as F
13
from torch.testing import FileCheck
14

15

16
# These settings must be applied before `common_utils` infers
# `GRAPH_EXECUTOR`. This file **requires** them: applying them only after
# `GRAPH_EXECUTOR` has been inferred would wrongly run or skip some tests.
torch._C._jit_set_profiling_executor(True)
# Despite its `_get_` prefix, this call is used here to switch graph-executor
# optimization on (it is one of the settings the comment above refers to).
torch._C._get_graph_executor_optimize(True)
24

25
from itertools import combinations, permutations, product
26
from textwrap import dedent
27

28
from jit.test_fuser_common import TestFuserCommon  # noqa: F401
29
from test_jit import (
30
    backward_graph,
31
    get_lstm_inputs,
32
    get_milstm_inputs,
33
    LSTMCellC,
34
    LSTMCellF,
35
    LSTMCellS,
36
    MiLSTMCell,
37
)
38

39
from torch.testing._internal.common_device_type import (
40
    instantiate_device_type_tests,
41
    onlyCPU,
42
    OpDTypes,
43
    ops,
44
)
45
from torch.testing._internal.common_jit import JitCommonTestCase
46
from torch.testing._internal.common_methods_invocations import op_db
47
from torch.testing._internal.common_utils import (
48
    enable_profiling_mode_for_profiling_tests,
49
    GRAPH_EXECUTOR,
50
    IS_FBCODE,
51
    ProfilingMode,
52
    run_tests,
53
    skipIfTorchDynamo,
54
    slowTest,
55
    TEST_WITH_ASAN,
56
    TEST_WITH_ROCM,
57
)
58
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
59
from torch.testing._internal.jit_utils import (
60
    clone_inputs,
61
    get_traced_sample_variant_pairs,
62
    JitTestCase,
63
    NoTracerWarnContextManager,
64
    RUN_CUDA,
65
    RUN_CUDA_HALF,
66
    RUN_CUDA_MULTI_GPU,
67
    set_fusion_group_inlining,
68
    TensorExprTestOptions,
69
    warmup_backward,
70
)
71

72

73
# Node kind used by the TensorExpr fuser for fusion-group nodes; tests search
# optimized graphs for it.
FUSION_GROUP = "prim::TensorExprGroup"
# Result of `torch._C._llvm_enabled()` for this build.
LLVM_ENABLED = torch._C._llvm_enabled()

# Autograd-check related node kinds that tests tolerate (via `except_for`)
# when asserting that a backward graph is fully fused.
autograd_check_set = {
    "aten::__is__",
    "prim::ListConstruct",
    "prim::AutogradAllZero",
    "prim::AutogradAllNonZero",
}
82

83

84
def strip_profiling_nodes(nodes):
    """Return a list of ``nodes`` without bailout-related profiling nodes.

    A node is dropped when ``node.kind()`` is ``prim::BailoutTemplate`` or
    ``prim::BailOut``; the remaining nodes keep their original order.
    """
    bailout_kinds = frozenset(("prim::BailoutTemplate", "prim::BailOut"))
    return list(filter(lambda node: node.kind() not in bailout_kinds, nodes))
87

88

89
def warmup_forward(f, *args, profiling_count=2):
    """Call ``f(*args)`` ``profiling_count`` times and return the last result.

    Callers use this to run a scripted function enough times before
    inspecting its optimized graph (presumably so the profiling executor has
    recorded the inputs).

    Args:
        f: the callable to invoke.
        *args: positional arguments forwarded to ``f`` on every call.
        profiling_count: number of calls to make.

    Returns:
        The value returned by the final call, or ``None`` when
        ``profiling_count`` is 0 or negative (no call is made).
    """
    # Bind up front so a non-positive count returns None instead of raising
    # UnboundLocalError.
    results = None
    for _ in range(profiling_count):
        results = f(*args)
    return results
94

95

96
@contextlib.contextmanager
def texpr_reductions_enabled():
    """Enable TensorExpr reductions for the duration of the ``with`` block.

    The previous setting (as returned by the setter) is restored on exit,
    including when the block raises.
    """
    saved = torch._C._jit_set_texpr_reductions_enabled(True)
    with contextlib.ExitStack() as stack:
        stack.callback(torch._C._jit_set_texpr_reductions_enabled, saved)
        yield
103

104

105
@contextlib.contextmanager
def texpr_enable_strategy(strategy):
    """Install ``strategy`` as the JIT fusion strategy inside the ``with`` block.

    ``strategy`` is passed straight to ``torch._C._jit_set_fusion_strategy``;
    the strategy that call returns is reinstated on exit, even on error.
    """
    saved = torch._C._jit_set_fusion_strategy(strategy)
    with contextlib.ExitStack() as stack:
        stack.callback(torch._C._jit_set_fusion_strategy, saved)
        yield
112

113

114
@contextlib.contextmanager
def inline_fusion_groups():
    """Turn fusion-group inlining on for the duration of the ``with`` block.

    The value read before enabling it is written back on exit, even when the
    block raises.
    """
    saved = torch._C._debug_get_fusion_group_inlining()
    torch._C._debug_set_fusion_group_inlining(True)
    with contextlib.ExitStack() as stack:
        stack.callback(torch._C._debug_set_fusion_group_inlining, saved)
        yield
122

123

124
class TestTEFuser(JitTestCase):
125
    def setUp(self):
        """Configure the fusion strategy and the shared per-test fixtures.

        The previously active fusion strategy is saved so ``tearDown`` can
        restore it.
        """
        super().setUp()
        self.tensorexpr_options = TensorExprTestOptions()

        # `self.dynamic_shapes` is not set here; it is expected to be provided
        # by the specializations of this class defined later in the file.
        mode = "DYNAMIC" if self.dynamic_shapes else "STATIC"
        self.old_fusion_strategy = torch._C._jit_set_fusion_strategy([(mode, 20)])

        self.devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        self.int_dtypes = [torch.int8, torch.int16, torch.int32, torch.int64, torch.bool]
        self.fp_dtypes = [torch.float16, torch.float32, torch.float64, torch.bfloat16]
        self.dtypes = self.int_dtypes + self.fp_dtypes
150

151
    def tearDown(self):
        """Restore the TensorExpr options and fusion strategy saved by ``setUp``."""
        saved_strategy = self.old_fusion_strategy
        self.tensorexpr_options.restore()
        torch._C._jit_set_fusion_strategy(saved_strategy)
        super().tearDown()
155

156
    def assertAllFused(self, graph, except_for=None):
        """Assert that the top-level nodes of ``graph`` are all fused.

        The top-level block of ``graph`` must satisfy:

        * exactly one guard node is present: a ``prim::TypeCheck``,
          ``prim::RequiresGradCheck`` or ``prim::TensorExprDynamicGuard`` node,
          or an autodiff guard (an ``aten::all`` over a ``prim::ListConstruct``
          with at least one input produced by ``prim::AutogradAllNonZero`` or
          ``prim::AutogradAllZero``);
        * every ``prim::If`` node directly follows a guard node;
        * every other node is a ``prim::Constant`` or has a kind listed in
          ``except_for``.

        Args:
            graph: the TorchScript graph to inspect.
            except_for: optional collection of node kinds allowed to remain
                outside the fusion group.
        """
        except_for = except_for if except_for is not None else set()
        # TODO - upstream
        guards = (
            "prim::TypeCheck",
            "prim::RequiresGradCheck",
            "prim::TensorExprDynamicGuard",
        )
        autodiff_check_kinds = ("prim::AutogradAllNonZero", "prim::AutogradAllZero")
        guard_found = False

        def autodiff_guard(node):
            # Matches `aten::all(prim::ListConstruct(...))` where some list
            # element is produced by an AutogradAllNonZero/AutogradAllZero node.
            if node.kind() != "aten::all":
                return False
            inps = list(node.inputs())
            if len(inps) != 1 or inps[0].node().kind() != "prim::ListConstruct":
                return False
            return any(
                li_inp.node().kind() in autodiff_check_kinds
                for li_inp in inps[0].node().inputs()
            )

        def is_guard(node):
            return node.kind() in guards or autodiff_guard(node)

        for node in graph.block().nodes():
            if node.kind() == "prim::Constant":
                continue
            if is_guard(node):
                self.assertFalse(guard_found, "Found more than one guard node")
                guard_found = True
                continue
            if node.kind() in except_for:
                continue
            if node.kind() == "prim::If":
                self.assertTrue(
                    is_guard(node.prev()), "prim::If is not preceded by a guard"
                )
                continue
            self.fail("Found unexpected node:" + node.kind())

        self.assertTrue(guard_found, "No guard node found in graph")
199

200
    def assertLastGraphAllFused(self):
        """Apply ``assertAllFused`` to the most recently executed optimized graph."""
        last_graph = torch.jit.last_executed_optimized_graph()
        self.assertAllFused(last_graph)
202

203
    def findFusionGroups(self, graph):
        """Collect the ``Subgraph`` of every fusion-group node under ``graph``.

        ``graph`` may be a graph or a block. Blocks nested inside non-fusion
        nodes are searched recursively; the interior of a fusion group is not.
        Subgraphs are returned in traversal order.
        """
        groups = []
        for node in graph.nodes():
            if node.kind() != FUSION_GROUP:
                for sub_block in node.blocks():
                    groups.extend(self.findFusionGroups(sub_block))
            else:
                groups.append(node.g("Subgraph"))
        return groups
212

213
    def test_typecheck(self):
        small = torch.ones(1)

        def fused_kernel(a, b):
            return (a + b) * 2.0

        scripted = self.checkScript(fused_kernel, (small, small))
        # Sanity check: the kernel really was fused into a single group.
        groups = self.findFusionGroups(scripted.graph_for(small, small))
        self.assertEqual(len(groups), 1)

        # Re-run with a bigger (size 2) input. Unless the type check fails and
        # triggers a recompilation, the kernel would still produce a size-1
        # tensor and silently compute the wrong result.
        larger = torch.ones(2)
        self.assertEqual(scripted(larger, larger), fused_kernel(larger, larger))
232

233
    def test_sum_simple(self):
        """With TE reductions enabled, ``(x * x).sum()`` should be fully fused."""

        def func(x):
            x2 = x * x
            return x2.sum()

        with texpr_reductions_enabled():
            a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
            a = a.reshape(5, 3)
            # checkScript's return value is not needed; only fusion is checked.
            self.checkScript(func, (a,))
            self.assertLastGraphAllFused()
243

244
    def test_nop(self):
        """Intentionally empty test; it does nothing."""
246

247
    def test_sum_dim(self):
        """Summing over dim 0 fuses fully, whether written as 0 or as -2.

        The input is 2-D (5x3), so dims ``0`` and ``-2`` are the same axis.
        """

        def func(x):
            return x.sum((0,)) * 2

        def func_neg(x):
            return x.sum((-2,)) * 2

        with texpr_reductions_enabled():
            a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
            a = a.reshape(5, 3)
            for fn in (func, func_neg):
                self.checkScript(fn, (a,))
                self.assertLastGraphAllFused()
261

262
    def test_sum_keepdim_cast(self):
        """A keepdim sum that casts to double should still be fully fused."""

        def func(x):
            return x.sum((0,), keepdim=True, dtype=torch.double) * 2

        with texpr_reductions_enabled():
            values = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
            matrix = values.reshape(5, 3)
            self.checkScript(func, (matrix,))
            self.assertLastGraphAllFused()
272

273
    def test_abs(self):
        """``x.abs() * 2`` should be fully fused on every available device."""
        for device in self.devices:
            # Defined inside the loop, so each device gets a fresh function object.
            def func(x):
                return x.abs() * 2

            a = torch.randn(5, device=device)
            self.checkScript(func, (a,))
            self.assertLastGraphAllFused()
282

283
    def test_unsqueeze_size_calculation(self):
        """An unsqueeze feeding a broadcast add should still be fully fused."""
        for device in self.devices:

            def foo(b, d):
                x = d.unsqueeze(1)
                y = x * 42.0
                z = b + y
                r = z / 42.0
                return r

            b_input = torch.rand(20, 28, device=device, requires_grad=True)
            d_input = torch.rand(20, device=device)
            args = (b_input, d_input)
            scripted = self.checkScript(foo, args)
            self.assertAllFused(scripted.graph_for(*args))
299

300
    def test_zero_element_tensors(self):
        """Scripting ``atan2`` over zero-element tensors should pass checkScript."""
        for device in self.devices:

            def decode(sin_t, cos_t):
                theta = torch.atan2(sin_t.float(), cos_t.float())
                return theta

            sin = torch.zeros(0, device=device)
            cos = torch.zeros(0, device=device)
            # Only checkScript's own checks matter; its return value is unused.
            self.checkScript(decode, [sin, cos])
311

312
    def test_arg_configurations_smoke(self):
        # Smoke test: the same kernel must not be reused for contiguous and
        # non-contiguous arguments.
        # TODO: add optionally enabled debug counters to the fuser to verify
        #       that we really can tell the difference between configurations
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")

        for device in self.devices:

            def f(x, y):
                z1, z2 = (x + y).chunk(2, dim=1)
                return z1 * z2

            x = torch.randn(4, 4, dtype=torch.float, device=device)
            y = torch.randn(4, 4, dtype=torch.float, device=device)
            traced_f = torch.jit.trace(f, (x, y))
            contiguous_out = traced_f(x.t().contiguous(), y)
            strided_out = traced_f(x.t(), y)
            self.assertEqual(contiguous_out, strided_out)
330

331
    def test_broadcast(self):
        """``x * scale + shift`` with 1-D scale/shift broadcast over a 4x4 input."""
        for device in self.devices:

            def scaleshift(x, scale, shift):
                return x * scale + shift

            shapes = [(4, 4), (4,), (4,)]
            inputs = [
                torch.randn(*shape, dtype=torch.float, device=device)
                for shape in shapes
            ]
            self.checkScript(scaleshift, inputs)
343

344
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
    @unittest.skipIf(
        GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on"
    )
    def test_cuda_half(self):
        """Compare traced half-precision runs against float eager runs cast to half.

        Both the forward outputs and the input gradients are compared.
        """
        x = torch.randn(4, 4, dtype=torch.half, device="cuda")
        y = torch.randn(4, 4, dtype=torch.half, device="cuda")

        funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp]

        # Note: Non fused inputs must be float to prevent loss of precision
        ref_inputs = (x.float(), y.float())
        half_inputs = (x, y)
        for fn in funcs:
            eager_args = [t.clone().requires_grad_() for t in ref_inputs]
            fused_args = [t.clone().requires_grad_() for t in half_inputs]

            # Verifies outputs
            fusion = torch.jit.trace(fn, fused_args, check_trace=False)
            eager_out = fn(*eager_args)
            fused_out = fusion(*fused_args)
            eager_out_half = [t.half() for t in eager_out]
            self.assertEqual(eager_out_half, fused_out)

            # Verifies gradients
            for ref_out, out in zip(eager_out_half, fused_out):
                ref_grads = torch.autograd.grad(
                    ref_out.float().sum(),
                    eager_args,
                    allow_unused=True,
                    retain_graph=True,
                )
                grads = torch.autograd.grad(
                    out.sum(),
                    fused_args,
                    allow_unused=True,
                    retain_graph=True,
                )
                self.assertEqual([g.half() for g in ref_grads], grads)
385

386
    def test_checks_cat_inputs(self):
        # single fusion node causes error
        with set_fusion_group_inlining(True):
            for device in self.devices:
                # cat must not be treated as broadcasting: all of its inputs have
                # to be checked for matching map sizes before the kernel runs.
                def f(x, y):
                    return torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0)

                # NOTE: y is broadcastable to x, yet f(x, y) must be 3x4, not 4x4.
                x = torch.randn(2, 4, dtype=torch.float, device=device)
                y = torch.randn(1, 4, dtype=torch.float, device=device)
                args = (x, y)

                scripted = self.checkScript(f, args)
                self.assertEqual(scripted(*args).shape, (3, 4))
                self.assertAllFused(scripted.graph_for(*args))
404

405
    def test_chunk(self):
        """A three-way chunk combined elementwise should be fully fused."""
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")

        for device in self.devices:

            def fn(x):
                a, b, c = x.chunk(3, 1)
                return a * b + c

            sample = torch.randn(10, 6, dtype=torch.float, device=device)
            self.checkScript(fn, [sample])
            self.assertLastGraphAllFused()
419

420
    def test_chunk_correctness(self):
        """Four-way chunks along dims 0, 1 and 2 pass checkScript and fully fuse."""
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")

        for device in self.devices:

            def chunk_4_0(x):
                x0, x1, x2, x3 = x.chunk(4, 0)
                return x0 + x1 + x2 + x3

            def chunk_4_1(x):
                x0, x1, x2, x3 = x.chunk(4, 1)
                return x0 + x1 + x2 + x3

            def chunk_4_last(x):
                x0, x1, x2, x3 = x.chunk(4, 2)
                return x0 + x1 + x2 + x3

            cases = [
                # splitSize = 1
                torch.randn(4, 4, 4, dtype=torch.float, device=device),
                # contiguous case
                torch.randn(12, 8, 16, dtype=torch.float, device=device),
                # non-contiguous case
                torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(
                    1, 2
                ),
            ]
            chunkers = (chunk_4_0, chunk_4_1, chunk_4_last)
            # Outer loop over tensors, inner loop over chunk functions.
            for tensor, fn in product(cases, chunkers):
                self.checkScript(fn, [tensor])
                self.assertLastGraphAllFused()
454

455
    def test_chunk_distributes(self):
        """The traced graph contains a fusion group and exactly one ConstantChunk."""
        # A single skip check is enough (the original repeated it verbatim).
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")

        for device in self.devices:

            def f(x, y):
                z1, z2 = (x + y).chunk(2, dim=1)
                return z1 * z2

            x = torch.randn(4, 4, dtype=torch.float, device=device)
            y = torch.randn(4, 4, dtype=torch.float, device=device)

            ge = self.checkTrace(f, (x, y))
            graph = ge.graph_for(x, y)
            # XXX: The old fuser does broadcast_tensors but the new fuser doesn't.
            # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \
            #     .check_count('ConstantChunk', 2, exactly=True).run(str(graph))
            FileCheck().check("with " + FUSION_GROUP + "_").check_count(
                "ConstantChunk", 1, exactly=True
            ).run(str(graph))
479

480
    def test_chunk_motion_deduplicates_inputs(self):
        """Chunking a product that reuses the same input should fully fuse."""
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")

        for device in self.devices:

            def func1(x):
                z = x * x
                z0, z1 = z.chunk(2)
                return z0 * z1

            def func2(x):
                z = x * x * x
                z0, z1 = z.chunk(2)
                return z0 * z1

            sample = [torch.tensor([1.1, 1.2], device=device, dtype=torch.float)]
            for func in (func1, func2):
                self.checkScript(func, sample)
                self.assertLastGraphAllFused()
500

501
    def test_chunk_multiple(self):
        """Several chunked inputs, used out of order, still produce a fused graph."""
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")

        for device in self.devices:
            # The arguments are intentionally used out of order as a test to see
            # if the fusion compiler adds extra args in the correct order
            def fn(s, x, y, z):
                z1, z2 = z.chunk(2, 2)
                x1, x2, x3 = x.chunk(3, 1)
                y1, y2 = y.chunk(2, 0)
                return s + x1 + x2 + x3 + y1 + y2 + z1 + z2

            shapes = ((5, 2, 3), (5, 6, 3), (10, 2, 3), (5, 2, 6))
            inputs = [
                torch.randn(*shape, dtype=torch.float, device=device)
                for shape in shapes
            ]

            scripted = self.checkScript(fn, inputs)
            self.assertAllFused(scripted.graph_for(*inputs))
523

524
    def test_minmax(self):
        """``torch.max``/``torch.min`` of ``(2 * a, b)`` fuse fully, also with NaN.

        Every (function, inputs, device) combination is checked once.
        """
        # No outer per-device loop: `product(..., self.devices)` below already
        # visits every device, so an outer loop only repeated all the work.

        def tmax(a, b):
            return torch.max(2 * a, b)

        def tmin(a, b):
            return torch.min(2 * a, b)

        a = torch.randn(4, 4, dtype=torch.float)
        b = torch.randn(4, 4, dtype=torch.float)
        nan = torch.tensor(float("nan"), dtype=torch.float)

        for f, inputs, device in product(
            (tmax, tmin), ([a, b], [a, nan], [b, nan]), self.devices
        ):
            inputs = [t.to(device) for t in inputs]
            s = self.checkScript(f, inputs)
            self.assertAllFused(s.graph_for(*inputs))
543

544
    def test_clamp(self):
        """Clamp variants fuse in the forward graph and in the backward graph."""
        for device in self.devices:

            def func2(a, b):
                return torch.clamp(a + b, min=0, max=2)

            def funcInf(a, b):
                return torch.clamp(a + b, min=0, max=float("inf"))

            def funcNegInf(a, b):
                return torch.clamp(a + b, min=float("-inf"), max=0)

            def funcOptMin(a, b):
                return torch.clamp(a + b, max=2)

            def funcOptMax(a, b):
                return torch.clamp(a + b, min=0)

            a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True)
            b = torch.randn(4, 4, dtype=torch.float, device=device)
            nan = torch.tensor(float("nan"), dtype=torch.float, device=device)

            # Node kinds tolerated outside fusion groups in the forward and the
            # backward graph, respectively.
            fwd_allowed = {"aten::size", "aten::_size_if_not_equal"}
            bwd_allowed = {"aten::Float", "aten::_grad_sum_to_size"}
            bwd_allowed |= autograd_check_set

            funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
            for f, (inp1, inp2) in product(funcs, [[a, b], [a, nan]]):
                s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
                self.assertAllFused(s.graph_for(inp1, inp2), except_for=fwd_allowed)

                out = s(inp1, inp2)
                with enable_profiling_mode_for_profiling_tests():
                    warmup_backward(out.sum())
                self.assertAllFused(backward_graph(s), except_for=bwd_allowed)
584

585
    def test_clamp_double(self):
        """Clamping a double tensor by a tiny float bound fuses (except aten::sub)."""
        for device in self.devices:

            def clamp_double(x, eta: float):
                return 1 - x.clamp(eta, 1 - eta)

            x = torch.tensor([1.0, 1.0], dtype=torch.double, device=device)
            eta = 1e-9
            tolerances = {"atol": 1e-10, "rtol": 1e-5}
            s = self.checkScript(
                clamp_double,
                (x, eta),
                profiling=ProfilingMode.PROFILING,
                **tolerances,
            )
            self.assertAllFused(s.graph_for(x, eta), except_for={"aten::sub"})
601

602
    def test_clamp_int(self):
        """Clamping an integer tensor by a large Python int should fuse fully."""
        for device in self.devices:

            def clamp_int(x, eta: int):
                return x.clamp(0, eta)

            x = torch.tensor([1, 1], device=device)
            eta = 2**32  # exceeds the int32 range
            s = self.checkScript(clamp_int, (x, eta), profiling=ProfilingMode.PROFILING)
            self.assertAllFused(s.graph_for(x, eta))
612

613
    def test_add_bool(self):
        """Adding three bool tensors should fuse fully for several shapes."""
        for device, size in product(self.devices, [(1,), (2,), (4, 4)]):

            def f(x, y, z):
                return x + y + z

            x, y, z = (
                torch.randint(0, 2, size, dtype=torch.bool, device=device)
                for _ in range(3)
            )
            traced = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
            self.assertAllFused(traced.graph_for(x, y, z))
625

626
    def test_mul_bool(self):
        """Multiplying three 4x4 bool tensors should fuse fully."""
        for device in self.devices:

            def f(x, y, z):
                return x * y * z

            x, y, z = (
                torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
                for _ in range(3)
            )
            traced = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
            self.assertAllFused(traced.graph_for(x, y, z))
638

639
    def test_div_bool(self):
        """``(x + y) / z`` on bool tensors (z all ones) should fuse fully."""
        for device in self.devices:

            def f(x, y, z):
                return (x + y) / z

            x, y = (
                torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
                for _ in range(2)
            )
            z = torch.ones_like(x, dtype=torch.bool, device=device)
            traced = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
            self.assertAllFused(traced.graph_for(x, y, z))
651

652
    def test_bitwise_ops(self):
        """Chained bitwise/shift ops on integer dtypes match eager mode and fuse."""

        def apply(fn):
            return lambda x, y, z: fn(fn(x, y), z)

        binary_ops = [
            operator.__and__,
            operator.__or__,
            operator.__xor__,
            operator.__lshift__,
            operator.__rshift__,
        ]
        for dtype, op, device in product(self.int_dtypes, binary_ops, self.devices):
            try:
                x, y, z = (self.data_for(dtype, device) for _ in range(3))
                fn = apply(op)
                ref = fn(x, y, z)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser.  Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            try:
                traced = torch.jit.trace(fn, (x, y, z))
                self.assertEqual(ref, traced(x, y, z))
                self.assertAllFused(traced.graph_for(x, y, z))
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dtype), op.__name__, device])
                ) from e
684

685
    def test_minmax_int_ops(self):
        """Chained torch.min/torch.max on integer dtypes match eager mode and fuse."""

        def apply(fn):
            return lambda x, y, z: fn(fn(x, y), z)

        for dtype, op, device in product(
            self.int_dtypes, [torch.min, torch.max], self.devices
        ):
            try:
                x, y, z = (self.data_for(dtype, device) for _ in range(3))
                fn = apply(op)
                ref = fn(x, y, z)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser.  Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            try:
                traced = torch.jit.trace(fn, (x, y, z))
                self.assertEqual(ref, traced(x, y, z))
                self.assertAllFused(traced.graph_for(x, y, z))
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dtype), op.__name__, device])
                ) from e
711

712
    def test_comparison_eq_ne(self):
        """Masks built from ``==`` / ``!=`` comparisons should fuse fully."""
        for device in self.devices:

            def f(x, y):
                mask = (x == 0).type_as(x)
                z = x * mask + y
                mask = (x != 0).type_as(x)
                z = z * mask + y
                return z

            x, y = (
                torch.randn(4, 4, dtype=torch.float, device=device) for _ in range(2)
            )
            traced = self.checkTrace(f, (x, y))
            self.assertAllFused(traced.graph_for(x, y))
727

728
    @staticmethod
    def fn_test_comparison_gt_lt(x, y):
        """Combine ``x`` and ``y`` through ``x > 0`` and ``x < 0`` masks.

        Each mask is cast to ``x``'s type before use:
        ``((x * (x > 0) + y) * (x < 0)) + y``.
        """
        positive = (x > 0).type_as(x)
        partial = x * positive + y
        negative = (x < 0).type_as(x)
        return partial * negative + y
735

736
    def test_comparison_gt_lt(self):
        """Tracing ``fn_test_comparison_gt_lt`` should yield a fully fused graph."""
        for device in self.devices:
            x, y = (
                torch.randn(4, 4, dtype=torch.float, device=device) for _ in range(2)
            )
            traced = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
            self.assertAllFused(traced.graph_for(x, y))
743

744
    def test_comparison_ge_le(self):
        """``>=`` / ``<=`` masks fuse, both without and with inputs requiring grad."""
        for device in self.devices:

            def f(x, y):
                mask = (x >= 0).type_as(x)
                z = x * mask + y
                mask = (x <= 0).type_as(x)
                z = z * mask + y
                return z

            x = torch.randn(4, 4, dtype=torch.float, device=device)
            y = torch.randn(4, 4, dtype=torch.float, device=device)

            traced = self.checkTrace(f, (x, y))
            self.assertAllFused(traced.graph_for(x, y))

            # Re-check once the inputs require grad; the size-related node kinds
            # below are tolerated outside the fusion group in that case.
            for t in (x, y):
                t.requires_grad_(True)
            allowed = (
                "aten::size",
                "prim::BroadcastSizes",
                "aten::_size_if_not_equal",
            )
            self.assertAllFused(traced.graph_for(x, y), except_for=allowed)
769

770
    def test_addcmul(self):
        """The add and the addcmul should land in a single fusion group."""
        for device in self.devices:
            t = torch.randn(1, 4, dtype=torch.float, device=device)
            t1 = torch.randn(4, 1, dtype=torch.float, device=device)
            t2 = torch.randn(1, 4, dtype=torch.float, device=device)

            # `t1` is not used by `foo`, hence `allow_unused=True` below.
            def foo(t, t1, t2):
                return t.addcmul(t + 1, t2, value=0.1)

            args = (t, t1, t2)
            traced = self.checkTrace(foo, args, allow_unused=True)
            groups = self.findFusionGroups(traced.graph_for(*args))
            self.assertEqual(len(groups), 1)
            FileCheck().check("aten::add(").check("aten::addcmul(").run(str(groups[0]))
786

787
    # TODO: We leak CUDA memory here because the traced graph holds onto a
    # constant-ified tensor. Since the Python-global CompilationUnit is alive
    # until the end of the process, the memory is effectively leaked.
    # Removed `_cuda` suffix from this test which disables leak-checking.
    # If this is a real problem, we'll need to revisit Torchscript Function
    # lifetimes in Python.
    def test_lerp(self):
        """``torch.lerp`` with a scalar weight should fuse fully."""
        for device in self.devices:
            start = torch.randn(4, 1, dtype=torch.float, device=device)
            end = torch.randn(1, 4, dtype=torch.float, device=device)
            weight = torch.tensor(0.5, dtype=torch.float, device=device)

            # scalar weight overload
            def foo_weight_scalar(start, end):
                return torch.lerp(start + 1, end, 0.5)

            # tensor weight overload (captures `weight`; unused until the TODO
            # below is resolved)
            def foo_weight_tensor(start, end):
                return torch.lerp(start + 1, end, weight)

            traced_scalar = self.checkTrace(foo_weight_scalar, (start, end))
            self.assertAllFused(traced_scalar.graph_for(start, end))

            # TODO: uncomment when TE enables support for scalar tensors
            # ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
            # graph = ge_weight_tensor.graph_for(start, end)
            # self.assertAllFused(graph)
815

816
    def test_concat(self):
        """A cat of two elementwise results should be fully fused."""
        # disabling concat causes error with single concat node
        with set_fusion_group_inlining(True):
            for device in self.devices:
                hx, cx = (
                    torch.randn(3, 20, dtype=torch.float, device=device)
                    for _ in range(2)
                )

                def foo(hx, cx):
                    return torch.cat((hx + cx, hx * cx))

                traced = self.checkTrace(foo, (hx, cx))
                graph = traced.graph_for(hx, cx)
                self.assertAllFused(graph)
                # XXX: TE fuser can handle concats in a fusion group.
                # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
831

832
    def test_remove_output_used_only_in_size(self):
        # The intermediate `c` is only needed for size checks, so the guarding
        # prim::If (and the fusion group inside it) must expose one output.
        for dev in self.devices:

            def test_fuse(a, b):
                c = a + b
                d = c + b
                return d

            fn = torch.jit.script(test_fuse)
            lhs = torch.ones(1, requires_grad=True, device=dev)
            rhs = torch.ones(1, requires_grad=True, device=dev)
            warmup_forward(fn, lhs, rhs, profiling_count=3)

            outer = fn.graph_for(lhs, rhs)
            diff_graphs = outer.findAllNodes("prim::DifferentiableGraph")
            self.assertEqual(len(diff_graphs), 1)

            subgraph = diff_graphs[0].g("Subgraph")
            guards = [node for node in subgraph.nodes() if node.kind() == "prim::If"]
            self.assertEqual(len(guards), 1)

            # the if node and the fusion group inside it should only have one output
            self.assertEqual(len(list(guards[0].outputs())), 1)
853

854
    def test_concat_invariant(self):
        for dev in self.devices:
            # Invariant: the output of prim::FusedConcat may
            # not be an input to any node inside the FusionGroup.
            def fn(x, y, z):
                x1 = x + y
                y1 = x - y
                w = torch.cat([x1, y1])
                return w + z

            # x, y and z, created in that order
            args = tuple(
                torch.randn(*shape, dtype=torch.float, device=dev)
                for shape in ((2, 2), (2, 2), (4, 2))
            )
            traced = self.checkTrace(fn, args)
            self.assertAllFused(traced.graph_for(*args), except_for={"aten::add"})
            # XXX: TE fuser can handle concats inside a fusion group.
            # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
872

873
    @staticmethod
874
    def fn_test_exp(x, y):
875
        return (x + 0.5 * y).exp()
876

877
    def test_exp(self):
        # Traced fn_test_exp must be fully fused on every device.
        for dev in self.devices:
            lhs, rhs = (
                torch.randn(4, 4, dtype=torch.float, device=dev) for _ in range(2)
            )
            traced = self.checkTrace(self.fn_test_exp, (lhs, rhs))
            self.assertAllFused(traced.graph_for(lhs, rhs))
884

885
    def test_threshold(self):
        # torch.threshold followed by additions should fuse into one group.
        for dev in self.devices:

            def f(x):
                return torch.threshold(x, 0, -10) + x + x + x

            sample = torch.tensor([-1, -0.5, 0, 1, 2, 3], device=dev)
            scripted = self.checkScript(f, (sample,))
            self.assertAllFused(scripted.graph_for(sample))
894

895
    def test_scalar_arg(self):
        # A float scalar argument should fuse with the tensor ops, both when
        # the tensor input does and does not require grad.
        for dev in self.devices:

            def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
                return p * (x * x + x)

            inp = torch.randn(4, 4, dtype=torch.float, device=dev)
            scale = 3
            scripted = self.checkScript(fn_test_scalar_arg, (inp, scale))
            self.assertAllFused(scripted.graph_for(inp, scale))

            inp.requires_grad_(True)

            # use another function otherwise we will bailout
            # and won't be able to do fused checks
            def fn_test_scalar_arg_requires_grad(
                x: torch.Tensor, p: float
            ) -> torch.Tensor:
                return p * (x * x + x)

            scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
            # three warm-up runs before inspecting the optimized graph
            for _ in range(3):
                scripted(inp, scale)
            self.assertAllFused(
                scripted.graph_for(inp, scale),
                except_for=(
                    "aten::size",
                    "prim::BroadcastSizes",
                    "aten::_size_if_not_equal",
                ),
            )
927

928
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
    def test_fusion_reuse_multi_gpu(self):
        def fn(x, y):
            return x * y * x * y

        inputs_cpu = [torch.randn(4, 4, dtype=torch.float) for _ in range(2)]
        inputs_cuda0 = [t.cuda(0) for t in inputs_cpu]
        inputs_cuda1 = [t.cuda(1) for t in inputs_cpu]

        # Should not crash; these should compile different kernels.
        scripted = self.checkScript(fn, inputs_cpu)
        self.assertAllFused(scripted.graph_for(*inputs_cpu))
        for gpu_args in (inputs_cuda0, inputs_cuda1):
            scripted(*gpu_args)
946

947
    # TODO: we're currently not checking 'device' in the type info when pulling
    # nodes into a fusion group. We should fix that and re-enable this test.
    # NOTE(review): this test is not skipped beyond the CUDA guards; the cache
    # size assertion at the end is what is disabled.
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
    def test_kernel_cache_multi_gpu(self):
        def not_fusible(x):
            return x

        def fn(x, y, z):
            x_out = x * x * x * x * x  # fusion: lambda x. x * x * x * x * x
            y_out = y * y * y * y * y
            z_out = z * z * z * z * z
            return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)

        # one input on the default device, one on cuda:0, one on cuda:1
        inputs = [
            torch.randn(4, 4, dtype=torch.float, device=dev)
            for dev in (None, "cuda:0", "cuda:1")
        ]

        prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()

        # There are 3 FusionGroups. Because they have the same graph, they
        # should reuse the same KernelSpec in the KernelSpec cache.
        scripted = self.checkScript(fn, inputs)
        self.assertGraphContainsExactly(
            scripted.graph_for(*inputs), FUSION_GROUP, 3, True
        )
        new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
        # XXX: This assumes that the same kernel isn't already used by another test
        # FIXME: Use the TE fuser's way of querying the cache.
        # self.assertEqual(new_cache_size - prev_cache_size, 1)
977

978
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
    def test_nonzero_device_cuda(self):
        # Fusion must also work for tensors on a non-default GPU.
        device = "cuda:1"
        x, y = (
            torch.tensor([value], dtype=torch.float, device=device)
            for value in (0.4, 0.7)
        )

        def doit(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y) + x))

        traced = self.checkTrace(doit, (x, y))
        self.assertAllFused(traced.graph_for(x, y))
989

990
    def test_lstm(self):
        # The scripted LSTM cell should be fully fused apart from the
        # prim::TupleConstruct that packages its outputs.
        for device in self.devices:
            inputs = get_lstm_inputs(device, training=True)
            module = self.checkScript(LSTMCellS, inputs)
            # graph_for takes the example inputs as positional arguments,
            # matching `graph_for(*inputs)` used by the other LSTM tests.
            self.assertAllFused(
                module.graph_for(*inputs), except_for={"prim::TupleConstruct"}
            )
997

998
    def test_lstm_concat(self):
        # The traced concat-LSTM cell should fuse everything except the nodes
        # listed in `except_nodes`.
        # single fusion node causes error
        with set_fusion_group_inlining(True):
            for device in self.devices:
                inputs = get_lstm_inputs(device)
                ge = self.checkTrace(LSTMCellC, inputs)
                graph = ge.graph_for(*inputs)
                except_nodes = {"prim::TupleConstruct", "aten::linear"}
                # TODO... Chunk
                if self.dynamic_shapes:
                    except_nodes |= {"aten::add", "prim::ConstantChunk"}
                # Check the graph fetched above rather than querying it again.
                self.assertAllFused(graph, except_for=except_nodes)
                # XXX: TE fuser can handle concats inside a fusion group.
                # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
1014

1015
    def test_lstm_gates_permutations(self):
        # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
        # Test that any permutation of this will still result in one FusionGroup
        # (two when dynamic shapes are enabled).
        terms = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"]
        template = dedent(
            """
        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
            gates = {} + {} + {} + {}
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
            return ingate * forgetgate * cellgate * outgate
        """
        )
        expected_groups = 2 if self.dynamic_shapes else 1
        for device in self.devices:
            for ordering in permutations(terms):
                code = template.format(*ordering)
                # eager reference implementation compiled from the same source
                namespace = {}
                exec(code, globals(), namespace)
                unit = torch.jit.CompilationUnit(code)
                inputs = get_lstm_inputs(device, training=False)
                self.assertEqual(unit.cell(*inputs), namespace["cell"](*inputs))
                self.assertGraphContainsExactly(
                    unit.cell.graph_for(*inputs), FUSION_GROUP, expected_groups
                )
1040

1041
    # TODO: Fuser doesn't work at all when inputs require grad. Fix that
    def test_lstm_traced(self):
        # The traced functional LSTM cell should produce the expected number of
        # fusion groups, with sigmoid/tanh inside the last one.
        for dev in self.devices:
            inputs = get_lstm_inputs(dev)
            traced = self.checkTrace(LSTMCellF, inputs)
            groups = self.findFusionGroups(traced.graph_for(*inputs))
            # TODO: chunk
            expected = 2 if self.dynamic_shapes else 1
            self.assertEqual(len(groups), expected)
            checker = FileCheck()
            if not self.dynamic_shapes:
                checker.check("Chunk")
            # after the length check, [-1] is index 0 (static) or 1 (dynamic)
            checker.check("aten::sigmoid").check("aten::tanh").run(str(groups[-1]))
1057

1058
    def test_milstm(self):
        # The scripted multiplicative-integration LSTM cell should have one
        # fusion group inside its DifferentiableGraph, and backward must run.
        if self.dynamic_shapes:
            self.skipTest("don't run conv with dynamic shapes")

        for device in self.devices:
            inputs = get_milstm_inputs(device, training=True)
            module = self.checkScript(MiLSTMCell, inputs)
            forward_graph = module.graph_for(*inputs)
            # Dynamic shapes are skipped above, so only the static-shape
            # expectation (a single group) can apply here. TODO: chunk
            self.assertGraphContainsExactly(
                forward_graph, FUSION_GROUP, 1, consider_subgraphs=True
            )
            FileCheck().check("DifferentiableGraph").check("TupleConstruct").check_next(
                "return"
            ).check(FUSION_GROUP).run(str(forward_graph))
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
1076

1077
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skip("rand_like is not supported yet")
    def test_rand_cuda(self):
        class M(torch.jit.ScriptModule):
            __constants__ = ["d"]

            def __init__(self) -> None:
                super().__init__()
                self.d = torch.device("cuda")

            @torch.jit.script_method
            def create(self, x):
                return x * x + x + torch.rand_like(x)

        x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda")
        m = M()
        first = m.create(x)
        second = m.create(x)
        # x is all zeros, so each output is just a rand_like sample in [0, 1)
        self.assertNotEqual(first, second)
        for sample in (first, second):
            self.assertTrue(torch.all(sample >= 0))
            self.assertTrue(torch.all(sample < 1))
        self.assertAllFused(m.create.graph_for(x))
1101

1102
    @staticmethod
1103
    def fn_test_relu(x, y):
1104
        return F.relu(x + 0.5 * y)
1105

1106
    def test_relu(self):
        # Traced fn_test_relu must be fully fused on every device.
        for dev in self.devices:
            lhs, rhs = (
                torch.randn(4, 4, dtype=torch.float, device=dev) for _ in range(2)
            )
            traced = self.checkTrace(self.fn_test_relu, (lhs, rhs))
            self.assertAllFused(traced.graph_for(lhs, rhs))
1113

1114
    def test_erf(self):
        # erf/erfc should fuse with and without requires_grad on the input.
        size_ops = (
            "aten::size",
            "prim::BroadcastSizes",
            "aten::_size_if_not_equal",
        )
        # only enabled on gpu
        for device in (dev for dev in self.devices if dev != "cpu"):

            def fn_test_erf(x):
                return F.relu(torch.erf(x) - torch.erfc(x))

            x = torch.randn(4, 4, dtype=torch.float, device=device)
            scripted = self.checkScript(
                fn_test_erf, (x,), profiling=ProfilingMode.PROFILING
            )
            self.assertAllFused(scripted.graph_for(x))

            x.requires_grad_(True)
            scripted = self.checkScript(
                fn_test_erf, (x,), profiling=ProfilingMode.PROFILING
            )
            self.assertAllFused(scripted.graph_for(x), except_for=size_ops)
1136

1137
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skip("rand_like is not supported yet")
    def test_rand_broadcast_cuda(self):
        def fn_test_rand(x, y):
            r = torch.rand_like(y)
            return r * x + x

        # If using profiling, a different function is needed to test different
        # shapes, or we'll use a cached script.
        def fn_test_rand2(x, y):
            r = torch.rand_like(y)
            return r * x * x

        size_ops = (
            "aten::size",
            "prim::BroadcastSizes",
            "aten::_size_if_not_equal",
        )

        x = torch.randn(4, 4, dtype=torch.float, device="cuda")
        y = torch.randn(4, 4, dtype=torch.float, device="cuda")
        scripted = torch.jit.script(fn_test_rand)
        warmup_forward(scripted, x, y)
        scripted(x, y)
        self.assertAllFused(scripted.graph_for(x, y))
        x.requires_grad_(True)
        scripted(x, y)
        self.assertAllFused(scripted.graph_for(x, y), except_for=size_ops)

        # test that broadcasting random produces correct results
        x = torch.ones(4, 4, dtype=torch.float, device="cuda")
        y = torch.ones(4, dtype=torch.float, device="cuda")
        scripted = torch.jit.script(fn_test_rand2)
        warmup_forward(scripted, x, y)
        result = scripted(x, y)
        # r has y's 1-D shape, so every row of the broadcast result is identical
        self.assertEqual(result[0, :] + torch.zeros(4, 4, device="cuda"), result)
1174

1175
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skip("rand_like is not supported yet")
    def test_rand_diamond(self):
        # (x + r) + (y - r) must equal x + y: the same r feeds both branches.
        def fn_test_diamond(x, y):
            r = torch.rand_like(y)
            a = x + r
            b = y - r
            return a + b

        x, y = (
            torch.randn(4, 4, dtype=torch.float, device="cuda") for _ in range(2)
        )
        scripted = torch.jit.script(fn_test_diamond)
        warmup_forward(scripted, x, y)
        self.assertEqual(scripted(x, y), x + y)
1190

1191
    def test_scalar(self):
        # Zero-dim CPU tensors should still be fully fused.
        def fn(x, y):
            return 2 * x + y

        args = (
            torch.tensor(0.1, dtype=torch.float, device="cpu"),
            torch.tensor(1, dtype=torch.float, device="cpu"),
        )
        scripted = self.checkScript(fn, args)
        self.assertAllFused(scripted.graph_for(*args))
1199

1200
    def test_inlined_optimized_graph(self):
        @torch.jit.script
        def foo(x):
            return torch.relu(x + x)

        # Three input shapes of different rank, each run three times.
        for shape in ([4, 4], [10], [2, 2, 2]):
            for _ in range(3):
                foo(torch.rand(shape))

        g = torch.jit.last_executed_optimized_graph()

        FileCheck().check_count("prim::If", 1, exactly=True).check(
            "prim::TensorExpr"
        ).run(g)
        torch._C._jit_pass_inline(g)
        # after inlining, three If/TensorExpr pairs are expected
        checker = FileCheck()
        for _ in range(3):
            checker.check("prim::If").check("prim::TensorExpr")
        checker.run(g)
1224

1225
    def test_small_constant(self):
        # Very small float constants must survive fusion.
        for dev in self.devices:

            def fn_test_small_constant(x, y):
                return (1e-8 * x + 5e-9 * y) * 1e8

            lhs, rhs = (
                torch.randn(4, 4, dtype=torch.float, device=dev) for _ in range(2)
            )
            traced = self.checkTrace(fn_test_small_constant, (lhs, rhs))
            self.assertAllFused(traced.graph_for(lhs, rhs))
1236

1237
    # Currently we don't pull constants into fusion groups, because in some
    # cases it could remove the constant from the original graph and now our
    # fusion group needs to return that constant for its other users.
    # Instead of never pulling constants into the fusion group, we should just
    # be more careful at how we rewrite its users.
    # TODO: fix that and reenable the test.
    # NOTE(review): this test is not skipped here, so the TODO may be stale.
    def test_tensor_scalar_ops(self):
        for dev in self.devices:

            def should_fuse(x):
                z = 3.0
                y = x + z
                return x * y

            def should_fuse_scalar(x, z):
                y = x + int(z)
                return x * y

            # Case 1: a float constant inside the function.
            args = [torch.randn(2, 2, dtype=torch.float, device=dev)]
            scripted = self.checkScript(should_fuse, args)
            groups = self.findFusionGroups(scripted.graph_for(*args))
            self.assertEqual(len(groups), 1)
            FileCheck().check("aten::add").check("aten::mul").run(str(groups[0]))

            # Case 2: the scalar arrives as a 0-dim tensor input.
            def make_args(value):
                return [
                    torch.randn(2, 2, dtype=torch.float, device=dev),
                    torch.tensor(value, dtype=torch.float, device=dev),
                ]

            scripted = self.checkScript(should_fuse_scalar, make_args(3.0))
            # Check that the fused graph computes correct results when the scalar
            # input changes.
            args = make_args(7.0)
            self.assertEqual(scripted(*args), should_fuse_scalar(*args))
            # The TE fuser supports fusion of non-constant scalars
            self.assertGraphContainsExactly(
                scripted.graph_for(*args), FUSION_GROUP, 1, consider_subgraphs=True
            )
1278

1279
    def test_where_and_typing(self):
        # A bool mask output and a double `where` output fused together.
        for dev in self.devices:

            def f(x, y):
                mask = x > y
                res = torch.where(mask, x, y)
                return mask, res

            lhs, rhs = (
                torch.randn(4, 4, dtype=torch.double, device=dev) for _ in range(2)
            )
            scripted = self.checkScript(f, (lhs, rhs))
            self.assertAllFused(
                scripted.graph_for(lhs, rhs), except_for={"prim::TupleConstruct"}
            )
1294

1295
    def test_disabled(self):
        # With CPU fusion switched off, scripting must not create any fusion
        # group. The previous global setting is restored in `finally`, so a
        # failing assertion cannot leave CPU fusion disabled for later tests.
        old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
        torch._C._jit_override_can_fuse_on_cpu(False)
        try:

            def fn(a):
                return a**2 + a

            x = torch.randn(4, dtype=torch.float, device="cpu")
            s = self.checkScript(fn, (x,))
            g = s.graph_for(x)
            self.assertEqual(len(self.findFusionGroups(g)), 0)
        finally:
            torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
1308

1309
    def data_for(self, dtype, device="cuda", size=None):
        # Build a small test tensor of `dtype` on `device`.
        # Base values: the float tensor [1., 2.] when `size` is None, otherwise
        # uniform random floats of shape `size`. The base is then converted:
        #   bool      -> elementwise `> 2`
        #   quantized -> quantize_per_tensor(scale=0.1, zero_point=1)
        #   other     -> `.to(dtype)`
        # NOTE(review): `> 2` is False for every possible base value (1, 2, or
        # [0, 1)), so bool data is always all-False; confirm this is intended.
        if size is not None:
            base = torch.rand(*size, device=device)
        else:
            base = torch.arange(1, 3, dtype=torch.float, device=device)

        if dtype == torch.bool:
            return base > 2
        if dtype in (torch.qint8, torch.quint8, torch.qint32):
            return torch.quantize_per_tensor(base, 0.1, 1, dtype=dtype)
        return base.to(dtype)
1320

1321
    def test_torch_to(self):
        def assert_not_fused(fn, *args):
            # Call twice (warm-up), then require that the last executed
            # optimized graph contains no TensorExpr group.
            fn(*args)
            fn(*args)
            FileCheck().check_not("TensorExpr").run(
                torch.jit.last_executed_optimized_graph()
            )

        # test no op
        @torch.jit.script
        def to_same_dtype(x):
            return x.to(torch.float)

        assert_not_fused(to_same_dtype, torch.tensor([3.0], dtype=torch.float))

        # test not fusing non-const inputs
        @torch.jit.script
        def to_runtime_dtype(x, dtype: int):
            return x.to(dtype)

        assert_not_fused(
            to_runtime_dtype, torch.tensor([3.0], dtype=torch.float), torch.int
        )

        # test not fusing to_pinned inputs
        @torch.jit.script
        def to_pinned(x, dtype: int):
            return x.to(pin_memory=True)

        assert_not_fused(to_pinned, torch.tensor([3.0], dtype=torch.float), torch.int)

        # test across-device not supported
        if torch.cuda.is_available():

            @torch.jit.script
            def to_cuda(x):
                return x.to(device="cuda")

            assert_not_fused(to_cuda, torch.tensor([3.0], dtype=torch.float))

        sizes = [(1, 4), (4, 4)]
        # reuses cast impl, smaller dtype set for faster test
        dtypes = [
            torch.bool,
            torch.int,
            torch.float16,
            torch.float32,
            torch.float64,
        ]

        class MyMod(torch.nn.Module):
            def __init__(self, dtype):
                super().__init__()
                self.dtype = dtype

            def forward(self, x):
                return x.to(self.dtype)

        half_types = (torch.float16, torch.bfloat16)
        for dtype, output_dtype, device, size in product(
            dtypes, dtypes, self.devices, sizes
        ):
            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
            half_on_cpu = device == "cpu" and dtype in half_types
            if half_on_cpu or dtype == output_dtype:
                continue

            x = self.data_for(dtype, device, size=size)
            eager = MyMod(output_dtype)
            ref = eager.forward(x)
            # use freezing to make non-Tensor args to `to` constant
            frozen = torch.jit.freeze(torch.jit.script(eager.eval()))
            warmup_forward(frozen.forward, x)
            self.assertEqual(ref, frozen.forward(x))
            self.assertLastGraphAllFused()
1404

1405
    @unittest.skip("Temporarily disabled")
    def test_masked_fill(self):
        # Traced masked_fill with a Python scalar fill value must match eager
        # and be fully fused for every dtype/device/scalar/size combination.
        dtypes = [
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
            # torch.float16,
            torch.float32,
            torch.float64,
            torch.bool,
        ]
        sizes = [(2,), (4, 4)]
        for self_dtype, device, scalar_val, size in product(
            dtypes, self.devices, [0.4, 3], sizes
        ):
            input_v = self.data_for(self_dtype, device, size=size)
            mask = self.data_for(torch.bool, device, size=size)

            def fn(input_v, mask):
                return torch.masked_fill(input_v, mask, scalar_val)

            ref = fn(input_v, mask)
            try:
                t = torch.jit.trace(fn, (input_v, mask))
                torch.testing.assert_close(ref, t(input_v, mask))
                self.assertLastGraphAllFused()
            except Exception as e:
                # Name the op explicitly: there is no `op` variable in this test,
                # and referencing one would raise NameError here and hide `e`.
                raise RuntimeError(
                    " ".join(
                        [
                            "Failed:",
                            str(self_dtype),
                            "masked_fill",
                            str(scalar_val),
                            device,
                            str(size),
                        ]
                    )
                ) from e
1445

1446
    def test_isnan(self):
        # Traced isnan must match eager and be fully fused for each dtype/device.
        first = torch.rand([4])
        first[0] = float("nan")
        samples = [first, torch.tensor([float("nan"), 0.5])]
        dtypes = [
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bool,
        ]
        half_types = (torch.float16, torch.bfloat16)

        for sample, dev, dt in product(samples, self.devices, dtypes):
            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
            if dev == "cpu" and dt in half_types:
                continue
            converted = sample.to(device=dev, dtype=dt)
            try:
                traced = torch.jit.trace(lambda x: x.isnan(), (converted,))
                warmup_forward(traced, converted)
                self.assertEqual(traced(converted), converted.isnan())
                self.assertLastGraphAllFused()
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dt), "isnan", dev])
                ) from e
1475

1476
    def test_gelu(self):
        # Traced F.gelu must match eager and be fully fused for each supported
        # dtype/device/size and for both `approximate` modes.
        #
        # `approximate` is a keyword-only string argument of F.gelu, so it is
        # bound in a closure. torch.jit.trace only accepts tensor example
        # inputs, and it records the string as a constant in the traced graph.
        def apply(fn, approximate):
            return lambda x: fn(x, approximate=approximate)

        unary_ops = [
            F.gelu,
        ]
        approximations = ["none", "tanh"]
        sizes = [(1,), (2,), (4, 4)]
        for dtype, op, approximate, device, size in product(
            self.dtypes, unary_ops, approximations, self.devices, sizes
        ):
            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
                continue
            try:
                x = self.data_for(dtype, device, size=size)
                fn = apply(op, approximate)
                ref = fn(x)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser.  Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            try:
                t = torch.jit.trace(fn, (x,))
                torch.testing.assert_close(ref, t(x))
                self.assertAllFused(t.graph_for(x))
            except Exception as e:
                raise RuntimeError(
                    " ".join(
                        [
                            "Failed:",
                            str(dtype),
                            op.__name__,
                            approximate,
                            device,
                            str(size),
                        ]
                    )
                ) from e
1508

1509
    def test_unary_ops(self):
        # Each op below, traced alone, must match eager and be fully fused for
        # every dtype/device/size combination that eager supports.
        with torch._jit_internal._disable_emit_hooks():

            def apply(fn):
                return lambda x: fn(x)

            unary_ops = [
                torch.lgamma,
                torch.sigmoid,
                torch.reciprocal,
                torch.neg,
                torch.relu,
                F.relu6,
                torch.log,
                torch.log10,
                torch.log1p,
                torch.log2,
                torch.exp,
                torch.expm1,
                torch.erf,
                torch.erfc,
                torch.cos,
                torch.sin,
                torch.tan,
                torch.acos,
                torch.asin,
                torch.cosh,
                torch.sinh,
                torch.atan,
                torch.tanh,
                F.hardtanh,
                F.hardsigmoid,
                F.hardswish,
                F.softplus,
                F.silu,
                F.mish,
                F.elu,
                torch.sqrt,
                torch.rsqrt,
                torch.abs,
                # TODO broken on int8 since
                # https://github.com/pytorch/pytorch/pull/85144
                # RuntimeError: Invalid integral op_type: 23
                # torch.ceil,
                # torch.floor,
                # torch.round,
                # torch.trunc,
                torch.frac,
                # TODO: broken on ROCm?
                # F.hardshrink,
                F.leaky_relu,
                lambda x: torch.threshold(x, 0, -10),
                # TODO: broken since type promotion was added
                # lambda x: torch.clamp(x, -10, 10),
            ]
            gpu_only = {torch.erf, torch.erfc}
            sizes = [(1,), (2,), (4, 4)]
            half_types = (torch.float16, torch.bfloat16)
            for dtype, op, device, size in product(
                self.dtypes, unary_ops, self.devices, sizes
            ):
                on_cpu = device == "cpu"
                # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
                if on_cpu and dtype in half_types:
                    continue
                # todo - re-enable. fails with .500
                if dtype == torch.bfloat16 and op == torch.round:
                    continue
                if on_cpu and op in gpu_only:
                    continue

                try:
                    x = self.data_for(dtype, device, size=size)
                    fn = apply(op)
                    ref = fn(x)
                except Exception:
                    # If eager mode doesn't support a dtype/op/device combo,
                    # neither does the fuser.  Catch everything to avoid needing to
                    # guess what errors might be thrown by eager.
                    continue

                try:
                    traced = torch.jit.trace(fn, (x,))
                    torch.testing.assert_close(ref, traced(x))
                    self.assertAllFused(traced.graph_for(x))
                except Exception as e:
                    failure = ["Failed:", str(dtype), op.__name__, device, str(size)]
                    raise RuntimeError(" ".join(failure)) from e
1596

1597
    def test_binary_ops(self):
        # Trace every binary op for each (dtype, device) combination that eager
        # mode accepts; the traced result must match eager and, except for the
        # float-only ops on non-float dtypes, the graph must be fully fused.
        def wrap(op):
            return lambda x, y: op(x, y)

        ops_under_test = [
            operator.__and__,
            operator.__or__,
            operator.__xor__,
            torch.add,
            torch.sub,
            torch.mul,
            torch.min,
            torch.max,
            lambda x, y: torch.lerp(x, y, 0.5),
            torch.atan2,
            torch.div,
            torch.eq,
            torch.ne,
            torch.ge,
            torch.gt,
            torch.lt,
            torch.fmod,
            torch.remainder,
            lambda x, y: y.type_as(x),
        ]
        # Ops whose fusion is only asserted for floating-point dtypes.
        float_only_fusion = [
            torch.fmod,
            torch.remainder,
        ]
        for dtype, op, device in product(self.dtypes, ops_under_test, self.devices):
            if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
                continue
            try:
                lhs = self.data_for(dtype, device)
                rhs = self.data_for(dtype, device)
                fn = wrap(op)
                expected = fn(lhs, rhs)
            except Exception:
                # Eager mode rejects this dtype/op/device combination, so the
                # fuser is not expected to handle it either.
                continue
            try:
                traced = torch.jit.trace(fn, (lhs, rhs))
                self.assertEqual(expected, traced(lhs, rhs))
                check_fusion = dtype.is_floating_point or op not in float_only_fusion
                if check_fusion:
                    self.assertAllFused(traced.graph_for(lhs, rhs))
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dtype), op.__name__, device])
                ) from e
1649

1650
    def test_binary_scalar_ops(self):
        """Compile scalar-only binary-op IR with NNC and compare to the interpreter.

        Every (dtype_x, dtype_y, op) combination is rendered into a small
        TorchScript IR graph. Graphs that the JIT interpreter cannot run are
        skipped. The rest are compiled into a ``TensorExprKernel``, whose output
        must match the interpreter for every pair of sample values.
        """
        ir_template = """
        graph(%x : {dtype_x}, %y : {dtype_y}):
          %z = {op}(%x, %y)
          return (%z)"""

        binary_ops = [
            "aten::mul",
            "aten::add",
            "aten::sub",
            "aten::div",
            "aten::lt",
            "aten::le",
            "aten::eq",
            "aten::ne",
            "aten::gt",
            "aten::ge",
            "aten::__or__",
            "aten::__xor__",
            "aten::__and__",
            "aten::__lshift__",
            "aten::__rshift__",
        ]
        dtypes = ["int", "float", "bool"]
        values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]}
        devices = self.devices
        # NOTE(review): the graph inputs are Python scalars, so ``device`` only
        # appears in failure messages; each device repeats the same check.
        for dtype_x, dtype_y, op, device in product(
            dtypes, dtypes, binary_ops, devices
        ):
            # Fill the template fields explicitly instead of via ``**locals()``,
            # which silently breaks when a local variable is renamed.
            code = ir_template.format(dtype_x=dtype_x, dtype_y=dtype_y, op=op)
            samples = list(product(values[dtype_x], values[dtype_y]))

            # Interpret the graph once per sample; if the interpreter rejects
            # this IR, don't bother checking NNC.
            try:
                graph = torch._C.parse_ir(code)
                for x, y in samples:
                    torch._C._jit_interpret_graph(graph, (x, y))
            except Exception:
                continue

            # Compile the graph
            try:
                k = torch._C._te.TensorExprKernel(graph)
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Compilation failed:", device, str(code)])
                ) from e

            # Run the compiled kernel and compare against the interpreter.
            for x, y in samples:
                ref = torch._C._jit_interpret_graph(graph, (x, y))
                try:
                    res = k.run((x, y))
                    self.assertEqual(ref, res)
                except Exception as e:
                    raise RuntimeError(
                        " ".join(
                            ["Failed at runtime:", device, str(x), str(y), str(code)]
                        )
                    ) from e
1713

1714
    def test_matmul(self):
        """Check traced ``torch.matmul`` against eager for several shape pairs.

        Correctness is verified for every size pair. The all-fused check is only
        made for the size pairs NNC supports (currently 2D x 2D).
        """
        if self.dynamic_shapes:
            self.skipTest("don't run matmul with dynamic shapes")

        def fn(x, y):
            return torch.matmul(x, y)

        devices = ["cpu"]  # No cuda support for ext calls yet
        sizes = [
            [[128, 128], [128, 128]],
            [[10, 10], [10, 10]],
            [[1, 16], [16, 128]],
            [[128], [128]],
            [[128], [128, 128]],
            [[3], [3]],
            [[3, 4], [4]],
            [[10, 3, 4], [4]],
            [[10, 3, 4], [10, 4, 5]],
            [[10, 3, 4], [4, 5]],
        ]

        # Only 2D x 2D matrix multiply is supported. For non-supported sizes we
        # still want to run results verification to test that we didn't
        # accidentally fuse it, but we skip the 'is-fused' check.
        # Entries are compared against ``str(size)``.
        # TODO: add support for other shape combinations and make this set empty:
        skip_is_fused_check_sizes = {
            "[[128], [128]]",
            "[[128], [128, 128]]",
            "[[3], [3]]",
            "[[3, 4], [4]]",
            "[[10, 3, 4], [4]]",
            "[[10, 3, 4], [10, 4, 5]]",
            "[[10, 3, 4], [4, 5]]",
        }
        for dtype, size, device in product(self.dtypes, sizes, devices):
            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
                continue
            try:
                size_x, size_y = size
                x = self.data_for(dtype, device, size=size_x)
                y = self.data_for(dtype, device, size=size_y)
                ref = fn(x, y)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser.  Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            try:
                t = torch.jit.trace(fn, (x, y))
                # Extra call before the checked one; presumably a warm-up run
                # for the profiling executor.
                t(x, y)
                self.assertEqual(ref, t(x, y))
                if str(size) not in skip_is_fused_check_sizes:
                    self.assertAllFused(t.graph_for(x, y))
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dtype), str(size), device])
                ) from e
1769

1770
    def test_binary_tensor_scalar_ops(self):
        """Trace ``op(tensor, scalar)`` for many ops and check correctness and fusion.

        The scalar is captured by the traced closure. Combinations that eager
        mode rejects are skipped.
        """
        # NOTE(review): presumably disables JIT emit hooks so they don't
        # interfere with tracing in this test; confirm.
        with torch._jit_internal._disable_emit_hooks():

            def apply_with_scalar(fn, scalar):
                return lambda x: fn(x, scalar)

            # FIXME: Fails in IR Eval: torch.int64 and_ cpu
            binary_ops = [
                operator.__and__,
                operator.__or__,
                operator.__xor__,
                torch.add,
                torch.sub,
                torch.mul,
                torch.eq,
                torch.ne,
                torch.ge,
                torch.lt,
                torch.gt,
            ]
            devices = self.devices
            # Maybe we should split this into separate tests to speed it up by
            # only using  scalar values relevant to particular ops
            scalars = [1.5, 3, 0, -2.0, -1]
            for dtype, op, device, scalar in product(
                self.dtypes, binary_ops, devices, scalars
            ):
                if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
                    continue
                try:
                    x = self.data_for(dtype, device)
                    fn = apply_with_scalar(op, scalar)
                    ref = fn(x)
                except Exception:
                    # If eager mode doesn't support a dtype/op/device combo,
                    # neither does the fuser.  Catch everything to avoid needing to
                    # guess what errors might be thrown by eager.
                    continue
                try:
                    # ``(x,)`` is a 1-tuple; ``(x)`` would just be ``x``.
                    t = torch.jit.trace(fn, (x,))
                    self.assertEqual(ref, t(x))
                    self.assertAllFused(t.graph_for(x))
                except Exception as e:
                    raise RuntimeError(
                        " ".join(
                            ["Failed:", str(dtype), op.__name__, device, str(scalar)]
                        )
                    ) from e
1816

1817
    def test_binary_div_ops(self):
        """Trace ``div``/``remainder``/``fmod`` against non-zero scalar divisors.

        Only correctness against eager mode is asserted here; no fusion check
        is made.
        """
        def apply_with_scalar(fn, scalar):
            return lambda x: fn(x, scalar)

        binary_ops = [
            torch.div,
            torch.remainder,
            torch.fmod,
        ]
        devices = self.devices
        # Maybe we should split this into separate tests to speed it up by
        # only using  scalar values relevant to particular ops
        # 0 is deliberately left out of the divisors (presumably to avoid
        # division by zero).
        scalars = [1.5, 3, -2.0, -1]
        for dtype, op, device, scalar in product(
            self.dtypes, binary_ops, devices, scalars
        ):
            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
                continue
            try:
                x = self.data_for(dtype, device)
                fn = apply_with_scalar(op, scalar)
                ref = fn(x)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser.  Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            try:
                # ``(x,)`` is a 1-tuple; ``(x)`` would just be ``x``.
                t = torch.jit.trace(fn, (x,))
                self.assertEqual(ref, t(x))
            except Exception as e:
                raise RuntimeError(
                    f"Failed: {dtype} {op.__name__} {device} {scalar}"
                ) from e
1851

1852
    def test_binary_pow(self):
        """Trace ``torch.pow(tensor, scalar)`` and check correctness and fusion."""
        def apply_with_scalar(fn, scalar):
            return lambda x: fn(x, scalar)

        dtypes = [
            # FIXME: 'pow' fails with dtype=torch.float16/device=cuda/scalar=0
            # torch.float16,
            torch.float32,
            torch.float64,
            # torch.bool intentionally not included
        ]
        binary_ops = [
            torch.pow,
        ]
        # Maybe we should split this into separate tests to speed it up by
        # only using  scalar values relevant to particular ops
        scalars = [1.5, 3, 0, -2.0, -1]
        for dtype, op, device, scalar in product(
            dtypes, binary_ops, self.devices, scalars
        ):
            # Currently a no-op for the dtypes above; kept so that re-enabling
            # float16 (see FIXME) still skips it on CPU.
            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
                continue
            try:
                x = self.data_for(dtype, device)
                fn = apply_with_scalar(op, scalar)
                ref = fn(x)
            except Exception:
                # If eager mode doesn't support a dtype/op/device combo,
                # neither does the fuser.  Catch everything to avoid needing to
                # guess what errors might be thrown by eager.
                continue
            try:
                # ``(x,)`` is a 1-tuple; ``(x)`` would just be ``x``.
                t = torch.jit.trace(fn, (x,))
                self.assertEqual(ref, t(x))
                self.assertAllFused(t.graph_for(x))
            except Exception as e:
                raise RuntimeError(
                    " ".join(
                        ["Failed:", str(dtype), op.__name__, device, str(scalar)]
                    )
                ) from e
1891

1892
    def test_ternary_ops(self):
        # Trace each ternary op for every (dtype, device) that eager supports and
        # require the traced graph to match eager and be fully fused.
        def wrap(op):
            return lambda x, y, z: op(x, y, z)

        ops_under_test = [
            torch.lerp,
            torch.addcmul,
        ]
        for dtype, op, device in product(self.dtypes, ops_under_test, self.devices):
            if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
                continue
            try:
                operands = tuple(self.data_for(dtype, device) for _ in range(3))
                fn = wrap(op)
                expected = fn(*operands)
            except Exception:
                # Eager mode rejects this dtype/op/device combination, so the
                # fuser is not expected to handle it either.
                continue
            try:
                traced = torch.jit.trace(fn, operands)
                self.assertEqual(expected, traced(*operands))
                self.assertAllFused(traced.graph_for(*operands))
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dtype), op.__name__, device])
                ) from e
1923

1924
    def test_ternary_norm_ops(self):
        # Trace F.batch_norm(input, running_mean, running_var) on an NCHW input
        # with 3 channels and require a fully fused graph matching eager.
        def wrap(op):
            return lambda x, y, z: op(x, y, z)

        for dtype, op, device in product(self.dtypes, [F.batch_norm], self.devices):
            if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
                continue
            try:
                inp = self.data_for(dtype, device, size=[5, 3, 128, 128])
                mean = self.data_for(dtype, device, size=[3])
                var = self.data_for(dtype, device, size=[3])
                fn = wrap(op)
                expected = fn(inp, mean, var)
            except Exception:
                # Eager mode rejects this dtype/op/device combination, so the
                # fuser is not expected to handle it either.
                continue
            operands = (inp, mean, var)
            try:
                traced = torch.jit.trace(fn, operands)
                self.assertEqual(expected, traced(*operands))
                self.assertAllFused(traced.graph_for(*operands))
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dtype), op.__name__, device])
                ) from e
1954

1955
    @unittest.skip(
        "FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure"
    )
    def test_list_ops(self):
        # torch.cat over a list of squared inputs; skipped until list
        # construction is pulled into the fusion group.
        def wrap(op):
            return lambda x, y, z: op([x * x, y * y, z * z])

        for dtype, op, device in product(self.dtypes, [torch.cat], self.devices):
            if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
                continue
            try:
                operands = tuple(
                    self.data_for(dtype, device, size=[5, 4, 1, 7]) for _ in range(3)
                )
                fn = wrap(op)
                expected = fn(*operands)
            except Exception:
                # Eager mode rejects this dtype/op/device combination, so the
                # fuser is not expected to handle it either.
                continue
            try:
                traced = torch.jit.trace(fn, operands)
                self.assertEqual(expected, traced(*operands))
                self.assertAllFused(traced.graph_for(*operands))
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dtype), op.__name__, device])
                ) from e
1988

1989
    def test_where_ops(self):
        # torch.where with tensor/tensor, tensor/scalar and scalar/tensor
        # branches, driven by a bool condition tensor.
        def wrap(op):
            return lambda cond, x, y: op(cond, x, y)

        variants = [
            torch.where,
            lambda cond, x, y: torch.where(cond, x, 3.1415),
            lambda cond, x, y: torch.where(cond, 42, y),
        ]
        for dtype, op, device in product(self.dtypes, variants, self.devices):
            if device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
                continue
            try:
                cond = self.data_for(torch.bool, device)
                x = self.data_for(dtype, device)
                y = self.data_for(dtype, device)
                fn = wrap(op)
                expected = fn(cond, x, y)
            except Exception:
                # Eager mode rejects this dtype/op/device combination, so the
                # fuser is not expected to handle it either.
                continue
            operands = (cond, x, y)
            try:
                traced = torch.jit.trace(fn, operands)
                self.assertEqual(expected, traced(*operands))
                self.assertAllFused(traced.graph_for(*operands))
            except Exception as e:
                raise RuntimeError(
                    " ".join(["Failed:", str(dtype), op.__name__, device])
                ) from e
2021

2022
    def test_unsupported_dtypes(self):
        # For dtypes the test expects NNC not to fuse, a traced ``x * x + x``
        # must still match eager and contain no fusion groups at all.
        unfusable = [
            torch.uint8,
            torch.complex32,
            torch.complex64,
            torch.complex128,
            torch.qint8,
            torch.quint8,
            torch.qint32,
        ]
        for device in self.devices:

            def fn(x):
                return x * x + x

            for dtype in unfusable:
                try:
                    inp = self.data_for(dtype, device)
                    expected = fn(inp)
                except Exception:
                    # Eager mode can't handle this dtype on this device, so
                    # there is nothing to compare against.
                    continue
                traced = torch.jit.trace(fn, (inp,))
                self.assertEqual(expected, traced(inp))
                groups = self.findFusionGroups(traced.graph_for(inp))
                self.assertEqual(len(groups), 0)
2049

2050
    def test_superslomo(self):
        # CPU is only exercised when LLVM is available.
        targets = self.devices.copy()
        if not LLVM_ENABLED:
            targets.remove("cpu")
        for device in targets:
            # Test extracted from Super-SloMo: https://github.com/avinashpaliwal/Super-SloMo
            # A few interesting things happen here: strided inputs of mixed size,
            # plus outputs of mixed shapes.  The latter characteristic happened to
            # expose a memory corruption bug due to not properly guarding the
            # outputs.
            def eager(t0, t1, t2, t3, t4):
                t5 = torch.mul(t0, t4)
                t6 = torch.mul(t2, t3)
                t7 = torch.mul(t6, t1)
                t9 = torch.add(t5, t7)
                t11 = torch.add(t0, t6)
                ft_p = torch.div(t9, t11)
                return (ft_p, t11, t9, t6)

            # Built in the same order as the eager argument list; t0 and t2 are
            # non-contiguous (transpose / permute of a broadcast view).
            inputs = [
                torch.rand(1, 6, 352, 352, device=device).transpose(0, 1),
                torch.rand(6, 3, 352, 352, device=device),
                torch.rand(6, device=device)[None, None, None, :].permute(3, 0, 1, 2),
                torch.rand(6, 1, 352, 352, device=device),
                torch.rand(6, 3, 352, 352, device=device),
            ]

            scripted = torch.jit.script(eager)
            for _ in range(4):
                for actual, expected in zip(scripted(*inputs), eager(*inputs)):
                    torch.testing.assert_close(actual, expected)
                    self.assertAllFused(
                        scripted.graph_for(*inputs), except_for={"prim::TupleConstruct"}
                    )
2084

2085
    def test_sub_gt_and(self):
        """Script a sub / compare / bitwise-and pattern with an un-profiled use.

        ``checkScript`` is called only for its checks; its return value is not
        needed.
        """
        for device in self.devices:

            def eager(t1, t2, t3, t4, t: float):
                w = t1 - t2
                h = t3 - t4
                k = (w > t) & (h > t)
                assert k.dtype == torch.bool
                if t > 0.5:
                    # Putting a use of k in a never-executed conditional prevents
                    # profiling its type, which leaves it as "Tensor".  If we
                    # propagate Tensor back to the definition of k, we have to be
                    # careful not to create a fusion group containing it.
                    return k + 1
                return w

            t = torch.rand(8, dtype=torch.float, device=device)
            self.checkScript(eager, (t, t, t, t, 0.1))
2103

2104
    @skipIfTorchDynamo("too slow")
    def test_chunk_mul_one(self):
        """Script a 3-way ``chunk`` where only the first piece is scaled."""
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")

        for device in self.devices:

            def eager(x):
                z, y, w = torch.chunk(x, 3, -1)
                return z * 3, y, w

            x = torch.rand(64, 1, 3072, dtype=torch.float, device=device)
            self.checkScript(eager, (x,))
2118

2119
    def test_eq_unsqueeze_type_as(self):
        """Script ``eq`` -> ``unsqueeze`` -> ``type_as`` and return both mask forms."""
        for device in self.devices:

            def eager(a, b):
                mask = b == 1
                mask = torch.unsqueeze(mask, -1)
                x = mask.type_as(a)
                return x, mask

            a = torch.rand(1, 64, 1024, device=device, dtype=torch.float)
            b = torch.randint(-2, 2, (1, 64), device=device, dtype=torch.long)
            self.checkScript(eager, (a, b))
2131

2132
    def test_neg_pow(self):
        # neg(pow(a, b)) with tensor/tensor, tensor/scalar and scalar/tensor
        # operands; only correctness is checked (see TODO below).
        def eager_tt(a: torch.Tensor, b: torch.Tensor):
            return torch.neg(torch.pow(a, b))

        def eager_ts(a: torch.Tensor, b: float):
            return torch.neg(torch.pow(a, b))

        def eager_st(a: float, b: torch.Tensor):
            return torch.neg(torch.pow(a, b))

        a = torch.rand(1, dtype=torch.float)
        b = torch.rand(1, dtype=torch.float)
        s = b.item()
        cases = [
            (eager_tt, (a, b)),
            (eager_ts, (a, s)),
            (eager_st, (s, b)),
        ]
        for fn, args in cases:
            # TODO: re-enable fusion, which doesn't work right now. just test
            # correctness for now:
            # self.assertAllFused(self.checkScript(fn, args).graph_for(*args))
            self.checkScript(fn, args)
2152

2153
    @unittest.skipIf(not LLVM_ENABLED, "Too slow to run with the TE interpreter")
    def test_conv2d_depthwise(self):
        # Depthwise conv2d (groups == in_channels == 72) must be fully fused.
        if self.dynamic_shapes:
            self.skipTest("don't run conv with dynamic shapes")

        def eager(input, weight, bias):
            return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=72)

        # (input, weight, bias), created in that order.
        args = (
            torch.rand((1, 72, 56, 56), dtype=torch.float),
            torch.rand((72, 1, 3, 3), dtype=torch.float),
            torch.rand((72), dtype=torch.float),
        )
        scripted = self.checkScript(eager, args)
        self.assertAllFused(scripted.graph_for(*args))
2167

2168
    def test_conv2d(self):
        """A regular (groups=1) conv2d must not end up in a TensorExpr group."""
        if self.dynamic_shapes:
            self.skipTest("don't run conv with dynamic shapes")

        def eager(input, weight, bias):
            return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=1)

        # Named ``inp`` to avoid shadowing the ``input`` builtin at this level.
        inp = torch.rand((1, 64, 56, 56), dtype=torch.float)
        weight = torch.rand((64, 64, 3, 3), dtype=torch.float)
        bias = torch.rand((64), dtype=torch.float)

        self.checkScript(eager, (inp, weight, bias))
        FileCheck().check_not("TensorExpr").run(
            torch.jit.last_executed_optimized_graph()
        )
2183

2184
    def test_type_as_cat(self):
        # cat(x, y.type_as(x)) for every pair of CPU-fusable dtypes must match
        # eager for both a zero and a one tensor and be fully fused.
        with inline_fusion_groups():

            def eager(x, y):
                return torch.cat((x, y.type_as(x)), dim=1)

            supported = self.dtypes.copy()
            # CPU fuser doesn't support float16 or bfloat16.
            for reduced in (torch.float16, torch.bfloat16):
                supported.remove(reduced)
            for src_dtype, other_dtype in product(supported, supported):
                x = torch.randint(2, (1, 13)).to(src_dtype)
                zero = torch.tensor([[0]]).to(other_dtype)
                one = torch.tensor([[1]]).to(other_dtype)
                traced = torch.jit.trace(eager, (x, zero))
                for _ in range(3):
                    torch.testing.assert_close(traced(x, zero), eager(x, zero))
                    torch.testing.assert_close(traced(x, one), eager(x, one))
                self.assertAllFused(traced.graph_for(x, one))
2203

2204
    def test_to_device(self):
        # ``.to(device="cpu")`` on a CPU tensor followed by relu must be fully
        # fused.
        def eager(x):
            return x.to(device="cpu").relu()

        inp = torch.rand(8)
        scripted = self.checkScript(eager, (inp,))
        self.assertAllFused(scripted.graph_for(inp))
2211

2212
    def test_dims(self):
        # Broadcasting divide where ``x`` has size-1 dims with non-standard
        # strides (built with ``as_strided``); the graph must be fully fused.
        def eager(x, y):
            return x / (y + 0.0001)

        strided = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided(
            (1, 1, 768), (768, 1, 1)
        )
        divisor = torch.tensor([[[2.0]]], dtype=torch.float32)
        args = (strided, divisor)
        scripted = self.checkScript(eager, args)
        self.assertAllFused(scripted.graph_for(*args))
2222

2223
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_channels_last_dims_dynamic(self):
        """Dynamic-shape fusion must preserve channels-last with size-1 dims.

        For every subset of the dims of a ``[2, 3, 4, 5]`` tensor, those dims
        are set to 1 and the input is made channels-last on CUDA. The traced
        ``x + (y + 0.0001)`` must then match eager and return a channels-last
        result, and the last optimized graph must mention TensorExpr.
        """
        def eager(x, y):
            return x + (y + 0.0001)

        indices = [0, 1, 2, 3]
        # Every subset of ``indices``, from the empty subset up to all four dims.
        subsets = [
            subset
            for r in range(len(indices) + 1)
            for subset in combinations(indices, r)
        ]

        # ``ones_at`` rather than ``set`` so the builtin isn't shadowed.
        for ones_at in subsets:
            size = [2, 3, 4, 5]
            for index in ones_at:
                size[index] = 1
            inp = torch.rand(size).to(memory_format=torch.channels_last).cuda()
            with texpr_enable_strategy([("DYNAMIC", 20)]):
                foo_s = torch.jit.trace(eager, (inp, inp))
                for _ in range(3):
                    out = foo_s(inp, inp)
                out_eager = eager(inp, inp)
                self.assertEqual(out_eager, out)
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
                g = torch.jit.last_executed_optimized_graph()
                FileCheck().check("TensorExpr").run(g)
2248

2249
    def test_exhaust_specializations(self):
        # Under the ("STATIC", 1) strategy, running the scripted function on two
        # different shapes must leave exactly two TensorExpr occurrences in the
        # inlined last-executed graph.
        # NOTE(review): presumably the second occurrence is the fallback after
        # the single static specialization is exhausted; confirm.
        with texpr_enable_strategy([("STATIC", 1)]):

            @torch.jit.script
            def foo(x):
                return x + x + x

            for shape in ([2, 2], [4, 4, 4]):
                for _ in range(3):
                    foo(torch.rand(shape))

            g = torch.jit.last_executed_optimized_graph()
            torch._C._jit_pass_inline(g)

            FileCheck().check_count("TensorExpr", 2, exactly=True).run(g)
2266

2267
    def test_unsqueeze_var_dim(self):
        """Script ``unsqueeze`` whose dim arrives as a runtime ``int`` argument."""
        def eager(x, y, z: int):
            return x * torch.unsqueeze(y, dim=z)

        # ``x`` is permuted, so it is non-contiguous.
        x = torch.rand(4, 4, 64).permute(1, 0, 2)
        y = torch.rand(4, 4)
        z = 2
        self.checkScript(eager, (x, y, z))
2275

2276
    def _test_fwd_bwd(self, fn):
        """Run ``fn`` eagerly and scripted through 11 gradient-descent steps.

        Two copies of the same input are each updated with the gradient from
        their own path (eager vs. scripted), using the same random upstream
        gradient ``g0`` at every step. Each step's input depends on the
        previous steps' gradients, so comparing the final forward outputs also
        checks that the earlier backward passes agreed.

        Args:
            fn: a TorchScript-compatible function taking a single tensor.
        """
        x = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True)
        xs = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True)
        script = torch.jit.script(fn)
        for _ in range(11):
            y = fn(x)
            g0 = torch.rand_like(y)
            y.backward(g0)

            ys = script(xs)
            ys.backward(g0)

            # In-place SGD-style update of both leaves; grads are reset so the
            # next backward does not accumulate into them.
            with torch.no_grad():
                x -= 0.1 * x.grad
                xs -= 0.1 * xs.grad
                x.grad = None
                xs.grad = None
        torch.testing.assert_close(y, ys)
2294

2295
    def test_relu_fwd_bwd(self):
        # Forward + backward of a scaled relu, eager vs. scripted.
        def forward(x):
            return torch.relu(x * 1.01)

        self._test_fwd_bwd(forward)
2300

2301
    def test_hardswish_fwd_bwd(self):
        # Forward + backward of a scaled hardswish, eager vs. scripted.
        def forward(x):
            return F.hardswish(x) * 1.01

        self._test_fwd_bwd(forward)
2306

2307
    def test_hardsigmoid_fwd_bwd(self):
        # Forward + backward of a scaled hardsigmoid, eager vs. scripted.
        def forward(x):
            return F.hardsigmoid(x) * 1.01

        self._test_fwd_bwd(forward)
2312

2313
    def test_cat_graph_opt(self):
        # log(cat([...])) over inputs with different leading sizes must fuse.
        def foo(x, y, z):
            return torch.log(torch.cat([x, y, z]))

        inputs = (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5]))
        self.checkScript(foo, inputs)
        # TODO: not sure why the updated graph isn't reflected in
        # last_optimized_graph
        self.assertLastGraphAllFused()
2322

2323
    def test_dynamic_cat(self):
        # Nested cats over lists whose inner operands differ in size from pair
        # to pair; the scripted function is simply run three times.
        with inline_fusion_groups():

            @torch.jit.script
            def repro(
                xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor]
            ):
                return [
                    torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1)
                    for x, y, z in zip(xs, ys, zs)
                ]

            n = 3
            for _ in range(3):
                xs = [torch.ones(21) for _ in range(n)]
                # Note: concat of ys and zs will have the same size for each
                # pair, even though the individual ys and zs do not.
                ys = [torch.ones(n - i) for i in range(n)]
                zs = [torch.ones(i) for i in range(n)]
                repro(xs, ys, zs)
2343

2344
    def test_scalar_only_inputs(self):
        # The only graph input is a Python float; the tensor is created inside.
        def eager(b: float):
            return torch.ones(1) * b

        self.checkScript(eager, (1.0,))
2350

2351
    def test_cat_2k_args(self):
        # A cat with 2000 operands must not be placed in a fusion group.
        with inline_fusion_groups():

            def eager(x):
                return torch.relu(torch.cat([x] * 2000))

            inp = torch.randn(1)
            traced = self.checkTrace(eager, (inp,))
            groups = self.findFusionGroups(traced.graph_for(inp))
            self.assertEqual(len(groups), 0)
2361

2362
    def test_adaptive_avg_pool2d(self):
        # TODO: once the adaptive_avg_pool2d is available in OpInfo DB, this
        # test should be moved there
        with inline_fusion_groups():

            def foo1(x):
                return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2))

            def foo2(x):
                # Integer output_size form.
                return torch.nn.functional.adaptive_avg_pool2d(x, 2)

            inp = torch.randn(4, 4, 4)
            for fn in (foo1, foo2):
                traced = torch.jit.trace(fn, (inp,))
                # Compile the traced graph directly with NNC and compare it
                # with the traced module's own output.
                kernel = torch._C._te.TensorExprKernel(traced.graph)
                expected = traced(inp)
                self.assertEqual(kernel.run((inp,)), expected)
2379

2380
    def test_unrolled_cat(self):
        # A scripted cat loop that is unrolled after profiling must still
        # produce correctly-sized results when given a longer input.
        with inline_fusion_groups():

            def eager(x):
                ret = torch.empty(0)
                for i in range(x.shape[0]):
                    ret = torch.cat([ret, x[i].relu()])
                return ret

            script = torch.jit.script(eager)

            # Warm up on a size-1 input: the loop body runs once, so the
            # profile data has size=1 "burned in" and the loop gets unrolled.
            warm = torch.ones(1, 1)
            for _ in range(3):
                script(warm)
            torch.testing.assert_close(eager(warm), script(warm))

            # A longer input now reaches the unrolled path; trusting the
            # burned-in size=1 would yield an incorrectly-sized tensor.
            longer = torch.ones((8, 1))
            torch.testing.assert_close(eager(longer), script(longer))
2404

2405
    @skipIfTorchDynamo("too slow")
    @unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan")
    @unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans")
    def test_batch_norm(self):
        # Trace each batch_norm(+relu) variant, require the traced graph to be
        # fully fused, and compare it against eager execution.
        def check(fn, args):
            traced = torch.jit.trace(fn, args)
            self.assertAllFused(traced.graph_for(*args))
            # TODO: Are `NaN`'s actually ok here or did this pass silently
            #  before, because `equal_nan=True` was the default?
            # NOTE(review): `x` (from randn) is also passed as running_var, so
            # negative variances are a likely source of those NaNs — confirm.
            torch.testing.assert_close(fn(*args), traced(*args), equal_nan=True)

        # `x` fills every per-channel argument slot that is not None.
        def bn(i, x):
            return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu()

        def bn_no_weight(i, x):
            return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu()

        def bn_no_bias(i, x):
            return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu()

        def bn_neither(i, x):
            return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu()

        variants = (bn, bn_no_weight, bn_no_bias, bn_neither)
        for device in self.devices:
            inp = torch.randn(4, 16, 32, 40, device=device)
            per_channel = torch.randn(16, device=device)
            for fn in variants:
                check(fn, (inp, per_channel))
2433

2434
    def test_profiler(self):
        # The fused kernel should appear by name in the autograd profiler table.
        @torch.jit.script
        def test(x, y, z):
            return x * y + z

        a, b, c = [torch.randn(4) for _ in range(3)]
        with torch.autograd.profiler.profile() as prof:
            for _ in range(3):
                test(a, b, c)
        self.assertIn("fused_mul_add", prof.table())
2444

2445
    def test_skip_grad_in_check(self):
        # Warm up a scripted function on a plain input. Then call it under
        # torch.inference_mode with an input that has requires_grad set, and
        # assert that the last optimized graph contains exactly one prim::If.
        # Presumably (per the test name) this checks that the fusion guard
        # skips the requires_grad check in inference mode — confirm.
        @torch.jit.script
        def foo(x):
            return (x + 2) / 2

        inp = torch.rand([4, 4])
        # Repeated calls, presumably so the profiling executor profiles and
        # then optimizes the graph.
        for _ in range(3):
            foo(inp)

        inp.requires_grad_(True)
        with torch.inference_mode():
            for _ in range(3):
                foo(inp)
        g = torch.jit.last_executed_optimized_graph()
        # NOTE(review): the inline pass runs twice, presumably so that calls
        # exposed by the first inlining (e.g. fallback functions) are inlined
        # as well — verify whether the second call is actually needed.
        torch._C._jit_pass_inline(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check_count("prim::If", 1, exactly=True).run(g)
2462

2463
    def test_dynamic_shapes(self):
        # Under the "DYNAMIC" fusion strategy, traced functions called with
        # inputs of varying sizes should reuse a single dynamic kernel
        # (counted via TensorExprDynamicGuard nodes) for each input generator
        # below. Each generator produces a differently laid-out tensor.
        from functools import partial

        n = 10

        # `R` is a closure variable bound inside the device loop below,
        # before any of these generators is called.
        gen_tensor = (
            lambda n: R(1, n),
            lambda n: R(n, n),
            lambda n: R(n, n).transpose(0, 1),
            lambda n: R(n + 1, n + 1, 2)[:n, n, 0],
            lambda n: R(n, n, 2)[:, :, 0],
            lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last),
        )

        with texpr_enable_strategy([("DYNAMIC", 20)]):

            def foo(x, y, z):
                return torch.sigmoid(torch.tanh(x))

            foo.__disable_jit_function_caching__ = True

            def fi(x, y, z):
                return torch.tanh(x + y)

            fi.__disable_jit_function_caching__ = True

            def fum(x, y, z):
                return torch.tanh(x + y) + z

            fum.__disable_jit_function_caching__ = True

            funcs = [foo, fi, fum]
            with inline_fusion_groups():
                for device in self.devices:
                    R = partial(torch.randn, device=device)

                    for func in funcs:
                        for gen in gen_tensor:
                            inps = (gen(n), gen(n), gen(n))
                            func_s = torch.jit.trace(func, inps, check_trace=False)
                            torch._C._jit_pass_erase_shape_information(func_s.graph)
                            # Profiling runs at size n.
                            for _ in range(2):
                                x, y, z = gen(n), gen(n), gen(n)
                                func_s(x, y, z)

                            # NOTE(review): every iteration uses size n + 1.
                            # The loop counter may have been meant as an
                            # offset (`n + incr`, compare the `n + i` loop
                            # below) — confirm before changing.
                            for _ in range(3):
                                func_s(*[gen(n + 1) for _ in range(3)])

                            g = torch.jit.last_executed_optimized_graph()
                            torch._C._jit_pass_inline(g)
                            torch._C._jit_pass_dce(g)

                            # We should see only one optimized kernel
                            FileCheck().check_count(
                                "TensorExprDynamicGuard", 1, exactly=True
                            ).run(g)
                            self.assertEqual(func(*inps), func_s(*inps))

                    # One traced function fed every generator in turn should
                    # end up with one dynamic guard per generator.
                    gen = gen_tensor[0]
                    inps = (gen(n), gen(n), gen(n))
                    foo_s = torch.jit.trace(foo, inps)
                    torch._C._jit_pass_erase_shape_information(foo_s.graph)
                    for gen in gen_tensor:
                        for i in range(3):
                            foo_s(*[gen(n + i) for _ in range(3)])
                            inps = (gen(n), gen(n), gen(n))
                            self.assertEqual(foo_s(*inps), foo(*inps))
                    g = torch.jit.last_executed_optimized_graph()
                    torch._C._jit_pass_inline(g)
                    torch._C._jit_pass_dce(g)
                    FileCheck().check_count(
                        "TensorExprDynamicGuard", len(gen_tensor), exactly=True
                    ).run(g)
2539

2540
    @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA")
    def test_autocast_up(self):
        # An autocast up-cast from half followed by exp should be fully fused.
        def f(x):
            return torch.exp(x._autocast_to_full_precision(True, True))

        x = torch.rand((2, 2), dtype=torch.half, device="cuda")
        scripted = torch.jit.script(f)
        for _ in range(2):
            scripted(x)
        self.assertLastGraphAllFused()
2552

2553
    @unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA")
    def test_autocast_down(self):
        # sigmoid followed by an autocast down-cast to half should be fully fused.
        def f(x):
            return torch.sigmoid(x)._autocast_to_reduced_precision(
                True, True, torch.half, torch.half
            )

        x = torch.rand((2, 2), dtype=torch.float, device="cuda")
        scripted = torch.jit.script(f)
        for _ in range(2):
            scripted(x)
        self.assertLastGraphAllFused()
2565

2566
    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    def test_to_dtype(self):
        # A chain of autocast/`.to` dtype conversions should fuse for a
        # float32 input; for a bfloat16 input, two fusion groups are expected.
        def f(x):
            reduced = torch.sigmoid(x)._autocast_to_reduced_precision(
                True, True, torch.half, torch.bfloat16
            )
            return (
                reduced._autocast_to_full_precision(True, True)
                .to(dtype=torch.bfloat16)
                .to(dtype=torch.float32)
            )

        fp32_in = torch.rand((2, 2), dtype=torch.float32)
        fp32_traced = torch.jit.trace(f, fp32_in)
        for _ in range(2):
            fp32_traced(fp32_in)
        self.assertLastGraphAllFused()
        self.assertEqual(f(fp32_in), fp32_traced(fp32_in), atol=4e-3, rtol=4e-3)

        bf16_in = torch.rand((2, 2), dtype=torch.bfloat16)
        bf16_traced = torch.jit.trace(f, bf16_in)
        for _ in range(2):
            bf16_traced(bf16_in)
        fusion_groups = self.findFusionGroups(bf16_traced.graph_for(bf16_in))
        self.assertEqual(len(fusion_groups), 2)
        self.assertEqual(f(bf16_in), bf16_traced(bf16_in), atol=4e-3, rtol=4e-3)
2591

2592
    def test_with_strict_fusion(self):
        # torch.jit.strict_fusion() is checked for five behaviors:
        # (1) fully-fusible code runs fused;
        # (2) an unfusible op raises;
        # (3) eager use only warns;
        # (4) the autodiff path raises;
        # (5) a region split into several fusion groups raises.
        def success(x):
            with torch.jit.strict_fusion():
                return x + x + x

        self.checkScript(success, (torch.rand([4]),))
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("aten::add").check("prim::TensorExprGroup").run(g)

        # torch.rand is not fused, so running this under strict_fusion must
        # fail. The error quotes the offending source line, so this body must
        # keep the literal `torch.rand([4]` that is checked below.
        def foo(x):
            with torch.jit.strict_fusion():
                return x + x + torch.rand([4]) + 3

        with self.assertRaises(Exception) as error_out:
            foo_s = torch.jit.script(foo)
            foo_s(torch.rand([4]))
            foo_s(torch.rand([4]))
        fc = FileCheck().check("Found unfused operators")
        fc.check("aten::rand(SymInt[] size")
        fc.check("torch.rand([4]").run(str(error_out.exception))

        # Eager use only warns. Force the "always" filter so the warning is
        # recorded regardless of the ambient warning filters. Search every
        # recorded warning instead of assuming ours is the first one.
        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter("always")
            foo(torch.rand([4]))

        FileCheck().check("Only works in script mode").run(
            "\n".join(str(w.message) for w in warns)
        )

        def test_autodiff(x):
            with torch.jit.strict_fusion():
                return torch.rand([4]) + x + x + x

        foo_s = torch.jit.script(test_autodiff)
        inp = torch.rand([4], requires_grad=True)
        with self.assertRaises(Exception) as error_out:
            for _ in range(3):
                foo_s(inp)
        f = FileCheck().check("unfused operators").check("aten::rand")
        f.run(str(error_out.exception))

        def test_separate_fusions(x, y):
            with torch.jit.strict_fusion():
                return x + x + x, y + y + y

        inp = torch.rand([4], requires_grad=True)
        with self.assertRaises(Exception) as error_out:
            # Script once; re-scripting the same function on every iteration
            # is redundant.
            foo_s = torch.jit.script(test_separate_fusions)
            for _ in range(3):
                foo_s(inp, inp)

        f = FileCheck().check("Found multiple fusions")
        f.run(str(error_out.exception))
2643

2644
    def test_constant_chunk_shapes(self):
        # Regression scenario: buildShapeExpressions used to fail on
        #
        # %1 : Tensor = Constant[..]  # not supported, we don't build this shape
        # %2 : Tensor = Constant[..]  # not supported
        # %3 : Tensor = aten::add(%1, %2)  # inputs not supported, we don't build shape
        # ... = prim::ConstantChunk[..](%3)  # it forgets to check whether input shapes exist, and fails
        if self.dynamic_shapes:
            self.skipTest("TODO: chunk dynamic shapes")

        for device in self.devices:

            def f(x, y):
                r = torch.tensor(4)
                z1, z2 = (x + y + r).chunk(2, dim=1)
                return z1 * z2

            x, y = (
                torch.randn(4, 4, dtype=torch.float, device=device) for _ in range(2)
            )

            graph = self.checkTrace(f, (x, y)).graph_for(x, y)

            # Confirm the scenario under test: a fusion group followed by
            # exactly one ConstantChunk.
            FileCheck().check("with " + FUSION_GROUP + "_").check_count(
                "ConstantChunk", 1, exactly=True
            ).run(str(graph))

            f_traced = torch.jit.trace(f, (x, y))

            # None of these repeated calls may raise.
            for _ in range(4):
                res = f_traced(x, y)

            self.assertEqual(res, f(x, y))
2679

2680
    @unittest.skipIf(not RUN_CUDA_HALF, "half-precision NNC fusion requires CUDA")
    def test_pow_multiple_dtype(self):
        # https://github.com/pytorch/pytorch/issues/75476
        # A half tensor raised to a float exponent inside a scripted function.
        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
            return torch.sigmoid(p) ** gamma

        x = torch.rand((2, 2), dtype=torch.half, device="cuda")
        expected = fn(x)

        script_fn = torch.jit.script(fn)
        # Several calls; only the last (optimized) result is compared.
        results = [script_fn(x) for _ in range(4)]

        self.assertEqual(expected, results[-1])
2697

2698

2699
class TestTEFuserStatic(TestTEFuser):
    # Concrete TestTEFuser variant; tests read this flag via `self.dynamic_shapes`.
    dynamic_shapes = False
2701

2702

2703
class TestTEFuserDynamic(TestTEFuser):
    # Concrete TestTEFuser variant; tests read this flag via `self.dynamic_shapes`
    # (some, e.g. test_constant_chunk_shapes, skip themselves when it is True).
    dynamic_shapes = True
2705

2706

2707
# Remove the shared base class from module scope so that only the concrete
# Static/Dynamic subclasses are picked up as test classes.
del TestTEFuser
2708

2709
# OpInfo names (as produced by get_name) that TensorExprKernel is expected to
# compile and run correctly on CPU; exercised by TestNNCOpInfo.test_working and
# excluded from TestNNCOpInfo.test_unsupported. Variant entries use the
# "<name>.<variant_test_name>" form, e.g. "div.floor_rounding".
works_list = [
    "__radd__",
    "__rdiv__",
    "__rmul__",
    "__rmod__",
    "abs",
    "acos",
    "add",
    "addcmul",
    "addmm.decomposed",
    "asin",
    "atan",
    "atan2",
    "ceil",
    "clamp",
    "clamp.scalar",
    "contiguous",
    "cos",
    "cosh",
    "div.no_rounding_mode",
    "div.true_rounding",
    "div.floor_rounding",
    "div.trunc_rounding",
    "eq",
    "erf",
    "erfc",
    "exp",
    "expand",
    "expand_as",
    "expm1",
    "floor",
    "fmod",
    "fmod.autodiffed",
    "ge",
    "gt",
    "isnan",
    "le",
    "lerp",
    "lgamma",
    "log",
    "log10",
    "log1p",
    "log2",
    "lt",
    "masked_fill",
    "max.binary",
    "mean",
    "min.binary",
    "mm",
    "mul",
    "ne",
    "neg",
    "nn.functional.hardshrink",
    "nn.functional.hardsigmoid",
    "nn.functional.hardswish",
    "nn.functional.softplus",
    "nn.functional.hardtanh",
    "nn.functional.leaky_relu",
    "nn.functional.relu",
    "nn.functional.relu6",
    "nn.functional.softsign",
    "nn.functional.tanhshrink",
    "nn.functional.threshold",
    "permute",
    "pow",
    "reciprocal",
    "remainder",
    "remainder.autodiffed",
    "reshape",
    "reshape_as",
    "round",
    "rsub",
    "rsub.rsub_tensor",
    "rsqrt",
    "sigmoid",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "sub",
    "sum",
    "t",
    "tan",
    "tanh",
    "transpose",
    "true_divide",
    "trunc",
    "unsqueeze",
    "view",
    "view_as",
    "where",
    "bool",
    "byte",
    "char",
    "double",
    "float",
    "half",
    "int",
    "long",
    "short",
    "bool.channels_last",
    "byte.channels_last",
    "char.channels_last",
    "double.channels_last",
    "float.channels_last",
    "half.channels_last",
    "int.channels_last",
    "long.channels_last",
    "short.channels_last",
]
2819

2820
# OpInfo names (as produced by get_name) that are expected to fail in
# TestNNCOpInfo.te_compile. TestNNCOpInfo.test_failures raises if one of them
# starts working, prompting a move into works_list.
known_failures = [
    "__rmatmul__",
    "frac",
    "matmul",
]
2825

2826
# If your OpInfo test causes this test to fail, add it here
2827
skip_ops = ["conj"]
2828

2829

2830
def get_name(op):
    # Return the full OpInfo identifier used by works_list / known_failures /
    # skip_ops: "<name>" or, for variants, "<name>.<variant_test_name>".
    variant = op.variant_test_name
    if variant == "":
        return ".".join([op.name])
    return ".".join((op.name, variant))
2835

2836

2837
# Purpose of this class is to allow super() calls from TestNNCOpInfo:
# super(TestNNCOpInfoParent, self) resolves to JitCommonTestCase's methods.
# - super() [with no arguments] fails, presumably because of how instantiate_device_type_tests works.
# - super(TestNNCOpInfo, self) fails because TestNNCOpInfo gets deleted from global scope.
# - super(JitCommonTestCase, self).fn() would skip JitCommonTestCase.fn() implementation
class TestNNCOpInfoParent(JitCommonTestCase):
    pass
2843

2844

2845
class TestNNCOpInfo(TestNNCOpInfoParent):
    # OpInfo-driven tests that compile single ops with NNC's TensorExprKernel
    # (see te_compile) and compare the results against eager execution.

    def setUp(self):
        super(TestNNCOpInfoParent, self).setUp()
        self.tensorexpr_options = TensorExprTestOptions()

    def tearDown(self):
        self.tensorexpr_options.restore()
        super(TestNNCOpInfoParent, self).tearDown()

    def te_compile(self, device, dtype, op):
        # For each sample input of `op`: generate a small Python function that
        # calls op.op, trace it, and compile the traced graph with
        # TensorExprKernel. Both the compiled kernel and its fallback are
        # checked against eager output.
        #
        # Tensor arguments become parameters of the generated function. All
        # other arguments are inlined into its source via repr().
        if op.name in skip_ops:
            return
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in sample_inputs_itr:
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            param_names = []
            param_values = []
            fx_args = []
            for idx, v in enumerate(arg_values):
                if isinstance(v, torch.Tensor):
                    param_names.append(f"arg_{idx}")
                    param_values.append(v)
                    fx_args.append(param_names[-1])
                else:
                    fx_args.append(f"{repr(v)}")

            for k, v in kwarg_values.items():
                if isinstance(v, torch.Tensor):
                    param_names.append(k)
                    param_values.append(v)
                    fx_args.append(f"{k} = {k}")
                else:
                    fx_args.append(f"{k} = {repr(v)}")

            code = f"""
def f({', '.join(param_names)}):
    return op.op({', '.join(fx_args)})"""
            # `inf` must be resolvable because repr() of an infinite float
            # is the bare name `inf`.
            g = {"torch": torch, "inf": math.inf, "op": op}
            exec(code, g)
            f = g["f"]
            f.__module__ = "test"

            ts_g = torch.jit.trace(f, param_values)
            kernel = torch._C._te.TensorExprKernel(ts_g.graph)
            correct_val = f(*param_values)
            self.assertEqual(kernel.run(tuple(param_values)), correct_val)
            self.assertEqual(kernel.fallback(tuple(param_values)), correct_val)

    @onlyCPU
    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    @ops(
        [op for op in op_db if get_name(op) in works_list],
        allowed_dtypes=(torch.float,),
    )
    def test_working(self, device, dtype, op):
        self.te_compile(device, dtype, op)

    @onlyCPU
    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    @ops(
        [op for op in op_db if get_name(op) in known_failures],
        allowed_dtypes=(torch.float,),
    )
    def test_failures(self, device, dtype, op):
        # Known failures must raise somewhere in te_compile; a clean run means
        # the op should be moved to works_list.
        try:
            self.te_compile(device, dtype, op)
        except Exception:
            pass
        else:
            raise RuntimeError(
                "Expected test to fail. If it now works, move op into works_list"
            )

    @onlyCPU
    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    @ops(
        [op for op in op_db if get_name(op) not in works_list + known_failures],
        allowed_dtypes=(torch.float,),
    )
    def test_unsupported(self, device, dtype, op):
        if get_name(op) in skip_ops:
            return
        try:
            with warnings.catch_warnings():
                # Must be the fully qualified torch.jit.TracerWarning: an
                # undefined bare name would raise NameError here. The broad
                # `except` below would swallow that error before te_compile
                # ran, so the test would pass vacuously.
                warnings.simplefilter("ignore", torch.jit.TracerWarning)
                self.te_compile(device, dtype, op)
        except Exception:
            pass
        else:
            raise RuntimeError(
                "Expected test to fail. If it now works, move op into works_list"
            )

    @slowTest
    @onlyCPU
    @ops(op_db, dtypes=OpDTypes.supported)
    def test_nnc_correctness(self, device, dtype, op):
        if not op.supports_tracing:
            self.skipTest("Requires tracing support")

        with NoTracerWarnContextManager():
            variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)

            for variant, sample in variant_sample_pairs:
                trace = create_traced_fn(self, variant, cache_traced_fn=True)
                ref = variant(
                    *clone_inputs((sample.input, *sample.args)), **sample.kwargs
                )

                # First call warms up the traced function; the second result
                # is compared against eager.
                trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
                val = trace(
                    *clone_inputs((sample.input, *sample.args)), **sample.kwargs
                )

                atol = 2e-1 if dtype == torch.bfloat16 else 1e-5
                rtol = 2e-1 if dtype == torch.bfloat16 else 1e-5
                self.assertEqual(ref, val, atol=atol, rtol=rtol)

            # https://github.com/pytorch/pytorch/issues/35600
            # each torch.jit.trace adds state to the _python_cu compilation unit
            # since this test traces a lot of functions, out-of-memory can occur
            # if the CU is not cleared.
            torch.jit._state._python_cu.drop_all_functions()
2970

2971

2972
# CPU fuser not currently used in fbcode.
# Note the trailing comma: ("cuda") would be a plain string, not a tuple.
only_for = ("cuda",) if IS_FBCODE else ("cpu", "cuda")
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
2975

2976

2977
# Purpose of this class is to allow super() calls. (See TestNNCOpInfoParent)
# TestLoopnestRandomization calls super(TestLoopnestRandomizationParent, self),
# which resolves to JitTestCase's setUp/tearDown.
class TestLoopnestRandomizationParent(JitTestCase):
    pass
2980

2981

2982
class TestLoopnestRandomization(TestLoopnestRandomizationParent):
    # Runs NNC-fused code with PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED set, so
    # the random loop-nest transformation codepath is exercised. All global
    # JIT/fuser state changed in setUp is restored in tearDown.

    _SEED_ENV_VAR = "PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"

    def setUp(self):
        super(TestLoopnestRandomizationParent, self).setUp()
        self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
        self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu()
        self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()

        torch._C._jit_override_can_fuse_on_cpu(True)
        # TODO: force LLVM. need to add it to asan, mac, windows builds + sandcastle
        # torch._C._jit_set_te_must_use_llvm_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)

        self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)

        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)

        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
        torch._C._jit_set_texpr_fuser_enabled(True)

        # The previous value was already saved in old_must_use_cpu_state above.
        torch._C._jit_set_te_must_use_llvm_cpu(False)

        # Set the seed to 1. This tests the codepath through random
        # transformation. Remember the prior value (None if unset) so that
        # tearDown can restore it exactly.
        self.old_random_transform_seed = os.environ.get(self._SEED_ENV_VAR)
        os.environ[self._SEED_ENV_VAR] = "1"

    def tearDown(self):
        torch._C._jit_set_profiling_executor(self.old_profiling_executor)
        torch._C._get_graph_executor_optimize(self.old_profiling_mode)

        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
        torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state)
        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)

        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)

        # Restore the seed variable exactly as it was before setUp (removing
        # it if it was unset), rather than forcing "0" onto the process
        # environment for every later test.
        if self.old_random_transform_seed is None:
            os.environ.pop(self._SEED_ENV_VAR, None)
        else:
            os.environ[self._SEED_ENV_VAR] = self.old_random_transform_seed
        super(TestLoopnestRandomizationParent, self).tearDown()

    @onlyCPU
    @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
    def test_relu(self, device):
        def fn_test_relu(x, y):
            return F.relu(x + 0.5 * y)

        x = torch.randn(4, 4, dtype=torch.float, device=device)
        y = torch.randn(4, 4, dtype=torch.float, device=device)

        fn = fn_test_relu
        traced_fn = torch.jit.trace(fn, (x, y))

        ref = fn(x, y)
        res = traced_fn(x, y)
        # A bare `assert` is stripped under `python -O`. Keep torch.allclose
        # so the original tolerances are unchanged.
        self.assertTrue(torch.allclose(ref, res))
3041

3042

3043
# ("cpu",) with a trailing comma: ("cpu") would be a plain string, not a tuple.
instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("cpu",))
3044

3045

3046
if __name__ == "__main__":
    # Run all instantiated test classes when this file is executed directly.
    run_tests()
3048

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.