import torch.nn.functional as F
from torch.testing import FileCheck
torch._C._jit_set_profiling_executor(True)
torch._C._get_graph_executor_optimize(True)
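# Note: the two calls above opt this module into the profiling graph executor
# and graph optimization, which is what lets prim::TensorExprGroup fusion
# groups form in the tests below.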
from itertools import combinations, permutations, product
from textwrap import dedent
from jit.test_fuser_common import TestFuserCommon
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
from torch.testing._internal.common_jit import JitCommonTestCase
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests,
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from torch.testing._internal.jit_utils import (
get_traced_sample_variant_pairs,
NoTracerWarnContextManager,
set_fusion_group_inlining,
TensorExprTestOptions,
FUSION_GROUP = "prim::TensorExprGroup"
LLVM_ENABLED = torch._C._llvm_enabled()
"prim::AutogradAllNonZero",
"prim::AutogradAllZero",
"prim::ListConstruct",
def strip_profiling_nodes(nodes):
profiling_opcodes = {"prim::BailoutTemplate", "prim::BailOut"}
return [n for n in nodes if n.kind() not in profiling_opcodes]
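# warmup_forward (below) simply runs a function a couple of times so the
# profiling executor can record tensor shapes and trigger TensorExpr fusion
# before the tests inspect the optimized graph.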
def warmup_forward(f, *args, profiling_count=2):
for i in range(profiling_count):
@contextlib.contextmanager
def texpr_reductions_enabled():
old = torch._C._jit_set_texpr_reductions_enabled(True)
torch._C._jit_set_texpr_reductions_enabled(old)
@contextlib.contextmanager
def texpr_enable_strategy(strategy):
old = torch._C._jit_set_fusion_strategy(strategy)
torch._C._jit_set_fusion_strategy(old)
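# Minimal usage sketch (with placeholder names fn/example_inputs): temporarily
# switch the fuser to a dynamic-shape strategy while tracing, then restore the
# previous setting on exit:
#   with texpr_enable_strategy([("DYNAMIC", 20)]):
#       traced = torch.jit.trace(fn, example_inputs)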
@contextlib.contextmanager
def inline_fusion_groups():
old_inlining = torch._C._debug_get_fusion_group_inlining()
torch._C._debug_set_fusion_group_inlining(True)
torch._C._debug_set_fusion_group_inlining(old_inlining)
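# inline_fusion_groups saves the current fusion-group-inlining debug flag,
# forces it on for the duration of the block, and restores the old value on
# exit; several cat/chunk tests below run under it.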
class TestTEFuser(JitTestCase):
self.tensorexpr_options = TensorExprTestOptions()
fusion_strategy = [("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)]
self.old_fusion_strategy = torch._C._jit_set_fusion_strategy(fusion_strategy)
self.devices = ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
self.dtypes = self.int_dtypes + self.fp_dtypes
self.tensorexpr_options.restore()
torch._C._jit_set_fusion_strategy(self.old_fusion_strategy)
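# assertAllFused (below) walks the optimized graph and asserts that, apart from
# constants, guard/bailout nodes, and anything listed in except_for, every node
# ended up inside a prim::TensorExprGroup fusion group.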
def assertAllFused(self, graph, except_for=None):
except_for = except_for if except_for is not None else set()
"prim::RequiresGradCheck",
"prim::TensorExprDynamicGuard",
def autodiff_guard(node):
if node.kind() != "aten::all":
inps = list(node.inputs())
if len(inps) != 1 or inps[0].node().kind() != "prim::ListConstruct":
li_inps = list(inps[0].node().inputs())
for li_inp in li_inps:
if li_inp.node().kind() in (
"prim::AutogradAllNonZero",
"prim::AutogradAllZero",
return node.kind() in guards or autodiff_guard(node)
for node in graph.block().nodes():
if node.kind() == "prim::Constant":
self.assertFalse(guard_found)
if node.kind() in except_for:
if node.kind() == "prim::If":
self.assertTrue(is_guard(node.prev()))
self.assertTrue(False, "Found unexpected node:" + node.kind())
self.assertTrue(guard_found)
def assertLastGraphAllFused(self):
self.assertAllFused(torch.jit.last_executed_optimized_graph())
def findFusionGroups(self, graph):
for n in graph.nodes():
if n.kind() == FUSION_GROUP:
result.append(n.g("Subgraph"))
for block in n.blocks():
result += self.findFusionGroups(block)
def test_typecheck(self):
def fused_kernel(a, b):
scripted = self.checkScript(fused_kernel, (a, a))
graph = scripted.graph_for(a, a)
fusion_groups = self.findFusionGroups(graph)
self.assertEqual(len(fusion_groups), 1)
self.assertEqual(scripted(a, a), fused_kernel(a, a))
def test_sum_simple(self):
with texpr_reductions_enabled():
a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
scripted = self.checkScript(func, (a,))
self.assertLastGraphAllFused()
def test_sum_dim(self):
return x.sum((0,)) * 2
return x.sum((-2,)) * 2
with texpr_reductions_enabled():
a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
scripted = self.checkScript(func, (a,))
self.assertLastGraphAllFused()
scripted = self.checkScript(func_neg, (a,))
self.assertLastGraphAllFused()
def test_sum_keepdim_cast(self):
return x.sum((0,), keepdim=True, dtype=torch.double) * 2
with texpr_reductions_enabled():
a = torch.tensor(list(range(0, 15)), dtype=torch.float, device="cpu")
self.checkScript(func, (a,))
self.assertLastGraphAllFused()
for device in self.devices:
a = torch.randn(5, device=device)
scripted = self.checkScript(func, (a,))
self.assertLastGraphAllFused()
def test_unsqueeze_size_calculation(self):
for device in self.devices:
torch.rand(20, 28, device=device, requires_grad=True),
torch.rand(20, device=device),
scripted = self.checkScript(foo, inputs)
self.assertAllFused(scripted.graph_for(*inputs))
def test_zero_element_tensors(self):
for device in self.devices:
def decode(sin_t, cos_t):
theta = torch.atan2(sin_t.float(), cos_t.float())
sin = torch.zeros(0, device=device)
cos = torch.zeros(0, device=device)
ge = self.checkScript(decode, inputs)
def test_arg_configurations_smoke(self):
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
for device in self.devices:
z1, z2 = (x + y).chunk(2, dim=1)
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
traced_f = torch.jit.trace(f, (x, y))
self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
def test_broadcast(self):
for device in self.devices:
def scaleshift(x, scale, shift):
return x * scale + shift
torch.randn(4, 4, dtype=torch.float, device=device),
torch.randn(4, dtype=torch.float, device=device),
torch.randn(4, dtype=torch.float, device=device),
self.checkScript(scaleshift, inputs)
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_HALF, "no half support")
GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on"
def test_cuda_half(self):
x = torch.randn(4, 4, dtype=torch.half, device="cuda")
y = torch.randn(4, 4, dtype=torch.half, device="cuda")
funcs = [self.fn_test_comparison_gt_lt, self.fn_test_relu, self.fn_test_exp]
inputs = (x.float(), y.float())
fusion_inputs = (x, y)
local_inputs = [t.clone().requires_grad_() for t in inputs]
local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]
fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False)
outputs = fn(*local_inputs)
fusion_outputs = fusion(*local_fusion_inputs)
outputs_half = [t.half() for t in outputs]
self.assertEqual(outputs_half, fusion_outputs)
for output, fusion_output in zip(outputs_half, fusion_outputs):
grads = torch.autograd.grad(
output.float().sum(),
fusion_grads = torch.autograd.grad(
grads_half = [t.half() for t in grads]
self.assertEqual(grads_half, fusion_grads)
def test_checks_cat_inputs(self):
with set_fusion_group_inlining(True):
for device in self.devices:
return torch.cat([x + 2 * x + x**2, y + 4 * y + y**3], dim=0)
x = torch.randn(2, 4, dtype=torch.float, device=device)
y = torch.randn(1, 4, dtype=torch.float, device=device)
scripted = self.checkScript(f, (x, y))
self.assertEqual(scripted(x, y).shape, (3, 4))
self.assertAllFused(scripted.graph_for(x, y))
def test_chunk(self):
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
for device in self.devices:
a, b, c = x.chunk(3, 1)
inputs = [torch.randn(10, 6, dtype=torch.float, device=device)]
self.checkScript(fn, inputs)
self.assertLastGraphAllFused()
def test_chunk_correctness(self):
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
for device in self.devices:
x0, x1, x2, x3 = x.chunk(4, 0)
return x0 + x1 + x2 + x3
x0, x1, x2, x3 = x.chunk(4, 1)
return x0 + x1 + x2 + x3
x0, x1, x2, x3 = x.chunk(4, 2)
return x0 + x1 + x2 + x3
fns = [chunk_4_0, chunk_4_1, chunk_4_last]
torch.randn(4, 4, 4, dtype=torch.float, device=device),
torch.randn(12, 8, 16, dtype=torch.float, device=device),
torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(
for tensor in tensors:
self.checkScript(fn, [tensor])
self.assertLastGraphAllFused()
def test_chunk_distributes(self):
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
for device in self.devices:
z1, z2 = (x + y).chunk(2, dim=1)
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
ge = self.checkTrace(f, (x, y))
graph = ge.graph_for(x, y)
FileCheck().check("with " + FUSION_GROUP + "_").check_count(
"ConstantChunk", 1, exactly=True
def test_chunk_motion_deduplicates_inputs(self):
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
for device in self.devices:
inputs = [torch.tensor([1.1, 1.2], device=device, dtype=torch.float)]
for func in [func1, func2]:
self.checkScript(func, inputs)
self.assertLastGraphAllFused()
def test_chunk_multiple(self):
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
for device in self.devices:
z1, z2 = z.chunk(2, 2)
x1, x2, x3 = x.chunk(3, 1)
y1, y2 = y.chunk(2, 0)
return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
torch.randn(5, 2, 3, dtype=torch.float, device=device),
torch.randn(5, 6, 3, dtype=torch.float, device=device),
torch.randn(10, 2, 3, dtype=torch.float, device=device),
torch.randn(5, 2, 6, dtype=torch.float, device=device),
ge = self.checkScript(fn, inputs)
self.assertAllFused(ge.graph_for(*inputs))
def test_minmax(self):
for device in self.devices:
return torch.max(2 * a, b)
return torch.min(2 * a, b)
a = torch.randn(4, 4, dtype=torch.float)
b = torch.randn(4, 4, dtype=torch.float)
nan = torch.tensor(float("nan"), dtype=torch.float)
for f, inputs, device in product(
(tmax, tmin), ([a, b], [a, nan], [b, nan]), self.devices
inputs = [t.to(device) for t in inputs]
s = self.checkScript(f, inputs)
self.assertAllFused(s.graph_for(*inputs))
def test_clamp(self):
for device in self.devices:
return torch.clamp(a + b, min=0, max=2)
return torch.clamp(a + b, min=0, max=float("inf"))
def funcNegInf(a, b):
return torch.clamp(a + b, min=float("-inf"), max=0)
def funcOptMin(a, b):
return torch.clamp(a + b, max=2)
def funcOptMax(a, b):
return torch.clamp(a + b, min=0)
a = torch.randn(4, 4, dtype=torch.float, device=device, requires_grad=True)
b = torch.randn(4, 4, dtype=torch.float, device=device)
nan = torch.tensor(float("nan"), dtype=torch.float, device=device)
funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
for f, inputs in product(funcs, [[a, b], [a, nan]]):
s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
s.graph_for(inp1, inp2),
except_for={"aten::size", "aten::_size_if_not_equal"},
with enable_profiling_mode_for_profiling_tests():
warmup_backward(c.sum())
graph = backward_graph(s)
except_for={"aten::Float", "aten::_grad_sum_to_size"}.union(
def test_clamp_double(self):
for device in self.devices:
def clamp_double(x, eta: float):
return 1 - x.clamp(eta, 1 - eta)
x = torch.tensor([1.0, 1.0], dtype=torch.double, device=device)
s = self.checkScript(
profiling=ProfilingMode.PROFILING,
self.assertAllFused(s.graph_for(x, eta), except_for={"aten::sub"})
def test_clamp_int(self):
for device in self.devices:
def clamp_int(x, eta: int):
return x.clamp(0, eta)
x = torch.tensor([1, 1], device=device)
s = self.checkScript(clamp_int, (x, eta), profiling=ProfilingMode.PROFILING)
self.assertAllFused(s.graph_for(x, eta))
def test_add_bool(self):
sizes = [(1,), (2,), (4, 4)]
for device, size in product(self.devices, sizes):
x = torch.randint(0, 2, size, dtype=torch.bool, device=device)
y = torch.randint(0, 2, size, dtype=torch.bool, device=device)
z = torch.randint(0, 2, size, dtype=torch.bool, device=device)
ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
self.assertAllFused(ge.graph_for(x, y, z))
def test_mul_bool(self):
for device in self.devices:
x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
z = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
self.assertAllFused(ge.graph_for(x, y, z))
def test_div_bool(self):
for device in self.devices:
x = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
y = torch.randint(0, 2, (4, 4), dtype=torch.bool, device=device)
z = torch.ones_like(x, dtype=torch.bool, device=device)
ge = self.checkTrace(f, (x, y, z), inputs_require_grads=False)
self.assertAllFused(ge.graph_for(x, y, z))
def test_bitwise_ops(self):
return lambda x, y, z: fn(fn(x, y), z)
devices = self.devices
for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
z = self.data_for(dtype, device)
t = torch.jit.trace(fn, (x, y, z))
self.assertEqual(ref, t(x, y, z))
self.assertAllFused(t.graph_for(x, y, z))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device])
def test_minmax_int_ops(self):
return lambda x, y, z: fn(fn(x, y), z)
binary_ops = [torch.min, torch.max]
devices = self.devices
for dtype, op, device in product(self.int_dtypes, binary_ops, devices):
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
z = self.data_for(dtype, device)
t = torch.jit.trace(fn, (x, y, z))
self.assertEqual(ref, t(x, y, z))
self.assertAllFused(t.graph_for(x, y, z))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device])
def test_comparison_eq_ne(self):
for device in self.devices:
mask = (x == 0).type_as(x)
mask = (x != 0).type_as(x)
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
ge = self.checkTrace(f, (x, y))
self.assertAllFused(ge.graph_for(x, y))
def fn_test_comparison_gt_lt(x, y):
mask = (x > 0).type_as(x)
mask = (x < 0).type_as(x)
def test_comparison_gt_lt(self):
for device in self.devices:
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
self.assertAllFused(ge.graph_for(x, y))
def test_comparison_ge_le(self):
for device in self.devices:
mask = (x >= 0).type_as(x)
mask = (x <= 0).type_as(x)
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
ge = self.checkTrace(f, (x, y))
self.assertAllFused(ge.graph_for(x, y))
x.requires_grad_(True)
y.requires_grad_(True)
"prim::BroadcastSizes",
"aten::_size_if_not_equal",
def test_addcmul(self):
for device in self.devices:
t = torch.randn(1, 4, dtype=torch.float, device=device)
t1 = torch.randn(4, 1, dtype=torch.float, device=device)
t2 = torch.randn(1, 4, dtype=torch.float, device=device)
return t.addcmul(t + 1, t2, value=0.1)
ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
graph = ge.graph_for(t, t1, t2)
fusion_groups = self.findFusionGroups(graph)
self.assertEqual(len(fusion_groups), 1)
FileCheck().check("aten::add(").check("aten::addcmul(").run(
str(fusion_groups[0])
for device in self.devices:
start = torch.randn(4, 1, dtype=torch.float, device=device)
end = torch.randn(1, 4, dtype=torch.float, device=device)
weight = torch.tensor(0.5, dtype=torch.float, device=device)
def foo_weight_scalar(start, end):
return torch.lerp(start + 1, end, 0.5)
def foo_weight_tensor(start, end):
return torch.lerp(start + 1, end, weight)
ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
graph = ge_weight_scalar.graph_for(start, end)
self.assertAllFused(graph)
def test_concat(self):
with set_fusion_group_inlining(True):
for device in self.devices:
hx = torch.randn(3, 20, dtype=torch.float, device=device)
cx = torch.randn(3, 20, dtype=torch.float, device=device)
return torch.cat((hx + cx, hx * cx))
ge = self.checkTrace(foo, (hx, cx))
graph = ge.graph_for(hx, cx)
self.assertAllFused(graph)
def test_remove_output_used_only_in_size(self):
for device in self.devices:
scripted_f = torch.jit.script(test_fuse)
x = torch.ones(1, requires_grad=True, device=device)
y = torch.ones(1, requires_grad=True, device=device)
warmup_forward(scripted_f, x, y, profiling_count=3)
g = scripted_f.graph_for(x, y)
diff_nodes = g.findAllNodes("prim::DifferentiableGraph")
self.assertEqual(len(diff_nodes), 1)
g = diff_nodes[0].g("Subgraph")
if_nodes = [n for n in g.nodes() if n.kind() == "prim::If"]
self.assertEqual(len(if_nodes), 1)
self.assertEqual(len(list(if_nodes[0].outputs())), 1)
def test_concat_invariant(self):
for device in self.devices:
w = torch.cat([x1, y1])
x = torch.randn(2, 2, dtype=torch.float, device=device)
y = torch.randn(2, 2, dtype=torch.float, device=device)
z = torch.randn(4, 2, dtype=torch.float, device=device)
ge = self.checkTrace(fn, (x, y, z))
graph = ge.graph_for(x, y, z)
self.assertAllFused(graph, except_for={"aten::add"})
def fn_test_exp(x, y):
return (x + 0.5 * y).exp()
for device in self.devices:
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
ge = self.checkTrace(self.fn_test_exp, (x, y))
self.assertAllFused(ge.graph_for(x, y))
def test_threshold(self):
for device in self.devices:
return torch.threshold(x, 0, -10) + x + x + x
x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device=device)
scripted = self.checkScript(f, (x,))
self.assertAllFused(scripted.graph_for(x))
def test_scalar_arg(self):
for device in self.devices:
def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
return p * (x * x + x)
x = torch.randn(4, 4, dtype=torch.float, device=device)
scripted = self.checkScript(fn_test_scalar_arg, (x, p))
self.assertAllFused(scripted.graph_for(x, p))
x.requires_grad_(True)
def fn_test_scalar_arg_requires_grad(
x: torch.Tensor, p: float
return p * (x * x + x)
scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
scripted.graph_for(x, p),
"prim::BroadcastSizes",
"aten::_size_if_not_equal",
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
def test_fusion_reuse_multi_gpu(self):
torch.randn(4, 4, dtype=torch.float),
torch.randn(4, 4, dtype=torch.float),
inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]
ge = self.checkScript(fn, inputs_cpu)
self.assertAllFused(ge.graph_for(*inputs_cpu))
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
def test_kernel_cache_multi_gpu(self):
x_out = x * x * x * x * x
y_out = y * y * y * y * y
z_out = z * z * z * z * z
return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)
torch.randn(4, 4, dtype=torch.float),
torch.randn(4, 4, dtype=torch.float, device="cuda:0"),
torch.randn(4, 4, dtype=torch.float, device="cuda:1"),
prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
ge = self.checkScript(fn, inputs)
self.assertGraphContainsExactly(ge.graph_for(*inputs), FUSION_GROUP, 3, True)
new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
def test_nonzero_device_cuda(self):
device = "cuda:" + str(1)
x = torch.tensor([0.4], dtype=torch.float, device=device)
y = torch.tensor([0.7], dtype=torch.float, device=device)
return torch.sigmoid(torch.tanh(x * (x + y) + x))
ge = self.checkTrace(doit, (x, y))
self.assertAllFused(ge.graph_for(x, y))
for device in self.devices:
inputs = get_lstm_inputs(device, training=True)
module = self.checkScript(LSTMCellS, inputs)
module.graph_for(inputs), except_for={"prim::TupleConstruct"}
def test_lstm_concat(self):
with set_fusion_group_inlining(True):
for device in self.devices:
inputs = get_lstm_inputs(device)
ge = self.checkTrace(LSTMCellC, inputs)
graph = ge.graph_for(*inputs)
except_nodes = {"prim::TupleConstruct", "aten::linear"}
if self.dynamic_shapes:
except_nodes = except_nodes.union(
{"aten::add", "prim::ConstantChunk"}
self.assertAllFused(ge.graph_for(*inputs), except_for=except_nodes)
def test_lstm_gates_permutations(self):
for device in self.devices:
choices = ["x.mm(w_ih.t())", "hx.mm(w_hh.t())", "b_ih", "b_hh"]
def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
gates = {} + {} + {} + {}
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
return ingate * forgetgate * cellgate * outgate
for permutation in permutations(choices, len(choices)):
code = template.format(*permutation)
exec(code, globals(), scope)
cu = torch.jit.CompilationUnit(code)
fusion_group_len = 2 if self.dynamic_shapes else 1
inputs = get_lstm_inputs(device, training=False)
self.assertEqual(cu.cell(*inputs), scope["cell"](*inputs))
forward_graph = cu.cell.graph_for(*inputs)
self.assertGraphContainsExactly(
forward_graph, FUSION_GROUP, fusion_group_len
def test_lstm_traced(self):
for device in self.devices:
inputs = get_lstm_inputs(device)
ge = self.checkTrace(LSTMCellF, inputs)
graph = ge.graph_for(*inputs)
fusion_groups = self.findFusionGroups(graph)
fusion_group_len = 2 if self.dynamic_shapes else 1
self.assertEqual(len(fusion_groups), fusion_group_len)
if not self.dynamic_shapes:
f.check("aten::sigmoid").check("aten::tanh").run(
str(fusion_groups[0 if not self.dynamic_shapes else 1])
def test_milstm(self):
if self.dynamic_shapes:
self.skipTest("don't run conv with dynamic shapes")
for device in self.devices:
inputs = get_milstm_inputs(device, training=True)
module = self.checkScript(MiLSTMCell, inputs)
forward_graph = module.graph_for(*inputs)
fusion_group_len = 2 if self.dynamic_shapes else 1
self.assertGraphContainsExactly(
forward_graph, FUSION_GROUP, fusion_group_len, consider_subgraphs=True
FileCheck().check("DifferentiableGraph").check("TupleConstruct").check_next(
).check(FUSION_GROUP).run(str(forward_graph))
hy, cy = module(*inputs)
warmup_backward((hy + cy).sum())
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skip("rand_like is not supported yet")
def test_rand_cuda(self):
class M(torch.jit.ScriptModule):
__constants__ = ["d"]
def __init__(self) -> None:
self.d = torch.device("cuda")
@torch.jit.script_method
def create(self, x):
return x * x + x + torch.rand_like(x)
x = torch.zeros([3, 4, 5], dtype=torch.float, device="cuda")
self.assertNotEqual(out1, out2)
self.assertTrue(torch.all(out1 >= 0))
self.assertTrue(torch.all(out1 < 1))
self.assertTrue(torch.all(out2 >= 0))
self.assertTrue(torch.all(out2 < 1))
self.assertAllFused(m.create.graph_for(x))
def fn_test_relu(x, y):
return F.relu(x + 0.5 * y)
def test_relu(self):
for device in self.devices:
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
ge = self.checkTrace(self.fn_test_relu, (x, y))
self.assertAllFused(ge.graph_for(x, y))
for device in self.devices:
return F.relu(torch.erf(x) - torch.erfc(x))
x = torch.randn(4, 4, dtype=torch.float, device=device)
ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING)
self.assertAllFused(ge.graph_for(x))
x.requires_grad_(True)
ge = self.checkScript(fn_test_erf, (x,), profiling=ProfilingMode.PROFILING)
self.assertAllFused(
"prim::BroadcastSizes",
"aten::_size_if_not_equal",
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skip("rand_like is not supported yet")
def test_rand_broadcast_cuda(self):
def fn_test_rand(x, y):
r = torch.rand_like(y)
def fn_test_rand2(x, y):
r = torch.rand_like(y)
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
script_f = torch.jit.script(fn_test_rand)
warmup_forward(script_f, x, y)
out = script_f(x, y)
self.assertAllFused(script_f.graph_for(x, y))
x.requires_grad_(True)
out = script_f(x, y)
self.assertAllFused(
script_f.graph_for(x, y),
"prim::BroadcastSizes",
"aten::_size_if_not_equal",
x = torch.ones(4, 4, dtype=torch.float, device="cuda")
y = torch.ones(4, dtype=torch.float, device="cuda")
script_f = torch.jit.script(fn_test_rand2)
warmup_forward(script_f, x, y)
out = script_f(x, y)
self.assertEqual(out[0, :] + torch.zeros(4, 4, device="cuda"), out)
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skip("rand_like is not supported yet")
def test_rand_diamond(self):
def fn_test_diamond(x, y):
r = torch.rand_like(y)
x = torch.randn(4, 4, dtype=torch.float, device="cuda")
y = torch.randn(4, 4, dtype=torch.float, device="cuda")
script_f = torch.jit.script(fn_test_diamond)
warmup_forward(script_f, x, y)
out = script_f(x, y)
self.assertEqual(out, x + y)
def test_scalar(self):
x = torch.tensor(0.1, dtype=torch.float, device="cpu")
y = torch.tensor(1, dtype=torch.float, device="cpu")
ge = self.checkScript(fn, (x, y))
self.assertAllFused(ge.graph_for(x, y))
def test_inlined_optimized_graph(self):
return torch.relu(x + x)
foo(torch.rand([4, 4]))
foo(torch.rand([10]))
foo(torch.rand([2, 2, 2]))
g = torch.jit.last_executed_optimized_graph()
FileCheck().check_count("prim::If", 1, exactly=True).check(
torch._C._jit_pass_inline(g)
f.check("prim::If").check("prim::TensorExpr")
def test_small_constant(self):
for device in self.devices:
def fn_test_small_constant(x, y):
return (1e-8 * x + 5e-9 * y) * 1e8
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
ge = self.checkTrace(fn_test_small_constant, (x, y))
self.assertAllFused(ge.graph_for(x, y))
def test_tensor_scalar_ops(self):
for device in self.devices:
def should_fuse_scalar(x, z):
inputs = [torch.randn(2, 2, dtype=torch.float, device=device)]
ge = self.checkScript(should_fuse, inputs)
graph = ge.graph_for(*inputs)
fusion_groups = self.findFusionGroups(graph)
self.assertEqual(len(fusion_groups), 1)
FileCheck().check("aten::add").check("aten::mul").run(str(fusion_groups[0]))
torch.randn(2, 2, dtype=torch.float, device=device),
torch.tensor(3.0, dtype=torch.float, device=device),
ge = self.checkScript(should_fuse_scalar, inputs)
torch.randn(2, 2, dtype=torch.float, device=device),
torch.tensor(7.0, dtype=torch.float, device=device),
self.assertEqual(ge(*inputs), should_fuse_scalar(*inputs))
self.assertGraphContainsExactly(
ge.graph_for(*inputs), FUSION_GROUP, 1, consider_subgraphs=True
def test_where_and_typing(self):
for device in self.devices:
res = torch.where(mask, x, y)
x = torch.randn(4, 4, dtype=torch.double, device=device)
y = torch.randn(4, 4, dtype=torch.double, device=device)
script_f = self.checkScript(f, (x, y))
self.assertAllFused(
script_f.graph_for(x, y), except_for={"prim::TupleConstruct"}
def test_disabled(self):
old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
torch._C._jit_override_can_fuse_on_cpu(False)
x = torch.randn(4, dtype=torch.float, device="cpu")
s = self.checkScript(fn, (x,))
self.assertEqual(len(self.findFusionGroups(g)), 0)
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
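# data_for (below) builds a small sample tensor of the requested dtype and
# device, with special handling for torch.bool and the quantized dtypes
# (via quantize_per_tensor); most op tests use it to generate inputs.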
def data_for(self, dtype, device="cuda", size=None):
v = torch.arange(1, 3, dtype=torch.float, device=device)
v = torch.rand(*size, device=device)
if dtype == torch.bool:
elif dtype in [torch.qint8, torch.quint8, torch.qint32]:
return torch.quantize_per_tensor(v, 0.1, 1, dtype=dtype)
def test_torch_to(self):
return x.to(torch.float)
foo(torch.tensor([3.0], dtype=torch.float))
foo(torch.tensor([3.0], dtype=torch.float))
FileCheck().check_not("TensorExpr").run(
torch.jit.last_executed_optimized_graph()
def foo(x, dtype: int):
foo(torch.tensor([3.0], dtype=torch.float), torch.int)
foo(torch.tensor([3.0], dtype=torch.float), torch.int)
FileCheck().check_not("TensorExpr").run(
torch.jit.last_executed_optimized_graph()
def foo(x, dtype: int):
return x.to(pin_memory=True)
foo(torch.tensor([3.0], dtype=torch.float), torch.int)
foo(torch.tensor([3.0], dtype=torch.float), torch.int)
FileCheck().check_not("TensorExpr").run(
torch.jit.last_executed_optimized_graph()
if torch.cuda.is_available():
return x.to(device="cuda")
foo(torch.tensor([3.0], dtype=torch.float))
foo(torch.tensor([3.0], dtype=torch.float))
FileCheck().check_not("TensorExpr").run(
torch.jit.last_executed_optimized_graph()
sizes = [(1, 4), (4, 4)]
class MyMod(torch.nn.Module):
def __init__(self, dtype):
def forward(self, x):
return x.to(self.dtype)
for dtype, output_dtype, device, size in product(
dtypes, dtypes, self.devices, sizes
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
if dtype == output_dtype:
x = self.data_for(dtype, device, size=size)
mod = MyMod(output_dtype)
ref = mod.forward(x)
mod = torch.jit.freeze(torch.jit.script(mod.eval()))
warmup_forward(mod.forward, x)
self.assertEqual(ref, mod.forward(x))
self.assertLastGraphAllFused()
@unittest.skip("Temporarily disabled")
def test_masked_fill(self):
sizes = [(2,), (4, 4)]
for self_dtype, device, scalar_val, size in product(
dtypes, self.devices, [0.4, 3], sizes
input_v = self.data_for(self_dtype, device, size=size)
mask = self.data_for(torch.bool, device, size=size)
def fn(input_v, mask):
return torch.masked_fill(input_v, mask, scalar_val)
ref = fn(input_v, mask)
t = torch.jit.trace(fn, (input_v, mask))
torch.testing.assert_close(ref, t(input_v, mask))
self.assertLastGraphAllFused()
except Exception as e:
def test_isnan(self):
inputs = [x, torch.tensor([float("nan"), 0.5])]
for inp, device, dtype in product(inputs, self.devices, dtypes):
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
inp = inp.to(device=device, dtype=dtype)
f = torch.jit.trace(lambda x: x.isnan(), (inp,))
warmup_forward(f, inp)
self.assertEqual(f(inp), inp.isnan())
self.assertLastGraphAllFused()
except Exception as e:
" ".join(["Failed:", str(dtype), "isnan", device])
def test_gelu(self):
return lambda x, approximate: fn(x, approximate)
sizes = [(1,), (2,), (4, 4)]
for dtype, op, device, size in product(
self.dtypes, unary_ops, self.devices, sizes
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
x = self.data_for(dtype, device, size=size)
cond = self.data_for(torch.bool, device)
t = torch.jit.trace(fn, (x, cond))
torch.testing.assert_close(ref, t(x, cond))
self.assertAllFused(t.graph_for(x, cond))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
def test_unary_ops(self):
with torch._jit_internal._disable_emit_hooks():
return lambda x: fn(x)
lambda x: torch.threshold(x, 0, -10),
gpu_only = {torch.erf, torch.erfc}
sizes = [(1,), (2,), (4, 4)]
for dtype, op, device, size in product(
self.dtypes, unary_ops, self.devices, sizes
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
if dtype == torch.bfloat16 and op == torch.round:
if op in gpu_only and device == "cpu":
x = self.data_for(dtype, device, size=size)
t = torch.jit.trace(fn, (x,))
torch.testing.assert_close(ref, t(x))
self.assertAllFused(t.graph_for(x))
except Exception as e:
["Failed:", str(dtype), op.__name__, device, str(size)]
def test_binary_ops(self):
return lambda x, y: fn(x, y)
lambda x, y: torch.lerp(x, y, 0.5),
lambda x, y: y.type_as(x),
devices = self.devices
for dtype, op, device in product(self.dtypes, binary_ops, devices):
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
t = torch.jit.trace(fn, (x, y))
self.assertEqual(ref, t(x, y))
if op not in fp_only or dtype.is_floating_point:
self.assertAllFused(t.graph_for(x, y))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device])
def test_binary_scalar_ops(self):
return lambda x, y: fn(x, y)
graph(%x : {dtype_x}, %y : {dtype_y}):
dtypes = ["int", "float", "bool"]
values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]}
devices = self.devices
for dtype_x, dtype_y, op, device in product(
dtypes, dtypes, binary_ops, devices
code = ir_template.format(**locals())
graph = torch._C.parse_ir(code)
for x, y in product(values[dtype_x], values[dtype_y]):
ref = torch._C._jit_interpret_graph(graph, (x, y))
k = torch._C._te.TensorExprKernel(graph)
except Exception as e:
" ".join(["Compilation failed:", device, str(code)])
for x, y in product(values[dtype_x], values[dtype_y]):
ref = torch._C._jit_interpret_graph(graph, (x, y))
self.assertEqual(ref, res)
except Exception as e:
["Failed at runtime:", device, str(x), str(y), str(code)]
def test_matmul(self):
if self.dynamic_shapes:
self.skipTest("don't run conv with dynamic shapes")
return torch.matmul(x, y)
[[128, 128], [128, 128]],
[[10, 10], [10, 10]],
[[1, 16], [16, 128]],
[[128], [128, 128]],
[[10, 3, 4], [10, 4, 5]],
[[10, 3, 4], [4, 5]],
skip_is_fused_check_sizes = [
"[[128], [128, 128]]",
"[[10, 3, 4], [4]]",
"[[10, 3, 4], [10, 4, 5]]",
"[[10, 3, 4], [4, 5]]",
for dtype, size, device in product(self.dtypes, sizes, devices):
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
size_x, size_y = size
x = self.data_for(dtype, device, size=size_x)
y = self.data_for(dtype, device, size=size_y)
except Exception as e:
t = torch.jit.trace(fn, (x, y))
self.assertEqual(ref, t(x, y))
if str(size) not in skip_is_fused_check_sizes:
self.assertAllFused(t.graph_for(x, y))
except Exception as e:
raise RuntimeError(" ".join(["Failed:", str(dtype), device])) from e
def test_binary_tensor_scalar_ops(self):
with torch._jit_internal._disable_emit_hooks():
def apply_with_scalar(fn, scalar):
return lambda x: fn(x, scalar)
devices = self.devices
scalars = [1.5, 3, 0, -2.0, -1]
for dtype, op, device, scalar in product(
self.dtypes, binary_ops, devices, scalars
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
x = self.data_for(dtype, device)
fn = apply_with_scalar(op, scalar)
t = torch.jit.trace(fn, (x))
self.assertEqual(ref, t(x))
self.assertAllFused(t.graph_for(x))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device])
def test_binary_div_ops(self):
def apply_with_scalar(fn, scalar):
return lambda x: fn(x, scalar)
devices = self.devices
scalars = [1.5, 3, -2.0, -1]
for dtype, op, device, scalar in product(
self.dtypes, binary_ops, devices, scalars
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
x = self.data_for(dtype, device)
fn = apply_with_scalar(op, scalar)
t = torch.jit.trace(fn, (x))
self.assertEqual(ref, t(x))
except Exception as e:
f"Failed: {dtype} {op.__name__} {device} {scalar}"
def test_binary_pow(self):
def apply_with_scalar(fn, scalar):
return lambda x: fn(x, scalar)
scalars = [1.5, 3, 0, -2.0, -1]
for dtype, op, device, scalar in product(
dtypes, binary_ops, self.devices, scalars
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
x = self.data_for(dtype, device)
fn = apply_with_scalar(op, scalar)
t = torch.jit.trace(fn, (x))
self.assertEqual(ref, t(x))
self.assertAllFused(t.graph_for(x))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device])
def test_ternary_ops(self):
return lambda x, y, z: fn(x, y, z)
devices = self.devices
for dtype, op, device in product(self.dtypes, ternary_ops, devices):
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
z = self.data_for(dtype, device)
t = torch.jit.trace(fn, (x, y, z))
self.assertEqual(ref, t(x, y, z))
self.assertAllFused(t.graph_for(x, y, z))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device])
def test_ternary_norm_ops(self):
return lambda x, y, z: fn(x, y, z)
devices = self.devices
for dtype, op, device in product(self.dtypes, ternary_ops, devices):
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
x = self.data_for(dtype, device, size=[5, 3, 128, 128])
y = self.data_for(dtype, device, size=[3])
z = self.data_for(dtype, device, size=[3])
t = torch.jit.trace(fn, (x, y, z))
self.assertEqual(ref, t(x, y, z))
self.assertAllFused(t.graph_for(x, y, z))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device])
"FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure"
def test_list_ops(self):
return lambda x, y, z: fn([x * x, y * y, z * z])
devices = self.devices
for dtype, op, device in product(self.dtypes, list_ops, devices):
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
x = self.data_for(dtype, device, size=[5, 4, 1, 7])
y = self.data_for(dtype, device, size=[5, 4, 1, 7])
z = self.data_for(dtype, device, size=[5, 4, 1, 7])
t = torch.jit.trace(fn, (x, y, z))
self.assertEqual(ref, t(x, y, z))
self.assertAllFused(t.graph_for(x, y, z))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device])
def test_where_ops(self):
return lambda cond, x, y: fn(cond, x, y)
lambda cond, x, y: torch.where(cond, x, 3.1415),
lambda cond, x, y: torch.where(cond, 42, y),
devices = self.devices
for dtype, op, device in product(self.dtypes, ops, devices):
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
cond = self.data_for(torch.bool, device)
x = self.data_for(dtype, device)
y = self.data_for(dtype, device)
ref = fn(cond, x, y)
t = torch.jit.trace(fn, (cond, x, y))
self.assertEqual(ref, t(cond, x, y))
self.assertAllFused(t.graph_for(cond, x, y))
except Exception as e:
" ".join(["Failed:", str(dtype), op.__name__, device])
def test_unsupported_dtypes(self):
for device in self.devices:
unsupported_dtypes = [
for dtype in unsupported_dtypes:
x = self.data_for(dtype, device)
t = torch.jit.trace(fn, (x,))
self.assertEqual(ref, t(x))
self.assertEqual(len(self.findFusionGroups(t.graph_for(x))), 0)
def test_superslomo(self):
devices = self.devices.copy()
if not LLVM_ENABLED:
devices.remove("cpu")
for device in devices:
def eager(t0, t1, t2, t3, t4):
t5 = torch.mul(t0, t4)
t6 = torch.mul(t2, t3)
t7 = torch.mul(t6, t1)
t9 = torch.add(t5, t7)
t11 = torch.add(t0, t6)
ft_p = torch.div(t9, t11)
return (ft_p, t11, t9, t6)
t0 = torch.rand(1, 6, 352, 352, device=device).transpose(0, 1)
t1 = torch.rand(6, 3, 352, 352, device=device)
t2 = torch.rand(6, device=device)[None, None, None, :].permute(3, 0, 1, 2)
t3 = torch.rand(6, 1, 352, 352, device=device)
t4 = torch.rand(6, 3, 352, 352, device=device)
inputs = [t0, t1, t2, t3, t4]
script = torch.jit.script(eager)
for pair in zip(script(*inputs), eager(*inputs)):
torch.testing.assert_close(test, ref)
self.assertAllFused(
script.graph_for(*inputs), except_for={"prim::TupleConstruct"}
def test_sub_gt_and(self):
for device in self.devices:
def eager(t1, t2, t3, t4, t: float):
k = (w > t) & (h > t)
assert k.dtype == torch.bool
t = torch.rand(8, dtype=torch.float, device=device)
scripted = self.checkScript(eager, (t, t, t, t, 0.1))
@skipIfTorchDynamo("too slow")
def test_chunk_mul_one(self):
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
for device in self.devices:
z, y, w = torch.chunk(x, 3, -1)
x = torch.rand(64, 1, 3072, dtype=torch.float, device=device)
script = self.checkScript(eager, (x,))
def test_eq_unsqueeze_type_as(self):
for device in self.devices:
mask = torch.unsqueeze(mask, -1)
a = torch.rand(1, 64, 1024, device=device, dtype=torch.float)
b = torch.randint(-2, 2, (1, 64), device=device, dtype=torch.long)
script = self.checkScript(eager, (a, b))
def test_neg_pow(self):
def eager_tt(a: torch.Tensor, b: torch.Tensor):
return torch.neg(torch.pow(a, b))
def eager_ts(a: torch.Tensor, b: float):
return torch.neg(torch.pow(a, b))
def eager_st(a: float, b: torch.Tensor):
return torch.neg(torch.pow(a, b))
a = torch.rand(1, dtype=torch.float)
b = torch.rand(1, dtype=torch.float)
script = self.checkScript(eager_tt, (a, b))
script = self.checkScript(eager_ts, (a, s))
script = self.checkScript(eager_st, (s, b))
@unittest.skipIf(not LLVM_ENABLED, "Too slow to run with the TE interpreter")
def test_conv2d_depthwise(self):
if self.dynamic_shapes:
self.skipTest("don't run conv with dynamic shapes")
def eager(input, weight, bias):
return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=72)
input = torch.rand((1, 72, 56, 56), dtype=torch.float)
weight = torch.rand((72, 1, 3, 3), dtype=torch.float)
bias = torch.rand((72), dtype=torch.float)
script = self.checkScript(eager, (input, weight, bias))
self.assertAllFused(script.graph_for(input, weight, bias))
def test_conv2d(self):
if self.dynamic_shapes:
self.skipTest("don't run conv with dynamic shapes")
def eager(input, weight, bias):
return torch.conv2d(input, weight, bias, stride=1, padding=1, groups=1)
input = torch.rand((1, 64, 56, 56), dtype=torch.float)
weight = torch.rand((64, 64, 3, 3), dtype=torch.float)
bias = torch.rand((64), dtype=torch.float)
script = self.checkScript(eager, (input, weight, bias))
FileCheck().check_not("TensorExpr").run(
torch.jit.last_executed_optimized_graph()
def test_type_as_cat(self):
with inline_fusion_groups():
return torch.cat((x, y.type_as(x)), dim=1)
dtypes = self.dtypes.copy()
dtypes.remove(torch.float16)
dtypes.remove(torch.bfloat16)
for dtype1, dtype2 in product(dtypes, dtypes):
x = torch.randint(2, (1, 13)).to(dtype1)
zero = torch.tensor([[0]]).to(dtype2)
one = torch.tensor([[1]]).to(dtype2)
script = torch.jit.trace(eager, (x, zero))
torch.testing.assert_close(script(x, zero), eager(x, zero))
torch.testing.assert_close(script(x, one), eager(x, one))
self.assertAllFused(script.graph_for(x, one))
def test_to_device(self):
return x.to(device="cpu").relu()
script = self.checkScript(eager, (x,))
self.assertAllFused(script.graph_for(x))
def test_dims(self):
return x / (y + 0.0001)
x = torch.linspace(-1, 1, 768, dtype=torch.float32).as_strided(
(1, 1, 768), (768, 1, 1)
y = torch.tensor([[[2.0]]], dtype=torch.float32)
script = self.checkScript(eager, (x, y))
self.assertAllFused(script.graph_for(x, y))
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_channels_last_dims_dynamic(self):
return x + (y + 0.0001)
indices = [0, 1, 2, 3]
for i in range(0, len(indices) + 1):
for subset in combinations(indices, i):
inp = torch.rand(size).to(memory_format=torch.channels_last).cuda()
with texpr_enable_strategy([("DYNAMIC", 20)]):
foo_s = torch.jit.trace(eager, (inp, inp))
out = foo_s(inp, inp)
out_eager = eager(inp, inp)
self.assertEqual(out_eager, out)
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
g = torch.jit.last_executed_optimized_graph()
FileCheck().check("TensorExpr").run(g)
def test_exhaust_specializations(self):
with texpr_enable_strategy([("STATIC", 1)]):
foo(torch.rand([2, 2]))
foo(torch.rand([4, 4, 4]))
g = torch.jit.last_executed_optimized_graph()
torch._C._jit_pass_inline(g)
FileCheck().check_count("TensorExpr", 2, exactly=True).run(g)
def test_unsqueeze_var_dim(self):
def eager(x, y, z: int):
return x * torch.unsqueeze(y, dim=z)
x = torch.rand(4, 4, 64).permute(1, 0, 2)
y = torch.rand(4, 4)
script = self.checkScript(eager, (x, y, z))
def _test_fwd_bwd(self, fn):
x = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True)
xs = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True)
script = torch.jit.script(fn)
g0 = torch.rand_like(y)
with torch.no_grad():
torch.testing.assert_close(y, ys)
def test_relu_fwd_bwd(self):
return torch.relu(x * 1.01)
self._test_fwd_bwd(eager)
def test_hardswish_fwd_bwd(self):
return F.hardswish(x) * 1.01
self._test_fwd_bwd(eager)
def test_hardsigmoid_fwd_bwd(self):
return F.hardsigmoid(x) * 1.01
self._test_fwd_bwd(eager)
def test_cat_graph_opt(self):
return torch.log(torch.cat([x, y, z]))
foo, (torch.rand([5, 5]), torch.rand([2, 5]), torch.rand([1, 5]))
self.assertLastGraphAllFused()
def test_dynamic_cat(self):
with inline_fusion_groups():
xs: List[torch.Tensor], ys: List[torch.Tensor], zs: List[torch.Tensor]
torch.cat([x, torch.cat([y, z], dim=-1)], dim=-1)
for x, y, z in zip(xs, ys, zs)
xs = [torch.ones(21) for _ in range(N)]
ys = [torch.ones(N - i) for i in range(N)]
zs = [torch.ones(i) for i in range(N)]
def test_scalar_only_inputs(self):
def eager(b: float):
script = self.checkScript(eager, (1.0,))
def test_cat_2k_args(self):
with inline_fusion_groups():
return torch.relu(torch.cat([x for _ in range(2000)]))
trace = self.checkTrace(eager, (x,))
fusion_groups = self.findFusionGroups(trace.graph_for(x))
self.assertEqual(len(fusion_groups), 0)
def test_adaptive_avg_pool2d(self):
with inline_fusion_groups():
return torch.nn.functional.adaptive_avg_pool2d(x, (2, 2))
return torch.nn.functional.adaptive_avg_pool2d(x, (2))
x = torch.randn(4, 4, 4)
for foo in [foo1, foo2]:
f = torch.jit.trace(foo, (x,))
kernel = torch._C._te.TensorExprKernel(f.graph)
self.assertEqual(kernel.run((x,)), correct_val)
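# Note: torch._C._te.TensorExprKernel compiles the traced graph with NNC
# directly, without going through the fuser pass, so the test can compare the
# kernel's output against the eager result even for ops the fuser would not
# pull into a fusion group on its own.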
def test_unrolled_cat(self):
with inline_fusion_groups():
ret = torch.empty(0)
for i in range(x.shape[0]):
ret = torch.cat([ret, x[i].relu()])
script = torch.jit.script(eager)
x = torch.ones(1, 1)
torch.testing.assert_close(eager(x), script(x))
x = torch.ones((8, 1))
torch.testing.assert_close(eager(x), script(x))
@skipIfTorchDynamo("too slow")
@unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan")
@unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans")
def test_batch_norm(self):
trace = torch.jit.trace(fn, args)
self.assertAllFused(trace.graph_for(*args))
torch.testing.assert_close(fn(*args), trace(*args), equal_nan=True)
return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu()
def bn_no_weight(i, x):
return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu()
def bn_no_bias(i, x):
return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu()
def bn_neither(i, x):
return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu()
for device in self.devices:
i = torch.randn(4, 16, 32, 40, device=device)
x = torch.randn(16, device=device)
for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]:
def test_profiler(self):
args = [torch.randn(4) for _ in range(3)]
with torch.autograd.profiler.profile() as prof:
self.assertIn("fused_mul_add", prof.table())
def test_skip_grad_in_check(self):
inp = torch.rand([4, 4])
inp.requires_grad_(True)
with torch.inference_mode():
g = torch.jit.last_executed_optimized_graph()
torch._C._jit_pass_inline(g)
torch._C._jit_pass_inline(g)
FileCheck().check_count("prim::If", 1, exactly=True).run(g)
def test_dynamic_shapes(self):
from functools import partial
lambda n: R(n, n).transpose(0, 1),
lambda n: R(n + 1, n + 1, 2)[:n, n, 0],
lambda n: R(n, n, 2)[:, :, 0],
lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last),
with texpr_enable_strategy([("DYNAMIC", 20)]):
return torch.sigmoid(torch.tanh(x))
foo.__disable_jit_function_caching__ = True
return torch.tanh(x + y)
fi.__disable_jit_function_caching__ = True
return torch.tanh(x + y) + z
fum.__disable_jit_function_caching__ = True
funcs = [foo, fi, fum]
with inline_fusion_groups():
for device in self.devices:
I = partial(torch.randint, 0, 100, device=device)
R = partial(torch.randn, device=device)
for i, func in enumerate(funcs):
for j, gen in enumerate(gen_tensor):
2503
inps = (gen(n), gen(n), gen(n))
2504
func_s = torch.jit.trace(func, inps, check_trace=False)
2505
torch._C._jit_pass_erase_shape_information(func_s.graph)
2507
x, y, z = gen(n), gen(n), gen(n)
2510
for incr in range(3):
2511
func_s(*[gen(n + 1) for _ in range(3)])
2513
g = torch.jit.last_executed_optimized_graph()
2514
torch._C._jit_pass_inline(g)
2515
torch._C._jit_pass_dce(g)
2518
FileCheck().check_count(
2519
"TensorExprDynamicGuard", 1, exactly=True
2521
self.assertEqual(func(*inps), func_s(*inps))
2524
inps = (gen(n), gen(n), gen(n))
2525
foo_s = torch.jit.trace(foo, inps)
2526
torch._C._jit_pass_erase_shape_information(foo_s.graph)
2528
for gen in gen_tensor:
2530
foo_s(*[gen(n + i) for _ in range(3)])
2531
inps = (gen(n), gen(n), gen(n))
2532
self.assertEqual(foo_s(*inps), foo(*inps))
2533
g = torch.jit.last_executed_optimized_graph()
2534
torch._C._jit_pass_inline(g)
2535
torch._C._jit_pass_dce(g)
2536
FileCheck().check_count(
2537
"TensorExprDynamicGuard", len(gen_tensor), exactly=True
2540
@unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA")
2541
def test_autocast_up(self):
2543
y = x._autocast_to_full_precision(True, True)
2547
x = torch.rand((2, 2), dtype=torch.half, device="cuda")
2548
scr = torch.jit.script(f)
2551
self.assertLastGraphAllFused()
@unittest.skipIf(not RUN_CUDA, "half-precision NNC fusion requires CUDA")
def test_autocast_down(self):
y = torch.sigmoid(x)
z = y._autocast_to_reduced_precision(True, True, torch.half, torch.half)
x = torch.rand((2, 2), dtype=torch.float, device="cuda")
scr = torch.jit.script(f)
self.assertLastGraphAllFused()
@unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
def test_to_dtype(self):
y = torch.sigmoid(x)
z = y._autocast_to_reduced_precision(True, True, torch.half, torch.bfloat16)
h = z._autocast_to_full_precision(True, True)
i = h.to(dtype=torch.bfloat16)
j = i.to(dtype=torch.float32)
x = torch.rand((2, 2), dtype=torch.float32)
scr = torch.jit.trace(f, x)
self.assertLastGraphAllFused()
self.assertEqual(f(x), scr(x), atol=4e-3, rtol=4e-3)
bf_x = torch.rand((2, 2), dtype=torch.bfloat16)
bf_scr = torch.jit.trace(f, bf_x)
graph = bf_scr.graph_for(bf_x)
fusion_groups = self.findFusionGroups(graph)
self.assertEqual(len(fusion_groups), 2)
self.assertEqual(f(bf_x), bf_scr(bf_x), atol=4e-3, rtol=4e-3)
def test_with_strict_fusion(self):
with torch.jit.strict_fusion():
scripted = self.checkScript(success, (torch.rand([4]),))
g = torch.jit.last_executed_optimized_graph()
FileCheck().check_not("aten::add").check("prim::TensorExprGroup").run(g)
with torch.jit.strict_fusion():
return x + x + torch.rand([4]) + 3
with self.assertRaises(Exception) as error_out:
foo_s = torch.jit.script(foo)
foo_s(torch.rand([4]))
foo_s(torch.rand([4]))
print(torch.jit.last_executed_optimized_graph())
fc = FileCheck().check("Found unfused operators")
fc.check("aten::rand(SymInt[] size")
fc.check("torch.rand([4]").run(str(error_out.exception))
with warnings.catch_warnings(record=True) as warns:
foo(torch.rand([4]))
FileCheck().check("Only works in script mode").run(str(warns[0]))
def test_autodiff(x):
with torch.jit.strict_fusion():
return torch.rand([4]) + x + x + x
foo_s = torch.jit.script(test_autodiff)
inp = torch.rand([4], requires_grad=True)
with self.assertRaises(Exception) as error_out:
f = FileCheck().check("unfused operators").check("aten::rand")
f.run(str(error_out.exception))
def test_separate_fusions(x, y):
with torch.jit.strict_fusion():
return x + x + x, y + y + y
inp = torch.rand([4], requires_grad=True)
with self.assertRaises(Exception) as error_out:
foo_s = torch.jit.script(test_separate_fusions)
f = FileCheck().check("Found multiple fusions")
f.run(str(error_out.exception))
def test_constant_chunk_shapes(self):
if self.dynamic_shapes:
self.skipTest("TODO: chunk dynamic shapes")
for device in self.devices:
z1, z2 = (x + y + r).chunk(2, dim=1)
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
ge = self.checkTrace(f, (x, y))
graph = ge.graph_for(x, y)
FileCheck().check("with " + FUSION_GROUP + "_").check_count(
"ConstantChunk", 1, exactly=True
f_traced = torch.jit.trace(f, (x, y))
res = f_traced(x, y)
self.assertEqual(res, f(x, y))
@unittest.skipIf(not RUN_CUDA_HALF, "half-precision NNC fusion requires CUDA")
def test_pow_multiple_dtype(self):
def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
p = torch.sigmoid(p)
x = torch.rand((2, 2), dtype=torch.half, device="cuda")
script_fn = torch.jit.script(fn)
self.assertEqual(ref, res)
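# The two subclasses below re-run the entire TestTEFuser suite twice: once with
# the fuser in static-shape mode and once in dynamic-shape mode (see the
# fusion_strategy selection in setUp).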
class TestTEFuserStatic(TestTEFuser):
dynamic_shapes = False
class TestTEFuserDynamic(TestTEFuser):
dynamic_shapes = True
"div.no_rounding_mode",
"div.true_rounding",
"div.floor_rounding",
"div.trunc_rounding",
"nn.functional.hardshrink",
"nn.functional.hardsigmoid",
"nn.functional.hardswish",
"nn.functional.softplus",
"nn.functional.hardtanh",
"nn.functional.leaky_relu",
"nn.functional.relu",
"nn.functional.relu6",
"nn.functional.softsign",
"nn.functional.tanhshrink",
"nn.functional.threshold",
"remainder.autodiffed",
"bool.channels_last",
"byte.channels_last",
"char.channels_last",
"double.channels_last",
"float.channels_last",
"half.channels_last",
"int.channels_last",
"long.channels_last",
"short.channels_last",
if op.variant_test_name != "":
l.append(op.variant_test_name)
class TestNNCOpInfoParent(JitCommonTestCase):
class TestNNCOpInfo(TestNNCOpInfoParent):
super(TestNNCOpInfoParent, self).setUp()
self.tensorexpr_options = TensorExprTestOptions()
self.tensorexpr_options.restore()
super(TestNNCOpInfoParent, self).tearDown()
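# te_compile (below) takes an OpInfo sample, builds a small wrapper function
# around op.op with the sample's tensor and non-tensor arguments, traces it,
# compiles the traced graph with torch._C._te.TensorExprKernel, and checks
# both kernel.run and kernel.fallback against the eager result.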
def te_compile(self, device, dtype, op):
if op.name in skip_ops:
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
for sample_input in sample_inputs_itr:
arg_values = [sample_input.input] + list(sample_input.args)
kwarg_values = sample_input.kwargs
for idx, v in enumerate(arg_values):
if isinstance(v, torch.Tensor):
param_names.append(f"arg_{idx}")
param_values.append(v)
fx_args.append(param_names[-1])
fx_args.append(f"{repr(v)}")
for k, v in kwarg_values.items():
if isinstance(v, torch.Tensor):
param_names.append(k)
param_values.append(v)
fx_args.append(f"{k} = {k}")
fx_args.append(f"{k} = {repr(v)}")
def f({', '.join(param_names)}):
return op.op({', '.join(fx_args)})"""
g = {"torch": torch, "inf": math.inf, "op": op}
f.__module__ = "test"
out = f(*param_values)
ts_g = torch.jit.trace(f, param_values)
kernel = torch._C._te.TensorExprKernel(ts_g.graph)
correct_val = f(*param_values)
self.assertEqual(kernel.run(tuple(param_values)), correct_val)
self.assertEqual(kernel.fallback(tuple(param_values)), correct_val)
@unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
[op for op in op_db if get_name(op) in works_list],
allowed_dtypes=(torch.float,),
def test_working(self, device, dtype, op):
self.te_compile(device, dtype, op)
@unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
[op for op in op_db if get_name(op) in known_failures],
allowed_dtypes=(torch.float,),
def test_failures(self, device, dtype, op):
self.te_compile(device, dtype, op)
except Exception as e:
"Expected test to fail. If it now works, move op into works_list"
@unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
[op for op in op_db if get_name(op) not in works_list + known_failures],
allowed_dtypes=(torch.float,),
def test_unsupported(self, device, dtype, op):
if get_name(op) in skip_ops:
with warnings.catch_warnings():
warnings.simplefilter("ignore", TracerWarning)
self.te_compile(device, dtype, op)
except Exception as e:
"Expected test to fail. If it now works, move op into works_list"
@ops(op_db, dtypes=OpDTypes.supported)
def test_nnc_correctness(self, device, dtype, op):
if not op.supports_tracing:
self.skipTest("Requires tracing support")
with NoTracerWarnContextManager() as no_warn:
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
for variant, sample in variant_sample_pairs:
trace = create_traced_fn(self, variant, cache_traced_fn=True)
*clone_inputs((sample.input, *sample.args)), **sample.kwargs
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
*clone_inputs((sample.input, *sample.args)), **sample.kwargs
atol = 2e-1 if dtype == torch.bfloat16 else 1e-5
rtol = 2e-1 if dtype == torch.bfloat16 else 1e-5
self.assertEqual(ref, val, atol=atol, rtol=rtol)
torch.jit._state._python_cu.drop_all_functions()
only_for = ("cuda",) if IS_FBCODE else ("cpu", "cuda")
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
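# The randomization tests below exercise NNC's randomized loop-nest transforms:
# setUp sets PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED and forces the TensorExpr
# fuser on, and tearDown restores all of the executor and fuser flags.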
class TestLoopnestRandomizationParent(JitTestCase):
class TestLoopnestRandomization(TestLoopnestRandomizationParent):
super(TestLoopnestRandomizationParent, self).setUp()
self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
self.old_must_use_cpu_state = torch._C._jit_get_te_must_use_llvm_cpu()
self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)
self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
torch._C._debug_set_fusion_group_inlining(False)
self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
torch._C._jit_set_texpr_fuser_enabled(True)
self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "1"
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
torch._C._get_graph_executor_optimize(self.old_profiling_mode)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_must_use_cpu_state)
torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
os.environ["PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED"] = "0"
super(TestLoopnestRandomizationParent, self).tearDown()
@unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
def test_relu(self, device):
def fn_test_relu(x, y):
return F.relu(x + 0.5 * y)
x = torch.randn(4, 4, dtype=torch.float, device=device)
y = torch.randn(4, 4, dtype=torch.float, device=device)
traced_fn = torch.jit.trace(fn, (x, y))
res = traced_fn(x, y)
assert torch.allclose(ref, res)
instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("cpu",))
if __name__ == "__main__":