import unittest

import numpy as np

import torch
import torch._C._te as te
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase

LLVM_ENABLED = torch._C._llvm_enabled()
14
def construct_adder(n: int, dtype=torch.float32):
    """Build an NNC codegen that computes C[i] = A[i] + B[i] for i in [0, n).

    Args:
        n: fixed length of the 1-D input/output buffers.
        dtype: element dtype of all three buffers.

    Returns:
        A te.CodeGen (ir_eval backend) callable as cg.call([tA, tB, tC]).
    """
    A = te.BufHandle("A", [n], dtype)
    B = te.BufHandle("B", [n], dtype)

    def compute(i):
        return A.load([i]) + B.load([i])

    C = te.Compute("C", [n], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    return te.construct_codegen("ir_eval", stmt, [A, B, C])
30
class TestTensorExprPyBind(JitTestCase):
    def test_simple_sum(self):
        """Smoke test: the adder codegen produces A + B via cg.call."""
        # NOTE(review): the original size constant was lost in extraction;
        # any positive n works since construct_adder is parameterized.
        n = 32
        cg = construct_adder(n)

        tA = torch.randn(n)
        tB = torch.randn(n)
        tC = torch.empty(n)
        cg.call([tA, tB, tC])
        torch.testing.assert_close(tA + tB, tC)
41
def test_call_raw(self):
43
cg = construct_adder(n, dtype=torch.float64)
45
tA = torch.randn(n, dtype=torch.float64)
46
tB = torch.randn(n, dtype=torch.float64)
47
tC = torch.empty(n, dtype=torch.float64)
48
cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
49
torch.testing.assert_close(tA + tB, tC)
51
def test_external_calls(self):
54
A = te.BufHandle("A", [1, 4], dtype)
55
B = te.BufHandle("B", [4, 1], dtype)
56
C = te.BufHandle("C", [1, 1], dtype)
58
s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])
60
loopnest = te.LoopNest(s, [C])
61
loopnest.prepare_for_codegen()
62
codegen = te.construct_codegen("ir_eval", s, [A, B, C])
66
tC = torch.empty(1, 1)
67
codegen.call([tA, tB, tC])
68
torch.testing.assert_close(torch.matmul(tA, tB), tC)
70
def test_dynamic_shape(self):
71
dN = te.VarHandle(torch.int32)
72
A = te.BufHandle([dN], torch.float64)
73
B = te.BufHandle([dN], torch.float64)
76
return A.load(i) - B.load(i)
78
C = te.Compute("C", [dN], compute)
80
loopnest = te.LoopNest([C])
81
loopnest.prepare_for_codegen()
83
cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN])
85
def test_with_shape(n):
86
tA = torch.randn(n, dtype=torch.double)
87
tB = torch.randn(n, dtype=torch.double)
88
tC = torch.empty(n, dtype=torch.double)
89
cg.call([tA, tB, tC, n])
90
torch.testing.assert_close(tA - tB, tC)
95
def test_dynamic_shape_2d(self):
96
dN = te.VarHandle(torch.int32)
97
dM = te.VarHandle(torch.int32)
98
A = te.BufHandle([dN, dM], torch.float64)
99
B = te.BufHandle([dN, dM], torch.float64)
102
return A.load([i, j]) - B.load([i, j])
104
C = te.Compute("C", [dN, dM], compute)
106
loopnest = te.LoopNest([C])
107
loopnest.prepare_for_codegen()
109
cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM])
111
def test_with_shape(n, m):
112
tA = torch.randn(n, m, dtype=torch.double)
113
tB = torch.randn(n, m, dtype=torch.double)
114
tC = torch.empty(n, m, dtype=torch.double)
115
cg.call([tA, tB, tC, n, m])
116
torch.testing.assert_close(tA - tB, tC)
118
test_with_shape(2, 4)
119
test_with_shape(5, 3)
121
def test_dtype_error(self):
122
te.BufHandle("a", [1], torch.float32)
123
self.assertRaises(TypeError, lambda: te.BufHandle("a", [1], "float55"))
125
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
126
def test_kernel_with_tensor_inputs(self):
130
device, size = "cpu", (4, 4)
131
x = torch.rand(size, device=device)
132
y = torch.rand(size, device=device)
133
z = torch.rand(size, device=device)
136
graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
137
%b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
138
%c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
139
%6 : int = prim::Constant[value=1]()
140
%7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6)
141
%3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6)
144
graph = torch._C.parse_ir(graph_str)
146
kernel = te.TensorExprKernel(graph)
147
res1 = kernel.run((x, y, z))
148
res2 = kernel.fallback((x, y, z))
150
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
151
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
153
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
154
def test_kernel_with_scalar_inputs(self):
158
x = torch.tensor(0.1, dtype=torch.float, device="cpu")
159
y = torch.tensor(0.6, dtype=torch.float, device="cpu")
160
z = torch.tensor(0.7, dtype=torch.float, device="cpu")
163
graph(%a.1 : Float(requires_grad=0, device=cpu),
164
%b.1 : Float(requires_grad=0, device=cpu),
165
%c.1 : Float(requires_grad=0, device=cpu)):
166
%3 : int = prim::Constant[value=1]()
167
%6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3)
168
%9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3)
171
graph = torch._C.parse_ir(graph_str)
173
kernel = te.TensorExprKernel(graph)
174
res1 = kernel.run((x, y, z))
175
res2 = kernel.fallback((x, y, z))
177
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
178
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
180
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
181
def test_kernel_shape_prop(self):
182
device, size = "cpu", (4, 4)
183
x = torch.rand(size, device=device)
184
y = torch.rand(size, device=device)
187
graph(%a : Tensor, %b : Tensor):
188
%c : Tensor = aten::mul(%a, %b)
191
graph = torch._C.parse_ir(graph_str)
193
exception_thrown = False
195
kernel = te.TensorExprKernel(graph)
199
exception_thrown = True
200
assert exception_thrown
203
example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
204
torch._C._te.annotate_input_shapes(graph, example_inputs)
205
torch._C._jit_pass_propagate_shapes_on_graph(graph)
208
kernel = te.TensorExprKernel(graph)
210
res = kernel.run((x, y))
211
correct = torch.mul(x, y)
212
np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
214
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
215
def test_kernel_shape_prop_module(self):
216
class TestModule(torch.nn.Module):
217
def forward(self, x, y):
220
graph = torch.jit.script(TestModule()).graph
224
exception_thrown = False
226
kernel = te.TensorExprKernel(graph)
228
exception_thrown = True
229
assert exception_thrown
232
example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
234
exception_thrown = False
236
torch._C._te.annotate_input_shapes(graph, example_inputs)
239
exception_thrown = True
240
assert exception_thrown
243
torch._C._te.remove_unused_self_argument(graph)
246
torch._C._te.annotate_input_shapes(graph, example_inputs)
247
torch._C._jit_pass_propagate_shapes_on_graph(graph)
250
kernel = te.TensorExprKernel(graph)
252
device, size = "cpu", (4, 4)
253
x = torch.rand(size, device=device)
254
y = torch.rand(size, device=device)
256
res = kernel.run((x, y))
257
correct = TestModule().forward(x, y)
258
np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
260
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
261
def test_kernel_with_t(self):
265
device, size = "cpu", (3, 4)
266
x = torch.rand(size, device=device)
269
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
270
%3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1)
273
graph = torch._C.parse_ir(graph_str)
275
kernel = te.TensorExprKernel(graph)
276
res1 = kernel.run((x,))
277
res2 = kernel.fallback((x,))
279
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
280
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
282
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
283
def test_kernel_with_transpose(self):
285
return a.transpose(-1, -2)
287
device, size = "cpu", (3, 4)
288
x = torch.rand(size, device=device)
291
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
292
%2 : int = prim::Constant[value=-1]()
293
%3 : int = prim::Constant[value=-2]()
294
%4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3)
297
graph = torch._C.parse_ir(graph_str)
299
kernel = te.TensorExprKernel(graph)
300
res1 = kernel.run((x,))
301
res2 = kernel.fallback((x,))
303
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
304
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
306
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
307
def test_kernel_with_permute(self):
309
return a.permute([2, 1, 0])
311
device, size = "cpu", (3, 4, 5)
312
x = torch.rand(size, device=device)
315
graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
316
%1 : int = prim::Constant[value=2]()
317
%2 : int = prim::Constant[value=1]()
318
%3 : int = prim::Constant[value=0]()
319
%4 : int[] = prim::ListConstruct(%1, %2, %3)
320
%5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4)
323
graph = torch._C.parse_ir(graph_str)
325
kernel = te.TensorExprKernel(graph)
326
res1 = kernel.run((x,))
327
res2 = kernel.fallback((x,))
329
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
330
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
332
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
333
def test_kernel_with_custom_lowering(self):
335
return a.nan_to_num()
338
x = torch.ones((2, 2), device=device)
339
x[0, 0] = x[1, 1] = torch.nan
341
graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
342
%none : NoneType = prim::Constant()
343
%y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none)
346
graph = torch._C.parse_ir(graph_str)
348
def my_custom_lowering(inputs, out_shape, out_stride, out_type, device):
350
load = inputs[0].as_buf().load(idxs)
351
return te.ifThenElse(
352
te.ExprHandle.isnan(load), te.ExprHandle.float(0.0), load
355
return te.Compute2("custom_nan_to_num", out_shape, compute)
357
kernel = te.TensorExprKernel(graph, {"aten::nan_to_num": my_custom_lowering})
358
res1 = kernel.run((x,))
359
res2 = kernel.fallback((x,))
361
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
362
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
364
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
365
def test_kernel_with_expand(self):
367
return a.expand((2, 3, 4))
370
x = torch.rand((1, 3, 1), device=device)
372
graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
373
%1 : int = prim::Constant[value=2]()
374
%2 : int = prim::Constant[value=3]()
375
%3 : int = prim::Constant[value=4]()
376
%4 : int[] = prim::ListConstruct(%1, %2, %3)
377
%5 : bool = prim::Constant[value=0]()
378
%6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5)
381
graph = torch._C.parse_ir(graph_str)
383
kernel = te.TensorExprKernel(graph)
384
res1 = kernel.run((x,))
385
res2 = kernel.fallback((x,))
387
np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
388
np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
390
@unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
391
def test_alloc_in_loop(self):
393
te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]
395
body = te.Block([tmp.store([0], a.load([0])), b.store([0], tmp.load([0]))])
397
i = te.VarHandle("i", torch.int32)
398
body = te.For.make(i, 0, 100, body)
399
nest = te.LoopNest(body, [b])
400
nest.prepare_for_codegen()
401
f = te.construct_codegen("llvm", nest.simplify(), [a, b])
402
ta, tb = (torch.ones(1) for _ in range(2))
403
f.call([ta.data_ptr(), tb.data_ptr()])
406
class TestExprHandlePyBind(JitTestCase):
407
def test_unary_ops(self):
409
torch.sin: torch._C._te.sin,
410
torch.cos: torch._C._te.cos,
411
torch.tan: torch._C._te.tan,
412
torch.asin: torch._C._te.asin,
413
torch.acos: torch._C._te.acos,
414
torch.atan: torch._C._te.atan,
415
torch.sinh: torch._C._te.sinh,
416
torch.cosh: torch._C._te.cosh,
417
torch.tanh: torch._C._te.tanh,
418
torch.sigmoid: torch._C._te.sigmoid,
419
torch.exp: torch._C._te.exp,
420
torch.expm1: torch._C._te.expm1,
421
torch.abs: torch._C._te.abs,
422
torch.log: torch._C._te.log,
423
torch.log2: torch._C._te.log2,
424
torch.log10: torch._C._te.log10,
425
torch.log1p: torch._C._te.log1p,
426
torch.erf: torch._C._te.erf,
427
torch.erfc: torch._C._te.erfc,
428
torch.sqrt: torch._C._te.sqrt,
429
torch.rsqrt: torch._C._te.rsqrt,
430
torch.ceil: torch._C._te.ceil,
431
torch.floor: torch._C._te.floor,
432
torch.round: torch._C._te.round,
433
torch.trunc: torch._C._te.trunc,
434
torch.lgamma: torch._C._te.lgamma,
435
torch.frac: torch._C._te.frac,
438
def construct_te_fn(op, n: int, dtype=torch.float32):
439
A = torch._C._te.BufHandle("A", [n], dtype)
442
return op(A.load([i]))
444
C = te.Compute("C", [n], compute)
446
loopnest = te.LoopNest([C])
447
loopnest.prepare_for_codegen()
448
stmt = te.simplify(loopnest.root_stmt())
450
return te.construct_codegen("ir_eval", stmt, [A, C])
454
for torch_op, te_op in unary_operators.items():
457
te_fn = construct_te_fn(te_op, n, torch.float32)
460
assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
463
if __name__ == "__main__":
    # Standard PyTorch test entry point (imported at the top of the file).
    run_tests()