import torch.nn.functional as F
from torch.testing import FileCheck
from unittest import skipIf

from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
    enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \
    RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward
from textwrap import dedent
from itertools import product, permutations
from torch.testing._internal.common_cuda import with_tf32_off

from test_jit import backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
    LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
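
# When the harness selects the profiling graph executor, switch the JIT over to it
# before any test graphs are compiled.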
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)


def strip_profiling_nodes(nodes):
    profiling_opcodes = {'prim::BailoutTemplate', 'prim::BailOut'}
    return [n for n in nodes if n.kind() not in profiling_opcodes]


def warmup_forward(f, *args):
    profiling_count = 2
    for i in range(profiling_count):
        results = f(*args)

    return results
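

# Most tests below script or trace a small function and then check, via
# assertAllFused, that the optimized graph contains exactly one prim::FusionGroup.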
@skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "skip due to SIGIOT failures, #67646")
class TestFuser(JitTestCase):
    def assertAllFused(self, graph, except_for=()):
        # If the graph was differentiated, the interesting part lives in the
        # single DifferentiableGraph subgraph.
        diff_graphs = [n for n in graph.nodes() if n.kind() == 'prim::DifferentiableGraph']
        if len(diff_graphs) > 0:
            self.assertEqual(len(diff_graphs), 1)
            graph = diff_graphs[0].g('Subgraph')

        allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::BailoutTemplate',
                         'prim::BailOut', 'prim::TupleConstruct'} | set(except_for)
        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
                        'got {}'.format(graph))
        self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)

    def _test_fused_abs(self, device='cpu'):
        def func(x):
            return x.abs() * 2

        a = torch.randn(5, device=device)
        scripted = self.checkScript(func, (a,))
        self.assertAllFused(scripted.graph_for(a))

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_abs_cpu(self):
        self._test_fused_abs()

    @unittest.skipIf(not IS_WINDOWS, "This is meant to be Windows-specific")
    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_abs_cpu_unicode_temp_dir(self):
        with TemporaryDirectoryName(suffix='中文') as dname:
            shell_env = os.environ.copy()
            shell_env['TMP'] = dname
            cmd = [sys.executable, os.path.basename(__file__), type(self).__name__ + '.test_abs_cpu']
            legacy_jit_flag = '--jit-executor=legacy'
            for v in sys.argv:
                if v == legacy_jit_flag:
                    cmd.append(legacy_jit_flag)
            return_code = shell(cmd, cwd=os.path.dirname(__file__), env=shell_env)
            self.assertEqual(return_code, 0)

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_abs_cuda(self):
        self._test_fused_abs(device="cuda")

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_zero_element_tensors(self):
        def decode(sin_t, cos_t):
            theta = torch.atan2(sin_t.float(), cos_t.float())
            return theta

        sin = torch.zeros(0, device="cuda")
        cos = torch.zeros(0, device="cuda")
        inputs = [sin, cos]
        ge = self.checkScript(decode, inputs)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_arg_configurations_smoke_cuda(self):
        # Smoke test: contiguous and non-contiguous argument configurations
        # must not be dispatched to the same cached kernel.
        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
        traced_f = torch.jit.trace(f, (x, y,))
        self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_broadcast_cuda(self):
        def scaleshift(x, scale, shift):
            return x * scale + shift

        inputs = [
            torch.randn(4, 4, dtype=torch.float, device='cuda'),
            torch.randn(4, dtype=torch.float, device='cuda'),
            torch.randn(4, dtype=torch.float, device='cuda'),
        ]
        ge = self.checkTrace(scaleshift, inputs)
        self.assertAllFused(ge.graph_for(*inputs))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no bfloat support with profiling on")
    def test_cuda_bfloat16(self):
        def foo(x, y):
            return (x + y).relu()

        m = torch.jit.script(foo)
        x = torch.randn(65536).cuda().bfloat16()
        y = torch.randn_like(x)
        self.assertAllFused(m.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
    def test_cuda_half(self):
        x = torch.randn(4, 4, dtype=torch.half, device='cuda')
        y = torch.randn(4, 4, dtype=torch.half, device='cuda')

        funcs = [
            self.fn_test_comparison_gt_lt,
            self.fn_test_relu,
            self.fn_test_exp,
        ]

        # Note: non-fused reference inputs must be float to prevent loss of precision
        inputs = (x.float(), y.float())
        fusion_inputs = (x, y)
        for fn in funcs:
            local_inputs = [t.clone().requires_grad_() for t in inputs]
            local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]

            # Verify outputs
            fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False)
            outputs = fn(*local_inputs)
            fusion_outputs = fusion(*local_fusion_inputs)
            outputs_half = [t.half() for t in outputs]
            self.assertEqual(outputs_half, fusion_outputs)

            # Verify gradients
            for output, fusion_output in zip(outputs_half, fusion_outputs):
                grads = torch.autograd.grad(
                    output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
                fusion_grads = torch.autograd.grad(
                    fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
                grads_half = [t.half() for t in grads]
                self.assertEqual(grads_half, fusion_grads)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_checks_cat_inputs(self):
        # The cat inputs deliberately have different sizes along dim 0, so the
        # fuser must check all of them before launching a single kernel.
        def f(x, y):
            return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)

        x = torch.randn(2, 4, dtype=torch.float, device='cuda')
        y = torch.randn(1, 4, dtype=torch.float, device='cuda')

        scripted = self.checkScript(f, (x, y))
        self.assertAllFused(scripted.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_remainder_cuda(self):
        def cuda_rem(x, y):
            return 1 + torch.remainder(x, y) - 1

        a = torch.rand([512], dtype=torch.float).cuda()
        b = torch.rand([512], dtype=torch.float).cuda()
        inputs = (a, b)
        ge = self.checkScript(cuda_rem, inputs)
        graph = ge.graph_for(*inputs)
        self.assertAllFused(graph)

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_chunk_cuda(self):
        def fn(x):
            a, b, c = x.chunk(3, 1)
            return a * b + c

        inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]

        ge = self.checkScript(fn, inputs)
        graph = ge.graph_for(*inputs)
        self.assertAllFused(graph)
        FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph))

    @staticmethod
    def _test_chunk_correctness(self, device='cpu'):
        # Static helper: `self` is passed explicitly at the call sites below.
        def chunk_4_0(x):
            x0, x1, x2, x3 = x.chunk(4, 0)
            return x0 + x1 + x2 + x3

        def chunk_4_1(x):
            x0, x1, x2, x3 = x.chunk(4, 1)
            return x0 + x1 + x2 + x3

        def chunk_4_last(x):
            x0, x1, x2, x3 = x.chunk(4, 2)
            return x0 + x1 + x2 + x3

        fns = [chunk_4_0, chunk_4_1, chunk_4_last]
        tensors = [
            # splitSize = 1
            torch.randn(4, 4, 4, dtype=torch.float, device=device),

            # contiguous case
            torch.randn(12, 8, 16, dtype=torch.float, device=device),

            # non-contiguous case
            torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
        ]

        for tensor in tensors:
            for fn in fns:
                self.checkScript(fn, [tensor])

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_chunk_correctness(self):
        return self._test_chunk_correctness(self, 'cpu')

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_chunk_correctness_cuda(self):
        return self._test_chunk_correctness(self, 'cuda')

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_chunk_distributes_cuda(self):
        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(f, (x, y))
        graph = ge.graph_for(x, y)
        FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_') \
            .check_count('ConstantChunk', 2, exactly=True).run(str(graph))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_chunk_motion_deduplicates_inputs(self):
        def func1(x):
            z = x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        def func2(x):
            z = x * x * x
            z0, z1 = z.chunk(2)
            return z0 * z1

        inputs = [
            torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
        ]
        for func in [func1, func2]:
            module = self.checkScript(func, inputs)
            forward_graph = module.graph_for(*inputs)
            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
            fusion_group = list(forward_graph.nodes())[-1]
            # The chunked value is computed inside the fusion group, so the
            # group should only need the original tensor as its input.
            self.assertEqual(len(list(fusion_group.inputs())), 1)

    @unittest.skipIf(not RUN_CUDA, "No CUDA")
    def test_chunk_multiple_cuda(self):
        # The arguments are intentionally used out of order as a test to see
        # whether the fusion compiler adds extra arguments in the correct order.
        def fn(s, x, y, z):
            z1, z2 = z.chunk(2, 2)
            x1, x2, x3 = x.chunk(3, 1)
            y1, y2 = y.chunk(2, 0)
            return s + x1 + x2 + x3 + y1 + y2 + z1 + z2

        inputs = [
            torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
            torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
            torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
            torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
        ]

        ge = self.checkScript(fn, inputs)
        self.assertAllFused(ge.graph_for(*inputs))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_minmax(self):
        def tmax(a, b):
            return torch.max(2 * a, b)

        def tmin(a, b):
            return torch.min(2 * a, b)

        a = torch.randn(4, 4, dtype=torch.float, device="cuda")
        b = torch.randn(4, 4, dtype=torch.float, device="cuda")
        nan = torch.tensor(float('nan'), dtype=torch.float, device="cuda")

        for f, inputs in product(
                (tmax, tmin),
                ([a, b], [a, nan], [b, nan])):
            s = self.checkScript(f, inputs)
            self.assertAllFused(s.graph_for(*inputs))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_clamp(self):
        def func2(a, b):
            return torch.clamp(a + b, min=0, max=2)

        def funcInf(a, b):
            return torch.clamp(a + b, min=0, max=float('inf'))

        def funcOptMin(a, b):
            return torch.clamp(a + b, max=2)

        def funcOptMax(a, b):
            return torch.clamp(a + b, min=0)

        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
        nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

        funcs = (func2, funcInf, funcOptMin, funcOptMax)
        for f, inputs in product(funcs, [[a, b], [a, nan]]):
            f.__disable_jit_function_caching__ = True
            inp1, inp2 = inputs
            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
            c = s(inp1, inp2)
            with enable_profiling_mode_for_profiling_tests():
                warmup_backward(c.sum())
            graph = backward_graph(s)
            self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
    def test_dropout(self):
        def func(x):
            x = torch.nn.functional.dropout(x)
            return torch.nn.functional.relu(x)

        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        s = torch.jit.script(func)
        c = s(a)
        c = s(a)
        warmup_backward(c.sum())
        # skip_check to skip extra bailout nodes in between
        graph = backward_graph(s, skip_check=True)
        self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'})

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_comparison_eq_ne(self):
        def f(x, y):
            mask = (x == 0).type_as(x)
            z = x * mask + y
            mask = (x != 0).type_as(x)
            z = z * mask + y
            return z

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(f, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @staticmethod
    def fn_test_comparison_gt_lt(x, y):
        mask = (x > 0).type_as(x)
        z = x * mask + y
        mask = (x < 0).type_as(x)
        z = z * mask + y
        return z

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_comparison_gt_lt_cuda(self):
        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_comparison_ge_le_cuda(self):
        def f(x, y):
            mask = (x >= 0).type_as(x)
            z = x * mask + y
            mask = (x <= 0).type_as(x)
            z = z * mask + y
            return z

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(f, (x, y))
        self.assertAllFused(ge.graph_for(x, y))
        x.requires_grad_(True)
        y.requires_grad_(True)
        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
                                                            "aten::_size_if_not_equal"))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_addcmul_cuda(self):
        t = torch.randn(1, 4, dtype=torch.float, device='cuda')
        t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
        t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')

        def foo(t, t1, t2):
            return t.addcmul(t + 1, t2, value=0.1)

        ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
        graph = ge.graph_for(t, t1, t2)
        self.assertAllFused(graph)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_lerp(self):
        start = torch.randn(4, 1, dtype=torch.float, device='cuda')
        end = torch.randn(1, 4, dtype=torch.float, device='cuda')
        weight = torch.tensor(0.5, dtype=torch.float, device='cuda')

        # scalar weight overload
        def foo_weight_scalar(start, end):
            return torch.lerp(start + 1, end, 0.5)

        # tensor weight overload
        def foo_weight_tensor(start, end):
            return torch.lerp(start + 1, end, weight)

        ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
        graph = ge_weight_scalar.graph_for(start, end)
        self.assertAllFused(graph)

        ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
        graph = ge_weight_tensor.graph_for(start, end)
        self.assertAllFused(graph)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_concat_cuda(self):
        hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
        cx = torch.randn(3, 20, dtype=torch.float, device='cuda')

        def foo(hx, cx):
            return torch.cat((hx + cx, hx * cx))

        ge = self.checkTrace(foo, (hx, cx))
        graph = ge.graph_for(hx, cx)
        self.assertAllFused(graph)
        FileCheck().check("FusedConcat").check_next("return").run(str(graph))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_concat_invariant_cuda(self):
        # Invariant: the output of prim::FusedConcat may
        # not be an input to any node inside the FusionGroup.
        def fn(x, y, z):
            x1 = x + 1
            y1 = y + 1
            w = torch.cat([x1, y1])
            return w + z

        x = torch.randn(2, 2, dtype=torch.float, device='cuda')
        y = torch.randn(2, 2, dtype=torch.float, device='cuda')
        z = torch.randn(4, 2, dtype=torch.float, device='cuda')
        ge = self.checkTrace(fn, (x, y, z))
        graph = ge.graph_for(x, y, z)
        self.assertAllFused(graph, except_for={'aten::add'})
        FileCheck().check("FusedConcat").check_next("return").run(str(graph))

    @staticmethod
    def fn_test_exp(x, y):
        return (x + .5 * y).exp()

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_exp_cuda(self):
        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(self.fn_test_exp, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
    @torch._jit_internal._disable_emit_hooks_decorator
    @_inline_everything
    def test_fuse_decompose_normalization(self):
        class ResLike(torch.jit.ScriptModule):
            def __init__(self, norm_module):
                super(ResLike, self).__init__()
                self.nm = norm_module

            @torch.jit.script_method
            def forward(self, x, y):
                return y + torch.relu(self.nm(x))

        def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph):
            model = ResLike(nm).cuda()
            model_noopt = ResLike(nm).cuda()
            model_noopt.load_state_dict(model.state_dict())
            x = torch.randn(2, 16, 8, 8, device='cuda')
            y = torch.randn(2, 16, 8, 8, device='cuda')

            # FIXME: We need differentiation for CNNs for this optimization to trigger
            with torch.no_grad():
                out = model(x, y)
                graph = model.graph_for(x, y)
                rep = str(graph)

                with torch.jit.optimized_execution(False):
                    out_noopt = model_noopt(x, y)
                    rep_noopt = str(model_noopt.graph_for(x, y))
                self.assertEqual(out, out_noopt, atol=3e-5)

            # Check that the normalization op has really been decomposed
            for node_in_graph in in_opt_graph:
                self.assertIn(node_in_graph, rep)

            for node_not_in_graph in not_in_opt_graph:
                self.assertNotIn(node_not_in_graph, rep)
                self.assertIn(node_not_in_graph, rep_noopt)

            fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
            self.assertEqual(len(fusion_groups), 1)
            fused_graph = str(fusion_groups[0].g('Subgraph'))
            for node_in_fusegraph in in_fusegraph:
                self.assertIn(node_in_fusegraph, fused_graph)

        # test for batchnorm decompose
        bm = nn.BatchNorm2d(16)
        test_norm_decompose(bm, ['aten::batch_norm_update_stats'],
                            ['aten::batch_norm('], ['aten::sqrt'])

        # test for layernorm decompose
        lm = nn.LayerNorm(8)
        test_norm_decompose(lm, ['aten::batch_norm_stats'],
                            ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add'])

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_threshold(self):
        def f(x):
            return torch.threshold(x, 0, -10) + x + x + x

        x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
        scripted = self.checkScript(f, (x,))
        self.assertAllFused(scripted.graph_for(x))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_scalar_arg_cuda(self):
        def fn_test_scalar_arg(x: torch.Tensor, p: float) -> torch.Tensor:
            return p * (x * x + x)

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        p = 3
        scripted = self.checkScript(fn_test_scalar_arg, (x, p))
        self.assertAllFused(scripted.graph_for(x, p))

        x.requires_grad_(True)

        # use another function otherwise we will bail out
        # and won't be able to do fused checks
        def fn_test_scalar_arg_requires_grad(x: torch.Tensor, p: float) -> torch.Tensor:
            return p * (x * x + x)

        scripted = torch.jit.script(fn_test_scalar_arg_requires_grad)
        out = scripted(x, p)
        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes",
                                                                  "aten::_size_if_not_equal"))

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @unittest.skip("deduplicating introduces aliasing in backward graph's outputs")
    @enable_cpu_fuser
    def test_fuser_deduplication(self):
        # The gradients of sigmoid(x + y) w.r.t. x and y are identical, so the
        # fuser should emit them as a single deduplicated output.
        def f(x, y):
            return torch.sigmoid(x + y)

        b = torch.randn(5, 5, requires_grad=True)
        a = torch.randn(5, 5, requires_grad=True)
        s = self.checkScript(f, (a, b))
        self.assertAllFused(s.graph_for(a, b), except_for={
            'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'})

        c = s(a, b)
        results = warmup_backward(c.sum(), [a, b])
        ga2, gb2 = results.pop()
        graph = backward_graph(s)
        self.assertAllFused(graph)
        # the deduplicated gradients should share storage
        self.assertEqual(ga2.data_ptr(), gb2.data_ptr())

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
    def test_fuser_iou(self):
        # Checks that most of an Intersection-over-Union computation is fused.
        def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
            ltx = torch.max(b1x1, b2x1)
            lty = torch.max(b1y1, b2y1)
            rbx = torch.min(b1x2, b2x2)
            rby = torch.min(b1y2, b2y2)

            w = (rbx - ltx).clamp(min=0, max=float('inf'))
            h = (rby - lty).clamp(min=0, max=float('inf'))
            inter = w * h

            area1 = (b1x2 - b1x1) * (b1y2 - b1y1)
            area2 = (b2x2 - b2x1) * (b2y2 - b2y1)
            iou = inter / (area1 + area2 - inter)
            return iou

        box1 = torch.randn(5, 4, requires_grad=True)
        box2 = torch.randn(5, 4, requires_grad=True)
        # unsqueezing can currently not be fused
        b1x1 = box1[:, 0].unsqueeze(1)  # [N, 1]
        b1y1 = box1[:, 1].unsqueeze(1)
        b1x2 = box1[:, 2].unsqueeze(1)
        b1y2 = box1[:, 3].unsqueeze(1)
        b2x1 = box2[:, 0].unsqueeze(0)  # [1, M]
        b2y1 = box2[:, 1].unsqueeze(0)
        b2x2 = box2[:, 2].unsqueeze(0)
        b2y2 = box2[:, 3].unsqueeze(0)

        s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
        self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
                            except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})

        with enable_profiling_mode_for_profiling_tests(True):
            c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
            warmup_backward(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
            graph = backward_graph(s)
            self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
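
    # The next two tests require multiple CUDA devices: they exercise kernel
    # re-specialization across devices and the fuser's kernel-spec cache.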
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
    @enable_cpu_fuser
    def test_fusion_reuse_multi_gpu(self):
        def fn(x, y):
            return x * y * x * y

        inputs_cpu = [
            torch.randn(4, 4, dtype=torch.float),
            torch.randn(4, 4, dtype=torch.float),
        ]
        inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
        inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]

        # Should not crash; these should compile different kernels.
        ge = self.checkScript(fn, inputs_cpu)
        self.assertAllFused(ge.graph_for(*inputs_cpu))
        ge(*inputs_cuda0)
        ge(*inputs_cuda1)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
    @enable_cpu_fuser
    def test_kernel_cache_multi_gpu(self):
        def not_fusible(x):
            return x

        def fn(x, y, z):
            x_out = x * x * x * x * x
            y_out = y * y * y * y * y
            z_out = z * z * z * z * z
            return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)

        inputs = [
            torch.randn(4, 4, dtype=torch.float),
            torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
            torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
        ]

        prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()

        # There are 3 FusionGroups. Because they have the same graph, they
        # should only add one kernel spec to the cache.
        ge = self.checkScript(fn, inputs)
        self.assertGraphContainsExactly(
            ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
        new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
        self.assertEqual(new_cache_size - prev_cache_size, 1)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
    def test_nonzero_device_cuda(self):
        device = 'cuda:' + str(1)
        x = torch.tensor([0.4], dtype=torch.float, device=device)
        y = torch.tensor([0.7], dtype=torch.float, device=device)

        def doit(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y) + x))

        ge = self.checkTrace(doit, (x, y))
        self.assertAllFused(ge.graph_for(x, y))
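
    # The LSTM and MiLSTM cells used below (LSTMCellS/C/F, MiLSTMCell) come from
    # test_jit; these tests check that each cell's element-wise math collapses
    # into a single prim::FusionGroup.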
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_lstm_cuda(self):
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(
            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
        self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
        # Everything is differentiable but the TupleConstruct return
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").run(str(forward_graph))

        with enable_profiling_mode_for_profiling_tests(True):
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
            backward = backward_graph(module)
        self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                                  "aten::_grad_sum_to_size"))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
    # We want float tensors to be computed at full precision in order to use the default precision.
    @with_tf32_off
    def test_lstm_concat_cuda(self):
        inputs = get_lstm_inputs('cuda')
        ge = self.checkTrace(LSTMCellC, inputs)
        graph = ge.graph_for(*inputs)
        FileCheck().check("FusedConcat").check_next("return").run(str(graph))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_lstm_gates_permutations_cuda(self):
        # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
        # Test that any permutation of this sum still results in one FusionGroup.
        choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
        template = dedent('''
        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
            gates = {} + {} + {} + {}
            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
            return ingate * forgetgate * cellgate * outgate
        ''')
        for permutation in permutations(choices, len(choices)):
            code = template.format(*permutation)
            scope = {}
            exec(code, globals(), scope)
            cu = torch.jit.CompilationUnit(code)

            inputs = get_lstm_inputs('cuda', training=False)
            self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
            forward_graph = cu.cell.graph_for(*inputs)
            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
    # We want float tensors to be computed at full precision in order to use the default precision.
    @with_tf32_off
    def test_lstm_traced_cuda(self):
        inputs = get_lstm_inputs('cuda')
        ge = self.checkTrace(LSTMCellF, inputs)
        graph = ge.graph_for(*inputs)
        FileCheck().check_not("Chunk").check_not("aten::sigmoid") \
            .check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \
            .check_next("return").check_not("FusionGroup_2").run(str(graph))

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
    @enable_cpu_fuser
    def test_lstm_traced_cpu(self):
        inputs = get_lstm_inputs('cpu')
        try:
            ge = self.checkTrace(LSTMCellF, inputs)
            graph = ge.graph_for(*inputs)
            FileCheck().check("FusionGroup").run(str(graph))
        except RuntimeError as e:
            if 'Failed to compile' in e.args[0]:
                warnings.warn('CPU fuser test has failed! This is not a hard failure, '
                              'because the kernels sometimes trigger bugs in compilers '
                              '(most notably GCC 7.2).')
                raise unittest.SkipTest('Failed to compile') from e
            else:
                raise

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_milstm_cuda(self):
        inputs = get_milstm_inputs('cuda', training=True)
        module = self.checkScript(MiLSTMCell, inputs)
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(
            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").check("FusionGroup").run(str(forward_graph))
        hy, cy = module(*inputs)
        warmup_backward((hy + cy).sum())

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
    def test_rand_cuda(self):
        class M(torch.jit.ScriptModule):
            __constants__ = ['d']

            def __init__(self):
                super(M, self).__init__()
                self.d = torch.device('cuda')

            @torch.jit.script_method
            def create(self, x):
                return x * x + x + torch.rand_like(x)

        x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
        m = M()
        out1 = m.create(x)
        out2 = m.create(x)
        self.assertNotEqual(out1, out2)
        self.assertTrue(torch.all(out1 >= 0))
        self.assertTrue(torch.all(out1 < 1))
        self.assertTrue(torch.all(out2 >= 0))
        self.assertTrue(torch.all(out2 < 1))
        self.assertAllFused(m.create.graph_for(x))

    @staticmethod
    def fn_test_relu(x, y):
        return F.relu(x + .5 * y)

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_relu_cuda(self):
        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(self.fn_test_relu, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_erf_cuda(self):
        def fn_test_erf(x):
            return F.relu(torch.erf(x) - torch.erfc(x))

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        ge = self.checkTrace(fn_test_erf, (x,))
        self.assertAllFused(ge.graph_for(x))
        x.requires_grad_(True)
        ge = self.checkTrace(fn_test_erf, (x,))
        self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes",
                                                         "aten::_size_if_not_equal"))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "borked on the legacy executor")
    def test_rand_broadcast_cuda(self):
        def fn_test_rand(x, y):
            r = torch.rand_like(y)
            return r * x + x

        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
        script_f = torch.jit.script(fn_test_rand)
        out = script_f(x, y)
        self.assertAllFused(script_f.graph_for(x, y))
        x.requires_grad_(True)
        out = script_f(x, y)
        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
                                                                  "aten::_size_if_not_equal"))
        # test that broadcasting random produces correct results
        x = torch.ones(4, 4, dtype=torch.float, device='cuda')
        y = torch.ones(4, dtype=torch.float, device='cuda')
        out = script_f(x, y)
        self.assertEqual(out[0], out[1])

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_scalar(self):
        def fn(x, y):
            return 2 * x + y

        x = torch.tensor(0.1, dtype=torch.float, device='cpu')
        y = torch.tensor(1, dtype=torch.float, device='cpu')
        ge = self.checkScript(fn, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_small_constant_cuda(self):
        def fn_test_small_constant(x, y):
            return (1e-8 * x + 5e-9 * y) * 1e8
        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
        y = torch.randn(4, 4, dtype=torch.float, device='cuda')

        ge = self.checkTrace(fn_test_small_constant, (x, y))
        self.assertAllFused(ge.graph_for(x, y))

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    def test_tensor_scalar_ops_cuda(self):
        def should_fuse(x):
            z = 3.
            y = x + z
            return x * y

        # a scalar pulled out of a tensor at runtime is not a constant,
        # so this graph should not be fused
        def should_not_fuse(x, z):
            y = x + int(z)
            return x * y

        inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
        ge = self.checkScript(should_fuse, inputs)
        self.assertAllFused(ge.graph_for(*inputs))

        inputs = [
            torch.randn(2, 2, dtype=torch.float, device='cuda'),
            torch.tensor(3., dtype=torch.float, device='cuda'),
        ]
        ge = self.checkScript(should_not_fuse, inputs)
        self.assertGraphContainsExactly(
            ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
    @enable_cpu_fuser
    def test_where_and_typing(self):
        def f(x, y):
            mask = x > y
            res = torch.where(mask, x, y)
            # return both the bool mask and the double result to exercise typing
            return mask, res

        x = torch.randn(4, 4, dtype=torch.double)
        y = torch.randn(4, 4, dtype=torch.double)

        script_f = self.checkScript(f, (x, y))
        self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})

    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
    def test_grad_sum_to_size_elimination(self):

        def my_broadcasted_cell(a, b, c):
            return (a + b) + c

        s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
        s2 = torch.randn(5, 5, requires_grad=True, device='cuda')

        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
        forward_graph = module.graph_for(s1, s1, s1)
        self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                       "aten::_size_if_not_equal"))

        old_plans = set()
        for i in range(3):
            # mix broadcasting (5, 1) and full-size (5, 5) arguments
            args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
            args = [a.detach_().requires_grad_() for a in args]
            # recompile, so we don't trigger bailouts
            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
            res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
            warmup_backward(res.sum(), args)
            grads = torch.autograd.grad(res.sum(), args)
            for inp, gr in zip(args, grads):
                self.assertEqual(inp.shape, gr.shape)

            # pick out the backward graph that was newly compiled for this plan
            backward = None
            for g in all_backward_graphs(module):
                if str(g) not in old_plans:
                    assert backward is None
                    backward = g
                    old_plans.add(str(backward))
            num_grads = 1 if i > 0 else 0
            self.assertEqual(len([n for n in backward.nodes() if n.kind() == 'aten::_grad_sum_to_size']), num_grads)


if __name__ == '__main__':
    run_tests()