#include <gtest/gtest.h>

#include <ATen/code_template.h>
#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <cmath>
#include <sstream>
#include <stdexcept>

namespace torch {
namespace jit {

using namespace torch::indexing;
using namespace torch::jit::tensorexpr;

class Kernel : public ::testing::Test {
 public:
  void SetUp() override {
    getTEMustUseLLVMOnCPU() = false;
  }
};

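// Nearly every test below follows the same flow: parse TorchScript IR into a
// Graph, compile it with TensorExprKernel, push the inputs onto an IValue
// stack, run the kernel, and pop the output. A minimal sketch of that flow
// (the helper name `compileAndRun` is ours, for illustration only; the tests
// inline these steps so they can also inspect the generated statement):
[[maybe_unused]] static at::Tensor compileAndRun(
    const std::string& graph_string,
    const std::vector<at::Tensor>& inputs) {
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  return stack[0].toTensor();
}
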
TEST_F(Kernel, ParallelExternalCallBuf) {
  const auto graph_string = R"IR(
    graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu),
          %1 : Float(1000, 5000, strides=[5000, 1], device=cpu),
          %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)):
      %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1)
      %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2)
      return (%4))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, &*graph);
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR";

#ifdef TORCH_ENABLE_LLVM
  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
#endif
}

TEST_F(Kernel, InliningIntermediates) {
  // Here each mul has only one use, so it should be completely inlined.
  {
    const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %one : int = prim::Constant[value=1]()
        %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        %5 : Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one)
        return (%5))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    TensorExprKernel k(graph);
    auto stmt = k.getCodeGenStmt();
    std::ostringstream oss;
    oss << *stmt;
    torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str());
  }
  {
    const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=${device}),
            %1 : Float(5, 3, strides=[3, 1], device=${device})):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %one : int = prim::Constant[value=1]()
        %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one)
        %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one)
        %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0)
        return (%4, %5))IR";
    for (bool use_cuda : {false, true}) {
      if (!torch::cuda::is_available() && use_cuda) {
        continue;
      }

      at::jit::TemplateEnv env;
      env.s("device", use_cuda ? "cuda:0" : "cpu");
      const auto graph_string = format(graph_template, env);
      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);
      TensorExprKernel k(graph);
      auto stmt = k.getCodeGenStmt();
      std::ostringstream oss;
      oss << *stmt;
      // aten_mul has only one use, so it is inlined completely.
      torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str());

      // aten_sub should be removed by the CUDA backend via metavar rewriting
      // and by the CPU backend via horizontal fusion.
      torch::jit::testing::FileCheck().check_not("aten_sub")->run(oss.str());
    }
  }
}

TEST_F(Kernel, PreAllocIntermediateBufs) {
  const auto graph_string = R"IR(
    graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu),
          %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)):
      %2 : int = prim::Constant[value=1]()
      %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12
      %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15
      return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::matmul(a, b) + a;
  // The fourth constructor argument turns on pre-allocation of intermediate
  // buffers, which is what this test exercises.
  TensorExprKernel k(graph, {}, {}, true);

  std::vector<at::Tensor> inputs = {a, b};
  auto stmt = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *stmt;

  // Check whether the intermediate buffer has been added to constants.
  auto constants = k.getConstantDescriptors();
  ASSERT_EQ(constants.size(), 1);

  // Check the IR we produced.
  torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str());
  torch::jit::testing::FileCheck().check_not("Free")->run(oss.str());

  // Check correctness.
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, _1) {
  const auto graph_string = R"IR(
    graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
          %1 : Float(5, 3, strides=[3, 1], device=cpu)):
      %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
      %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
      return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, _2) {
  const auto graph_string = R"IR(
    graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
          %1 : Float(5, 3, strides=[1, 5], device=cpu)):
      %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
      %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
      return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, _3) {
  const auto graph_string = R"IR(
    graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
          %1 : Float(5, 3, strides=[12, 2], device=cpu)):
      %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
      %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
      return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
               .index({Slice(None, None, 2), Slice(None, None, 2)});
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, Huge) {
  const auto graph_string = R"IR(
    graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)):
      %1 : int = prim::Constant[value=0]()
      %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1)
      %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2)
      return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  std::ostringstream oss;
  oss << *k.getCodeGenStmt();
  // The 4000000000-iteration loop will be split into 500000000 x 8, and the
  // outer loop will be parallel. If LLVM is not present, the loop is not
  // split. Both 4000000000 and 500000000 end in eight zeros, so checking for
  // "00000000ll" in the output covers both cases.
  const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

TEST_F(Kernel, ParallelStrided) {
  const auto graph_string = R"IR(
    graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu),
          %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)):
      %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1)
      %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2)
      return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat))
               .index(
                   {Slice(None, None, 2),
                    Slice(None, None, 2),
                    Slice(None, None, 2)});
  auto ref = a * (a * b);
  auto o = at::zeros_like(ref);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
}

TEST_F(Kernel, DISABLED_Shape_Inference) {
  // Disabled: doesn't do stride propagation, and isn't currently being used.

  // Test TensorExpr shape inference capabilities: it should only require
  // shapes for the inputs.
  {
    const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[12, 2], device=cpu)):
        %2 : Tensor = aten::mul(%0, %1)
        %3 : Tensor = aten::mul(%0, %2)
        return (%3))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
                 .index({Slice(None, None, 2), Slice(None, None, 2)});
    auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = a * (a * b);
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    for (size_t i = 0; i < 5 * 3; i++) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    const auto graph_string = R"IR(
      graph(%0 : Float(8, 8, strides=[8, 1], device=cpu),
            %1 : Float(8, 8, strides=[8, 1], device=cpu)):
        %2 : Tensor = aten::mul(%0, %1)
        %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2)
        %r : Tensor = aten::mul(%3, %4)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat));
    auto t = torch::chunk(a * b, 2, 1);
    auto ref = t[0] * t[1];
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    TORCH_CHECK_EQ(o.sizes()[0], 8);
    TORCH_CHECK_EQ(o.sizes()[1], 4);
    for (size_t i = 0; i < 8 * 4; i++) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that shape inference handles aten::unsqueeze

    const auto graph_string = R"IR(
      graph(%a : Float(4, 2, strides=[2, 1], device=cpu),
            %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu),
            %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)):
        %one : int = prim::Constant[value=1]()
        %minus_one : int = prim::Constant[value=-1]()
        %three : int = prim::Constant[value=3]()
        %minus_four : int = prim::Constant[value=-4]()
        %a1 : Tensor = aten::unsqueeze(%a, %one) # new size: [4,1,2]
        %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1]
        %b1 : Tensor = aten::unsqueeze(%b, %three) # new size: [4,3,2,1]
        %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2]
        %ab : Tensor = aten::mul(%a2, %b1) # expected size: [4,3,2,1]
        %abc : Tensor = aten::mul(%ab, %c1) # expected size: [4,3,2,2]
        return (%abc))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) *
        at::unsqueeze(c, -4);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_mul)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that shape inference handles aten::cat

    const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2]
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto o = at::zeros({5, 19, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::cat({a, b, c}, 1);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_cat)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
    }
  }
  {
    // Test that we throw an error when the input list for aten::cat is empty.

    const auto graph_string = R"IR(
      graph():
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct()
        %r : Tensor = aten::cat(%inputs, %dim)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    auto compile = [&]() {
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat");
  }
  {
    // Test that we throw an error when 'dim' passed to aten::cat is invalid

    const auto ir_dim_99 = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=99]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
        return (%r))IR";
    const auto ir_dim_minus_6 = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=-6]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
        return (%r))IR";

    auto compile = [](const std::string& graph_string) {
      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index");
    ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index");
  }
}

TEST_F(Kernel, CatInputTypesPromotion) {
  {
    // Test that we properly promote input types for aten::cat: here Float
    // and Double inputs promote to Double, matching ATen's promotion rules.

    const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble));
    auto ref = at::cat({a, b, c}, 1);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_cat)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();

    // Check sizes
    TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
    TORCH_CHECK_EQ(o.dtype(), ref.dtype());
    size_t num_el = 1;
    for (const auto idx : c10::irange(ref.sizes().size())) {
      TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
      num_el *= ref.sizes()[idx];
    }

    // Check the contents
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]);
    }
  }
}

TEST_F(Kernel, ToDType) {
#ifdef TORCH_ENABLE_LLVM
  // The scalar type codes in the IR below follow c10::ScalarType:
  // 5 == Half, 6 == Float, 15 == BFloat16.
  const auto graph_string = R"IR(
    graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
      %1 : NoneType = prim::Constant()
      %2 : bool = prim::Constant[value=0]()
      %3 : int = prim::Constant[value=6]()
      %4 : int = prim::Constant[value=15]()
      %5 : int = prim::Constant[value=5]()
      %6 : bool = prim::Constant[value=1]()
      %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1)
      %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4)
      %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6)
      %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1)
      %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1)
      %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1)
      return (%k.3))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_to
# CHECK-NEXT: }
# CHECK-NEXT: })IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16));
  auto ref =
      at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat));

  std::vector<at::Tensor> inputs = {a};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_EQ(o.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
#endif
}

TEST_F(Kernel, CatAndInlineWithAConstantDim) {
  const auto graph_string = R"IR(
    graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu),
          %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)):
      %2 : bool = prim::Constant[value=0]()
      %3 : int = prim::Constant[value=1]()
      %4 : Tensor[] = prim::ListConstruct(%0, %1)
      %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3)
      %6 : Tensor[] = prim::ListConstruct(%5)
      %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3)
      %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2)
      return (%8, %7))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);

  auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::_cast_Float(at::cat({a, b}, 1), 0);

  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_EQ(o.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, CatWithEmptyInputs) {
  bool curr_cat_wo_conditionals = getCatWoConditionals();
  for (auto cat_wo_conditionals : {true, false}) {
    getCatWoConditionals() = cat_wo_conditionals;
    const auto graph_string = R"IR(
      graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu),
            %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)):
        %3 : int = prim::Constant[value=0]()
        %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0)
        %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1)
        %10 : Tensor[] = prim::ListConstruct(%6, %7)
        %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3)
        return (%11))IR";

    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    TensorExprKernel k(graph);

    auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0);

    std::vector<at::Tensor> inputs = {a, b};
    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();
    ASSERT_EQ(o.sizes(), ref.sizes());
    ASSERT_EQ(o.dtype(), ref.dtype());
    ASSERT_TRUE(at::allclose(o, ref));
  }
  getCatWoConditionals() = curr_cat_wo_conditionals;
}

TEST_F(Kernel, CatWoConditionals) {
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
    graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
          %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
          %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)):
      %dim : int = prim::Constant[value=1]()
      %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
      %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim)
      return (%r))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK: for
# CHECK: for
# CHECK: aten_cat
# CHECK: for
# CHECK: for
# CHECK: aten_cat
# CHECK: for
# CHECK: for
# CHECK: aten_cat)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::cat({a, b, c}, 1);

  std::vector<at::Tensor> inputs = {a, b, c};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
  getCatWoConditionals() = old_cat_wo_conditionals;
}

TEST_F(Kernel, OptimizeConditionals) {
  bool old_cat_wo_conditionals = getCatWoConditionals();
  bool old_opt_conditionals = getOptConditionals();
  getCatWoConditionals() = false;
  getOptConditionals() = true;
  const auto graph_string = R"IR(
    graph(%a : Float(5, 3, strides=[3, 1], device=cpu),
          %b : Float(5, 7, strides=[7, 1], device=cpu),
          %c : Float(5, 9, strides=[9, 1], device=cpu)):
      %dim : int = prim::Constant[value=1]()
      %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
      %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim)
      %t : Float(5, 19, strides=[19, 1]) = aten::relu(%r)
      return (%t))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string& verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK-NOT: Allocate
# CHECK-NOT: Free)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::relu(at::cat({a, b, c}, 1));

  std::vector<at::Tensor> inputs = {a, b, c};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
  getOptConditionals() = old_opt_conditionals;
  getCatWoConditionals() = old_cat_wo_conditionals;
}

namespace {

std::string dtypeConstant(ScalarType scalar_type) {
  if (scalar_type == ScalarType::Undefined) {
    return "None = prim::Constant()";
  } else {
    at::jit::TemplateEnv env_dtype;
    env_dtype.d("scalar_type", static_cast<int>(scalar_type));
    return format("int = prim::Constant[value=${scalar_type}]()", env_dtype);
  }
}

at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) {
  int64_t numel = std::accumulate(
      sizes.begin(),
      sizes.end(),
      1,
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      std::multiplies<int64_t>());
  std::vector<float> values(numel);
  std::iota(values.begin(), values.end(), 0);
  auto a = at::tensor(values, options);
  return a.reshape(sizes);
}

} // namespace

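// A small sanity check of the helpers above (this test is ours, added for
// illustration; the expectations follow directly from the helper
// definitions): dtypeConstant renders the IR constant used for a `dtype`
// argument, and iotaTensor fills a tensor with 0, 1, 2, ... in row-major
// order before reshaping it.
TEST_F(Kernel, IotaTensorHelperSanity) {
  auto t = iotaTensor({2, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  TORCH_CHECK_EQ(t.sizes()[0], 2);
  TORCH_CHECK_EQ(t.sizes()[1], 3);
  // Row-major fill: element (i, j) holds i * 3 + j.
  for (const auto i : c10::irange(6)) {
    TORCH_CHECK_EQ(((float*)t.data_ptr())[i], static_cast<float>(i));
  }
  ASSERT_EQ(dtypeConstant(ScalarType::Undefined), "None = prim::Constant()");
}
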
TEST_F(Kernel, SumAllAxes) {
  // Test lowering of sum on all axes.
  const auto graph_template = R"IR(
    graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
      %1 : ${dtype}
      %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1)
      return (%2))IR";
  auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
    at::jit::TemplateEnv env;
    env.s("dtype", dtypeConstant(scalar_type));
    if (scalar_type == ScalarType::Undefined) {
      env.s("out_dtype", "Float");
    } else {
      env.s("out_dtype", "Double");
    }
    const auto graph_string = format(graph_template, env);

    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto o = at::empty({}, TensorOptions(kCPU));
    std::optional<c10::ScalarType> dtype;
    if (scalar_type != ScalarType::Undefined) {
      dtype = static_cast<c10::ScalarType>(scalar_type);
    }
    auto ref = a.sum(/*dtype=*/dtype);
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string& verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    o = stack[0].toTensor();
    ASSERT_EQ(o.sizes(), ref.sizes());
    ASSERT_EQ(o.dtype(), ref.dtype());
    ASSERT_TRUE(at::allclose(o, ref));
  }
}

std::string li_to_str(at::ArrayRef<int64_t> li) {
  std::stringstream out;
  bool first = true;
  for (auto elem : li) {
    if (!first) {
      out << ", ";
    }
    out << elem;
    first = false;
  }
  return out.str();
}

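// For example (our illustration, not exercised by the tests directly):
// li_to_str({5, 19, 2}) returns "5, 19, 2", comma-separated with no trailing
// separator, which the templates below splice into their ${size} and
// ${strides} placeholders.
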
TEST_F(Kernel, SumOneAxis) {
  // Test lowering of sum on one axis.
  const auto graph_template = R"IR(
    graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
      %1 : int[] = prim::Constant[value=[${dim}]]()
      %2 : bool = prim::Constant[value=${keepdim}]()
      %3 : ${dtype}
      %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3)
      return (%4))IR";
  auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  for (int dim = -a.dim(); dim < a.dim(); ++dim) {
    for (bool keepdim : {false, true}) {
      for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
        at::jit::TemplateEnv env;
        env.d("dim", dim);
        env.d("keepdim", keepdim);
        env.s("dtype", dtypeConstant(scalar_type));
        std::optional<c10::ScalarType> dtype;
        if (scalar_type != ScalarType::Undefined) {
          dtype = static_cast<c10::ScalarType>(scalar_type);
        }
        auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype);
        if (scalar_type == ScalarType::Undefined) {
          env.s("out_dtype", "Float");
        } else {
          env.s("out_dtype", "Double");
        }
        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));
        const auto graph_string = format(graph_template, env);
        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        auto o = at::empty({}, TensorOptions(kCPU));
        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        // Check the IR we produced
        const std::string& verification_pattern =
            R"IR(
# CHECK: for (int64_t
# CHECK-NEXT: sum
# CHECK-NEXT: for (int64_t
# CHECK-NEXT: sum)IR";
        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        o = stack[0].toTensor();
        ASSERT_EQ(o.sizes(), ref.sizes());
        ASSERT_EQ(o.dtype(), ref.dtype());
        ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
      }
    }
  }
}

TEST_F(Kernel, SumMultipleAxes) {
  // Test lowering of sum on multiple axes.
  const auto graph_template = R"IR(
    graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)):
      %1 : int = prim::Constant[value=${dim1}]()
      %2 : int = prim::Constant[value=${dim2}]()
      %3 : int[] = prim::ListConstruct(%1, %2)
      %4 : bool = prim::Constant[value=${keepdim}]()
      %5 : ${dtype}
      %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5)
      return (%6))IR";
  auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  // Only iterate over positive values of axes to keep the running time
  // reasonable, since the number of pairs is quadratic.
  for (const auto dim1 : c10::irange(a.dim())) {
    for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) {
      for (bool keepdim : {false, true}) {
        at::jit::TemplateEnv env;
        env.d("dim1", dim1);
        env.d("dim2", dim2);
        env.d("keepdim", keepdim);
        env.s("dtype", dtypeConstant(ScalarType::Undefined));
        auto o = at::empty({}, TensorOptions(kCPU));
        auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim);

        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));

        const auto graph_string = format(graph_template, env);

        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        // Check the IR we produced
        const std::string& verification_pattern =
            R"IR(
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: sum)IR";
        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        o = stack[0].toTensor();
        ASSERT_EQ(o.sizes(), ref.sizes());
        ASSERT_EQ(o.dtype(), ref.dtype());
        ASSERT_TRUE(at::allclose(o, ref));
      }
    }
  }
}

// This test and the following Softmax tests only test with `dim` set to one
// of the valid input dimensions. They do not test with dim=None because that
// is supposed to be deprecated.
TEST_F(Kernel, Softmax2D) {
  const auto graph_template = R"IR(
    graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
      %1 : int = prim::Constant[value=${dim}]()
      %dt_float : int = prim::Constant[value=7]()
      %dt_none : NoneType = prim::Constant()
      %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt})
      return (%4))IR";

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
# CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size}
# CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_max
# CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size}
# CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_sum
# CHECK: for (int i0_2 = 0; i0_2 < 5
# CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
# CHECK-NEXT: aten_softmax)IR";

  for (bool empty_dtype : {false, true}) {
    for (auto log_softmax : {false, true}) {
      for (const auto softmax_dim : c10::irange(a.dim())) {
        auto softmax_dim_size = a.sizes()[softmax_dim];
        auto other_dim = (softmax_dim + 1) % a.dim();
        auto ref =
            log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
        at::jit::TemplateEnv env;
        env.d("dim", softmax_dim);
        env.s("op", log_softmax ? "log_softmax" : "softmax");
        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));
        env.s("dt", empty_dtype ? "dt_none" : "dt_float");

        const auto graph_string = format(graph_template, env);

        auto graph = std::make_shared<Graph>();
        parseIR(graph_string, &*graph);

        TensorExprKernel k(graph);
        std::vector<at::Tensor> inputs = {a};
        StmtPtr s = k.getCodeGenStmt();

        std::ostringstream oss;
        oss << *s;

        at::jit::TemplateEnv ver_env;
        ver_env.d("other_dim", other_dim);
        ver_env.d("other_dim_size", a.sizes()[other_dim]);
        ver_env.d("softmax_dim", softmax_dim);
        ver_env.d("softmax_dim_size", softmax_dim_size);
        const auto verification_pattern =
            format(verification_template, ver_env);

        // The verification string is temporarily disabled until the inlining
        // of exp() is benchmarked and a decision is made.
        // torch::jit::testing::FileCheck().run(verification_pattern,
        // oss.str());

        std::vector<IValue> stack = fmap<IValue>(inputs);
        k.run(stack);
        auto output = stack[0].toTensor();
        ASSERT_EQ(output.sizes(), ref.sizes());
        ASSERT_TRUE(at::allclose(output, ref));
      }
    }
  }
}

TEST_F(Kernel, Softmax3D) {
  const auto graph_template = R"IR(
    graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)):
      %1 : int = prim::Constant[value=${dim}]()
      %2 : int = prim::Constant[value=7]()
      %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
      return (%3))IR";

  auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
# CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
# CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
# CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_max
# CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
# CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
# CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_sum
# CHECK: for (int i0_2 = 0; i0_2 < 3
# CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4
# CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5
# CHECK-NEXT: aten_softmax)IR";

  for (auto log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(a.dim())) {
      auto softmax_dim_size = a.sizes()[softmax_dim];
      std::vector<int> other_dims;
      for (const auto i : c10::irange(a.dim())) {
        if (i != softmax_dim) {
          other_dims.push_back(i);
        }
      }
      auto ref =
          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      const auto graph_string = format(graph_template, env);

      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);

      TensorExprKernel k(graph);
      std::vector<at::Tensor> inputs = {a};
      StmtPtr s = k.getCodeGenStmt();

      std::ostringstream oss;
      oss << *s;

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", softmax_dim_size);
      const auto verification_pattern = format(verification_template, ver_env);

      // The verification string is temporarily disabled until the inlining
      // of exp() is benchmarked and a decision is made.
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

      std::vector<IValue> stack = fmap<IValue>(inputs);
      k.run(stack);
      auto output = stack[0].toTensor();

      ASSERT_EQ(output.sizes(), ref.sizes());
      ASSERT_TRUE(at::allclose(output, ref));
    }
  }
}

TEST_F(Kernel, Softmax4D) {
  const auto graph_template = R"IR(
    graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)):
      %1 : int = prim::Constant[value=${dim}]()
      %2 : int = prim::Constant[value=7]()
      %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
      return (%3))IR";

  auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string& verification_template =
      R"IR(
# CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
# CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
# CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size}
# CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_max
# CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
# CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
# CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size}
# CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
# CHECK-NEXT: aten_softmax_sum
# CHECK: for (int i0_2 = 0; i0_2 < 2
# CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
# CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2
# CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3
# CHECK-NEXT: aten_softmax)IR";

  for (auto log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(a.dim())) {
      auto softmax_dim_size = a.sizes()[softmax_dim];
      std::vector<int> other_dims;
      for (const auto i : c10::irange(a.dim())) {
        if (i != softmax_dim) {
          other_dims.push_back(i);
        }
      }
      auto ref =
          log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      const auto graph_string = format(graph_template, env);

      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);

      TensorExprKernel k(graph);
      std::vector<at::Tensor> inputs = {a};
      StmtPtr s = k.getCodeGenStmt();

      std::ostringstream oss;
      oss << *s;

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", a.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", a.sizes()[other_dims[1]]);
      ver_env.d("dim3", other_dims[2]);
      ver_env.d("dim3_size", a.sizes()[other_dims[2]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", softmax_dim_size);
      const auto verification_pattern = format(verification_template, ver_env);

      // The verification string is temporarily disabled until the inlining
      // of exp() is benchmarked and a decision is made.
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

      std::vector<IValue> stack = fmap<IValue>(inputs);
      k.run(stack);
      auto output = stack[0].toTensor();
      ASSERT_EQ(output.sizes(), ref.sizes());
      ASSERT_TRUE(at::allclose(output, ref));
    }
  }
}

TEST_F(Kernel, SignTest) {
  const auto graph_template = R"IR(
    graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)):
      %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0)
      return (%2))IR";

  auto run_test = [](const std::string& graph_string,
                     const at::Tensor& input) {
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    TensorExprKernel k(graph);
    StmtPtr s = k.getCodeGenStmt();

    std::vector<at::Tensor> inputs = {input};
    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();
    auto ref = at::sign(input);
    ASSERT_TRUE(at::allclose(o, ref));
  };
  auto common_options = at::TensorOptions()
                            .layout(at::kStrided)
                            .device(at::kCPU)
                            .requires_grad(false);
  int default_input_size = 100;
  for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) {
    at::Tensor corner_case_inputs;
    at::jit::TemplateEnv env;
    auto options = common_options;
    switch (scalar_type) {
      case ScalarType::Float: {
        env.s("dtype", "Float");
        options = options.dtype(at::kFloat);
        std::vector<float> input_float = {
            0.0f,
            -0.0f,
            std::numeric_limits<float>::infinity(),
            -std::numeric_limits<float>::infinity(),
            std::nanf("1"),
            -std::nanf("1")};
        corner_case_inputs = at::from_blob(
            input_float.data(),
            {static_cast<long>(input_float.size())},
            options);
        auto rand_input = at::rand({default_input_size}, options);
        auto input = at::cat({rand_input, corner_case_inputs});
        env.d("size", at::numel(input));
        const auto graph_string = format(graph_template, env);
        run_test(graph_string, input);
        break;
      }
      case ScalarType::Double: {
        env.s("dtype", "Double");
        options = options.dtype(at::kDouble);
        std::vector<double> input_double = {
            0.0,
            -0.0,
            std::numeric_limits<double>::infinity(),
            -std::numeric_limits<double>::infinity(),
            std::nan("1"),
            -std::nan("1")};
        corner_case_inputs = at::from_blob(
            input_double.data(),
            {static_cast<long>(input_double.size())},
            options);
        auto rand_input = at::rand({default_input_size}, options);
        auto input = at::cat({rand_input, corner_case_inputs});
        env.d("size", at::numel(input));
        const auto graph_string = format(graph_template, env);
        run_test(graph_string, input);
        break;
      }
      default:
        throw unsupported_dtype();
    }
  }
}

TEST_F(Kernel, InlineProducerIntoReduction) {
  // Inline producer (mul) into reduction (sum).
  const auto graph_string = R"IR(
    graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
          %1 : Float(5, 3, strides=[3, 1], device=cpu)):
      %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
      %3 : int = prim::Constant[value=7]()
      %4 : Double(device=cpu) = aten::sum(%2, %3)
      return (%4))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced.
  // We should have only one loop nest in the end.
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i_1 = 0ll; i_1 < 5
# CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
# CHECK-NEXT: sum
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto ref = (a * b).sum(at::kDouble);
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, InlineReductionIntoConsumer) {
  // Inline the producer (mul %2) into the reduction (sum %4), but do NOT
  // inline the reduction into its consumer (mul %5).
  const auto graph_string = R"IR(
    graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
          %1 : Float(5, 3, strides=[3, 1], device=cpu)):
      %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
      %3 : int = prim::Constant[value=6]()
      %4 : Float(device=cpu) = aten::sum(%2, %3)
      %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced.
  // We should have two loop nests in the end.
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i_1 = 0ll; i_1 < 5
# CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
# CHECK-NEXT: sum
# CHECK: for (int64_t i_2 = 0ll; i_2 < 5
# CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3
# CHECK-NEXT: aten_mul
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto ref = (a * b).sum(at::kFloat) * (a * b);
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, SanitizeNames_CUDA) {
  const auto graph_string = R"IR(
    graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0),
          %1 : Float(5, 3, strides=[3, 1], device=cuda:0)):
      %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
      %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
      return (%4))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  graph->inputs().at(0)->setDebugName("aten::add:");
  graph->inputs().at(1)->setDebugName("aten::add_");
  TensorExprKernel k(graph);
  auto a = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat));
  auto ref = a * (a * b);
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, SanitizeConstants_CUDA) {
  const auto graph_string = R"IR(
    graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)):
      %none : NoneType = prim::Constant()
      %size : int = prim::Constant[value=16]()
      %sizes : int[] = prim::ListConstruct(%size, %size)
      %30 : Device = prim::Constant[value="cuda"]()
      %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none)
      %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y)
      return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // The IR parser doesn't support tensor constants, so we insert a call to
  // aten::ones and then const-prop it.
  ConstantPropagation(graph);

  // We set the name of the constant to include special characters that are
  // not allowed. This should be fixed by the sanitizer in TensorExprKernel.
  graph->nodes().front()->output()->setDebugName("illegal.name");

  // Check that we have a constant node with an illegal name in the graph.
  auto const_node = graph->nodes().front();
  ASSERT_EQ(const_node->kind(), prim::Constant);
  ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos);

  TensorExprKernel k(graph);

  auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat));
  auto ref = x * y;
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, ConstantTensors) {
  const auto graph_string = R"IR(
    graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
      %none : NoneType = prim::Constant()
      %size : int = prim::Constant[value=16]()
      %sizes : int[] = prim::ListConstruct(%size, %size)
      %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none)
      %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
      return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // The IR parser doesn't support tensor constants, so we insert a call to
  // aten::ones and then const-prop it.
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto y = at::ones({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = x * y;
  ASSERT_TRUE(at::allclose(o, ref));
}

TEST_F(Kernel, ConstantTensorsNonContiguous) {
  const auto graph_string = R"IR(
    graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
      %none : NoneType = prim::Constant()
      %dtype : int = prim::Constant[value=6]()
      %c0 : int = prim::Constant[value=0]()
      %c256 : int = prim::Constant[value=256]()
      %c16 : int = prim::Constant[value=16]()
      %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
      %sizes : int[] = prim::ListConstruct(%c16, %c16)
      %y_t : Tensor = aten::view(%y_flat, %sizes)
      %y : Tensor = aten::t(%y_t)
      %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
      return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // The IR parser doesn't support tensor constants, so we generate several
  // aten calls to produce a non-contiguous constant tensor and const-prop it.
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {x};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto y = at::arange(0, 256, TensorOptions(kCPU).dtype(at::kFloat))
               .view({16, 16})
               .t();
  auto ref = x * y;
  ASSERT_TRUE(at::allclose(o, ref));
}

1568TEST_F(Kernel, RunFast) {
1569#ifdef TORCH_ENABLE_LLVM
1570// TODO: Implement call_raw in IREval and remove the ifdef
1571
1572const auto graph_string = R"IR(
1573graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
1574%1 : Float(5, 3, strides=[1, 5], device=cpu)):
1575%2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
1576%3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
1577return (%3))IR";
1578auto graph = std::make_shared<Graph>();
1579parseIR(graph_string, &*graph);
1580
1581auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1582auto b =
1583at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
1584auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1585auto ref = a * (a * b);
1586TensorExprKernel k(graph);
1587
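  // runFast bypasses the IValue stack: it takes raw data pointers for the
  // inputs and for preallocated outputs, so the kernel allocates nothing.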
  k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()});
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
#endif
}

TEST_F(Kernel, RunWithAllocatedOutputs) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
%1 : Float(5, 3, strides=[1, 5], device=cpu)):
%2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
%3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);

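  // Unlike run(), the preallocated output tensor is passed on the stack
  // ahead of the inputs and is written to in place.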
  std::vector<at::Tensor> args = {o, a, b};
  std::vector<IValue> stack = fmap<IValue>(args);
  k.runWithAllocatedOutputs(stack);
  for (size_t i = 0; i < 5 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
#endif
}

TEST_F(Kernel, CodegenInspection) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
%none : NoneType = prim::Constant()
%dtype : int = prim::Constant[value=6]()
%c0 : int = prim::Constant[value=0]()
%c256 : int = prim::Constant[value=256]()
%c16 : int = prim::Constant[value=16]()
%y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
%sizes : int[] = prim::ListConstruct(%c16, %c16)
%y_t : Tensor = aten::view(%y_flat, %sizes)
%y : Tensor = aten::t(%y_t)
%z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // IRParser doesn't support tensor constants, so we generate several aten
  // calls to produce a non-contiguous constant tensor and then const-prop it.
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  // Check that we can retrieve the generated assembly.
  auto asm_str = k.getCodeText("asm");
  const std::string& asm_verification_pattern =
      R"ASM(
# CHECK: .text
# CHECK: retq)ASM";
  torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str);

  // Check that we can retrieve info about the codegen parameters.
  auto constants = k.getConstantDescriptors();
  auto buf_args = k.getBufferArgs();
  // Expected buf args: [input0, output0, constant0].
  ASSERT_EQ(buf_args.size(), 3);
  ASSERT_EQ(constants.size(), 1);
  ASSERT_TRUE(
      !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar());
#endif
}

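// A custom NNC lowering: it matches the NNCLoweringFunction signature and
// builds a Compute expression that replaces NaNs with 0.0f via IfThenElse.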
Tensor lowerNanToNum(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  auto input_buf = std::get<BufHandle>(inputs[0]);
  auto e = Compute(
      "custom_nan_to_num",
      outputShape,
      outputStrides,
      [&](const std::vector<VarHandle>& axes) {
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        auto load = input_buf.load(indices);
        return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load);
      });
  return e;
}

TEST_F(Kernel, CustomLowering) {
  const auto graph_string = R"IR(
graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
%none : NoneType = prim::Constant()
%y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none)
return (%y)
)IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

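  // Register the custom lowering for aten::nan_to_num so that it is used
  // instead of the default lowering when the kernel is compiled.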
  std::unordered_map<c10::Symbol, NNCLoweringFunction> lowerings = {
      {aten::nan_to_num, lowerNanToNum}};
  TensorExprKernel k(graph, lowerings);

  auto stmt = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *stmt;

  // Check that our custom lowering is actually used.
  torch::jit::testing::FileCheck().check("custom_nan_to_num")->run(oss.str());
  torch::jit::testing::FileCheck().check("isnan")->run(oss.str());
}

TEST_F(Kernel, Vectorize) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
graph(%0 : Float(100, 16, strides=[16, 1], device=cpu),
%1 : Float(100, 16, strides=[16, 1], device=cpu)):
%2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1)
%3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2)
return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced: a Ramp node indicates a vectorized access.
  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 100 * 16; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
#endif
}

// TODO: To vectorize the loopnest for the 100x3 case, we need to flatten the
// loops first.
TEST_F(Kernel, DISABLED_FlattenVectorize) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
graph(%0 : Float(100, 3, strides=[3, 1], device=cpu),
%1 : Float(100, 3, strides=[3, 1], device=cpu)):
%2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1)
%3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2)
return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced: a Ramp node indicates a vectorized access.
  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  o = stack[0].toTensor();
  for (size_t i = 0; i < 100 * 3; i++) {
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
  }
#endif
}

TEST_F(Kernel, Strided1dWithinBounds) {
  auto ir = R"IR(
graph(%0 : Float(3, strides=[1], device=cpu),
%1 : Float(3, strides=[2], device=cpu)):
%2 : int = prim::Constant[value=1]()
%3 : Float(3, strides=[1]) = aten::add(%0, %1, %2)
return (%3))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);

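  // %1 is declared with strides=[2], so build b as a stride-2 view by taking
  // every other element of a length-6 tensor.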
  auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat))
               .index({Slice(None, None, 2)});
  auto expect = a + b;

  std::vector<at::Tensor> inputs = {a, b};

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);

  auto output = stack[0].toTensor();

  for (size_t i = 0; i < 3; ++i) {
    TORCH_CHECK_EQ(
        ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]);
  }
}

TEST_F(Kernel, InputAsOutput) {
  const auto graph_string = R"IR(
graph(%x : Float(5, 3, strides=[3, 1], device=cpu),
%y : Float(5, 3, strides=[1, 5], device=cpu)):
return (%x, %y))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

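  // The graph returns its inputs unchanged; the kernel's outputs must match
  // them, including the non-contiguous (transposed) %y.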
  auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x, y};

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  CHECK(at::allclose(x, stack[0].toTensor()));
  CHECK(at::allclose(y, stack[1].toTensor()));
}

TEST_F(Kernel, ScalarOut) {
  auto ir = R"IR(
graph(%x : int, %y : int):
%z : int = aten::mul(%x, %y)
%r : int = aten::mul(%z, %x)
return (%r, %z))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);

  auto stmt = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *stmt;

  // Verify the generated IR. We expect to see a scalar variable (Let) followed
  // by a store to a 0-dim buffer.
  const std::string& verification_pattern = R"IR(
# CHECK: int64_t
# CHECK-NEXT: [0ll] =
# CHECK-NEXT: int64_t
# CHECK-NEXT: [0ll] =
)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  int64_t x = 2, y = 3, r = 0, z = 0;

  // Verify that TEK::runFast works correctly with scalar outputs.
  std::vector<void*> inputs = {&x, &y};
  std::vector<void*> outputs = {&r, &z};
  k.runFast(inputs, outputs);
  TORCH_CHECK_EQ(z, x * y);
  TORCH_CHECK_EQ(r, z * x);

  // Verify that TEK::run works correctly with scalar outputs.
  std::vector<IValue> stack = {x, y};
  k.run(stack);
  TORCH_CHECK_EQ(stack[0], x * y * x);
  TORCH_CHECK_EQ(stack[1], x * y);
}

TEST_F(Kernel, ScalarTensorOut) {
  auto ir = R"IR(
graph(%x : int,
%xt : Long(3, strides=[1], device=cpu),
%y : int,
%yt : Long(3, strides=[1], device=cpu)):
%z : int = aten::mul(%x, %y)
%r : int = aten::mul(%z, %x)
%zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y)
%rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt)
return (%r, %rt, %z, %zt))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);
  int64_t x = 2, y = 3, r = 0, z = 0;
  auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2;
  auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3;
  auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));
  auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));

  // Verify that TEK::runFast works correctly with mixed scalar and tensor
  // inputs/outputs.
  std::vector<void*> inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()};
  std::vector<void*> outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()};
  k.runFast(inputs, outputs);
  TORCH_CHECK_EQ(z, x * y);
  TORCH_CHECK_EQ(r, z * x);
  ASSERT_TRUE(at::equal(zt, xt * yt));
  ASSERT_TRUE(at::equal(rt, zt * xt));

  // Verify that TEK::run works correctly with mixed scalar and tensor
  // inputs/outputs.
  std::vector<IValue> stack = {x, xt, y, yt};
  k.run(stack);
  TORCH_CHECK_EQ(stack[0], x * y * x);
  ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt));
  TORCH_CHECK_EQ(stack[2], x * y);
  ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt));
}

TEST_F(Kernel, FuseLoopsWithVariableBounds) {
#ifdef TORCH_ENABLE_LLVM
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
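  // SS(-k) in the IR below denotes a symbolic dimension; each id listed in
  // symbolic_shape_inputs is bound at runtime via the trailing int graph
  // inputs (%SS_2, %SS_3). Emitting cat without conditionals produces one
  // loopnest per input, letting the outer loops over the common dims fuse.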
  const auto graph_string = R"IR(
graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu),
%b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu),
%c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu),
%SS_2 : int,
%SS_3 : int):
%dim : int = prim::Constant[value=1]()
%inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
%r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2]
return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3};

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i
# CHECK-NEXT: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK-NOT: for (int64_t i
)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto run_kernel = [&](int dim1, int dim2) {
    auto a =
        at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto c =
        at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b, c}, 1);

    std::vector<IValue> stack =
        fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
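    // Append concrete values for the symbolic dims SS(-2) and SS(-3).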
    stack.emplace_back(dim1);
    stack.emplace_back(dim2);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20);
  getCatWoConditionals() = old_cat_wo_conditionals;
#endif
}

TEST_F(Kernel, FuseLoopsWithVariableConcatDim) {
#ifdef TORCH_ENABLE_LLVM
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
  const auto graph_string = R"IR(
graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
%b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
%c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
%SS_2 : int,
%SS_3 : int,
%SS_4 : int,
%SS_5 : int):
%dim : int = prim::Constant[value=1]()
%inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
%r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim)
return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5};

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i
# CHECK-NEXT: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK-NOT: for (int64_t i
)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto run_kernel = [&](int dim1, int dim2, int dim3) {
    auto a = at::rand(
        {dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand(
        {dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand(
        {dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b, c}, 1);

    std::vector<IValue> stack =
        fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
    stack.emplace_back(dim1);
    stack.emplace_back(dim2);
    stack.emplace_back(dim3);
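    // SS(-5) is the size of the concat output along dim 1: three inputs of
    // size dim3 each.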
    stack.emplace_back(3 * dim3);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20, 15);
  getCatWoConditionals() = old_cat_wo_conditionals;
#endif
}

TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) {
#ifdef TORCH_ENABLE_LLVM
  bool old_cat_wo_conditionals = getCatWoConditionals();
  getCatWoConditionals() = true;
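  // The inputs share SS(-2) and SS(-3) but differ along the concat dim
  // (SS(-4) vs SS(-5)), so those loops have mismatching bounds and must not
  // be fused.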
  const auto graph_string = R"IR(
graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
%b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu),
%SS_2 : int,
%SS_3 : int,
%SS_4 : int,
%SS_5 : int,
%SS_6 : int):
%dim : int = prim::Constant[value=1]()
%inputs : Tensor[] = prim::ListConstruct(%a, %b)
%r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim)
return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5, -6};

  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i
# CHECK-NEXT: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK-NOT: for (int64_t j
# CHECK-NOT: for (int64_t i
)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) {
    auto a = at::rand(
        {dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand(
        {dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b}, 1);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.emplace_back(dim2);
    stack.emplace_back(dim3);
    stack.emplace_back(dim4);
    stack.emplace_back(dim5);
    stack.emplace_back(dim4 + dim5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20, 15, 8);
  getCatWoConditionals() = old_cat_wo_conditionals;
#endif
}

} // namespace jit
} // namespace torch