import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import unittest
import itertools

from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo

from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions

LLVM_ENABLED = torch._C._llvm_enabled()


class BaseTestClass(JitTestCase):
    def setUp(self):
        super().setUp()
        self.tensorexpr_options = TensorExprTestOptions()
        self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        self.dtypes = [torch.float32, torch.bfloat16] if LLVM_ENABLED else [torch.float32]

    def tearDown(self):
        self.tensorexpr_options.restore()
        super().tearDown()

    def assertLastGraphAllFused(self):
        self.assertAllFused(torch.jit.last_executed_optimized_graph())


def warmup_and_run_forward(f, *args):
    for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
        results = f(*args)
    return results


class TestTensorExprFuser(BaseTestClass):
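
    # Shared pattern for the tests below: trace a small function, run it enough
    # times for the profiling executor to trigger fusion (warmup_and_run_forward),
    # assert that the whole optimized graph fused, then compare the fused result
    # against an eager or NumPy reference.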
    def test_easy(self):
        def easy(x, y):
            aaa = torch.add(x, y)
            return aaa

        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))

        a = torch.rand(1024)
        b = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())
    def test_three_arg(self):
        def easy(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(aaa, z)
            return bbb

        traced = torch.jit.trace(
            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        )

        a = torch.rand(1024)
        b = torch.rand(1024)
        c = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = a.numpy() + b.numpy() + c.numpy()
        np.testing.assert_allclose(npr, x.numpy())
    def test_four_arg(self):
        def run_addcmul(x, y, z, w):
            c = torch.addcmul(torch.add(x, y), z, w)
            return c

        for dev in self.devices:
            rand_a = torch.rand(1024, dtype=torch.float, device=dev)
            rand_b = torch.rand(1024, dtype=torch.float, device=dev)
            rand_c = torch.rand(1024, dtype=torch.float, device=dev)
            rand_d = torch.rand(1024, dtype=torch.float, device=dev)

            traced = torch.jit.trace(
                run_addcmul,
                (
                    torch.zeros(1024, dtype=torch.float, device=dev),
                    torch.zeros(1024, dtype=torch.float, device=dev),
                    torch.zeros(1024, dtype=torch.float, device=dev),
                    torch.zeros(1024, dtype=torch.float, device=dev),
                ),
            )

            x = warmup_and_run_forward(traced, rand_a, rand_b, rand_c, rand_d)
            self.assertLastGraphAllFused()
            y = run_addcmul(rand_a, rand_b, rand_c, rand_d)
            np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6)
    def test_three_arg2(self):
        for device in self.devices:
            def test(x, y, z):
                aaa = torch.add(x, y)
                bbb = torch.add(aaa, z)
                return bbb

            M = 32
            N = 32
            traced = torch.jit.trace(
                test,
                (
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                ),
            )

            a = torch.rand(M, N, device=device)
            b = torch.rand(M, N, device=device)
            c = torch.rand(M, N, device=device)
            x = warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()
            npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
            np.testing.assert_allclose(npr, x.cpu().numpy())
    def test_broadcast3(self):
        for device in self.devices:
            def test_body(M, N, L, K):
                def test(x, y, z):
                    v1 = torch.add(x, y)
                    v2 = torch.add(v1, z)
                    return v2

                a_shape = [M, N]
                b_shape = [L, M, 1]
                c_shape = [K, L, 1, 1]
                traced = torch.jit.trace(
                    test,
                    (
                        torch.rand(*a_shape, device=device),
                        torch.rand(*b_shape, device=device),
                        torch.rand(*c_shape, device=device),
                    ),
                )

                a = torch.rand(*a_shape, device=device)
                b = torch.rand(*b_shape, device=device)
                c = torch.rand(*c_shape, device=device)
                x = warmup_and_run_forward(traced, a, b, c)
                self.assertLastGraphAllFused()
                npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
                np.testing.assert_allclose(npr, x.cpu().numpy())

            test_configs = [[5, 2, 7, 3], [8, 8, 8, 8]]
            for test_config in test_configs:
                test_body(*test_config)
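
    # Broadcasting note: assuming the a/b shapes filled in above, [M, N] + [L, M, 1]
    # broadcasts to [L, M, N], and adding [K, L, 1, 1] yields [K, L, M, N], so a
    # single fused kernel has to index operands of three different ranks.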
    def test_all_combos(self):
        def easy(x, y, z):
            a = torch.add(x, y)
            b = torch.add(a, 3)
            c = torch.add(x, b)
            d = torch.add(c, a)
            return d

        def np_easy(x, y, z):
            a = x + y
            b = a + 3
            c = x + b
            d = c + a
            return d

        traced = torch.jit.trace(
            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        )

        a = torch.rand(1024)
        b = torch.rand(1024)
        c = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(npr, x.numpy())
    def test_rank_two(self):
        def easy(x, y, z):
            a = torch.add(x, y)
            b = torch.add(a, 3)
            c = torch.add(x, b)
            d = torch.add(c, a)
            return d

        def np_easy(x, y, z):
            a = x + y
            b = a + 3
            c = x + b
            d = c + a
            return d

        shape = 32, 32
        traced = torch.jit.trace(
            easy, (torch.rand(shape), torch.rand(shape), torch.rand(shape))
        )

        a = torch.rand(shape)
        b = torch.rand(shape)
        c = torch.rand(shape)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(npr, x.numpy())
    def test_broadcast(self):
        def easy(x, y, z):
            a = torch.add(x, y)
            b = torch.add(a, z)
            return b

        def np_easy(x, y, z):
            a = x + y
            b = a + z
            return b

        N = 32
        traced = torch.jit.trace(easy, (torch.rand(N, N), torch.rand(N), torch.rand(N, N)))

        a = torch.rand(N, N)
        b = torch.rand(N)
        c = torch.rand(N, N)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        npr = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(npr, x.numpy())
    def test_broadcast_2(self):
        zero = torch.tensor([0.0], dtype=torch.float)

        def foo(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(zero, aaa)
            return torch.add(bbb, z)

        def foo_np(x, y, z):
            a = x + y
            b = zero.numpy() + a
            return b + z

        x = torch.rand(3, 4)
        y = torch.ones(3, 1)
        z = torch.rand(4)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()

        rnp = foo_np(x.numpy(), y.numpy(), z.numpy())
        np.testing.assert_allclose(r, rnp)
    def test_broadcast_big2(self):
        zero = torch.tensor([0.0], dtype=torch.float)

        def foo(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(zero, aaa)
            return torch.add(bbb, z)

        def foo_np(x, y, z):
            a = x + y
            b = zero.numpy() + a
            return b + z

        x = torch.rand(32, 1024)
        y = torch.ones(32, 1)
        z = torch.rand(1024)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        rnp = foo_np(x.numpy(), y.numpy(), z.numpy())
        np.testing.assert_allclose(r, rnp)
    def test_alpha(self):
        def alpha(x):
            aaa = torch.add(x, x, alpha=2.0)
            return aaa

        traced = torch.jit.trace(alpha, (torch.tensor([1.0])))

        a = torch.tensor([1.0])
        x = warmup_and_run_forward(traced, a)
        np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy())
    def test_constant(self):
        def constant(x):
            bbb = torch.tensor([1.0])
            aaa = torch.add(x, bbb)
            return aaa

        traced = torch.jit.trace(constant, (torch.tensor([1.0])))

        a = torch.tensor([1.0])
        x = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + 1.0, x.numpy())
    def test_add_sub(self):
        def easy(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.sub(aaa, z)
            return bbb

        traced = torch.jit.trace(
            easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024))
        )

        a = torch.rand(1024)
        b = torch.rand(1024)
        c = torch.rand(1024)
        x = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy())
    def test_promotion(self):
        def easy(x, y):
            aaa = torch.add(x, y)
            return aaa

        traced = torch.jit.trace(
            easy,
            (torch.zeros(1024, dtype=torch.int32), torch.rand(1024, dtype=torch.float32)),
        )

        a = torch.zeros(1024, dtype=torch.int32)
        b = torch.rand(1024, dtype=torch.float32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy())
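
    # Type-promotion check: int32 + float32 promotes to float32, so the fused
    # kernel must produce the same promoted dtype (and values) as eager mode.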
    def test_double(self):
        TENSOR_LEN = 8192

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.rand(TENSOR_LEN, dtype=torch.float64), torch.full((TENSOR_LEN,), 0.5, dtype=torch.float64)),
        )

        a = torch.rand(TENSOR_LEN, dtype=torch.double)
        b = torch.full((TENSOR_LEN,), 0.5, dtype=torch.double)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
    def test_short(self):
        TENSOR_LEN = 8192

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16),
             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)),
        )

        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)
        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
    def test_char(self):
        # keep the upper bound within int8 range
        TENSOR_LEN = 100

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8),
             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)),
        )

        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
    def test_int64_promotion(self):
        # keep the upper bound within int8 range
        TENSOR_LEN = 100

        def easy(x, y):
            aaa = torch.add(x, y)
            bbb = torch.mul(aaa, y)
            return bbb

        traced = torch.jit.trace(
            easy,
            (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8),
             torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)),
        )

        a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)
        b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy())
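
    # Same idea one step up: int8 + int64 promotes to int64, a widening promotion
    # within the integer kinds rather than an int-to-float one.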
    def test_eq(self):
        def easy(x, y):
            c = torch.eq(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        a = torch.zeros(1024, dtype=torch.int32)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())
    def test_ne(self):
        def easy(x, y):
            c = torch.ne(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        a = torch.zeros(1024, dtype=torch.int32)
        b = torch.ones(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())
    def test_ge(self):
        def easy(x, y):
            c = torch.ge(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        aa = np.empty([1024], dtype=np.int32)
        aa.fill(5)
        a = torch.from_numpy(aa)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())
    def test_gt(self):
        def easy(x, y):
            c = torch.gt(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        a = torch.ones(1024, dtype=torch.int32)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.ones(1024), x.numpy())
    def test_le(self):
        def easy(x, y):
            c = torch.le(x, y)
            return c

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        aa = np.empty([1024], dtype=np.int32)
        aa.fill(5)
        a = torch.from_numpy(aa)
        b = torch.zeros(1024, dtype=torch.int32)
        x = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(np.zeros(1024), x.numpy())
    def test_lt(self):
        def easy(x, y):
            c = torch.lt(x, y)
            return c

        for dev in self.devices:
            traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev)))
            a = torch.ones(1024, dtype=torch.int32, device=dev)
            b = torch.zeros(1024, dtype=torch.int32, device=dev)
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy())
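
    # The comparison tests above all follow one recipe: trace on float zeros,
    # then feed int32 inputs whose element-wise comparison is known to be all
    # true or all false, and check the resulting ones/zeros pattern.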
    def test_min_max(self):
        def test(x, y):
            return torch.max(torch.min(x, y), torch.tensor([4.0]))

        traced = torch.jit.trace(test, (torch.zeros(1024), torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        b = 8.0 * torch.rand(1024)
        np.testing.assert_allclose(
            warmup_and_run_forward(traced, a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0])
        )
        self.assertLastGraphAllFused()
    def test_min_max_reduction(self):
        def test(x):
            return torch.min(x) + torch.max(x)

        traced = torch.jit.trace(test, (torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy()))
        self.assertLastGraphAllFused()
    def test_min_max_reduction2(self):
        def test(x):
            return x.min() + x.max()

        traced = torch.jit.trace(test, (torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy()))
        self.assertLastGraphAllFused()
    def test_min_max_reduction_dim1(self):
        def test(x):
            return torch.min(x, 1)[0] + torch.max(x, 1)[0]

        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
        a = 8.0 * torch.rand(16, 16)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(
            a.numpy(), axis=1) + np.amax(a.numpy(), axis=1))
        self.assertLastGraphAllFused()
    def test_min_max_reduction_dim1_2(self):
        def test(x):
            return torch.min(x * x, 1)

        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
        a = 8.0 * torch.rand(16, 16)
        np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin((a * a).numpy(), axis=1))
        self.assertLastGraphAllFused()
    def test_clamp(self):
        def test(x):
            return torch.clamp(x + 3.0, 0.0, 6.0)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            an = a.cpu().numpy()
            np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip(an + 3.0, 0.0, 6.0))
            self.assertLastGraphAllFused()
    def test_relu(self):
        def test(x):
            return torch.clamp(F.relu(x), 0, 0.5)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            an = a.cpu().numpy()
            np.testing.assert_allclose(warmup_and_run_forward(traced, a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5))
            self.assertLastGraphAllFused()
    def test_reps(self):
        def easy(x, y):
            aaa = torch.add(x, y)
            return aaa

        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))

        # run the traced function repeatedly on fresh inputs
        for _ in range(32):
            a = torch.ones(1024)
            b = torch.zeros(1024)
            x = warmup_and_run_forward(traced, a, b)
            np.testing.assert_allclose(np.ones(1024), x.numpy())
    def test_add_const_rhs(self):
        def test(x):
            return x + 3.0

        traced = torch.jit.trace(test, torch.rand(4))
        x = torch.rand(4)
        y = warmup_and_run_forward(traced, x)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(x.numpy() + 3.0, y.numpy())
    def test_int_output(self):
        def test(x, y, z):
            return x * y * z

        xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for i in range(3)]
        x, y, z = xs
        xn, yn, zn = (t.numpy() for t in xs)
        traced = torch.jit.trace(test, (x, y, z))
        res = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(xn * yn * zn, res.numpy())
    def test_binary_ops(self):
        def test_atan2(x, y):
            c = torch.atan2(torch.add(x, y), y)
            return c

        def test_gt(x, y):
            c = torch.gt(torch.add(x, y), y)
            return c

        def test_ge(x, y):
            c = torch.ge(torch.add(x, y), y)
            return c

        def test_lt(x, y):
            c = torch.lt(torch.add(x, y), y)
            return c

        def test_le(x, y):
            c = torch.le(torch.add(x, y), y)
            return c

        def test_lerp(x, y):
            c = torch.lerp(torch.add(x, 1), x, 2.0)
            return c

        def test_mul(x, y):
            c = torch.mul(torch.add(x, y), y)
            return c

        def test_ne(x, y):
            c = torch.ne(torch.add(x, y), y)
            return c

        def test_div(x, y):
            c = torch.div(torch.add(x, y), 2)
            return c

        def test_eq(x, y):
            c = torch.eq(torch.add(x, y), y)
            return c

        def test_fmod(x, y):
            c = torch.fmod(torch.add(x, y), 2)
            return c

        def test_sub(x, y):
            c = torch.sub(torch.add(x, y), x)
            return c

        def test_remainder(x, y):
            c = torch.remainder(torch.add(x, y), 3.0)
            return c

        def test_pow(x, y):
            c = torch.pow(torch.add(x, y), 2.0)
            return c

        def test_type_as(x, y):
            return x.type_as(torch.add(x, y))
        # comparison ops produce bool outputs and are handled separately below
        cmp_fns = {test_gt, test_ge, test_lt, test_le, test_ne, test_eq}
        non_cmp_fns = {test_atan2, test_lerp, test_mul, test_div, test_fmod,
                       test_sub, test_remainder, test_pow, test_type_as}
        all_test_fns = cmp_fns.union(non_cmp_fns)
        fn_dev_dtype = itertools.product(all_test_fns, self.devices, self.dtypes)
        for torch_fn, dev, data_type in fn_dev_dtype:
            if torch_fn is test_lerp and data_type is torch.bfloat16:
                continue
            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)
            in1 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            in2 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            traced = torch.jit.trace(torch_fn, (in1, in2))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            _atol = 2e-3
            _rtol = 1e-5
            if data_type is torch.bfloat16:
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                if torch_fn not in cmp_fns:
                    y = y.to(torch.bfloat16)
            else:
                y = torch_fn(rand_a, rand_b)
            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
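
    # For bfloat16 the driver above compares the fused kernel against a float32
    # run of the same traced function (cast back for non-comparison ops), since
    # an eager bfloat16 reference would itself carry large rounding error.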
    def test_unary_ops(self):
        def test_cast_float(x, y):
            c = torch.ops.aten._cast_Float(torch.add(x, y))
            return c

        def test_round(x, y):
            c = torch.round(torch.add(x, y))
            return c

        def test_sin(x, y):
            c = torch.sin(torch.add(x, y))
            return c

        def test_asin(x, y):
            c = torch.asin(torch.add(x, y))
            return c

        def test_sinh(x, y):
            c = torch.sinh(torch.add(x, y))
            return c

        def test_cos(x, y):
            c = torch.cos(torch.add(x, y))
            return c

        def test_acos(x, y):
            c = torch.acos(torch.add(x, y))
            return c

        def test_cosh(x, y):
            c = torch.cosh(torch.add(x, y))
            return c

        def test_tan(x, y):
            c = torch.tan(torch.add(x, y))
            return c

        def test_atan(x, y):
            c = torch.atan(torch.add(x, y))
            return c

        def test_tanh(x, y):
            c = torch.tanh(torch.add(x, y))
            return c

        def test_sqrt(x, y):
            c = torch.sqrt(torch.add(x, y))
            return c

        def test_rsqrt(x, y):
            c = torch.rsqrt(torch.add(x, y))
            return c

        def test_floor(x, y):
            c = torch.floor(torch.add(x, y))
            return c

        def test_ceil(x, y):
            c = torch.ceil(torch.add(x, y))
            return c

        def test_trunc(x, y):
            c = torch.trunc(torch.add(x, y))
            return c

        def test_abs(x, y):
            c = torch.abs(torch.add(x, y))
            return c

        def test_log(x, y):
            c = torch.log(torch.add(x, y))
            return c

        def test_log2(x, y):
            c = torch.log2(torch.add(x, y))
            return c

        def test_log10(x, y):
            c = torch.log10(torch.add(x, y))
            return c

        def test_log1p(x, y):
            c = torch.log1p(torch.add(x, y))
            return c

        def test_erf(x, y):
            c = torch.erf(torch.add(x, y))
            return c

        def test_exp(x, y):
            c = torch.exp(torch.add(x, y))
            return c

        def test_expm1(x, y):
            c = torch.expm1(torch.add(x, y))
            return c

        def test_erfc(x, y):
            c = torch.erfc(torch.add(x, y))
            return c

        def test_frac(x, y):
            c = torch.frac(torch.add(x, y))
            return c

        def test_lgamma(x, y):
            c = torch.lgamma(torch.add(x, y))
            return c

        def test_sigmoid(x, y):
            c = torch.sigmoid(torch.add(x, y))
            return c

        def test_reciprocal(x, y):
            c = torch.reciprocal(torch.add(x, y))
            return c

        def test_neg(x, y):
            c = torch.neg(torch.add(x, y))
            return c

        def test_relu(x, y):
            c = torch.relu(torch.add(x, y))
            return c

        def test_hardtanh(x, y):
            c = F.hardtanh(torch.add(x, y), -1.0, 1.0)
            return c

        def test_threshold(x, y):
            c = F.threshold(torch.add(x, y), 0.5, 10)
            return c
        fns = {
            test_round, test_sin, test_asin, test_sinh, test_cos, test_acos,
            test_cosh, test_tan, test_atan, test_tanh, test_sqrt, test_rsqrt,
            test_floor, test_ceil, test_trunc, test_abs, test_log, test_log2,
            test_log10, test_log1p, test_exp, test_expm1, test_frac,
            test_lgamma, test_sigmoid, test_reciprocal, test_neg, test_relu,
            test_hardtanh, test_threshold, test_cast_float,
        }
        gpu_only_fns = {test_erf, test_erfc}
        fn_dev_dtype = itertools.product(gpu_only_fns.union(fns), self.devices, self.dtypes)

        for torch_fn, dev, data_type in fn_dev_dtype:
            if torch_fn == test_lgamma and dev == "cuda":
                continue

            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)

            ins = 20 * torch.rand(1024, dtype=data_type, device=dev)
            cc = np.empty([1024], dtype=np.float32)
            cc.fill(np.nan)
            nans = torch.from_numpy(cc).to(dev)
            traced = torch.jit.trace(torch_fn, (ins, ins))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            _atol = 5e-3 if data_type is torch.bfloat16 else 2e-3
            _rtol = 1e-5
            if data_type is torch.bfloat16 and torch_fn not in gpu_only_fns:
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                y = y.to(torch.bfloat16)
            else:
                y = torch_fn(rand_a, rand_b)
            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
    def test_round_2(self):
        def round(x):
            return torch.round(x)

        for data_type in [torch.float32, torch.double]:
            a = torch.tensor([0.2, 1.6, 2.5, 3.5]).to(data_type)
            traced = torch.jit.trace(round, (a))
            x = warmup_and_run_forward(traced, a)
            self.assertLastGraphAllFused()
            y = round(a)
            self.assertEqual(x, y)
    def test_rand_like(self):
        N = 1 << 16

        def run_rand_like(x, y):
            return torch.rand_like(torch.add(x, y))

        for device in self.devices:
            x = torch.rand(N, device=device)
            traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)

            for data_type in self.dtypes:
                _x = x.to(dtype=data_type)
                x_v = warmup_and_run_forward(traced, _x, _x)
                self.assertLastGraphAllFused()

            x_np = x.cpu().numpy()
            x1_mean = np.mean(x_np)
            x2_mean = np.mean(x_np ** 2)
            x3_mean = np.mean(x_np ** 3)
            np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2)
            np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2)
            np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2)
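
    # The moment checks follow from X ~ U(0, 1): E[X] = 1/2, E[X^2] = 1/3, and
    # E[X^3] = 1/4, each verified with a loose 2% relative tolerance.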
    def test_nans(self):
        def test_max(x, y):
            return torch.max(2 * x, 2 * y)

        def test_min(x, y):
            return torch.min(2 * x, 2 * y)

        tmax = torch.jit.trace(test_max, (torch.rand(1), torch.rand(1)))
        tmin = torch.jit.trace(test_min, (torch.rand(1), torch.rand(1)))

        for data_type in self.dtypes:
            x = torch.tensor([np.nan]).to(dtype=data_type)
            y = torch.tensor([1.0]).to(dtype=data_type)

            assert np.isnan(warmup_and_run_forward(tmin, x, y).float().item())
            assert np.isnan(warmup_and_run_forward(tmin, y, x).float().item())
            self.assertLastGraphAllFused()
            assert np.isnan(warmup_and_run_forward(tmax, x, y).float().item())
            assert np.isnan(warmup_and_run_forward(tmax, y, x).float().item())
            self.assertLastGraphAllFused()
    def test_double_intrinsics(self):
        def do_pow(x):
            return torch.pow(x, 7)

        for device in self.devices:
            x = torch.rand(10, dtype=torch.double, device=device)
            traced = torch.jit.trace(do_pow, (x))
            x = warmup_and_run_forward(traced, x)
            self.assertLastGraphAllFused()
    def test_remainder(self):
        def run_remainder(x, y):
            c = torch.remainder(torch.add(x, y), x)
            return c

        for data_type in self.dtypes:
            a = torch.rand(1024, dtype=data_type)
            b = torch.rand(1024, dtype=data_type)
            zeros = torch.zeros(1024, dtype=data_type)
            cc = np.empty([1024], dtype=float)
            cc.fill(np.nan)
            nans = torch.from_numpy(cc).to(dtype=data_type)

            # random floats
            zeros1 = torch.zeros(1024, dtype=data_type)
            zeros2 = torch.zeros(1024, dtype=data_type)

            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            y = run_remainder(a, b)
            if data_type is torch.bfloat16:
                self.assertEqual(x, y, atol=4e-3, rtol=2e-3)
            else:
                self.assertEqual(x, y)

            # div by 0
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, zeros, a)
            self.assertLastGraphAllFused()
            y = run_remainder(zeros, a)
            self.assertEqual(x, y)

            # numerators with NaN
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, nans, a)
            self.assertLastGraphAllFused()
            y = run_remainder(nans, a)
            self.assertEqual(x, y)
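
    # torch.remainder uses Python modulo semantics (the result takes the sign of
    # the divisor), unlike torch.fmod; the zero-divisor and NaN sub-cases above
    # check that the fused kernel preserves those edge-case semantics.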
    def test_multioutput(self):
        def easy(x):
            b = x + 1
            c = b + b
            return (b, c)

        traced = torch.jit.trace(easy, (torch.zeros(1024)))

        a = torch.zeros(1024)
        b, c = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        bp = a.numpy() + 1
        cp = bp + bp
        np.testing.assert_allclose(b.numpy(), bp)
        np.testing.assert_allclose(c.numpy(), cp)
    def test_chunk(self):
        def easy(x):
            y = x + 1
            aaa, bbb = torch.chunk(y, 2)
            return aaa + bbb

        for data_type in self.dtypes:
            trace_input = torch.zeros(1024, 1024, dtype=data_type)
            traced = torch.jit.trace(easy, (trace_input))

            a = torch.zeros(32, 32, dtype=data_type)
            x = warmup_and_run_forward(traced, a)
            self.assertLastGraphAllFused()
            npr = a.float().numpy()
            npr2 = npr + 1
            npr_a, npr_b = np.array_split(npr2, 2)
            np.testing.assert_allclose(npr_a + npr_b, x.float().numpy())
    def test_cat(self):
        for device in self.devices:
            _dim = 1

            def foo(*args):
                args_2 = [v + i for i, v in enumerate(args)]
                v = torch.cat(args_2, dim=_dim)
                return v * v

            for data_type in self.dtypes:
                M = 16
                Ns = [128, 16, 1]
                values = [torch.zeros(M, N, dtype=data_type, device=device) for N in Ns]
                traced = torch.jit.trace(foo, values)

                x = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                ref = foo(*values)
                np.testing.assert_allclose(ref.cpu().float().numpy(), x.cpu().float().numpy())

            # channels-last inputs, varying the concat dimension
            for _cur_dim in range(4):
                _dim = _cur_dim
                values = [torch.randn((2, 3, 4, 5), device=device).to(memory_format=torch.channels_last) for _ in range(10)]
                traced = torch.jit.trace(foo, values)

                x = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                ref = foo(*values)
                self.assertEqual(ref, x)
    @unittest.skip("cat is broken with fusion group inlining disabled")
    def test_cat_only(self):
        for device in self.devices:
            def foo(*args):
                args_2 = [v + i for i, v in enumerate(args)]
                v = torch.cat(args_2, dim=1)
                return v

            M = 16
            Ns = [128, 16, 1]
            values = [torch.zeros(M, N, device=device) for N in Ns]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
    def test_cat_negative_dim(self):
        for device in self.devices:
            def foo(*args):
                v = torch.cat(args, dim=-1)
                return v * v

            M = 16
            Ns = [128, 16, 1]
            values = [torch.randn(M, N, device=device) for N in Ns]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
    def test_cat_promote_inputs(self):
        for device in self.devices:
            def foo(*args):
                v = torch.cat(args, dim=1)
                return v * v

            M = 16
            Ns = [128, 16, 1]
            dtypes = [torch.half, torch.float32, torch.double]
            values = [torch.randn(M, N, device=device, dtype=dt) for N, dt in zip(Ns, dtypes)]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
    def test_cat_empty_tensors(self):
        for device in self.devices:
            def foo(*args):
                v = torch.cat(args, dim=1)
                return v * v

            M = 16
            Ns = [128, 16, 1]
            empty = torch.tensor([], device=device, dtype=torch.double)
            values = [empty] + [torch.randn(M, N, device=device) for N in Ns]
            traced = torch.jit.trace(foo, values)

            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

            # now test with only empty tensors
            values = [empty for i in range(3)]
            traced = torch.jit.trace(foo, values)
            x = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            ref = foo(*values)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
    def test_cat_with_constant_dim(self):
        for device in self.devices:
            def foo(*args):
                v1 = torch.cat(args, dim=1)
                v2 = torch.cat([v1], dim=1)
                return v2 * v2

            empty = torch.tensor([], device=device, dtype=torch.float32)
            inputs = [empty] + [torch.randn(1, 64, device=device), torch.randn(1, 64, device=device)]
            traced = torch.jit.trace(foo, inputs)

            x = warmup_and_run_forward(traced, *inputs)
            self.assertLastGraphAllFused()
            ref = foo(*inputs)
            np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())
    def test_scalar(self):
        @torch.jit.script
        def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        @torch.jit.script
        def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        for test in (test_float, test_int):
            for data_type in self.dtypes:
                x, y, z = (torch.rand(4, dtype=data_type) for i in range(3))
                a, b = 1, 2
                test(x, y, z, a, b)
                r = test(x, y, z, a, b)
                self.assertEqual(r, x + y * a + z * b)
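
    # torch.add(x, y, alpha=a) computes x + a * y, which is exactly the
    # x + y * a + z * b reference that the assertEqual above spells out.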
    def test_loop(self):
        @torch.jit.script
        def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor:
            for i in range(0, z):
                x = x + y
            return x

        x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4)
        test(x, y, z)
        r = test(x, y, z)
        self.assertEqual(r, x + z * y)
    def test_slice(self):
        def easy(x, y):
            a = x[0:512:2]
            b = y[0:512:2]
            return a + b

        traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))

        a = torch.ones(1024, 1024)
        b = torch.zeros(1024, 1024)
        x = warmup_and_run_forward(traced, a, b)
        npr = a[0:512:2] + b[0:512:2]
        np.testing.assert_allclose(npr.numpy(), x.numpy())
    def test_unsqueeze(self, N=256):
        def easy(x, y):
            a = torch.unsqueeze(x, 0)
            b = torch.unsqueeze(y, 0)
            return a + b

        traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N)))

        a = torch.rand(N, N)
        x = warmup_and_run_forward(traced, a, a)
        npr = np.expand_dims(a, 0)
        npr = npr + npr
        np.testing.assert_allclose(npr, x.numpy())
    def _test_softmax(self, device):
        def test_softmax(x, y):
            a = F.softmax(x, dim=0, dtype=torch.float32)
            b = F.softmax(y, dim=0, dtype=torch.float32)
            c = F.softmax(x, dim=1, dtype=torch.float32)
            d = F.softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        def test_softmax_neg_index(x, y):
            a = F.softmax(x, dim=-2, dtype=torch.float32)
            b = F.softmax(y, dim=-2, dtype=torch.float32)
            c = F.softmax(x, dim=-1, dtype=torch.float32)
            d = F.softmax(y, dim=-1, dtype=torch.float32)
            return a + b + c + d

        def test_log_softmax(x, y):
            a = F.log_softmax(x, dim=0, dtype=torch.float32)
            b = F.log_softmax(y, dim=0, dtype=torch.float32)
            c = F.log_softmax(x, dim=1, dtype=torch.float32)
            d = F.log_softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        for test in (test_softmax, test_log_softmax, test_softmax_neg_index):
            for data_type in self.dtypes:
                old = torch._C._jit_set_texpr_reductions_enabled(True)
                traced_input = torch.randn(2, 3, dtype=data_type, device=device)
                traced = torch.jit.trace(test, (traced_input, traced_input))
                inp = torch.randn(2, 3, dtype=data_type, device=device)
                res = traced(inp, inp)

                ref = test(inp, inp)
                np.testing.assert_allclose(ref, res.cpu().numpy(), rtol=1e-06, atol=1e-06)
                torch._C._jit_set_texpr_reductions_enabled(old)

    def test_softmax_cpu(self):
        self._test_softmax('cpu')

    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @unittest.skip("global allocs are not supported yet.")
    def test_softmax_cuda(self):
        self._test_softmax('cuda')
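
    # The CUDA variant stays skipped because lowering softmax requires a global
    # temporary buffer, and NNC's CUDA codegen does not support global
    # allocations yet (hence the skip reason above).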
    def test_half_gelu(self):
        devices = ["cuda"] if torch.cuda.is_available() else []

        def bias_gelu(bias, y):
            x = bias + y
            return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

        for device in devices:
            a = torch.rand(1024, dtype=torch.half, device=device)
            b = torch.rand(1024, dtype=torch.half, device=device)
            traced = torch.jit.trace(bias_gelu, (a, b))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
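
    # bias_gelu is the erf formulation of GELU, x * 0.5 * (1 + erf(x / sqrt(2))),
    # with sqrt(2) hard-coded as 1.41421; running it in half precision exercises
    # the fuser's half-precision compute and cast handling on CUDA.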
    def test_half_bn_relu(self):
        devices = ["cuda"] if torch.cuda.is_available() else []

        def foo(a, b, c):
            y = torch.nn.functional.batch_norm(a, b, c)
            z = y.relu()
            return z

        for device in devices:
            a = torch.rand(16, 16, dtype=torch.half, device=device)
            b = torch.rand(16, dtype=torch.half, device=device)
            c = torch.rand(16, dtype=torch.half, device=device)
            traced = torch.jit.trace(foo, (a, b, c))

            x = warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()
    def test_exp_pow(self):
        def do_exp(x, y, z):
            return ((x * y) * 2) * torch.pow(z, 2)

        for device in self.devices:
            x = torch.rand(10, dtype=torch.double, device=device)
            y = torch.rand(10, dtype=torch.double, device=device)
            z = torch.rand(10, dtype=torch.double, device=device)
            traced = torch.jit.trace(do_exp, (x, y, z))
            x = warmup_and_run_forward(traced, x, y, z)
            self.assertLastGraphAllFused()
    def test_sin_pow(self):
        def test(x):
            return torch.sin(torch.pow(x, 0))

        for data_type, shape in itertools.product(self.dtypes, [[3], [5], [10]]):
            x = torch.rand(shape, dtype=data_type)
            scripted = torch.jit.script(test)
            out = warmup_and_run_forward(scripted, x)
            self.assertLastGraphAllFused()
            self.assertEqual(out, test(x))
    def test_transpose(self):
        @torch.jit.script
        def test(x, y, z):
            return x.transpose(0, 1) + y + z

        x = torch.rand(4, 5, 2, 3)
        y = torch.rand(5, 4, 2, 3)
        z = torch.rand(5, 4, 2, 3)
        ref = test(x, y, z)
        res = test(x, y, z)
        np.testing.assert_allclose(ref.numpy(), res.numpy())
    def test_sliced_stride(self):
        @torch.jit.script
        def test(x, y, z):
            return x + y + z

        x = torch.rand(16, 4, 2, 3)[::2]
        y = torch.rand(8, 4, 2, 3)
        z = torch.rand(8, 4, 2, 3)
        ref = test(x, y, z)
        res = test(x, y, z)
        np.testing.assert_allclose(ref.numpy(), res.numpy())
    @unittest.skip("dynamic shapes are not quite there yet")
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_dynamic_shape(self):
        with num_profiled_runs(2):
            @torch.jit.script
            def test(x, y, z):
                return x * y * z

            x, y, z = (torch.rand(4, 8).cuda() for _ in range(3))
            ref = test(x, y, z)
            _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
            res = test(x, y, z)
            np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())

            # broadcasting inputs
            x = torch.rand(4, 8).cuda()
            y = torch.rand(1, 8).cuda()
            z = torch.rand(4, 1).cuda()
            res = test(x, y, z)
            xn, yn, zn = (t.cpu().numpy() for t in (x, y, z))
            np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)

            # mismatched sizes must raise
            x = torch.rand(4, 8).cuda()
            y = torch.rand(4, 8).cuda()
            z = torch.rand(5, 8).cuda()
            try:
                res = test(x, y, z)
            except RuntimeError as e:
                assert "The size of tensor a (4) must match" in e.args[0]
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_guard_fails(self):
        @torch.jit.script
        def test(x, y, z):
            return x * y * z

        r1 = test(*[torch.rand(4).cuda() for _ in range(3)])
        r2 = test(*[torch.rand(4).cuda() for _ in range(3)])
        r3 = test(*[torch.rand(4).cuda() for _ in range(3)])
        r4 = test(*[torch.rand(7).cuda() for _ in range(3)])
    def test_bitwise_ops(self):
        def run_and(x, y):
            return x & (x & y)

        def run_or(x, y):
            return x & (x | y)

        def run_xor(x, y):
            return x ^ (x ^ y)

        def run_lshift(x, y):
            return x & (x << y)

        def run_rshift(x, y):
            return x & (x >> y)

        fns = {run_and, run_or, run_xor, run_lshift, run_rshift}

        for device in self.devices:
            for fn in fns:
                a = torch.ones(128, dtype=torch.int32, device=device)
                b = torch.zeros(128, dtype=torch.int32, device=device)
                inp = torch.ones(128, dtype=torch.int32, device=device)
                traced = torch.jit.trace(fn, (inp, inp))
                x = warmup_and_run_forward(traced, a, b)
                self.assertLastGraphAllFused()
                y = fn(a, b)
                np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
    def test_where(self):
        def run_where(x, y):
            return torch.where(torch.gt(x, y), x, y)

        for data_type in self.dtypes:
            a = torch.rand(1024, dtype=data_type)
            b = torch.rand(1024, dtype=data_type)
            zeros = torch.zeros(1024, dtype=data_type)
            traced = torch.jit.trace(run_where, (zeros, zeros))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            y = run_where(a, b)
            np.testing.assert_allclose(x.float().numpy(), y.float().numpy())
    def test_multi_rand(self):
        for device in self.devices:
            def test(x):
                y = torch.rand_like(x)
                return (x + y) - (y - x)

            _atol = 1e-6
            _rtol = 1e-5
            for data_type in self.dtypes:
                if data_type is torch.bfloat16:
                    _atol = 2e-2
                a = torch.rand(4, dtype=data_type, device=device)
                scripted = torch.jit.script(test)
                out = warmup_and_run_forward(scripted, a)
                self.assertLastGraphAllFused()
                assert torch.allclose(out, 2 * a, atol=_atol, rtol=_rtol)
    def test_mask(self):
        def test(x):
            return x.unsqueeze(1) == 0

        for d in self.devices:
            for data_type in self.dtypes:
                x = torch.rand(4, dtype=data_type, device=d) > 0.5
                scripted = torch.jit.script(test)
                out = warmup_and_run_forward(scripted, x)
                self.assertLastGraphAllFused()
                assert torch.equal(out, test(x))
    def test_simple_add(self):
        val = torch._C._jit_get_te_generate_block_code()
        torch._C._jit_set_te_generate_block_code(True)
        fall_bk = torch._C._jit_texpr_fallback_allowed()
        torch._C._jit_texpr_set_fallback_allowed(True)

        def simple(a, b):
            return torch.add(a, b)

        a = torch.ones(256, 256)
        b = torch.ones(256, 256)
        traced = torch.jit.trace(simple,
                                 (torch.ones(256, 256), torch.ones(256, 256)))
        f = traced(a, b)
        f_test = np.full((256, 256), 2, dtype=float)
        np.testing.assert_allclose(f.numpy(), f_test)
        torch._C._jit_set_te_generate_block_code(val)
        torch._C._jit_texpr_set_fallback_allowed(fall_bk)
    def test_strided_output_preserved(self):
        def foo(a, b):
            return a + b - a

        # smaller, easier to debug example
        x = torch.arange(0, 6)
        x = torch.as_strided(x, (2, 3), (1, 2))
        total = 0
        for i in range(2):
            for j in range(3):
                x[i, j] = total
                total += 1

        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)
        out_s = foo_script(x, x)
        out_eager = foo(x, x)
        self.assertEqual(out_s, out_eager)
        self.assertEqual(out_s.stride(), out_eager.stride())
        self.assertLastGraphAllFused()

        # channels-last case
        N, C, H, W, = 2, 3, 4, 5
        x = torch.rand(N, C, H, W).to(memory_format=torch.channels_last)
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)
        out_s = foo_script(x, x)
        out_eager = foo(x, x)
        self.assertEqual(out_s, out_eager)
        self.assertEqual(out_s.stride(), out_eager.stride())
        self.assertLastGraphAllFused()
    def test_alias_analysis_module(self):
        class AliasModule(nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                z = z + self.a
                self.b.add_(y)
                w = z + self.a
                z = w + x
                return z

        x = torch.randn(128, 128)

        def getModule(script):
            am = AliasModule()
            if script:
                return torch.jit.script(am)
            return am

        am = getModule(False)
        am_s = getModule(True)
        ref = am(x, x, x)
        test = am_s(x, x, x)
        torch.testing.assert_close(ref, test)

        # the in-place add_ on self.b must stay visible across runs
        ref = am(x, x, x)
        test = am_s(x, x, x)

        torch.testing.assert_close(ref, test)
    def test_alias_analysis_inputs(self):
        class AliasModule(nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.a
                z = w + x
                return z

        def getModule(script):
            am = AliasModule()
            if script:
                return torch.jit.script(am)
            return am

        am = getModule(False)
        am_s = getModule(True)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        ref = am(x, x, x)

        # x is mutated in place, so regenerate it with the same seed
        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        test = am_s(x, x, x)

        torch.testing.assert_close(ref, test)
    def test_alias_analysis_input_and_module(self):
        class AliasModule(nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.b
                z = w + x
                return z

        def getModule(script):
            am = AliasModule()
            if script:
                return torch.jit.script(am)
            return am

        am = getModule(False)
        am_s = getModule(True)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        am.b = x
        ref = am(x, x, x)

        torch.manual_seed(1337)
        x = torch.randn(128, 128)
        am_s.b = x
        test = am_s(x, x, x)

        torch.testing.assert_close(ref, test)
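
    # The three alias-analysis tests above share one goal: when an input, a
    # module attribute, or both alias the same storage, the fuser must either
    # honor the in-place add_ or bail out, so scripted and eager results match.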
    def test_multiple_outputs(self):
        for device in self.devices:
            def foo(a, b, c):
                t_next = c + 1
                t5 = t_next * b
                t6 = torch.unsqueeze(t_next, 1)
                t7 = a * t6
                return (t7, t5, t_next)

            for data_type in self.dtypes:
                a = torch.rand(20, 20, dtype=data_type, device=device)
                b = torch.rand(20 * 29, dtype=data_type, device=device).as_strided([20], [29])
                c = torch.ones(20, dtype=torch.int64, device=device)
                traced = torch.jit.trace(foo, (a, b, c))
                ref = foo(a, b, c)
                exp = traced(a, b, c)
                exp = traced(a, b, c)
                self.assertEqual(ref, exp)
    def test_propagated_mem_layout(self):
        def foo(a, b, c):
            t_next = c + 1
            t5 = t_next * b
            t6 = torch.unsqueeze(t_next, 1)
            t7 = a * t6
            return t7

        def foo_multi_outputs(a, b, c):
            t_next = c + 1
            t5 = t_next * b
            t6 = torch.unsqueeze(t_next, 1)
            t7 = a * t6
            return (t7, t5, t_next)

        def foo_multi_outputs_i_nhwc_o_nchw(a, b, c):
            t_next = c + 1
            t5 = t_next * b
            t6 = torch.unsqueeze(t_next, 1)
            t7 = a * t6
            t8 = t7.to(memory_format=torch.contiguous_format)
            return (t8, t7, t5, t_next)

        def run_foo_case(foo, a, b, c):
            traced_contiguous = torch.jit.trace(foo, (a, b, c))
            ref = foo(a, b, c)
            exp = traced_contiguous(a, b, c)
            exp = traced_contiguous(a, b, c)
            self.assertEqual(ref, exp)

        mem_layouts = list(itertools.product([torch.contiguous_format, torch.channels_last], repeat=3))
        shapes = [(2, 3, 4, 5), (2, 1, 1, 5), (1, 1, 1, 1)]
        permutes = [(0, 3, 2, 1), (0, 3, 1, 2)]
        funcs = [foo, foo_multi_outputs, foo_multi_outputs_i_nhwc_o_nchw]
        # materialize the product so both fusion strategies iterate the full set
        configs = list(itertools.product(funcs, shapes, mem_layouts, permutes))
        for strategy in ["STATIC", "DYNAMIC"]:
            old_strategy = torch.jit.set_fusion_strategy([(strategy, 10)])
            for _func, _shape, _mem_layouts, _permute in configs:
                a = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[0])
                b = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[1])
                c = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[2])
                run_foo_case(_func, a, b, c)

                a = a.permute(dims=_permute)
                b = b.permute(dims=_permute)
                c = c.permute(dims=_permute)
                run_foo_case(_func, a, b, c)

            torch.jit.set_fusion_strategy(old_strategy)
if __name__ == '__main__':
    run_tests()