# pytorch
#
# Fork: 0
# /
# test_tensorexpr_pybind.py
# 464 lines · 16.0 KB
1
# Owner(s): ["NNC"]
2

3
import torch
4
import numpy as np
5
import torch._C._te as te
6

7
from torch.testing._internal.common_utils import run_tests
8
from torch.testing._internal.jit_utils import JitTestCase
9
import unittest
10

11
LLVM_ENABLED = torch._C._llvm_enabled()
12

13

14
def construct_adder(n: int, dtype=torch.float32):
    """Build an ``ir_eval`` codegen computing ``C[i] = A[i] + B[i]`` for ``0 <= i < n``.

    Args:
        n: static length of the three 1-D buffers ``A``, ``B`` and ``C``.
        dtype: element dtype shared by all three buffers.

    Returns:
        The object produced by ``te.construct_codegen``. Its arguments are bound
        in the order ``[A, B, C]``: pass tensors to ``.call`` or raw data
        pointers to ``.call_raw``.
    """
    A = te.BufHandle("A", [n], dtype)
    B = te.BufHandle("B", [n], dtype)

    def compute(i):
        return A.load([i]) + B.load([i])

    C = te.Compute("C", [n], compute)

    loopnest = te.LoopNest([C])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    return te.construct_codegen("ir_eval", stmt, [A, B, C])
28

29

30
class TestTensorExprPyBind(JitTestCase):
    """Tests for the NNC tensor-expression Python bindings (``torch._C._te``)."""

    def test_simple_sum(self):
        n = 32
        cg = construct_adder(n)

        tA = torch.randn(n)
        tB = torch.randn(n)
        tC = torch.empty(n)
        cg.call([tA, tB, tC])
        torch.testing.assert_close(tA + tB, tC)

    def test_call_raw(self):
        # Same kernel shape as test_simple_sum, but invoked with raw data pointers.
        n = 16
        cg = construct_adder(n, dtype=torch.float64)

        tA = torch.randn(n, dtype=torch.float64)
        tB = torch.randn(n, dtype=torch.float64)
        tC = torch.empty(n, dtype=torch.float64)
        cg.call_raw([tA.data_ptr(), tB.data_ptr(), tC.data_ptr()])
        torch.testing.assert_close(tA + tB, tC)

    def test_external_calls(self):
        dtype = torch.float32

        A = te.BufHandle("A", [1, 4], dtype)
        B = te.BufHandle("B", [4, 1], dtype)
        C = te.BufHandle("C", [1, 1], dtype)

        # C is produced by the external function "nnc_aten_matmul" applied to A, B;
        # the result is checked against torch.matmul below.
        s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

        loopnest = te.LoopNest(s, [C])
        loopnest.prepare_for_codegen()
        codegen = te.construct_codegen("ir_eval", s, [A, B, C])

        tA = torch.ones(1, 4)
        tB = torch.ones(4, 1)
        tC = torch.empty(1, 1)
        codegen.call([tA, tB, tC])
        torch.testing.assert_close(torch.matmul(tA, tB), tC)

    def test_dynamic_shape(self):
        # The buffer length is the symbolic int32 variable dN; its concrete value
        # is supplied as the trailing argument of `cg.call`.
        dN = te.VarHandle(torch.int32)
        A = te.BufHandle([dN], torch.float64)
        B = te.BufHandle([dN], torch.float64)

        def compute(i):
            return A.load(i) - B.load(i)

        C = te.Compute("C", [dN], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN])

        def check_with_shape(n):
            tA = torch.randn(n, dtype=torch.double)
            tB = torch.randn(n, dtype=torch.double)
            tC = torch.empty(n, dtype=torch.double)
            cg.call([tA, tB, tC, n])
            torch.testing.assert_close(tA - tB, tC)

        # The same compiled kernel must work for different runtime lengths.
        check_with_shape(8)
        check_with_shape(31)

    def test_dynamic_shape_2d(self):
        # Both dimensions are symbolic; their values follow the buffers in `cg.call`.
        dN = te.VarHandle(torch.int32)
        dM = te.VarHandle(torch.int32)
        A = te.BufHandle([dN, dM], torch.float64)
        B = te.BufHandle([dN, dM], torch.float64)

        def compute(i, j):
            return A.load([i, j]) - B.load([i, j])

        C = te.Compute("C", [dN, dM], compute)

        loopnest = te.LoopNest([C])
        loopnest.prepare_for_codegen()

        cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM])

        def check_with_shape(n, m):
            tA = torch.randn(n, m, dtype=torch.double)
            tB = torch.randn(n, m, dtype=torch.double)
            tC = torch.empty(n, m, dtype=torch.double)
            cg.call([tA, tB, tC, n, m])
            torch.testing.assert_close(tA - tB, tC)

        check_with_shape(2, 4)
        check_with_shape(5, 3)

    def test_dtype_error(self):
        te.BufHandle("a", [1], torch.float32)  # ok
        # A string is not a valid dtype argument.
        self.assertRaises(TypeError, lambda: te.BufHandle("a", [1], "float55"))

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_tensor_inputs(self):
        def f(a, b, c):
            return a + b + c

        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)
        z = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %b.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %c.1 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %6 : int = prim::Constant[value=1]()
  %7 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %6)
  %3 : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::add(%7, %c.1, %6)
  return (%3)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x, y, z))
        res2 = kernel.fallback((x, y, z))
        correct = f(x, y, z)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_scalar_inputs(self):
        def f(a, b, c):
            return a + b + c

        x = torch.tensor(0.1, dtype=torch.float, device="cpu")
        y = torch.tensor(0.6, dtype=torch.float, device="cpu")
        z = torch.tensor(0.7, dtype=torch.float, device="cpu")

        graph_str = """
graph(%a.1 : Float(requires_grad=0, device=cpu),
      %b.1 : Float(requires_grad=0, device=cpu),
      %c.1 : Float(requires_grad=0, device=cpu)):
  %3 : int = prim::Constant[value=1]()
  %6 : Float(requires_grad=0, device=cpu) = aten::add(%a.1, %b.1, %3)
  %9 : Float(requires_grad=0, device=cpu) = aten::add(%6, %c.1, %3)
  return (%9)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x, y, z))
        res2 = kernel.fallback((x, y, z))
        correct = f(x, y, z)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_shape_prop(self):
        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)

        graph_str = """
graph(%a : Tensor, %b : Tensor):
  %c : Tensor = aten::mul(%a, %b)
  return (%c)
        """
        graph = torch._C.parse_ir(graph_str)

        # Graph doesn't have shape info for inputs => compilation should fail.
        # assertRaises (unlike a bare `assert`) still checks this under `python -O`.
        with self.assertRaises(RuntimeError):
            te.TensorExprKernel(graph)

        # Inject shape info and try compiling again
        example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
        torch._C._te.annotate_input_shapes(graph, example_inputs)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)

        # Now compilation should pass
        kernel = te.TensorExprKernel(graph)

        res = kernel.run((x, y))
        correct = torch.mul(x, y)
        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_shape_prop_module(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return x * x + y

        graph = torch.jit.script(TestModule()).graph

        # Try compiling the graph as-is. It should fail because it doesn't have
        # shape info.
        with self.assertRaises(RuntimeError):
            te.TensorExprKernel(graph)

        # Try injecting shape info for graph inputs
        example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]

        # Graph has a 'self' argument for which we can't set shapes
        with self.assertRaises(RuntimeError):
            torch._C._te.annotate_input_shapes(graph, example_inputs)

        # Remove 'self' argument and try annotating shapes one more time
        torch._C._te.remove_unused_self_argument(graph)

        # Inject shape info and try compiling again
        torch._C._te.annotate_input_shapes(graph, example_inputs)
        torch._C._jit_pass_propagate_shapes_on_graph(graph)

        # Now compilation should pass
        kernel = te.TensorExprKernel(graph)

        device, size = "cpu", (4, 4)
        x = torch.rand(size, device=device)
        y = torch.rand(size, device=device)

        res = kernel.run((x, y))
        correct = TestModule().forward(x, y)
        np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_t(self):
        def f(a):
            return a.t()

        device, size = "cpu", (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %3 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::t(%a.1)
  return (%3)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_transpose(self):
        def f(a):
            return a.transpose(-1, -2)

        device, size = "cpu", (3, 4)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=-1]()
  %3 : int = prim::Constant[value=-2]()
  %4 : Float(4, 3, strides=[4, 1], requires_grad=0, device=cpu) = aten::transpose(%a.1, %2, %3)
  return (%4)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_permute(self):
        def f(a):
            return a.permute([2, 1, 0])

        device, size = "cpu", (3, 4, 5)
        x = torch.rand(size, device=device)

        graph_str = """
graph(%a.1 : Float(3, 4, 5, strides=[20, 5, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=0]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : Float(5, 4, 3, strides=[12, 3, 1], requires_grad=0, device=cpu) = aten::permute(%a.1, %4)
  return (%5)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_custom_lowering(self):
        def f(a):
            return a.nan_to_num()

        device = "cpu"
        x = torch.ones((2, 2), device=device)
        x[0, 0] = x[1, 1] = torch.nan
        graph_str = """
graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
    %none : NoneType = prim::Constant()
    %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none)
    return (%y)
        """
        graph = torch._C.parse_ir(graph_str)

        def my_custom_lowering(inputs, out_shape, out_stride, out_type, device):
            # Only replaces NaN with 0.0 (no +/-inf handling). That is enough here
            # because `x` contains NaNs and ones but no infinities.
            def compute(idxs):
                load = inputs[0].as_buf().load(idxs)
                return te.ifThenElse(
                    te.ExprHandle.isnan(load), te.ExprHandle.float(0.0), load
                )

            return te.Compute2("custom_nan_to_num", out_shape, compute)

        kernel = te.TensorExprKernel(graph, {"aten::nan_to_num": my_custom_lowering})
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_kernel_with_expand(self):
        def f(a):
            return a.expand((2, 3, 4))

        device = "cpu"
        x = torch.rand((1, 3, 1), device=device)
        graph_str = """
graph(%a : Float(1, 3, 1, strides=[3, 1, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=2]()
  %2 : int = prim::Constant[value=3]()
  %3 : int = prim::Constant[value=4]()
  %4 : int[] = prim::ListConstruct(%1, %2, %3)
  %5 : bool = prim::Constant[value=0]()
  %6 : Float(2, 3, 4, strides=[12, 4, 0], requires_grad=0, device=cpu) = aten::expand(%a, %4, %5)
  return (%6)
        """
        graph = torch._C.parse_ir(graph_str)

        kernel = te.TensorExprKernel(graph)
        res1 = kernel.run((x,))
        res2 = kernel.fallback((x,))
        correct = f(x)
        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)

    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
    def test_alloc_in_loop(self):
        # `tmp` is not among the codegen arguments ([a, b]), so it is an
        # intermediate buffer. Presumably prepare_for_codegen inserts its
        # allocation. The 4-deep loop nest below puts every use of it inside loops.
        a, tmp, b = (
            te.BufHandle(name, [1], torch.float32) for name in ["a", "tmp", "b"]
        )
        body = te.Block([tmp.store([0], a.load([0])), b.store([0], tmp.load([0]))])
        for _ in range(4):
            i = te.VarHandle("i", torch.int32)
            body = te.For.make(i, 0, 100, body)
        nest = te.LoopNest(body, [b])
        nest.prepare_for_codegen()
        f = te.construct_codegen("llvm", nest.simplify(), [a, b])
        # Seed `b` with a value different from `a` so that the copy
        # a -> tmp -> b is actually observable in the result.
        ta, tb = torch.ones(1), torch.zeros(1)
        f.call([ta.data_ptr(), tb.data_ptr()])
        torch.testing.assert_close(tb, ta)
404

405

406
class TestExprHandlePyBind(JitTestCase):
407
    def test_unary_ops(self):
408
        unary_operators = {
409
            torch.sin: torch._C._te.sin,
410
            torch.cos: torch._C._te.cos,
411
            torch.tan: torch._C._te.tan,
412
            torch.asin: torch._C._te.asin,
413
            torch.acos: torch._C._te.acos,
414
            torch.atan: torch._C._te.atan,
415
            torch.sinh: torch._C._te.sinh,
416
            torch.cosh: torch._C._te.cosh,
417
            torch.tanh: torch._C._te.tanh,
418
            torch.sigmoid: torch._C._te.sigmoid,
419
            torch.exp: torch._C._te.exp,
420
            torch.expm1: torch._C._te.expm1,
421
            torch.abs: torch._C._te.abs,
422
            torch.log: torch._C._te.log,
423
            torch.log2: torch._C._te.log2,
424
            torch.log10: torch._C._te.log10,
425
            torch.log1p: torch._C._te.log1p,
426
            torch.erf: torch._C._te.erf,
427
            torch.erfc: torch._C._te.erfc,
428
            torch.sqrt: torch._C._te.sqrt,
429
            torch.rsqrt: torch._C._te.rsqrt,
430
            torch.ceil: torch._C._te.ceil,
431
            torch.floor: torch._C._te.floor,
432
            torch.round: torch._C._te.round,
433
            torch.trunc: torch._C._te.trunc,
434
            torch.lgamma: torch._C._te.lgamma,
435
            torch.frac: torch._C._te.frac,
436
        }
437

438
        def construct_te_fn(op, n: int, dtype=torch.float32):
439
            A = torch._C._te.BufHandle("A", [n], dtype)
440

441
            def compute(i):
442
                return op(A.load([i]))
443

444
            C = te.Compute("C", [n], compute)
445

446
            loopnest = te.LoopNest([C])
447
            loopnest.prepare_for_codegen()
448
            stmt = te.simplify(loopnest.root_stmt())
449

450
            return te.construct_codegen("ir_eval", stmt, [A, C])
451

452
        n = 10
453
        a = torch.rand(n)
454
        for torch_op, te_op in unary_operators.items():
455
            ref = torch_op(a)
456

457
            te_fn = construct_te_fn(te_op, n, torch.float32)
458
            res = torch.empty(n)
459
            te_fn.call([a, res])
460
            assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
461

462

463
if __name__ == "__main__":
    # Runs only when this file is executed directly, not when it is imported.
    run_tests()
465

# Use of cookies

# We use cookies in accordance with the Privacy Policy and the Cookie Policy.

# By clicking "Accept", you consent to JSC SberTech processing your personal
# data in order to improve our website and the GitVerse Service and to make
# them more convenient to use.

# You can disable cookies yourself in your browser settings.