# Owner(s): ["oncall: jit"]

import os
import sys

import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


@skipIfTorchDynamo()
class TestProfiler(JitTestCase):
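    # setUp forces the profiling executor and the TensorExpr fuser on CPU,
    # saving the previous flags so tearDown can restore them afterwards.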
    def setUp(self):
        self.prev_exec = torch._C._jit_set_profiling_executor(True)
        self.prev_profiling = torch._C._get_graph_executor_optimize(True)
        self.inline_autodiff = torch._C._debug_set_autodiff_subgraph_inlining(False)
        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
        self.can_fuse_on_cpu = torch._C._jit_can_fuse_on_cpu()
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        self.default_dtype = torch.get_default_dtype()
        self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True)
        torch.set_default_dtype(torch.double)
        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)
        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
        torch._C._jit_set_te_must_use_llvm_cpu(False)

    def tearDown(self):
        torch._C._jit_set_profiling_executor(self.prev_exec)
        torch._C._get_graph_executor_optimize(self.prev_profiling)
        torch._C._debug_set_autodiff_subgraph_inlining(self.inline_autodiff)
        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
        torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)
        torch.set_default_dtype(self.default_dtype)
        torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)

    def test_tensor_type_not_determined_by_inputs(self):
        @torch.jit.script
        def scalar_type_input(x, y, z):
            return x + y + 4 + z.item()

        x = torch.tensor([2, 2])
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1))
        scalar_type_input(x, x, torch.tensor(1.0))
        g = torch.jit.last_executed_optimized_graph()

        # item & add should not get pulled into the fusion group -
        # we expect to see Fusion Group (item / add) Fusion Group in the IR dump
        FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next(
            "Tensor = aten::add"
        ).check("TensorExpr").run(g)

        @torch.jit.script
        def non_const_dtype(x, y, cond: bool):
            dtype = torch.int16 if cond else torch.int32
            return (x + y + 3).sum(dtype=dtype)

        non_const_dtype(x, x, True)
        non_const_dtype(x, x, True)
        g = torch.jit.last_executed_optimized_graph()
        # because dtype is non-const, sum should not get pulled into the fusion group
        FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(
            g
        )

    def test_specialize_backward(self):
        def test_fuse(a, b):
            c = a * b
            d = c * b
            return d

        test_fuse.__disable_jit_function_caching__ = True

        scripted_f = torch.jit.script(test_fuse)
        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        scripted_f(x, y)
        b = scripted_f(x, y)
        warmup_backward(b)
        g = torch.jit.last_executed_optimized_graph()
        # Backward has an if node guarding specializations,
        # within the if node's true block there is only one if node
        # that guards a tensorexpr group
        optimized_block = next(g.findNode("prim::If").blocks())
        if_nodes = list(optimized_block.findAllNodes("prim::If"))

        self.assertEqual(len(if_nodes), 1)
        FileCheck().check("Group[Subgraph").run(str(if_nodes[0]))
        # no broadcasts occurred, sum_to_size has been specialized out
        self.assertIsNone(optimized_block.findNode("aten::_grad_sum_to_size"))
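
        # rerun with broadcasting inputs ([2, 2] + [1]); gradients must now be
        # summed back to the input shapes, so the reduction cannot be specialized away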
        broadcast_f = torch.jit.script(test_fuse)
        x = torch.ones([2, 2], requires_grad=True)
        y = torch.ones([1], requires_grad=True)
        broadcast_f(x, y)
        b = broadcast_f(x, y)
        b.backward(torch.ones([2, 2], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([2, 2], dtype=torch.float))
        # warmup_backward(b, torch.ones([2, 2], dtype=torch.float))
        g = torch.jit.last_executed_optimized_graph()
        optimized_block = next(g.findNode("prim::If").blocks())
        # broadcasts occurred, so we currently expect to see aten::_grad_sum_to_size
        self.assertIsNotNone(optimized_block.findNode("aten::_grad_sum_to_size"))

    def test_specialized_types(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            return d

        x = torch.tensor([0.5])
        for _ in range(3):
            foo(x, x)

        g = torch.jit.last_executed_optimized_graph()
        # Types should remain specialized for typecheck outputs & fusion outputs
        FileCheck().check("Double(").check_same("prim::TypeCheck").check_same(
            "Double("
        ).check("Double").check_same("TensorExpr").run(g)

        # other outputs should not be specialized
        FileCheck().check("Tensor = prim::If").run(g)

    def test_aliasing_merge(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * c
            d.add_(b)
            e = d + c
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        b = foo(x, y)
        g = torch.jit.last_executed_optimized_graph()
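        # the in-place add_ writes into a fused value, so fusion is split into
        # two TensorExpr groups, each guarded by its own prim::TypeCheck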
        self.assertEqual(len(list(g.findAllNodes("prim::TypeCheck"))), 2)
        FileCheck().check("TensorExpr").check("aten::add_").check("TensorExpr").run(g)

    def test_use_not_profiled(self):
        def foo(t1, t2, t3, t4, t: float):
            h = t1 + t2 + t3 + t4
            if t > 0.5:
                # Putting a use of t1 in a never-executed conditional prevents
                # profiling its type
                return t1 + 1
            return h

        t = torch.rand(8, dtype=torch.float)

        foo_script = torch.jit.script(foo)
        for _ in range(torch._C._jit_get_num_profiled_runs() + 1):
            foo_script(t, t, t, t, 0.1)

        self.assertEqual(foo(t, t, t, t, 0.1), foo_script(t, t, t, t, 0.1))
        g = torch.jit.last_executed_optimized_graph()
        # the unprofiled use of t1 should not block fusion: no unfused
        # aten::add remains ahead of the guarding prim::If
        FileCheck().check("graph").check_not("aten::add").check("prim::If").run(g)

    def test_not_fusing_scalar_ops(self):
        @torch.jit.script
        def foo(x: int, y: int):
            return x + y + 2 + 4 + 5 + 6

        foo(1, 2)
        foo(2, 3)
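        # arithmetic on Python ints stays scalar in the graph, and scalar-only
        # ops should not be handed to the TensorExpr fuser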
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("TensorExpr").run(g)

    def test_not_optimizing_property(self):
        @torch.jit.script
        def foo(x, y):
            return x + y + 1 + 2 + 3, x.size()

        x = torch.ones(1)
        foo(x, x)
        foo(x, x)
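        # even after profiling, x.size() should stay in the graph rather than
        # be folded to a constant from the profiled shape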
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("aten::size").run(g)
        x = torch.ones([2, 3, 5])
        self.assertEqual(foo(x, x), (x + x + 1 + 2 + 3, x.size()))

    def test_fallback_graph_not_specialized(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1)
        y = torch.ones(1)
        foo(x, y)
        foo(x, y)
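        # the fallback path is reached through a function call whose results
        # come back as a tuple, not through shape-specialized inline code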
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(
            g
        )

    def test_autograd_fallback_graph(self):
        @torch.jit.script
        def foo(a, b):
            c = a * b
            d = c * b
            e = d * b
            return d + e

        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        foo(x, y)
        b = foo(x, y)
        b.backward(torch.ones([1], dtype=torch.float), retain_graph=True)
        b.backward(torch.ones([1], dtype=torch.float))
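
        # in the backward graph the fallback shows up as a fallback_function call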
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check("fallback_function").check_next("CallFunction").run(g)

    def test_tensor_constant(self):
        def foo(a, b):
            return a + b + torch.tensor([2])

        x = torch.ones(1, requires_grad=False)
        foo_script = torch.jit.script(foo)
        foo_script(x, x)
        foo_script(x, x)

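        # scripted and eager results must agree, and exactly two adds should
        # remain in the optimized graph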
        self.assertEqual(foo_script(x, x), foo(x, x))
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_count("aten::add", 2, exactly=True).run(g)

    def test_local_fusion_strategy(self):
        @torch.jit.script
        def foo(x):
            return x + x + x

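        # set_fusion_strategy takes (kind, depth) pairs: "STATIC" specializes
        # on the observed shapes, and depth caps how many specializations compile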
        torch.jit.set_fusion_strategy([("STATIC", 1)])
        for _ in range(3):
            foo(torch.rand([10]))

        torch.jit.set_fusion_strategy([("STATIC", 10)])

        for i in range(10):
            foo(torch.rand([i]))
            foo(torch.rand([i]))
            foo(torch.rand([i]))

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_count(":TensorExprGroup", 2, exactly=True).run(g)

    def test_iterative_fusion(self):
        @torch.jit.script
        def foo(a, b, c, d):
            a = a + b
            b.add_(3)
            c = c + d + b
            a = a + 1
            return a, c

        x = torch.ones(1, requires_grad=False)
        foo(x, x, x, x)
        foo(x, x, x, x)

        # when we iterate through the block, we start by fusing a = a + b
        # with a = a + 1; if we continued iterating from that fusion point,
        # we would miss the fusion opportunity of c = c + d + b
        g = torch.jit.last_executed_optimized_graph()
        self.assertEqual(len(list(g.findAllNodes("prim::TensorExprGroup"))), 2)