pytorch

Форк
0
/
test_tensorexpr.py 
1720 строк · 57.4 Кб
1
# Owner(s): ["NNC"]
2

3
import numpy as np
4
import torch
5
import torch.nn.functional as F
6
from torch import nn
7
import unittest
8
import itertools
9

10
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo
11

12
from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions
13

14
# Whether this PyTorch build reports LLVM support (torch._C._llvm_enabled()).
# BaseTestClass.setUp only adds torch.bfloat16 to the tested dtypes when True.
LLVM_ENABLED = torch._C._llvm_enabled()
15

16
class BaseTestClass(JitTestCase):
    """Shared fixture for the TensorExpr (NNC) fuser tests.

    ``setUp`` installs the JIT options provided by ``TensorExprTestOptions``
    and records which devices and dtypes the tests should sweep over;
    ``tearDown`` calls ``restore()`` on those options.
    """

    def setUp(self):
        super().setUp()
        # Fuser-testing JIT configuration; undone by restore() in tearDown.
        self.tensorexpr_options = TensorExprTestOptions()
        # CPU is always tested; CUDA is added only when a GPU is present.
        self.devices = ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']
        # bfloat16 coverage is only enabled on LLVM-enabled builds.
        if LLVM_ENABLED:
            self.dtypes = [torch.float32, torch.bfloat16]
        else:
            self.dtypes = [torch.float32]

    def tearDown(self):
        self.tensorexpr_options.restore()
        super().tearDown()

    def assertLastGraphAllFused(self):
        """Check (via JitTestCase.assertAllFused) the most recently executed optimized graph."""
        last_graph = torch.jit.last_executed_optimized_graph()
        self.assertAllFused(last_graph)
29

30

31
def warmup_and_run_forward(f, *args):
    """Call ``f(*args)`` ``num_profiled_runs + 1`` times and return the last result.

    The call count is read from ``torch._C._jit_get_num_profiled_runs()``;
    the final call is presumably the one that executes the JIT-optimized graph
    after the profiling runs have been recorded.
    """
    profiled_runs = torch._C._jit_get_num_profiled_runs()
    for _ in range(profiled_runs + 1):
        last_output = f(*args)
    return last_output
35

36

37
@skipIfTorchDynamo()
38
class TestTensorExprFuser(BaseTestClass):
39
    def test_easy(self):
        """A single traced ``torch.add`` is fully fused and matches NumPy."""
        def easy(x, y):
            return torch.add(x, y)

        example_inputs = (torch.rand(1024), torch.rand(1024))
        traced = torch.jit.trace(easy, example_inputs)

        lhs, rhs = torch.rand(1024), torch.rand(1024)
        out = warmup_and_run_forward(traced, lhs, rhs)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(lhs.numpy() + rhs.numpy(), out.numpy())
51

52
    def test_three_arg(self):
        """Two chained adds over three 1-D inputs are fully fused."""
        def easy(x, y, z):
            partial = torch.add(x, y)
            return torch.add(partial, z)

        traced = torch.jit.trace(easy, tuple(torch.rand(1024) for _ in range(3)))

        a, b, c = (torch.rand(1024) for _ in range(3))
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        expected = a.numpy() + b.numpy() + c.numpy()
        np.testing.assert_allclose(expected, x.numpy())
69

70
    def test_four_arg(self):
        """``addcmul(add(x, y), z, w)`` fuses on every device and matches eager mode."""
        def run_addcmul(x, y, z, w):
            return torch.addcmul(torch.add(x, y), z, w)

        for dev in self.devices:
            inputs = [torch.rand(1024, dtype=torch.float, device=dev) for _ in range(4)]
            example = tuple(
                torch.zeros(1024, dtype=torch.float, device=dev) for _ in range(4)
            )
            traced = torch.jit.trace(run_addcmul, example)

            fused_out = warmup_and_run_forward(traced, *inputs)
            self.assertLastGraphAllFused()
            eager_out = run_addcmul(*inputs)
            np.testing.assert_allclose(
                fused_out.cpu().numpy(), eager_out.cpu().numpy(), atol=1e-6
            )
95

96
    def test_three_arg2(self):
        """Chained adds over 2-D (M x N) inputs fuse on every available device."""
        # The traced function and sizes do not depend on the device, so they
        # are defined once rather than on every loop iteration.
        def test(x, y, z):
            return torch.add(torch.add(x, y), z)

        M = 32
        N = 32
        for device in self.devices:
            traced = torch.jit.trace(
                test,
                (
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                ),
            )

            a = torch.rand(M, N, device=device)
            b = torch.rand(M, N, device=device)
            c = torch.rand(M, N, device=device)
            # warmup_and_run_forward issues every call needed; no separate
            # pre-call (whose result would be discarded) is required.
            x = warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()
            npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
            np.testing.assert_allclose(npr, x.cpu().numpy())
122

123
    def test_broadcast3(self):
        """Broadcasting rank-2, rank-3 and rank-4 operands together is fully fused."""
        def test(x, y, z):
            return torch.add(torch.add(x, y), z)

        for device in self.devices:
            for M, N, L, K in ([5, 2, 7, 3], [8, 8, 8, 8]):
                shapes = ([M, N], [L, M, 1], [K, L, 1, 1])
                traced = torch.jit.trace(
                    test,
                    tuple(torch.rand(*shape, device=device) for shape in shapes),
                )

                a, b, c = (torch.rand(*shape, device=device) for shape in shapes)
                x = warmup_and_run_forward(traced, a, b, c)
                self.assertLastGraphAllFused()
                npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
                np.testing.assert_allclose(npr, x.cpu().numpy())
154

155
    def test_all_combos(self):
        """A four-add graph that reuses intermediates is fully fused on 1-D inputs."""
        def easy(x, y, z):
            xy = torch.add(x, y)
            xyz = torch.add(xy, z)
            return torch.add(torch.add(x, xyz), xy)

        def np_easy(x, y, z):
            xy = x + y
            return (x + (xy + z)) + xy

        traced = torch.jit.trace(easy, tuple(torch.rand(1024) for _ in range(3)))

        a, b, c = (torch.rand(1024) for _ in range(3))
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np_easy(a.numpy(), b.numpy(), c.numpy()), x.numpy())
181

182
    def test_rank_two(self):
        """The four-add reuse graph is fully fused on 32 x 32 inputs."""
        def easy(x, y, z):
            xy = torch.add(x, y)
            xyz = torch.add(xy, z)
            return torch.add(torch.add(x, xyz), xy)

        def np_easy(x, y, z):
            xy = x + y
            return (x + (xy + z)) + xy

        shape = (32, 32)
        traced = torch.jit.trace(easy, tuple(torch.rand(shape) for _ in range(3)))

        a, b, c = (torch.rand(shape) for _ in range(3))
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np_easy(a.numpy(), b.numpy(), c.numpy()), x.numpy())
209

210
    def test_broadcast(self):
        """A 1-D operand broadcast against two N x N operands is fully fused."""
        def easy(x, y, z):
            return torch.add(torch.add(x, y), z)

        def np_easy(x, y, z):
            return (x + y) + z

        N = 32
        traced = torch.jit.trace(easy, (torch.rand(N, N), torch.rand(N), torch.rand(N, N)))

        a, b, c = torch.rand(N, N), torch.rand(N), torch.rand(N, N)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np_easy(a.numpy(), b.numpy(), c.numpy()), x.numpy())
231

232
    def test_broadcast_2(self):
        """Broadcasting (3, 4), (3, 1) and (4,) with a captured constant tensor fuses."""
        zero = torch.tensor([0.0], dtype=torch.float)
        zero_np = zero.numpy()

        def foo(x, y, z):
            return torch.add(torch.add(zero, torch.add(x, y)), z)

        def foo_np(x, y, z):
            return (zero_np + (x + y)) + z

        x = torch.rand(3, 4)
        y = torch.ones(3, 1)
        z = torch.rand(4)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(r, foo_np(x.numpy(), y.numpy(), z.numpy()))
255

256
    def test_broadcast_big2(self):
        """Broadcasting (32, 1024), (32, 1) and (1024,) with a captured constant fuses."""
        zero = torch.tensor([0.0], dtype=torch.float)
        zero_np = zero.numpy()

        def foo(x, y, z):
            return torch.add(torch.add(zero, torch.add(x, y)), z)

        def foo_np(x, y, z):
            return (zero_np + (x + y)) + z

        x = torch.rand(32, 1024)
        y = torch.ones(32, 1)
        z = torch.rand(1024)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(r, foo_np(x.numpy(), y.numpy(), z.numpy()))
278

279
    def test_alpha(self):
        """``torch.add`` with ``alpha=2.0`` computes ``x + 2 * x`` in the traced graph."""
        def alpha(x):
            return torch.add(x, x, alpha=2.0)

        traced = torch.jit.trace(alpha, (torch.tensor([1.0]),))

        a = torch.tensor([1.0])
        # Go through the same warm-up as the sibling tests so the checked
        # output comes from the post-profiling (optimized) run rather than
        # from a single, possibly profiling-only, call.
        x = warmup_and_run_forward(traced, a)
        np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy())
289

290
    @suppress_warnings
    def test_constant(self):
        """Adding a tensor constructed inside the traced function is fully fused."""
        def constant(x):
            return torch.add(x, torch.tensor([1.0]))

        traced = torch.jit.trace(constant, (torch.tensor([1.0]),))

        inp = torch.tensor([1.0])
        out = warmup_and_run_forward(traced, inp)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(inp.numpy() + 1.0, out.numpy())
303

304
    def test_add_sub(self):
        """``(x + y) - z`` is fully fused and matches NumPy."""
        def easy(x, y, z):
            return torch.sub(torch.add(x, y), z)

        example = [torch.rand(1024) for _ in range(3)]
        traced = torch.jit.trace(easy, tuple(example))

        a, b, c = [torch.rand(1024) for _ in range(3)]
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy())
320

321
    def test_promotion(self):
        """Adding an int32 tensor to a float32 tensor is fully fused and matches NumPy."""
        def easy(x, y):
            return torch.add(x, y)

        def make_inputs():
            return (
                torch.zeros(1024, dtype=torch.int32),
                torch.rand(1024, dtype=torch.float32),
            )

        traced = torch.jit.trace(easy, make_inputs())

        a, b = make_inputs()
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())
336

337
    def test_double(self):
        """``(x + y) * y`` on float64 tensors is fully fused and matches NumPy."""
        TENSOR_LEN = 8

        def easy(x, y):
            return torch.mul(torch.add(x, y), y)

        def make_inputs():
            return (
                torch.rand(TENSOR_LEN, dtype=torch.double),
                torch.full((TENSOR_LEN,), 0.5, dtype=torch.double),
            )

        traced = torch.jit.trace(easy, make_inputs())

        a, b = make_inputs()
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
355

356
    def test_short(self):
        """``(x + y) * y`` on int16 tensors is fully fused and matches NumPy."""
        TENSOR_LEN = 8

        def easy(x, y):
            return torch.mul(torch.add(x, y), y)

        def make_inputs():
            return tuple(
                torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)
                for _ in range(2)
            )

        traced = torch.jit.trace(easy, make_inputs())

        a, b = make_inputs()
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
375

376
    def test_char(self):
        """``(x + y) * y`` on int8 tensors is fully fused and matches NumPy."""
        TENSOR_LEN = 8

        def easy(x, y):
            return torch.mul(torch.add(x, y), y)

        def make_inputs():
            return tuple(
                torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
                for _ in range(2)
            )

        traced = torch.jit.trace(easy, make_inputs())

        a, b = make_inputs()
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
395

396
    def test_int64_promotion(self):
        """``(x + y) * y`` with int8 and int64 operands is fully fused and matches NumPy."""
        TENSOR_LEN = 8

        def easy(x, y):
            return torch.mul(torch.add(x, y), y)

        def make_inputs():
            return tuple(
                torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=dtype)
                for dtype in (torch.int8, torch.int64)
            )

        traced = torch.jit.trace(easy, make_inputs())

        a, b = make_inputs()
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
415

416
    def test_eq(self):
        """``torch.eq`` of two all-zero int32 tensors is fused and all-true."""
        def easy(x, y):
            return torch.eq(x, y)

        # Traced with float32 examples; the checked run uses int32 inputs.
        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        lhs, rhs = (torch.zeros(1024, dtype=torch.int32) for _ in range(2))
        out = warmup_and_run_forward(traced, lhs, rhs)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), out.numpy())
427

428
    def test_ne(self):
        """``torch.ne`` of all-zeros vs all-ones int32 tensors is fused and all-true."""
        def easy(x, y):
            return torch.ne(x, y)

        # Traced with float32 examples; the checked run uses int32 inputs.
        float_example = (torch.zeros(1024), torch.zeros(1024))
        traced = torch.jit.trace(easy, float_example)
        zeros_i32 = torch.zeros(1024, dtype=torch.int32)
        ones_i32 = torch.ones(1024, dtype=torch.int32)
        out = warmup_and_run_forward(traced, zeros_i32, ones_i32)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), out.numpy())
439

440
    def test_ge(self):
        """``torch.ge`` of all-fives vs all-zeros int32 tensors is fused and all-true."""
        def easy(x, y):
            return torch.ge(x, y)

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        fives = torch.full((1024,), 5, dtype=torch.int32)
        zeros_i32 = torch.zeros(1024, dtype=torch.int32)
        out = warmup_and_run_forward(traced, fives, zeros_i32)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), out.numpy())
453

454
    def test_gt(self):
        """``torch.gt`` of all-ones vs all-zeros int32 tensors is fused and all-true."""
        def easy(x, y):
            return torch.gt(x, y)

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        ones_i32 = torch.ones(1024, dtype=torch.int32)
        zeros_i32 = torch.zeros(1024, dtype=torch.int32)
        out = warmup_and_run_forward(traced, ones_i32, zeros_i32)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), out.numpy())
465

466
    def test_le(self):
        """``torch.le`` of all-fives vs all-zeros int32 tensors is fused and all-false."""
        def easy(x, y):
            return torch.le(x, y)

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        fives = torch.full((1024,), 5, dtype=torch.int32)
        zeros_i32 = torch.zeros(1024, dtype=torch.int32)
        out = warmup_and_run_forward(traced, fives, zeros_i32)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.zeros(1024), out.numpy())
479

480
    def test_lt(self):
        """``torch.lt`` of all-ones vs all-zeros int32 tensors is fused and all-false on every device."""
        def easy(x, y):
            return torch.lt(x, y)

        for dev in self.devices:
            example = (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev))
            traced = torch.jit.trace(easy, example)
            ones_i32 = torch.ones(1024, dtype=torch.int32, device=dev)
            zeros_i32 = torch.zeros(1024, dtype=torch.int32, device=dev)
            out = warmup_and_run_forward(traced, ones_i32, zeros_i32)
            self.assertLastGraphAllFused()
            np.testing.assert_allclose(np.zeros(1024), out.cpu().numpy())
492

493
    @suppress_warnings
    def test_min_max(self):
        """Elementwise ``min`` then ``max`` against a constant tensor is fully fused."""
        def test(x, y):
            return torch.max(torch.min(x, y), torch.tensor([4.0]))

        traced = torch.jit.trace(test, (torch.zeros(1024), torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        b = 8.0 * torch.rand(1024)
        actual = warmup_and_run_forward(traced, a, b)
        expected = np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0])
        np.testing.assert_allclose(actual, expected)
        self.assertLastGraphAllFused()
505

506
    def test_min_max_reduction(self):
        """Full reductions ``torch.min(x) + torch.max(x)`` are fully fused and match NumPy."""
        def test(x):
            return torch.min(x) + torch.max(x)

        traced = torch.jit.trace(test, (torch.zeros(1024),))
        a = 8.0 * torch.rand(1024)
        actual = warmup_and_run_forward(traced, a)
        a_np = a.numpy()
        np.testing.assert_allclose(actual, np.amin(a_np) + np.amax(a_np))
        self.assertLastGraphAllFused()
514

515
    def test_min_max_reduction2(self):
        """Method-form reductions ``x.min() + x.max()`` are fully fused and match NumPy."""
        def test(x):
            return x.min() + x.max()

        traced = torch.jit.trace(test, (torch.zeros(1024),))
        a = 8.0 * torch.rand(1024)
        actual = warmup_and_run_forward(traced, a)
        a_np = a.numpy()
        np.testing.assert_allclose(actual, np.amin(a_np) + np.amax(a_np))
        self.assertLastGraphAllFused()
523

524
    def test_min_max_reduction_dim1(self):
        """Per-row min + max values (reduction over dim 1) are fully fused."""
        def test(x):
            return torch.min(x, 1)[0] + torch.max(x, 1)[0]

        traced = torch.jit.trace(test, (torch.zeros(16, 16),))
        a = 8.0 * torch.rand(16, 16)
        actual = warmup_and_run_forward(traced, a)
        a_np = a.numpy()
        expected = np.amin(a_np, axis=1) + np.amax(a_np, axis=1)
        np.testing.assert_allclose(actual, expected)
        self.assertLastGraphAllFused()
533

534
    def test_min_max_reduction_dim1_2(self):
        """``torch.min(x * x, 1)`` (values and indices) is fully fused; values match NumPy."""
        def test(x):
            return torch.min(x * x, 1)

        traced = torch.jit.trace(test, (torch.zeros(16, 16),))
        a = 8.0 * torch.rand(16, 16)
        values = warmup_and_run_forward(traced, a)[0]
        np.testing.assert_allclose(values, np.amin((a * a).numpy(), axis=1))
        self.assertLastGraphAllFused()
542

543
    def test_clamp(self):
        """``clamp(x + 3, 0, 6)`` is fully fused on every device and matches ``np.clip``."""
        def test(x):
            return torch.clamp(x + 3.0, 0.0, 6.0)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev),))
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            expected = np.clip(a.cpu().numpy() + 3.0, 0.0, 6.0)
            actual = warmup_and_run_forward(traced, a).cpu()
            np.testing.assert_allclose(actual, expected)
            self.assertLastGraphAllFused()
553

554
    def test_relu(self):
        """``clamp(relu(x), 0, 0.5)`` is fully fused on every device and matches NumPy."""
        def test(x):
            return torch.clamp(F.relu(x), 0, 0.5)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev),))
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            expected = np.clip(np.maximum(0, a.cpu().numpy()), 0, 0.5)
            actual = warmup_and_run_forward(traced, a).cpu()
            np.testing.assert_allclose(actual, expected)
            self.assertLastGraphAllFused()
564

565
    def test_reps(self):
        """Repeated warm-up/run cycles of one traced add keep producing ``1 + 0``."""
        def easy(x, y):
            return torch.add(x, y)

        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))

        expected = np.ones(1024)
        for _ in range(32):
            ones, zeros = torch.ones(1024), torch.zeros(1024)
            out = warmup_and_run_forward(traced, ones, zeros)
            np.testing.assert_allclose(expected, out.numpy())
577

578
    def test_add_const_rhs(self):
        """Adding a Python float constant on the right-hand side is fully fused."""
        def test(x):
            return x + 3.0

        traced = torch.jit.trace(test, (torch.rand(4),))
        inp = torch.rand(4)
        out = warmup_and_run_forward(traced, inp)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(inp.numpy() + 3.0, out.numpy())
587

588
    def test_int_output(self):
        """A product of three int32 tensors is fully fused and matches NumPy."""
        def test(x, y, z):
            return x * y * z

        xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for _ in range(3)]
        x, y, z = xs
        xn, yn, zn = [t.numpy() for t in xs]
        traced = torch.jit.trace(test, tuple(xs))
        res = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(xn * yn * zn, res.numpy())
599

600
    def test_binary_ops(self):
        """Each binary-op pattern fuses and matches a reference on every device/dtype.

        For float32 the reference is eager mode. For bfloat16 the reference is
        the same traced graph run on float32 copies of the inputs (see the
        comment inside the loop). ``lerp`` is skipped for bfloat16.
        """
        def test_atan2(x, y):
            c = torch.atan2(torch.add(x, y), y)
            return c

        def test_gt(x, y):
            c = torch.gt(torch.add(x, y), y)
            return c

        def test_ge(x, y):
            c = torch.ge(torch.add(x, y), y)
            return c

        def test_lt(x, y):
            c = torch.lt(torch.add(x, y), y)
            return c

        def test_le(x, y):
            c = torch.le(torch.add(x, y), y)
            return c

        def test_lerp(x, y):
            c = torch.lerp(torch.add(x, 1), x, 2.0)
            return c

        def test_mul(x, y):
            c = torch.mul(torch.add(x, y), y)
            return c

        def test_ne(x, y):
            c = torch.ne(torch.add(x, y), y)
            return c

        def test_div(x, y):
            c = torch.div(torch.add(x, y), 2)
            return c

        def test_eq(x, y):
            c = torch.eq(torch.add(x, y), y)
            return c

        def test_fmod(x, y):
            c = torch.fmod(torch.add(x, y), 2)
            return c

        def test_sub(x, y):
            c = torch.sub(torch.add(x, y), x)
            return c

        def test_remainder(x, y):
            c = torch.remainder(torch.add(x, y), 3.0)
            return c

        def test_pow(x, y):
            c = torch.pow(torch.add(x, y), 2.0)
            return c

        def test_type_as(x, y):
            return x.type_as(torch.add(x, y))

        # Tuples rather than sets: set iteration order over function objects
        # follows their identity-based hashes, which differ between runs, so
        # the case order (and the random inputs each case receives) would not
        # be reproducible.
        cmp_fns = (
            test_gt,
            test_ge,
            test_lt,
            test_le,
            test_ne,
            test_eq,
        )

        non_cmp_fns = (
            test_atan2,
            test_lerp,
            test_mul,
            test_div,
            test_fmod,
            test_sub,
            test_remainder,
            test_pow,
            test_type_as,
        )

        all_test_fns = cmp_fns + non_cmp_fns
        fn_dev_dtype = itertools.product(all_test_fns, self.devices, self.dtypes)
        for torch_fn, dev, data_type in fn_dev_dtype:
            if torch_fn is test_lerp and data_type is torch.bfloat16:
                continue
            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)
            in1 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            in2 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            traced = torch.jit.trace(torch_fn, (in1, in2))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            _atol = 2e-3
            _rtol = 1e-5
            if data_type is torch.bfloat16:
                # Compared with ATen, NNC can skip intermediate BF16 <-> FP32
                # conversions. Take d = a + b + c: ATen evaluates it operator
                # by operator as
                #    tmp = to_bf16(to_fp32(a) + to_fp32(b))
                #    d = to_bf16(to_fp32(tmp) + to_fp32(c))
                # whereas NNC fuses the computation and drops the redundant
                # conversions:
                #    d = to_bf16(to_fp32(a) + to_fp32(b) + to_fp32(c))
                # To avoid that numeric deviation, simulate NNC by running the
                # traced graph on fp32 copies of the inputs and converting the
                # result back to bf16. Comparison results are left unconverted.
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                if torch_fn not in cmp_fns:
                    y = y.bfloat16()
                _atol = 2e-2
            else:
                y = torch_fn(rand_a, rand_b)
            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
715

716
    def test_unary_ops(self):
        """Each ``unary_op(x + y)`` pattern fuses and matches a reference on every device/dtype.

        For float32 (and for erf/erfc at every dtype) the reference is eager
        mode. For the remaining functions at bfloat16 the reference is the same
        traced graph run on float32 inputs, converted back to bfloat16.
        """
        def test_round(x, y):
            c = torch.round(torch.add(x, y))
            return c

        def test_sin(x, y):
            c = torch.sin(torch.add(x, y))
            return c

        def test_asin(x, y):
            c = torch.asin(torch.add(x, y))
            return c

        def test_sinh(x, y):
            c = torch.sinh(torch.add(x, y))
            return c

        def test_cos(x, y):
            c = torch.cos(torch.add(x, y))
            return c

        def test_acos(x, y):
            c = torch.acos(torch.add(x, y))
            return c

        def test_cosh(x, y):
            c = torch.cosh(torch.add(x, y))
            return c

        def test_tan(x, y):
            c = torch.tan(torch.add(x, y))
            return c

        def test_atan(x, y):
            c = torch.atan(torch.add(x, y))
            return c

        def test_tanh(x, y):
            c = torch.tanh(torch.add(x, y))
            return c

        def test_sqrt(x, y):
            c = torch.sqrt(torch.add(x, y))
            return c

        def test_rsqrt(x, y):
            c = torch.rsqrt(torch.add(x, y))
            return c

        def test_floor(x, y):
            c = torch.floor(torch.add(x, y))
            return c

        def test_ceil(x, y):
            c = torch.ceil(torch.add(x, y))
            return c

        def test_trunc(x, y):
            c = torch.trunc(torch.add(x, y))
            return c

        def test_abs(x, y):
            c = torch.abs(torch.add(x, y))
            return c

        def test_log(x, y):
            c = torch.log(torch.add(x, y))
            return c

        def test_log2(x, y):
            c = torch.log2(torch.add(x, y))
            return c

        def test_log10(x, y):
            c = torch.log10(torch.add(x, y))
            return c

        def test_log1p(x, y):
            c = torch.log1p(torch.add(x, y))
            return c

        def test_erf(x, y):
            c = torch.erf(torch.add(x, y))
            return c

        def test_exp(x, y):
            c = torch.exp(torch.add(x, y))
            return c

        def test_expm1(x, y):
            c = torch.expm1(torch.add(x, y))
            return c

        def test_erfc(x, y):
            c = torch.erfc(torch.add(x, y))
            return c

        def test_frac(x, y):
            c = torch.frac(torch.add(x, y))
            return c

        def test_lgamma(x, y):
            c = torch.lgamma(torch.add(x, y))
            return c

        def test_sigmoid(x, y):
            c = torch.sigmoid(torch.add(x, y))
            return c

        def test_reciprocal(x, y):
            c = torch.reciprocal(torch.add(x, y))
            return c

        def test_neg(x, y):
            c = torch.neg(torch.add(x, y))
            return c

        def test_relu(x, y):
            c = torch.relu(torch.add(x, y))
            return c

        def test_hardtanh(x, y):
            c = F.hardtanh(torch.add(x, y), -1.0, 1.0)
            return c

        def test_threshold(x, y):
            c = F.threshold(torch.add(x, y), 0.5, 10)
            return c

        # Functions always compared directly against eager mode, even for
        # bfloat16 (no fp32 simulation). They run on every device in
        # self.devices, not only on GPUs.
        eager_ref_fns = (
            test_erf,
            test_erfc,
        )
        fns = (
            test_round,
            test_sin,
            test_asin,
            test_sinh,
            test_cos,
            test_acos,
            test_cosh,
            test_tan,
            test_atan,
            test_sqrt,
            test_floor,
            test_ceil,
            test_trunc,
            test_abs,
            test_log,
            test_log2,
            test_log10,
            test_log1p,
            test_rsqrt,
            test_exp,
            test_expm1,
            test_frac,
            test_lgamma,
            test_reciprocal,
            test_neg,
            test_threshold,
            test_relu,
            test_tanh,
            test_hardtanh,
            test_sigmoid,
        )
        # Tuples rather than sets: set iteration order over function objects
        # follows identity-based hashes that differ between runs, which would
        # defeat the manual_seed below by handing each function different
        # random inputs from run to run.
        fn_dev_dtype = itertools.product(eager_ref_fns + fns, self.devices, self.dtypes)

        torch.manual_seed(0)
        for torch_fn, dev, data_type in fn_dev_dtype:
            if torch_fn is test_lgamma and dev == "cuda":
                # lgamma_cuda does not support BF16.
                # NOTE(review): this skips lgamma on CUDA for every dtype,
                # not only bfloat16.
                continue
            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)

            ins = 20 * torch.rand(1024, dtype=data_type, device=dev)
            traced = torch.jit.trace(torch_fn, (ins, ins))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            _atol = 5e-3 if data_type is torch.bfloat16 else 2e-3
            _rtol = 1e-5
            if data_type is torch.bfloat16 and torch_fn not in eager_ref_fns:
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                y = y.bfloat16()
            else:
                y = torch_fn(rand_a, rand_b)

            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
            # TODO: re-enable NaN-input coverage (an all-NaN first argument
            # compared against eager mode); it was disabled because those
            # checks failed for every function.
928

929

930
    def test_round_2(self):
        """Check that a fused ``torch.round`` matches eager mode.

        The inputs include halfway values (2.5, 3.5), so any tie-breaking
        difference between the fused kernel and eager evaluation shows up.
        """
        def round_fn(x):
            return torch.round(x)

        for data_type in [torch.float32, torch.double]:
            a = torch.tensor([0.2, 1.6, 2.5, 3.5]).to(data_type)
            traced = torch.jit.trace(round_fn, (a,))
            x = warmup_and_run_forward(traced, a)
            self.assertLastGraphAllFused()
            # Reference is eager evaluation on the *input*. Re-rounding the
            # fused output would trivially equal it.
            y = round_fn(a)
            self.assertEqual(x, y)
941

942
    def test_rand_like(self):
        """Check that fused ``rand_like`` output looks uniform on [0, 1).

        The first three raw moments of U(0, 1) are 1/2, 1/3 and 1/4. With
        N = 2**16 samples, the empirical moments of the kernel's output should
        fall within 2% of those values for every tested dtype.
        """
        N = 1 << 16

        def run_rand_like(x, y):
            return torch.rand_like(torch.add(x, y))

        for device in self.devices:
            x = torch.rand(N, device=device)
            # rand_like is nondeterministic, so the trace re-run check is disabled.
            traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)

            for data_type in self.dtypes:
                _x = x.to(dtype=data_type)
                x_v = warmup_and_run_forward(traced, _x, _x)
                self.assertLastGraphAllFused()

                # Validate the moments of the fused kernel's *output*, not of
                # the input tensor.
                out_np = x_v.float().cpu().numpy()
                np.testing.assert_allclose(np.mean(out_np), 1. / 2, rtol=2e-2)
                np.testing.assert_allclose(np.mean(out_np ** 2), 1. / 3, rtol=2e-2)
                np.testing.assert_allclose(np.mean(out_np ** 3), 1. / 4, rtol=2e-2)
964

965
    def test_nans(self):
        """Check that fused ``min``/``max`` propagate NaN from either operand, for every dtype."""
        def test_max(x, y):
            return torch.max(2 * x, 2 * y)

        def test_min(x, y):
            return torch.min(2 * x, 2 * y)

        for data_type in self.dtypes:
            # Trace per dtype so each dtype gets its own freshly profiled graph.
            example = (torch.rand(1, dtype=data_type), torch.rand(1, dtype=data_type))
            tmax = torch.jit.trace(test_max, example)
            tmin = torch.jit.trace(test_min, example)

            x = torch.tensor([np.nan]).to(dtype=data_type)
            y = torch.tensor([1.0]).to(dtype=data_type)

            # The checks live inside the loop so every dtype is exercised,
            # not just the last one.
            self.assertTrue(np.isnan(warmup_and_run_forward(tmin, x, y).float().item()))
            self.assertTrue(np.isnan(warmup_and_run_forward(tmin, y, x).float().item()))
            self.assertLastGraphAllFused()
            self.assertTrue(np.isnan(warmup_and_run_forward(tmax, x, y).float().item()))
            self.assertTrue(np.isnan(warmup_and_run_forward(tmax, y, x).float().item()))
            self.assertLastGraphAllFused()
985

986
    def test_double_intrinsics(self):
        """Check that float64 ``pow(x, 7)`` fuses on every device and matches eager."""
        def do_pow(x):
            return torch.pow(x, 7)

        for device in self.devices:
            x = torch.rand(10, dtype=torch.double, device=device)
            traced = torch.jit.trace(do_pow, (x,))
            out = warmup_and_run_forward(traced, x)
            self.assertLastGraphAllFused()
            # Fusion alone does not prove correctness; also compare values.
            self.assertEqual(out, do_pow(x))
995

996
    def test_remainder(self):
        """Check fused ``remainder`` against eager on random, zero-divisor and NaN inputs.

        ``run_remainder(x, y)`` computes ``remainder(x + y, x)``, so ``x`` is
        both part of the numerator and the divisor.
        """
        def run_remainder(x, y):
            c = torch.remainder(torch.add(x, y), x)
            return c

        for data_type in self.dtypes:
            a = torch.rand(1024, dtype=data_type)
            b = torch.rand(1024, dtype=data_type)
            zeros = torch.zeros(1024, dtype=data_type)
            # A length-1024 vector of NaNs, the same shape as the other inputs.
            # np.array(1024, ...) would give a 0-d array, i.e. one broadcast scalar.
            nans = torch.full((1024,), float("nan"), dtype=data_type)

            # Example inputs used only for tracing.
            zeros1 = torch.zeros(1024, dtype=data_type)
            zeros2 = torch.zeros(1024, dtype=data_type)

            # random floats
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            y = run_remainder(a, b)
            if data_type is torch.bfloat16:
                self.assertEqual(x, y, atol=4e-3, rtol=2e-3)
            else:
                self.assertEqual(x, y)

            # div by 0: remainder(0 + a, 0)
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, zeros, a)
            self.assertLastGraphAllFused()
            y = run_remainder(zeros, a)
            self.assertEqual(x, y)

            # numerators and denominators are nan
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, nans, a)
            self.assertLastGraphAllFused()
            y = run_remainder(nans, a)
            self.assertEqual(x, y)
1035

1036
    def test_multioutput(self):
        """A fused kernel that returns both an intermediate and the final value."""
        def easy(x):
            plus_one = x + 1
            doubled = plus_one + plus_one
            return (plus_one, doubled)

        traced = torch.jit.trace(easy, (torch.zeros(1024),))

        inp = torch.zeros(1024)
        out_first, out_second = warmup_and_run_forward(traced, inp)
        self.assertLastGraphAllFused()
        expected_first = inp.numpy() + 1
        expected_second = expected_first + expected_first
        np.testing.assert_allclose(out_first.numpy(), expected_first)
        np.testing.assert_allclose(out_second.numpy(), expected_second)
1051

1052
    def test_chunk(self):
        """Check ``torch.chunk`` into two halves inside a fused kernel, for each dtype."""
        def easy(x):
            y = x + 1
            first, second = torch.chunk(y, 2)
            return first + second

        for data_type in self.dtypes:
            # Trace with a large example, then run on a smaller input.
            traced = torch.jit.trace(easy, (torch.zeros(1024, 1024, dtype=data_type),))

            inp = torch.zeros(32, 32, dtype=data_type)
            out = warmup_and_run_forward(traced, inp)
            self.assertLastGraphAllFused()
            shifted = inp.float().numpy() + 1
            top, bottom = np.array_split(shifted, 2)
            np.testing.assert_allclose(top + bottom, out.float().numpy())
1069

1070
    def test_cat(self):
        """Check fused ``torch.cat``: mixed widths along dim 1, then channels-last inputs on every dim."""
        rows = 16
        widths = [128, 16, 1]
        for device in self.devices:
            # ``foo`` reads ``_dim`` from this scope when it is called, so
            # rebinding ``_dim`` (including as a loop variable below) changes the
            # dimension that later traces record.
            _dim = 1

            def foo(*args):
                shifted = [t + idx for idx, t in enumerate(args)]
                joined = torch.cat(shifted, dim=_dim)
                return joined * joined

            for data_type in self.dtypes:
                values = [torch.zeros(rows, w, dtype=data_type, device=device) for w in widths]
                traced = torch.jit.trace(foo, values)

                out = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                expected = foo(*values)
                np.testing.assert_allclose(expected.cpu().float().numpy(), out.cpu().float().numpy())

            # Test channels-last
            for _dim in range(4):
                values = [
                    torch.randn((2, 3, 4, 5), device=device).to(memory_format=torch.channels_last)
                    for _ in range(10)
                ]
                traced = torch.jit.trace(foo, values)

                out = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                expected = foo(*values)
                self.assertEqual(expected, out)
1100

1101
    # This test checks that we correctly handle fusion group with just aten::cat in it.
1102
    # Note that the test only makes sense with min_fusion_group=1, otherwise no
1103
    # fusion groups would be formed at all.
1104
    # TODO: Fix and re-enable the test.
1105
    @unittest.skip("cat is broken with fusion group inlining disabled")
    def test_cat_only(self):
        """A fusion group containing only the input adds and ``aten::cat``."""
        rows = 16
        widths = [128, 16, 1]
        for device in self.devices:
            def foo(*args):
                return torch.cat([t + idx for idx, t in enumerate(args)], dim=1)

            values = [torch.zeros(rows, w, device=device) for w in widths]
            traced = torch.jit.trace(foo, values)

            out = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            expected = foo(*values)
            np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())
1122

1123
    def test_cat_negative_dim(self):
        """``torch.cat`` with ``dim=-1`` inside a fused kernel."""
        rows = 16
        widths = [128, 16, 1]
        for device in self.devices:
            def foo(*args):
                joined = torch.cat(args, dim=-1)
                return joined * joined

            values = [torch.randn(rows, w, device=device) for w in widths]
            traced = torch.jit.trace(foo, values)

            out = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            expected = foo(*values)
            np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())
1138

1139
    def test_cat_promote_inputs(self):
        """``torch.cat`` over half/float/double inputs, so type promotion happens in the fused kernel."""
        rows = 16
        widths = [128, 16, 1]
        dtypes = [torch.half, torch.float32, torch.double]
        for device in self.devices:
            def foo(*args):
                joined = torch.cat(args, dim=1)
                return joined * joined

            values = [torch.randn(rows, w, device=device, dtype=dt) for w, dt in zip(widths, dtypes)]
            traced = torch.jit.trace(foo, values)

            out = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            expected = foo(*values)
            np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())
1155

1156
    def test_cat_empty_tensors(self):
        """``torch.cat`` with an empty 1-D tensor among regular inputs, and with only empty tensors."""
        rows = 16
        widths = [128, 16, 1]
        for device in self.devices:
            def foo(*args):
                joined = torch.cat(args, dim=1)
                return joined * joined

            def check(values):
                # Trace, run until optimized, and compare with eager.
                traced = torch.jit.trace(foo, values)
                out = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                expected = foo(*values)
                np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())

            empty = torch.tensor([], device=device, dtype=torch.double)
            check([empty] + [torch.randn(rows, w, device=device) for w in widths])

            # now test with only empty tensors
            check([empty] * 3)
1180

1181
    def test_cat_with_constant_dim(self):
        """Nested ``torch.cat`` calls (the outer over a single tensor) with an empty input."""
        for device in self.devices:
            def foo(*args):
                inner = torch.cat(args, dim=1)
                outer = torch.cat([inner], dim=1)
                return outer * outer

            inputs = [
                torch.tensor([], device=device, dtype=torch.float32),
                torch.randn(1, 64, device=device),
                torch.randn(1, 64, device=device),
            ]
            traced = torch.jit.trace(foo, inputs)

            out = warmup_and_run_forward(traced, *inputs)
            self.assertLastGraphAllFused()
            expected = foo(*inputs)
            np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())
1196

1197
    def test_scalar(self):
        """Scripted ``torch.add`` with float and int ``alpha`` arguments matches the eager formula."""
        @torch.jit.script
        def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        @torch.jit.script
        def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        alpha1, alpha2 = 1, 2
        for fn, data_type in itertools.product((test_float, test_int), self.dtypes):
            x, y, z = (torch.rand(4, dtype=data_type) for _ in range(3))
            # First call profiles; the second result is the one checked.
            fn(x, y, z, alpha1, alpha2)
            result = fn(x, y, z, alpha1, alpha2)
            self.assertEqual(result, x + y * alpha1 + z * alpha2)
1213

1214
    def test_loop(self):
        """Check a scripted ``for`` loop whose body holds elementwise adds, and its result."""
        @torch.jit.script
        def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor:
            b = y
            for i in range(0, z):
                # ``a`` is never used (kept as in the original test body).
                a = x + y
                b = b + y
            return b

        x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4)
        test(x, y, z)
        r = test(x, y, z)
        # b starts at y and gains another y on each of the z iterations.
        self.assertEqual(r, y * (z + 1))
1226

1227
    def test_slice(self):
        """Strided row slicing ``[0:512:2]`` on both operands of a traced add."""
        def easy(x, y):
            return x[0:512:2] + y[0:512:2]

        traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))

        inp = torch.ones(1024, 1024)
        out = traced(inp, inp)
        half = inp[0:512:2]
        np.testing.assert_allclose((half + half).numpy(), out.numpy())
1240

1241
    def test_unsqueeze(self, N=256):
        """``unsqueeze(·, 0)`` on both operands of a traced add of two ``N x N`` tensors."""
        def easy(x, y):
            return torch.unsqueeze(x, 0) + torch.unsqueeze(y, 0)

        traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N)))

        inp = torch.rand(N, N)
        out = traced(inp, inp)
        expanded = np.expand_dims(inp, 0)
        np.testing.assert_allclose(expanded + expanded, out.numpy())
1254

1255
    def _test_softmax(self, device):
        """Compare traced softmax/log_softmax against eager on ``device``.

        TE reductions are switched on via
        ``torch._C._jit_set_texpr_reductions_enabled(True)`` for each case. The
        previous setting (that call's return value) is restored in a
        ``finally`` so a failing comparison cannot leak the global flag into
        other tests.
        """
        def test_softmax(x, y):
            a = F.softmax(x, dim=0, dtype=torch.float32)
            b = F.softmax(y, dim=0, dtype=torch.float32)
            c = F.softmax(x, dim=1, dtype=torch.float32)
            d = F.softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        def test_softmax_neg_index(x, y):
            a = F.softmax(x, dim=-2, dtype=torch.float32)
            b = F.softmax(y, dim=-2, dtype=torch.float32)
            c = F.softmax(x, dim=-1, dtype=torch.float32)
            d = F.softmax(y, dim=-1, dtype=torch.float32)
            return a + b + c + d

        def test_log_softmax(x, y):
            a = F.log_softmax(x, dim=0, dtype=torch.float32)
            b = F.log_softmax(y, dim=0, dtype=torch.float32)
            c = F.log_softmax(x, dim=1, dtype=torch.float32)
            d = F.log_softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        for test in (test_softmax, test_log_softmax, test_softmax_neg_index):
            for data_type in self.dtypes:
                old = torch._C._jit_set_texpr_reductions_enabled(True)
                try:
                    traced_input = torch.randn(2, 3, dtype=data_type, device=device)
                    traced = torch.jit.trace(test, (traced_input, traced_input))
                    inp = torch.randn(2, 3, dtype=data_type, device=device)
                    res = traced(inp, inp)
                    # Use eager mode as reference. Move it to host memory so the
                    # numpy comparison also works when device is 'cuda'.
                    ref = test(inp, inp)
                    np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy(), rtol=1e-06, atol=1e-06)
                finally:
                    torch._C._jit_set_texpr_reductions_enabled(old)
1288

1289
    def test_softmax_cpu(self):
        """Run the shared softmax/log_softmax comparison on CPU."""
        self._test_softmax(device='cpu')
1291

1292
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @unittest.skip("global allocs are not supported yet.")
    def test_softmax_cuda(self):
        """Run the shared softmax/log_softmax comparison on CUDA (currently skipped)."""
        self._test_softmax(device='cuda')
1296

1297
    def test_half_gelu(self):
        """float16 bias + erf-based GELU fully fuses. CUDA only; does nothing without CUDA."""
        cuda_devices = ["cuda"] if torch.cuda.is_available() else []

        @torch.jit.script
        def bias_gelu(bias, y):
            x = bias + y
            return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

        for device in cuda_devices:
            bias = torch.rand(1024, dtype=torch.half, device=device)
            other = torch.rand(1024, dtype=torch.half, device=device)
            traced = torch.jit.trace(bias_gelu, (bias, other))
            warmup_and_run_forward(traced, bias, other)
            self.assertLastGraphAllFused()
1311

1312
    def test_half_bn_relu(self):
        """float16 ``batch_norm`` followed by ``relu`` fully fuses. CUDA only; does nothing without CUDA."""
        devices = ["cuda"] if torch.cuda.is_available() else []

        def foo(a, b, c):
            # b and c are passed positionally as batch_norm's running stats.
            y = torch.nn.functional.batch_norm(a, b, c)
            z = y.relu()
            return z

        for device in devices:
            a = torch.rand(16, 16, dtype=torch.half, device=device)
            b = torch.rand(16, dtype=torch.half, device=device)
            c = torch.rand(16, dtype=torch.half, device=device)
            traced = torch.jit.trace(foo, (a, b, c))
            warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()
1328

1329
    def test_exp_pow(self):
        """A scripted float64 ``mul``/``pow`` chain fully fuses on every device."""
        @torch.jit.script
        def do_exp(x, y, z):
            return ((x * y) * 2) * torch.pow(z, 2)

        for device in self.devices:
            x, y, z = (torch.rand(10, dtype=torch.double, device=device) for _ in range(3))
            traced = torch.jit.trace(do_exp, (x, y, z))
            warmup_and_run_forward(traced, x, y, z)
            self.assertLastGraphAllFused()
1341

1342
    def test_sin_pow(self):
        """``sin(pow(x, 0))`` fuses and matches eager for several small 1-D shapes."""
        def test(x):
            return torch.sin(torch.pow(x, 0))

        configs = itertools.product(self.dtypes, [[3], [5], [10]])
        for data_type, shape in configs:
            inp = torch.rand(shape, dtype=data_type)
            scripted = torch.jit.script(test)
            result = warmup_and_run_forward(scripted, inp)
            self.assertLastGraphAllFused()
            self.assertEqual(result, test(inp))
1352

1353
    def test_transpose(self):
        """A scripted add whose first operand is a ``transpose(0, 1)`` view gives the same result on repeated calls."""
        @torch.jit.script
        def test(x, y, z):
            return x.transpose(0, 1) + y + z

        x = torch.rand(4, 5, 2, 3)
        y, z = torch.rand(5, 4, 2, 3), torch.rand(5, 4, 2, 3)
        first, second = [test(x, y, z) for _ in range(2)]
        np.testing.assert_allclose(first.numpy(), second.numpy())
1363

1364
    def test_sliced_stride(self):
        """A scripted add with a ``[::2]``-sliced (non-contiguous) operand gives the same result on repeated calls."""
        @torch.jit.script
        def test(x, y, z):
            return x + y + z

        # Every other entry along dim 0 of a (16, ...) tensor -> an (8, ...) view.
        x = torch.rand(16, 4, 2, 3)[::2]
        y, z = torch.rand(8, 4, 2, 3), torch.rand(8, 4, 2, 3)
        first, second = [test(x, y, z) for _ in range(2)]
        np.testing.assert_allclose(first.numpy(), second.numpy())
1374

1375
    @unittest.skip("dynamic shapes are not quite there yet")
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_dynamic_shape(self):
        """A scripted CUDA kernel handles a changed size, a broadcast, and rejects mismatched shapes."""
        with num_profiled_runs(2):
            @torch.jit.script
            def test(x, y, z):
                return x * y * z
            x, y, z = (torch.rand(4, 8).cuda() for _ in range(3))
            ref = test(x, y, z)
            _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
            res = test(x, y, z)
            np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())

            # A wild broadcast appears.
            x = torch.rand(4, 8).cuda()
            y = torch.rand(1, 8).cuda()
            z = torch.rand(4, 1).cuda()
            res = test(x, y, z)
            xn, yn, zn = (t.cpu().numpy() for t in (x, y, z))
            np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)

            # Mismatched shapes shouldn't reach codegen. assertRaisesRegex makes
            # the test fail if no error is raised, unlike a bare try/except.
            x = torch.rand(4, 8).cuda()
            y = torch.rand(4, 8).cuda()
            z = torch.rand(5, 8).cuda()
            with self.assertRaisesRegex(RuntimeError, r"The size of tensor a \(4\) must match"):
                test(x, y, z)

            # Changing a static dimension fails guards.
            # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
            # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
            # res = test(x, y, z)
            # print(test.graph_for(x, y, z))
            # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)
1411

1412
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_guard_fails(self):
        """After several size-4 calls, a size-7 call must still produce correct values.

        Per the test name, the size-7 call is meant to fail the shape guard of
        the graph built from the earlier size-4 runs.
        """
        @torch.jit.script
        def test(x, y, z):
            return x * y * z

        for _ in range(3):
            test(*[torch.rand(4).cuda() for _ in range(3)])

        inputs = [torch.rand(7).cuda() for _ in range(3)]
        # The original computed this result but never checked it.
        self.assertEqual(test(*inputs), inputs[0] * inputs[1] * inputs[2])
1421

1422
    def test_bitwise_ops(self):
        """int32 bitwise and/or/xor/shift expressions fuse and match eager on every device."""
        def run_and(x, y):
            return x & (x & y)

        def run_or(x, y):
            return x & (x | y)

        def run_xor(x, y):
            return x ^ (x ^ y)

        def run_lshift(x, y):
            return x & (x << y)

        def run_rshift(x, y):
            return x & (x >> y)

        fns = {run_and, run_or, run_xor, run_lshift, run_rshift}

        for device, fn in itertools.product(self.devices, fns):
            ones = torch.ones(128, dtype=torch.int32, device=device)
            zeros = torch.zeros(128, dtype=torch.int32, device=device)
            example = torch.ones(128, dtype=torch.int32, device=device)
            traced = torch.jit.trace(fn, (example, example))
            out = warmup_and_run_forward(traced, ones, zeros)
            self.assertLastGraphAllFused()
            expected = fn(ones, zeros)
            np.testing.assert_allclose(out.cpu().numpy(), expected.cpu().numpy())
1450

1451
    def test_where(self):
        """``where(gt(x, y), x, y)`` fuses and matches eager for each dtype."""
        def run_where(x, y):
            return torch.where(torch.gt(x, y), x, y)

        for data_type in self.dtypes:
            lhs = torch.rand(1024, dtype=data_type)
            rhs = torch.rand(1024, dtype=data_type)
            example = torch.zeros(1024, dtype=data_type)
            traced = torch.jit.trace(run_where, (example, example))
            out = warmup_and_run_forward(traced, lhs, rhs)
            self.assertLastGraphAllFused()
            expected = run_where(lhs, rhs)
            np.testing.assert_allclose(out.float().numpy(), expected.float().numpy())
1464

1465
    def test_multi_rand(self):
        """One ``rand_like`` value used twice must cancel: ``(x + r) - (r - x) == 2 * x``."""
        for device in self.devices:
            def test(x):
                y = torch.rand_like(x)
                return (x + y) - (y - x)

            _rtol = 1e-5
            for data_type in self.dtypes:
                # Picked per dtype, so the looser bfloat16 tolerance cannot
                # carry over into a later dtype's iteration.
                _atol = 2e-2 if data_type is torch.bfloat16 else 2e-3
                a = torch.rand(4, dtype=data_type, device=device)
                scripted = torch.jit.script(test)
                out = warmup_and_run_forward(scripted, a)
                self.assertLastGraphAllFused()
                self.assertTrue(torch.allclose(out, 2 * a, atol=_atol, rtol=_rtol))
1481

1482
    def test_mask(self):
        """``x.unsqueeze(1) == 0`` on a bool tensor fuses and matches eager."""
        def test(x):
            return x.unsqueeze(1) == 0

        for device, data_type in itertools.product(self.devices, self.dtypes):
            # A bool tensor, built by thresholding random values of data_type.
            mask = torch.rand(4, dtype=data_type, device=device) > 0.5
            scripted = torch.jit.script(test)
            result = warmup_and_run_forward(scripted, mask)
            self.assertLastGraphAllFused()
            assert torch.equal(result, test(mask))
1493

1494
    def test_simple_add(self):
        """Traced add with TE block-code generation and fallback enabled.

        Both global JIT flags are restored in ``finally``, so a failing
        assertion cannot leak them into other tests.
        """
        val = torch._C._jit_get_te_generate_block_code()
        fall_bk = torch._C._jit_texpr_fallback_allowed()
        try:
            torch._C._jit_set_te_generate_block_code(True)
            torch._C._jit_texpr_set_fallback_allowed(True)

            def simple(a, b):
                return torch.add(a, b)

            a = torch.ones(256, 256)
            b = torch.ones(256, 256)
            traced = torch.jit.trace(simple,
                                     (torch.ones(256, 256), torch.ones(256, 256)))
            f = traced(a, b)
            f_test = np.full((256, 256), 2, dtype=float)
            np.testing.assert_allclose(f.numpy(), f_test)
        finally:
            torch._C._jit_set_te_generate_block_code(val)
            torch._C._jit_texpr_set_fallback_allowed(fall_bk)
1512

1513
    def test_strided_output_preserved(self):
        """Scripted (fused) ``a + b - a`` keeps the same values and strides as eager."""
        def foo(a, b):
            return a + b - a

        def check_against_eager(inp):
            foo_script = torch.jit.script(foo)
            # Two warm-up calls, then compare the third call with eager.
            foo_script(inp, inp)
            foo_script(inp, inp)
            out_s = foo_script(inp, inp)
            out_eager = foo(inp, inp)
            self.assertEqual(out_s, out_eager)
            self.assertEqual(out_s.stride(), out_eager.stride())
            self.assertLastGraphAllFused()

        # smaller, easier to debug example: a 2x3 view with strides (1, 2),
        # filled with 0..5 in row-major order.
        x = torch.as_strided(torch.arange(6), (2, 3), (1, 2))
        for value, (i, j) in enumerate(itertools.product(range(2), range(3))):
            x[i, j] = value
        check_against_eager(x)

        # more dims
        N, C, H, W = 2, 3, 4, 5
        check_against_eager(torch.rand(N, C, H, W).to(memory_format=torch.channels_last))
1545

1546
    def test_alias_analysis_module(self):
        """Scripted module matches eager, including after attribute ``a`` is rebound to alias ``b``.

        ``forward`` mutates ``self.b`` in place between two reads of
        ``self.a``, so once ``a is b`` the second read observes the mutation.
        """
        class AliasModule(nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                z = z + self.a
                self.b.add_(y)
                w = z + self.a
                z = w + x
                return z
        x = torch.randn(128, 128)

        def make_module(script):
            module = AliasModule()
            return torch.jit.script(module) if script else module

        eager_mod = make_module(False)
        scripted_mod = make_module(True)
        torch.testing.assert_close(eager_mod(x, x, x), scripted_mod(x, x, x))

        # Now do the aliasing
        eager_mod.a = eager_mod.b
        expected = eager_mod(x, x, x)

        scripted_mod.a = scripted_mod.b
        actual = scripted_mod(x, x, x)

        torch.testing.assert_close(expected, actual)
1583

1584
    def test_alias_analysis_inputs(self):
        """Scripted module matches eager when ``forward`` mutates an input that aliases the others.

        The same tensor is passed as x, y and z, and ``forward`` calls
        ``x.add_(y)`` before reading x again.
        """
        class AliasModule(nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.a
                z = w + x
                return z

        def make_module(script):
            module = AliasModule()
            return torch.jit.script(module) if script else module

        eager_mod = make_module(False)
        scripted_mod = make_module(True)

        outputs = []
        for module in (eager_mod, scripted_mod):
            # Reseed so both modules see an identical (freshly drawn) input.
            torch.manual_seed(1337)
            x = torch.randn(128, 128)
            outputs.append(module(x, x, x))

        torch.testing.assert_close(outputs[0], outputs[1])
1616

1617
    def test_alias_analysis_input_and_module(self):
        """Scripted module matches eager when a module attribute is the same tensor as the mutated input.

        ``self.b`` is set to the input tensor, and ``forward`` mutates that
        input with ``x.add_(y)`` before reading ``self.b``.
        """
        class AliasModule(nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.b
                z = w + x
                return z

        def make_module(script):
            module = AliasModule()
            return torch.jit.script(module) if script else module

        eager_mod = make_module(False)
        scripted_mod = make_module(True)

        outputs = []
        for module in (eager_mod, scripted_mod):
            # Reseed so both modules see an identical (freshly drawn) input.
            torch.manual_seed(1337)
            x = torch.randn(128, 128)
            module.b = x
            outputs.append(module(x, x, x))

        torch.testing.assert_close(outputs[0], outputs[1])
1651

1652
    def test_multiple_outputs(self):
        """Multi-output traced kernel with a strided input and an int64 input matches eager.

        A bug reported internally similar to the one reported in #48533.
        """
        def foo(a, b, c):
            t_next = c + 1
            t5 = t_next * b
            t6 = torch.unsqueeze(t_next, 1)
            t7 = a * t6
            return (t7, t5, t_next)

        for device, data_type in itertools.product(self.devices, self.dtypes):
            a = torch.rand(20, 20, dtype=data_type, device=device)
            # Every 29th element of a flat buffer: a length-20 view with stride 29.
            b = torch.rand(20 * 29, dtype=data_type, device=device).as_strided([20], [29])
            c = torch.ones(20, dtype=torch.int64, device=device)
            traced = torch.jit.trace(foo, (a, b, c))
            expected = foo(a, b, c)
            # Call twice and check the second result.
            traced(a, b, c)
            actual = traced(a, b, c)
            self.assertEqual(expected, actual)
1671

1672
    def test_propagated_mem_layout(self):
        """Traced elementwise graphs must match eager across memory layouts.

        Each function below is traced and run for every combination of input
        shape, per-input memory format (contiguous or channels_last) and input
        permutation, under both the "STATIC" and "DYNAMIC" fusion strategies.
        Traced outputs are compared against eager outputs.
        """
        def foo(a, b, c):
            t_next = c + 1
            t5 = t_next * b
            t7 = a * t5
            return t7

        def foo_multi_outputs(a, b, c):
            t_next = c + 1
            t5 = b * t_next
            t7 = a * t5
            return (t7, t5, t_next)

        def foo_multi_outputs_i_nhwc_o_nchw(a, b, c):
            t_next = c + 1
            t5 = b * t_next
            t7 = a * t5
            # Extra output explicitly converted to contiguous layout.
            t8 = t7.to(memory_format=torch.contiguous_format)
            return (t8, t7, t5, t_next)

        def run_foo_case(foo, a, b, c):
            traced_contiguous = torch.jit.trace(foo, (a, b, c))
            ref = foo(a, b, c)
            # The first traced call's result is discarded; presumably it is a
            # profiling run so the second call exercises the optimized graph
            # (TODO confirm against the configured number of profiled runs).
            traced_contiguous(a, b, c)
            exp = traced_contiguous(a, b, c)
            self.assertEqual(ref, exp)

        mem_layouts = list(
            itertools.product([torch.contiguous_format, torch.channels_last], repeat=3)
        )
        shapes = [(2, 3, 4, 5), (2, 1, 1, 5), (1, 1, 1, 1)]
        permutes = [(0, 3, 2, 1), (0, 3, 1, 2)]
        funcs = [foo, foo_multi_outputs, foo_multi_outputs_i_nhwc_o_nchw]
        # Materialize the product. A bare itertools.product iterator is
        # single-pass: the first strategy would exhaust it and the second
        # strategy would silently run zero cases.
        configs = list(itertools.product(funcs, shapes, mem_layouts, permutes))
        for strategy in ["STATIC", "DYNAMIC"]:
            old_strategy = torch.jit.set_fusion_strategy([(strategy, 10)])
            try:
                for _func, _shape, _mem_layouts, _permute in configs:
                    a = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[0])
                    b = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[1])
                    c = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[2])
                    run_foo_case(_func, a, b, c)

                    a = a.permute(dims=_permute)
                    b = b.permute(dims=_permute)
                    c = c.permute(dims=_permute)
                    run_foo_case(_func, a, b, c)
            finally:
                # Restore the previous strategy even when a case fails, so the
                # setting does not leak into later tests.
                torch.jit.set_fusion_strategy(old_strategy)
1718

1719
if __name__ == "__main__":
    # Executed directly: hand control to the shared PyTorch test runner.
    run_tests()
1721

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.