pytorch
1799 строк · 49.8 Кб
1#ifdef TORCH_ENABLE_LLVM2#include <gtest/gtest.h>3
4#include <test/cpp/tensorexpr/test_base.h>5
6#include <c10/util/irange.h>7#include <test/cpp/tensorexpr/padded_buffer.h>8#include <test/cpp/tensorexpr/test_utils.h>9#include <torch/csrc/jit/tensorexpr/eval.h>10#include <torch/csrc/jit/tensorexpr/ir.h>11#include <torch/csrc/jit/tensorexpr/ir_printer.h>12#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>13#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>14#include <torch/csrc/jit/tensorexpr/loopnest.h>15#include <torch/csrc/jit/tensorexpr/tensor.h>16#include <torch/csrc/jit/testing/file_check.h>17
18#include <cmath>19#include <numeric>20
21namespace torch {22namespace jit {23using namespace torch::jit::tensorexpr;24
25using LLVMExprEval = ExprEval<LLVMCodeGen>;26
// X-macro listing the scalar types covered by the typed tests. Each row is
// (C++ type, type name, sample value). gtest's parameterized tests cannot be
// used here because of the way these tests are instantiated.
#define TEST_LLVM_SCALAR_TYPES(_)    \
  _(uint8_t, Byte, 24)               \
  _(int8_t, Char, -20)               \
  _(int16_t, Short, 3332)            \
  _(int, Int, 123456)                \
  _(int64_t, Long, 2631563121321)    \
  _(float, Float, 0.122)             \
  _(double, Double, 0.21312)         \
  _(at::Half, Half, 0.128f)

38#define IMM_TEST(Type, Name, Val) \39TEST(LLVM, Name##ImmTest) { \40auto a = Name##Imm::make(Val); \41LLVMExprEval cg(a); \42if (std::is_floating_point<decltype(Val)>()) { \43ASSERT_NEAR(cg.value<Type>(), Val, 0.1); \44} else { \45ASSERT_EQ(cg.value<Type>(), Val); \46} \47}48TEST_LLVM_SCALAR_TYPES(IMM_TEST)49#undef IMM_TEST50
51#define ADD_TEST(Type, Name, Val) \52TEST(LLVM, Name##AddTest) { \53auto a = Name##Imm::make(Val); \54auto b = Name##Imm::make(Val * 2); \55auto c = Add::make(a, b); \56LLVMExprEval cg(c); \57if (std::is_floating_point<decltype(Val)>()) { \58ASSERT_NEAR(cg.value<Type>(), Val * 3, 0.1); \59} else { \60ASSERT_EQ(cg.value<Type>(), Val * 3); \61} \62}63TEST_LLVM_SCALAR_TYPES(ADD_TEST)64#undef ADD_TEST65
66#define SUB_TEST(Type, Name, Val) \67TEST(LLVM, Name##SubTest) { \68auto a = Name##Imm::make(Val * 2); \69auto b = Name##Imm::make(Val); \70auto c = Sub::make(a, b); \71LLVMExprEval cg(c); \72if (std::is_floating_point<decltype(Val)>()) { \73ASSERT_NEAR(cg.value<Type>(), Val, 0.1); \74} else { \75ASSERT_EQ(cg.value<Type>(), Val); \76} \77}78TEST_LLVM_SCALAR_TYPES(SUB_TEST)79#undef SUB_TEST80
81#define MUL_TEST(Type, Name, Val) \82TEST(LLVM, Name##MulTest) { \83auto a = Name##Imm::make(Val); \84auto b = Name##Imm::make((Type)4); \85auto c = Mul::make(a, b); \86LLVMExprEval cg(c); \87if (std::is_floating_point<decltype(Val)>()) { \88ASSERT_NEAR(cg.value<Type>(), Val * 4, 0.1); \89} else { \90ASSERT_EQ(cg.value<Type>(), Val * 4); \91} \92}93TEST_LLVM_SCALAR_TYPES(MUL_TEST)94#undef MUL_TEST95
96#define DIV_TEST(Type, Name, Val) \97TEST(LLVM, Name##DivTest) { \98auto a = Name##Imm::make((Type)6); \99auto b = Name##Imm::make((Type)3); \100auto c = Div::make(a, b); \101LLVMExprEval cg(c); \102if (std::is_floating_point<decltype(Val)>()) { \103ASSERT_NEAR(cg.value<Type>(), 2, 0.1); \104} else { \105ASSERT_EQ(cg.value<Type>(), 2); \106} \107}108TEST_LLVM_SCALAR_TYPES(DIV_TEST)109#undef DIV_TEST110
111TEST(LLVM, IntToFloatCastTest) {112auto a = IntImm::make(2);113auto b = Cast::make(kFloat, a);114LLVMExprEval cg(b, {});115ASSERT_EQ(cg.value<float>(), 2.0);116}
117
118TEST(LLVM, FloatToIntCastTest) {119auto a = FloatImm::make(2.0);120auto b = Cast::make(kInt, a);121LLVMExprEval cg(b);122ASSERT_EQ(cg.value<int>(), 2);123}
124
125TEST(LLVM, IntToLongCastTest) {126auto a = IntImm::make(12345);127auto b = Cast::make(kLong, a);128LLVMExprEval cg(b);129ASSERT_EQ(cg.value<int64_t>(), 12345);130}
131
132TEST(LLVM, ByteToCharCastTest) {133auto a = ByteImm::make(250);134auto b = Cast::make(kChar, a);135LLVMExprEval cg(b);136ASSERT_EQ(cg.value<int8_t>(), (int8_t)250);137}
138
139TEST(LLVM, HalfToLongCastTest) {140auto a = HalfImm::make(2.0);141auto b = Cast::make(kLong, a);142LLVMExprEval cg(b);143ASSERT_EQ(cg.value<int64_t>(), 2);144}
145
146TEST(LLVM, ByteToDoubleCastTest) {147auto a = ByteImm::make(2);148auto b = Cast::make(kDouble, a);149LLVMExprEval cg(b);150ASSERT_EQ(cg.value<double>(), 2);151}
152
153TEST(LLVM, FloatToByteCastTest) {154auto a = FloatImm::make(254.0);155auto b = Cast::make(kByte, a);156LLVMExprEval cg(b);157ASSERT_EQ(cg.value<uint8_t>(), 254);158}
159
160TEST(LLVM, FloatToCharCastTest) {161auto a = FloatImm::make(-2.0);162auto b = Cast::make(kChar, a);163LLVMExprEval cg(b);164ASSERT_EQ(cg.value<int8_t>(), -2);165}
166
167TEST(LLVM, ByteToFloatCastTest) {168auto a = ByteImm::make(254);169auto b = Cast::make(kFloat, a);170LLVMExprEval cg(b);171ASSERT_EQ(cg.value<float>(), 254.0);172}
173
174TEST(LLVM, CharToFloatCastTest) {175auto a = CharImm::make(-2);176auto b = Cast::make(kFloat, a);177LLVMExprEval cg(b);178ASSERT_EQ(cg.value<float>(), -2.0);179}
180
181TEST(LLVM, BitCast) {182/* constexpr int16_t ref16 = 1337; */183constexpr int32_t ref32 = 1337;184constexpr int64_t ref64 = 1337;185constexpr float reff32 = 1337.0f;186constexpr double reff64 = 1337.0f;187
188// this is broken189/*{190at::Half k_;
191at::Half* k = &k_;
192*reinterpret_cast<int16_t*>(k) = ref16;
193auto a = HalfImm::make(k);
194auto b = BitCast::make(kShort, a);
195LLVMExprEval cg(b);
196ASSERT_EQ(cg.value<int16_t>(), ref16);
197}*/
198
199{200float k = raw_bitcast<float>(ref32);201auto a = FloatImm::make(k);202auto b = BitCast::make(kInt, a);203LLVMExprEval cg(b);204ASSERT_EQ(cg.value<int32_t>(), ref32);205}206
207{208double k = raw_bitcast<double>(ref64);209auto a = DoubleImm::make(k);210auto b = BitCast::make(kLong, a);211LLVMExprEval cg(b);212ASSERT_EQ(cg.value<int64_t>(), ref64);213}214
215{216int64_t k = raw_bitcast<int64_t>(reff64);217auto a = LongImm::make(k);218auto b = BitCast::make(kDouble, a);219LLVMExprEval cg(b);220ASSERT_EQ(cg.value<double>(), reff64);221}222
223{224int32_t k = raw_bitcast<int32_t>(reff32);225auto a = IntImm::make(k);226auto b = BitCast::make(kFloat, a);227LLVMExprEval cg(b);228ASSERT_EQ(cg.value<float>(), reff32);229}230}
231
232TEST(LLVM, fastLogFloat) {233const int kTotalSize = 128 * 128;234BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);235BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);236
237VarHandle index = VarHandle("index", kInt);238ExprHandle load_a = a_buf.load(index);239StmtPtr store_b = b_buf.store({index}, fast_log(load_a));240StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);241
242PaddedBuffer<float> a_v(kTotalSize);243PaddedBuffer<float> b_v(kTotalSize);244
245for (const auto i : c10::irange(kTotalSize)) {246a_v(i) = at::randn({1}).item().to<float>();247}248
249LLVMCodeGen ir_eval(stmt, {a_buf, b_buf});250ir_eval.call({a_v, b_v});251
252for (const auto i : c10::irange(kTotalSize)) {253auto test = b_v(i);254auto ref = std::log(a_v(i));255if (std::isnan(ref)) {256ASSERT_EQ(std::isnan(test), true);257} else {258ASSERT_FLOAT_EQ(test, ref);259}260}261}
262
263TEST(LLVM, LetTest01) {264BufHandle a("A", {1}, kFloat);265std::vector<float> v = {1, 0};266std::vector<void*> args({v.data()});267VarHandle x("x", kFloat);268auto block = Block::make({269Let::make(x, 3.f),270a.store({0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f))),271});272
273LLVMCodeGen cg(block, {a});274ASSERT_EQ(cg.value<int>(args), 0);275ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 4.f);276}
277
278TEST(LLVM, LetTest02) {279BufHandle a("A", {1}, kFloat);280std::vector<float> v = {1, 0};281std::vector<void*> args({v.data()});282VarHandle x("x", kFloat);283VarHandle y("y", kFloat);284auto block = Block::make(285{Let::make(x, 3.f),286Let::make(y, 6.f),287a.store(288{IntImm::make(0)},289ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)))});290
291LLVMCodeGen cg(block, {a});292ASSERT_EQ(cg.value<int>(args), 0);293ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 6.f * 4.f);294}
295
296TEST(LLVM, LetTestMultitype) {297BufHandle a("A", {1}, kDouble);298std::vector<double> v = {1, 0};299std::vector<void*> args({v.data()});300VarHandle x("x", kByte);301VarHandle y("y", kHalf);302auto block = Block::make(303{Let::make(x, 3),304Let::make(y, 6.f),305a.store(306{0},307Cast::make(308kDouble,309ExprHandle(2.f) +310(x * ExprHandle(3.f) + y * ExprHandle(4.f))))});311
312LLVMCodeGen cg(block, {a});313ASSERT_EQ(cg.value<int>(args), 0);314ASSERT_EQ(v[0], 2.f + 3 * 3.f + 6.f * 4.f);315}
316
317TEST(LLVM, BufferTest) {318BufHandle a("A", {32}, kFloat);319std::vector<int32_t> v(5);320std::vector<void*> args({v.data()});321auto rv = IntImm::make(0);322LLVMExprEval cg(rv, {a});323ASSERT_EQ(cg.value<int>(args), 0);324}
325
326TEST(LLVM, BlockTest) {327BufHandle a("A", {32}, kInt);328std::vector<int32_t> v = {1, 2};329std::vector<void*> args({v.data()});330
331auto block = Block::make({332a.store({0}, 3),333a.store({1}, 4),334a.store({0}, 4),335});336
337LLVMCodeGen cg(block, {a});338ASSERT_EQ(cg.value<int>(args), 0);339ASSERT_EQ(v[0], 4);340ASSERT_EQ(v[1], 4);341}
342
343TEST(LLVM, LoadStoreTest) {344BufHandle a("A", {1}, kInt);345BufHandle b("B", {1}, kInt);346std::vector<int32_t> a_buffer = {42};347std::vector<int32_t> b_buffer = {-11};348
349auto store = b.store({0}, a.load(0));350LLVMCodeGen cg(store, {a, b});351std::vector<void*> args({a_buffer.data(), b_buffer.data()});352ASSERT_EQ(cg.value<int>(args), 0);353ASSERT_EQ(a_buffer[0], 42);354ASSERT_EQ(b_buffer[0], 42);355}
356
357TEST(LLVM, IfThenElseTest) {358BufHandle a("A", {1}, kInt);359BufHandle b("B", {1}, kInt);360BufHandle c("C", {1}, kInt);361std::vector<int32_t> a_buffer = {42};362std::vector<int32_t> b_buffer = {-11};363std::vector<int32_t> c_buffer = {1};364
365auto store = b.store({0}, IfThenElse::make(c.load(0), a.load(0), 0));366LLVMCodeGen cg(store, {a, b, c});367std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});368ASSERT_EQ(cg.value<int>(args), 0);369ASSERT_EQ(a_buffer[0], 42);370ASSERT_EQ(b_buffer[0], 42);371}
372
373// if (x < 10) x = x + 1
374TEST(LLVM, CondNoFalseBlockTest) {375BufHandle x("X", {1}, kInt);376auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);377auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr);378
379for (int32_t x_value : {0, 10, 20}) {380std::vector<int32_t> x_buffer = {x_value};381std::vector<void*> args({x_buffer.data()});382LLVMCodeGen cg(cond, {x});383ASSERT_EQ(cg.value<int>(args), 0);384if (x_value < 10) {385ASSERT_EQ(x_buffer[0], x_value + 1);386} else {387ASSERT_EQ(x_buffer[0], x_value);388}389}390}
391
392// if (x < 10) {
393// x = x + 1;
394// } else {
395// x = x - 1;
396// }
397TEST(LLVM, CondTest) {398BufHandle x("X", {1}, kInt);399auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);400auto cond =401Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));402auto block = Block::make({403cond,404x.store({0}, x.load(0) * 2),405});406
407for (int32_t x_value : {0, 10, 20}) {408std::vector<int32_t> x_buffer = {x_value};409std::vector<void*> args({x_buffer.data()});410LLVMCodeGen cg(block, {x});411ASSERT_EQ(cg.value<int>(args), 0);412if (x_value < 10) {413ASSERT_EQ(x_buffer[0], (x_value + 1) * 2);414} else {415ASSERT_EQ(x_buffer[0], (x_value - 1) * 2);416}417}418}
419
420// if (x < 10) {
421// if (x > 5) {
422// x = x + 1;
423// } else {
424// x = x - 1;
425// }
426// } else {
427// if (x <= 15) {
428// x = x + 2;
429// } else {
430// x = x - 2;
431// }
432// }
433TEST(LLVM, CondNestedTest) {434BufHandle x("X", {1}, kInt);435auto true_cmp =436CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT);437auto true_cond = Cond::make(438true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));439auto false_cmp =440CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE);441auto false_cond = Cond::make(442false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2));443auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);444auto cond = Cond::make(cmp, true_cond, false_cond);445
446for (int32_t x_value : {0, 8, 15, 20}) {447std::vector<int32_t> x_buffer = {x_value};448std::vector<void*> args({x_buffer.data()});449LLVMCodeGen cg(cond, {x});450ASSERT_EQ(cg.value<int>(args), 0);451if (x_value < 10) {452if (x_value > 5) {453ASSERT_EQ(x_buffer[0], x_value + 1);454} else {455ASSERT_EQ(x_buffer[0], x_value - 1);456}457} else {458if (x_value <= 15) {459ASSERT_EQ(x_buffer[0], x_value + 2);460} else {461ASSERT_EQ(x_buffer[0], x_value - 2);462}463}464}465}
466
467TEST(LLVM, DirectVectorization) {468constexpr int M = 3;469constexpr int N = 64;470BufHandle a("a", {M, N}, kFloat);471BufHandle b("b", {M, N}, kFloat);472BufHandle c("c", {M, N}, kFloat);473VarHandle m("m", kInt);474VarHandle n("n", kInt);475StmtPtr s = For::make(476m,4770,478M,479Store::make(480c,481{Ramp::make(m * 64, 1, 64)},482Load::make({kFloat, 64}, a, {Ramp::make(m * 64, 1, 64)}) *483Load::make({kFloat, 64}, b, {Ramp::make(m * 64, 1, 64)})));484LLVMCodeGen cg(s, {a, b, c});485}
486
487TEST(LLVM, VecLoadStoreTest) {488BufHandle a("A", {1}, kInt);489BufHandle b("B", {1}, kInt);490std::vector<int32_t> a_buffer = {1, 1, 1, 1};491std::vector<int32_t> b_buffer = {2, 2, 2, 2};492
493auto store = b.store({Ramp::make(0, 1, 4)}, a.load({Ramp::make(0, 1, 4)}));494LLVMCodeGen cg(store, {a, b});495std::vector<void*> args({a_buffer.data(), b_buffer.data()});496ASSERT_EQ(cg.value<int>(args), 0);497ASSERT_EQ(a_buffer[0], 1);498ASSERT_EQ(a_buffer[1], 1);499ASSERT_EQ(a_buffer[2], 1);500ASSERT_EQ(a_buffer[3], 1);501ASSERT_EQ(b_buffer[0], 1);502ASSERT_EQ(b_buffer[1], 1);503ASSERT_EQ(b_buffer[2], 1);504ASSERT_EQ(b_buffer[3], 1);505}
506
507#define FLOAT_INTRINSICS_TEST(Name, Lanes) \508TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) { \509BufHandle a("A", {1}, kFloat); \510BufHandle b("B", {1}, kFloat); \511float val = 0.5f; \512std::vector<float> a_buffer(Lanes, val); \513std::vector<float> b_buffer(Lanes, val); \514auto store = b.store( \515{Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \516LLVMCodeGen cg(store, {a, b}); \517std::vector<void*> args({a_buffer.data(), b_buffer.data()}); \518ASSERT_EQ(cg.value<int>(args), 0); \519for (const auto i : c10::irange(Lanes)) { \520ASSERT_FLOAT_EQ(a_buffer[i], val); \521} \522} // namespace jit523FLOAT_INTRINSICS_TEST(erf, 4)524FLOAT_INTRINSICS_TEST(erfc, 4)525FLOAT_INTRINSICS_TEST(acos, 4)526FLOAT_INTRINSICS_TEST(asin, 4)527FLOAT_INTRINSICS_TEST(atan, 4)528FLOAT_INTRINSICS_TEST(cosh, 4)529FLOAT_INTRINSICS_TEST(sinh, 4)530FLOAT_INTRINSICS_TEST(tanh, 4)531FLOAT_INTRINSICS_TEST(expm1, 4)532FLOAT_INTRINSICS_TEST(lgamma, 4)533FLOAT_INTRINSICS_TEST(erf, 8)534FLOAT_INTRINSICS_TEST(erfc, 8)535FLOAT_INTRINSICS_TEST(acos, 8)536FLOAT_INTRINSICS_TEST(asin, 8)537FLOAT_INTRINSICS_TEST(atan, 8)538FLOAT_INTRINSICS_TEST(cosh, 8)539FLOAT_INTRINSICS_TEST(sinh, 8)540FLOAT_INTRINSICS_TEST(tanh, 8)541FLOAT_INTRINSICS_TEST(expm1, 8)542FLOAT_INTRINSICS_TEST(lgamma, 8)543#undef FLOAT_INTRINSICS_TEST544
545#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \546TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) { \547BufHandle a("A", {1}, kDouble); \548BufHandle b("B", {1}, kDouble); \549float val = 0.5f; \550std::vector<double> a_buffer(Lanes, val); \551std::vector<double> b_buffer(Lanes, val); \552auto store = b.store( \553{Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \554LLVMCodeGen cg(store, {a, b}); \555std::vector<void*> args({a_buffer.data(), b_buffer.data()}); \556ASSERT_EQ(cg.value<int>(args), 0); \557for (const auto i : c10::irange(Lanes)) { \558ASSERT_FLOAT_EQ(a_buffer[i], val); \559} \560} // namespace jit561DOUBLE_INTRINSICS_TEST(erf, 2)562DOUBLE_INTRINSICS_TEST(erfc, 2)563DOUBLE_INTRINSICS_TEST(acos, 2)564DOUBLE_INTRINSICS_TEST(asin, 2)565DOUBLE_INTRINSICS_TEST(atan, 2)566DOUBLE_INTRINSICS_TEST(cosh, 2)567DOUBLE_INTRINSICS_TEST(sinh, 2)568DOUBLE_INTRINSICS_TEST(tanh, 2)569DOUBLE_INTRINSICS_TEST(expm1, 2)570DOUBLE_INTRINSICS_TEST(lgamma, 2)571DOUBLE_INTRINSICS_TEST(erf, 4)572DOUBLE_INTRINSICS_TEST(erfc, 4)573DOUBLE_INTRINSICS_TEST(acos, 4)574DOUBLE_INTRINSICS_TEST(asin, 4)575DOUBLE_INTRINSICS_TEST(atan, 4)576DOUBLE_INTRINSICS_TEST(cosh, 4)577DOUBLE_INTRINSICS_TEST(sinh, 4)578DOUBLE_INTRINSICS_TEST(tanh, 4)579DOUBLE_INTRINSICS_TEST(expm1, 4)580DOUBLE_INTRINSICS_TEST(lgamma, 4)581#undef DOUBLE_INTRINSICS_TEST582
583TEST(LLVM, VectorizerLoadStoreTest) {584BufHandle a("A", {1}, kInt);585
586Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); });587
588BufHandle c_buf(c.buf());589LoopNest l({c});590StmtPtr s = l.root_stmt();591ASSERT_TRUE(LoopNest::vectorize(to<For>(to<Block>(s)->front())));592
593ASSERT_TRUE(to<For>(to<Block>(s)->front()) == nullptr);594
595LLVMCodeGen cg(s, {a, c_buf});596
597std::vector<int> a_vec(4, 21);598std::vector<int> c_vec(4, 0);599std::vector<void*> args({a_vec.data(), c_vec.data()});600ASSERT_EQ(cg.value<int>(args), 0);601assertAllEqual(c_vec, 21);602}
603
604TEST(LLVM, VectorizeBitCast) {605BufHandle a("A", {128}, kInt);606
607Tensor c = Compute("c", {128}, [&](const VarHandle& i) {608return bitcast<float>(a.load(i));609});610
611BufHandle c_buf(c.buf());612LoopNest l({c});613StmtPtr s = l.root_stmt();614ASSERT_TRUE(LoopNest::vectorize(to<For>(to<Block>(s)->front())));615ASSERT_TRUE(to<For>(to<Block>(s)->front()) == nullptr);616
617LLVMCodeGen cg(s, {a, c_buf});618
619std::vector<int> a_vec(128);620std::vector<float> c_vec(128);621for (const auto i : c10::irange(128)) {622a_vec[i] = raw_bitcast<int>(1337.f);623}624std::vector<void*> args({a_vec.data(), c_vec.data()});625ASSERT_EQ(cg.value<int>(args), 0);626assertAllEqual(c_vec, 1337.f);627}
628
629TEST(LLVM, MemcpyTest) {630constexpr int N = 32;631BufHandle a("A", {N}, kInt);632BufHandle b("B", {N}, kInt);633std::vector<int32_t> a_buffer(N, 42);634std::vector<int32_t> b_buffer(N, 0);635
636VarHandle i("i", kInt);637auto expr = For::make(i, 0, N, b.store({i}, a.load(i)));638
639LLVMCodeGen cg(expr, {a, b});640
641std::vector<void*> args({a_buffer.data(), b_buffer.data()});642ASSERT_EQ(cg.value<int>(args), 0);643
644ASSERT_EQ(a_buffer.size(), N);645ASSERT_EQ(b_buffer.size(), N);646assertAllEqual(a_buffer, 42);647assertAllEqual(b_buffer, 42);648}
649
650TEST(LLVM, BzeroTest) {651constexpr int N = 32;652BufHandle b("B", {N}, kInt);653std::vector<int32_t> b_buffer(N, 11);654
655VarHandle i("i", kInt);656auto expr = For::make(i, 0, N, b.store({i}, 0));657
658LLVMCodeGen cg(expr, {b});659
660std::vector<void*> args({b_buffer.data()});661ASSERT_EQ(cg.value<int>(args), 0);662
663ASSERT_EQ(b_buffer.size(), N);664assertAllEqual(b_buffer, 0);665}
666
667TEST(LLVM, ElemwiseAdd) {668constexpr int N = 1024;669BufHandle a("A", {N}, kInt);670BufHandle b("B", {N}, kInt);671BufHandle c("C", {N}, kInt);672std::vector<int32_t> a_buffer(N, 41);673std::vector<int32_t> b_buffer(N, 1);674std::vector<int32_t> c_buffer(N, 1);675
676VarHandle i("i", kInt);677auto expr = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i))));678
679LLVMCodeGen cg(expr, {a, b, c});680
681std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});682ASSERT_EQ(cg.value<int>(args), 0);683
684ASSERT_EQ(a_buffer.size(), N);685ASSERT_EQ(b_buffer.size(), N);686ASSERT_EQ(c_buffer.size(), N);687assertAllEqual(a_buffer, 41);688assertAllEqual(b_buffer, 1);689assertAllEqual(c_buffer, 42);690}
691
692TEST(LLVM, ElemwiseAddFloat) {693constexpr int N = 1024;694BufHandle a("A", {N}, kFloat);695BufHandle b("B", {N}, kFloat);696BufHandle c("C", {N}, kFloat);697std::vector<float> a_buffer(N, 41);698std::vector<float> b_buffer(N, 1);699std::vector<float> c_buffer(N, 1);700
701VarHandle i("i", kInt);702auto expr = For::make(i, 0, N, c.store({i}, a.load(i) + b.load(i)));703
704LLVMCodeGen cg(expr, {a, b, c});705
706std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});707ASSERT_EQ(cg.value<int>(args), 0);708
709ASSERT_EQ(a_buffer.size(), N);710ASSERT_EQ(b_buffer.size(), N);711ASSERT_EQ(c_buffer.size(), N);712assertAllEqual(a_buffer, 41.0f);713assertAllEqual(b_buffer, 1.0f);714assertAllEqual(c_buffer, 42.0f);715}
716
717TEST(LLVM, ElemwiseLog10Float) {718constexpr int N = 1024;719BufHandle a("A", {N}, kFloat);720BufHandle b("B", {N}, kFloat);721std::vector<float> a_buffer(N, 10.0f);722std::vector<float> b_buffer(N, 2.0f);723
724VarHandle i("i", kInt);725auto expr = For::make(726i,7270,728N / 4,729b.store(730{Ramp::make(i * 4, 1, 4)}, log10(a.load({Ramp::make(i * 4, 1, 4)}))));731
732LLVMCodeGen cg(expr, {a, b});733
734std::vector<void*> args({a_buffer.data(), b_buffer.data()});735ASSERT_EQ(cg.value<int>(args), 0);736
737ASSERT_EQ(a_buffer.size(), N);738ASSERT_EQ(b_buffer.size(), N);739assertAllEqual(a_buffer, 10.0f);740assertAllEqual(b_buffer, 1.0f);741}
742
743TEST(LLVM, ElemwiseLog1pFloat) {744constexpr int N = 1024;745BufHandle a("A", {N}, kFloat);746BufHandle b("B", {N}, kFloat);747std::vector<float> a_buffer(N, expf(3.0f) - 1);748std::vector<float> b_buffer(N, 42.0f);749
750VarHandle i("i", kInt);751auto expr = For::make(752i,7530,754N / 4,755b.store(756{Ramp::make(i * 4, 1, 4)}, log1p(a.load({Ramp::make(i * 4, 1, 4)}))));757
758LLVMCodeGen cg(expr, {a, b});759
760std::vector<void*> args({a_buffer.data(), b_buffer.data()});761ASSERT_EQ(cg.value<int>(args), 0);762
763ASSERT_EQ(a_buffer.size(), N);764ASSERT_EQ(b_buffer.size(), N);765assertAllEqual(a_buffer, expf(3.0f) - 1);766ExpectAllNear(b_buffer, 3.0f, 1e-5f);767}
768
769TEST(LLVM, ElemwiseMaxInt) {770constexpr int N = 1024;771BufHandle a("A", {N}, kInt);772BufHandle b("B", {N}, kInt);773BufHandle c("C", {N}, kInt);774std::vector<int> a_buffer(N, 41);775std::vector<int> b_buffer(N, 1);776std::vector<int> c_buffer(N, 1);777
778VarHandle i("i", kInt);779auto expr =780For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false)));781
782LLVMCodeGen cg(expr, {a, b, c});783
784std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});785ASSERT_EQ(cg.value<int>(args), 0);786
787ASSERT_EQ(a_buffer.size(), N);788ASSERT_EQ(b_buffer.size(), N);789ASSERT_EQ(c_buffer.size(), N);790assertAllEqual(a_buffer, 41);791assertAllEqual(b_buffer, 1);792assertAllEqual(c_buffer, 41);793}
794
795TEST(LLVM, ElemwiseMinInt) {796constexpr int N = 1024;797BufHandle a("A", {N}, kInt);798BufHandle b("B", {N}, kInt);799BufHandle c("C", {N}, kInt);800std::vector<int> a_buffer(N, 41);801std::vector<int> b_buffer(N, 1);802std::vector<int> c_buffer(N, 1);803
804VarHandle i("i", kInt);805auto expr =806For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false)));807
808LLVMCodeGen cg(expr, {a, b, c});809
810std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});811ASSERT_EQ(cg.value<int>(args), 0);812
813ASSERT_EQ(a_buffer.size(), N);814ASSERT_EQ(b_buffer.size(), N);815ASSERT_EQ(c_buffer.size(), N);816assertAllEqual(a_buffer, 41);817assertAllEqual(b_buffer, 1);818assertAllEqual(c_buffer, 1);819}
820
821TEST(LLVM, ElemwiseMaxFloat) {822constexpr int N = 1024;823BufHandle a("A", {N}, kFloat);824BufHandle b("B", {N}, kFloat);825BufHandle c("C", {N}, kFloat);826std::vector<float> a_buffer(N, 41);827std::vector<float> b_buffer(N, 1);828std::vector<float> c_buffer(N, 1);829
830VarHandle i("i", kInt);831auto expr =832For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false)));833
834LLVMCodeGen cg(expr, {a, b, c});835
836std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});837ASSERT_EQ(cg.value<int>(args), 0);838
839ASSERT_EQ(a_buffer.size(), N);840ASSERT_EQ(b_buffer.size(), N);841ASSERT_EQ(c_buffer.size(), N);842assertAllEqual(a_buffer, 41.0f);843assertAllEqual(b_buffer, 1.0f);844assertAllEqual(c_buffer, 41.0f);845}
846
847TEST(LLVM, ElemwiseMaxNaNFloat) {848constexpr int N = 1024;849BufHandle a("A", {N}, kFloat);850BufHandle b("B", {N}, kFloat);851BufHandle c("C", {N}, kFloat);852std::vector<float> a_buffer(N, NAN);853std::vector<float> b_buffer(N, 1);854std::vector<float> c_buffer(N, 1);855
856VarHandle i("i", kInt);857auto expr =858For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false)));859
860LLVMCodeGen cg(expr, {a, b, c});861
862std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});863ASSERT_EQ(cg.value<int>(args), 0);864
865ASSERT_EQ(a_buffer.size(), N);866ASSERT_EQ(b_buffer.size(), N);867ASSERT_EQ(c_buffer.size(), N);868assertAllEqual(b_buffer, 1.0f);869for (auto const& elt : c_buffer) {870ASSERT_TRUE(std::isnan(elt));871}872}
873
874TEST(LLVM, ElemwiseMinFloat) {875constexpr int N = 1024;876BufHandle a("A", {N}, kFloat);877BufHandle b("B", {N}, kFloat);878BufHandle c("C", {N}, kFloat);879std::vector<float> a_buffer(N, 41);880std::vector<float> b_buffer(N, 1);881std::vector<float> c_buffer(N, 1);882
883VarHandle i("i", kInt);884auto expr =885For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false)));886
887LLVMCodeGen cg(expr, {a, b, c});888
889std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});890ASSERT_EQ(cg.value<int>(args), 0);891
892ASSERT_EQ(a_buffer.size(), N);893ASSERT_EQ(b_buffer.size(), N);894ASSERT_EQ(c_buffer.size(), N);895assertAllEqual(a_buffer, 41.0f);896assertAllEqual(b_buffer, 1.0f);897assertAllEqual(c_buffer, 1.0f);898}
899
900TEST(LLVM, ElemwiseMinNaNFloat) {901constexpr int N = 1024;902BufHandle a("A", {N}, kFloat);903BufHandle b("B", {N}, kFloat);904BufHandle c("C", {N}, kFloat);905std::vector<float> a_buffer(N, NAN);906std::vector<float> b_buffer(N, 1);907std::vector<float> c_buffer(N, 1);908
909VarHandle i("i", kInt);910auto expr =911For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false)));912
913LLVMCodeGen cg(expr, {a, b, c});914
915std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});916ASSERT_EQ(cg.value<int>(args), 0);917
918ASSERT_EQ(a_buffer.size(), N);919ASSERT_EQ(b_buffer.size(), N);920ASSERT_EQ(c_buffer.size(), N);921assertAllEqual(b_buffer, 1.0f);922for (auto const& elt : c_buffer) {923ASSERT_TRUE(std::isnan(elt));924}925}
926
927TEST(LLVM, ElemwiseMod) {928constexpr int N = 1024;929BufHandle a("A", {N}, kInt);930BufHandle b("B", {N}, kInt);931BufHandle c("C", {N}, kInt);932std::vector<int32_t> a_buffer(N, 41);933std::vector<int32_t> b_buffer(N, 23);934std::vector<int32_t> c_buffer(N, 18);935
936VarHandle i("i", kInt);937auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i))));938
939LLVMCodeGen cg(expr, {a, b, c});940
941std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});942ASSERT_EQ(cg.value<int>(args), 0);943
944ASSERT_EQ(a_buffer.size(), N);945ASSERT_EQ(b_buffer.size(), N);946ASSERT_EQ(c_buffer.size(), N);947assertAllEqual(a_buffer, 41);948assertAllEqual(b_buffer, 23);949assertAllEqual(c_buffer, 18);950}
951
952TEST(LLVM, CompareSelectIntEQ) {953constexpr int N = 1024;954BufHandle a("A", {N}, kInt);955BufHandle b("B", {N}, kInt);956BufHandle c("C", {N}, kInt);957std::vector<int> a_buffer(N, 1);958std::vector<int> b_buffer(N, 1);959std::vector<int> c_buffer(N, 0);960std::vector<int> c_ref(N, 1);961
962for (int i = 0; i < N / 2; i++) {963b_buffer[i] = 0;964c_ref[i] = 0;965}966
967VarHandle i("i", kInt);968auto expr = For::make(969i,9700,971N,972c.store(973{i},974CompareSelect::make(975a.load(i), b.load(i), CompareSelectOperation::kEQ)));976
977LLVMCodeGen cg(expr, {a, b, c});978
979std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});980ASSERT_EQ(cg.value<int>(args), 0);981
982ASSERT_EQ(a_buffer.size(), N);983ASSERT_EQ(b_buffer.size(), N);984ASSERT_EQ(c_buffer.size(), N);985
986assertAllEqual(a_buffer, 1);987for (const auto i : c10::irange(N)) {988ASSERT_EQ(c_ref[i], c_buffer[i]);989}990}
991
992TEST(LLVM, CompareSelectFloatEQ) {993constexpr int N = 1024;994BufHandle a("A", {N}, kFloat);995BufHandle b("B", {N}, kFloat);996BufHandle c("C", {N}, kInt);997std::vector<float> a_buffer(N, 1.0f);998std::vector<float> b_buffer(N, 1.0f);999std::vector<int> c_buffer(N, 0);1000
1001VarHandle i("i", kInt);1002auto expr = For::make(1003i,10040,1005N,1006c.store(1007{i},1008CompareSelect::make(1009a.load(i), b.load(i), CompareSelectOperation::kEQ)));1010
1011LLVMCodeGen cg(expr, {a, b, c});1012
1013std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});1014ASSERT_EQ(cg.value<int>(args), 0);1015
1016ASSERT_EQ(a_buffer.size(), N);1017ASSERT_EQ(b_buffer.size(), N);1018ASSERT_EQ(c_buffer.size(), N);1019
1020assertAllEqual(a_buffer, 1.0f);1021assertAllEqual(b_buffer, 1.0f);1022assertAllEqual(c_buffer, 1);1023}
1024
1025TEST(LLVM, CompareSelectByteGT) {1026constexpr int N = 1024;1027BufHandle a("A", {N}, kByte);1028BufHandle b("B", {N}, kByte);1029BufHandle c("C", {N}, kInt);1030std::vector<uint8_t> a_buffer(N, 0);1031std::vector<uint8_t> b_buffer(N, 0);1032std::vector<int> c_buffer(N, 0);1033std::vector<int> c_ref(N, 0);1034
1035for (int i = 0; i < N / 2; i++) {1036a_buffer[i] = 128;1037c_ref[i] = 1;1038}1039
1040VarHandle i("i", kInt);1041auto expr = For::make(1042i,10430,1044N,1045c.store(1046{i},1047CompareSelect::make(1048a.load(i), b.load(i), CompareSelectOperation::kGT)));1049
1050LLVMCodeGen cg(expr, {a, b, c});1051
1052std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});1053ASSERT_EQ(cg.value<int>(args), 0);1054
1055ASSERT_EQ(a_buffer.size(), N);1056ASSERT_EQ(b_buffer.size(), N);1057ASSERT_EQ(c_buffer.size(), N);1058
1059assertAllEqual(b_buffer, uint8_t(0));1060for (const auto i : c10::irange(N)) {1061ASSERT_EQ(c_ref[i], c_buffer[i]);1062}1063}
1064
1065TEST(LLVM, CompareSelectByteGE) {1066constexpr int N = 1024;1067BufHandle a("A", {N}, kByte);1068BufHandle b("B", {N}, kByte);1069BufHandle c("C", {N}, kInt);1070std::vector<uint8_t> a_buffer(N, 0);1071std::vector<uint8_t> b_buffer(N, 0);1072std::vector<int> c_buffer(N, 0);1073std::vector<int> c_ref(N, 1);1074
1075VarHandle i("i", kInt);1076auto expr = For::make(1077i,10780,1079N,1080c.store(1081{i},1082CompareSelect::make(1083a.load(i), b.load(i), CompareSelectOperation::kGE)));1084
1085LLVMCodeGen cg(expr, {a, b, c});1086
1087std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});1088ASSERT_EQ(cg.value<int>(args), 0);1089
1090ASSERT_EQ(a_buffer.size(), N);1091ASSERT_EQ(b_buffer.size(), N);1092ASSERT_EQ(c_buffer.size(), N);1093
1094assertAllEqual(b_buffer, uint8_t(0));1095for (const auto i : c10::irange(N)) {1096ASSERT_EQ(c_ref[i], c_buffer[i]);1097}1098}
1099
1100TEST(LLVM, CompareSelectByteLT) {1101constexpr int N = 1024;1102BufHandle a("A", {N}, kByte);1103BufHandle b("B", {N}, kByte);1104BufHandle c("C", {N}, kInt);1105std::vector<uint8_t> a_buffer(N, 0);1106std::vector<uint8_t> b_buffer(N, 128);1107std::vector<int> c_buffer(N, 0);1108std::vector<int> c_ref(N, 1);1109
1110for (int i = 0; i < N / 2; i++) {1111a_buffer[i] = 128;1112c_ref[i] = 0;1113}1114
1115VarHandle i("i", kInt);1116auto expr = For::make(1117i,11180,1119N,1120c.store(1121{i},1122CompareSelect::make(1123a.load(i), b.load(i), CompareSelectOperation::kLT)));1124
1125LLVMCodeGen cg(expr, {a, b, c});1126
1127std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});1128ASSERT_EQ(cg.value<int>(args), 0);1129
1130ASSERT_EQ(a_buffer.size(), N);1131ASSERT_EQ(b_buffer.size(), N);1132ASSERT_EQ(c_buffer.size(), N);1133
1134assertAllEqual(b_buffer, uint8_t(128));1135for (const auto i : c10::irange(N)) {1136ASSERT_EQ(c_ref[i], c_buffer[i]);1137}1138}
1139
1140TEST(LLVM, CompareSelectByteLE) {1141constexpr int N = 1024;1142BufHandle a("A", {N}, kByte);1143BufHandle b("B", {N}, kByte);1144BufHandle c("C", {N}, kInt);1145std::vector<uint8_t> a_buffer(N, 0);1146std::vector<uint8_t> b_buffer(N, 128);1147std::vector<int> c_buffer(N, 0);1148std::vector<int> c_ref(N, 1);1149
1150VarHandle i("i", kInt);1151auto expr = For::make(1152i,11530,1154N,1155c.store(1156{i},1157CompareSelect::make(1158a.load(i), b.load(i), CompareSelectOperation::kLE)));1159
1160LLVMCodeGen cg(expr, {a, b, c});1161
1162std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});1163ASSERT_EQ(cg.value<int>(args), 0);1164
1165ASSERT_EQ(a_buffer.size(), N);1166ASSERT_EQ(b_buffer.size(), N);1167ASSERT_EQ(c_buffer.size(), N);1168
1169assertAllEqual(b_buffer, uint8_t(128));1170for (const auto i : c10::irange(N)) {1171ASSERT_EQ(c_ref[i], c_buffer[i]);1172}1173}
1174
1175TEST(LLVM, StoreFloat) {1176BufHandle result("result", {1}, kFloat);1177std::vector<float> result_buffer = {0.0f};1178auto expr = result.store({0}, FloatImm::make(3.14f));1179LLVMCodeGen cg(expr, {result});1180std::vector<void*> args({result_buffer.data()});1181ASSERT_EQ(cg.value<int>(args), 0);1182ASSERT_EQ(result_buffer[0], 3.14f);1183}
1184
1185TEST(LLVM, SimpleMath01) {1186const int N = 1024;1187Tensor tensor = Compute(1188"f", {N}, [](const VarHandle& i) { return cast<float>(i * i + 1); });1189LoopNest l({tensor});1190StmtPtr stmt = l.root_stmt();1191BufHandle f_buf(tensor.buf());1192LLVMCodeGen cg(stmt, {f_buf});1193
1194PaddedBuffer<float> f_v(N, "f_v");1195std::vector<void*> args({f_v.data()});1196int value = cg.value<int>(args);1197ASSERT_EQ(value, 0);1198PaddedBuffer<float> f_ref(N, "f_ref");1199for (const auto i : c10::irange(N)) {1200f_ref(i) = i * i + 1;1201}1202ExpectAllNear(f_v, f_ref, 1e-5);1203}
1204
1205TEST(LLVM, ComputeMul) {1206const int N = 1024;1207BufHandle a("a", {N}, kFloat);1208BufHandle b("b", {N}, kFloat);1209Tensor c = Compute(1210"c", {N}, [&](const VarHandle& i) { return a.load(i) * b.load(i); });1211
1212BufHandle c_buf(c.buf());1213LoopNest l({c});1214StmtPtr s = l.root_stmt();1215
1216LLVMCodeGen cg(s, {a, b, c_buf});1217
1218std::vector<float> a_vec(N, 21.0f);1219std::vector<float> b_vec(N, 2.0f);1220std::vector<float> c_vec(N, 0.0f);1221std::vector<void*> args({a_vec.data(), b_vec.data(), c_vec.data()});1222ASSERT_EQ(cg.value<int>(args), 0);1223assertAllEqual(c_vec, 42.0f);1224}
1225
1226TEST(LLVM, BroadcastAdd) {1227const int M = 32;1228const int N = 1024;1229BufHandle a("a", {M, N}, kFloat);1230BufHandle b("b", {N}, kFloat);1231Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {1232return a.load(i, j) + b.load(j);1233});1234
1235BufHandle c_buf(c.buf());1236LoopNest l({c});1237l.prepareForCodegen();1238StmtPtr s = l.root_stmt();1239
1240LLVMCodeGen cg(s, {a, b, c_buf});1241
1242std::vector<float> av(M * N);1243std::iota(av.begin(), av.end(), 0);1244std::vector<float> bv(N);1245std::iota(bv.begin(), bv.end(), 0);1246std::vector<float> cv(M * N, 0);1247std::vector<void*> args({av.data(), bv.data(), cv.data()});1248ASSERT_EQ(cg.value<int>(args), 0);1249
1250for (const auto i : c10::irange(M)) {1251for (const auto j : c10::irange(N)) {1252ASSERT_EQ(cv[i * N + j], av[i * N + j] + bv[j]);1253}1254}1255}
1256
1257TEST(LLVM, BitwiseOps) {1258auto a = IntImm::make(59);1259auto b = IntImm::make(11);1260auto c = IntImm::make(101);1261auto d = IntImm::make(2);1262
1263ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d;1264LLVMExprEval cg(f);1265
1266ASSERT_EQ(cg.value<int>(), 11);1267}
1268
1269TEST(LLVM, ArithmeticRightShift) {1270auto a = CharImm::make(-4);1271auto b = CharImm::make(1);1272ExprHandle f = a >> b;1273LLVMExprEval cg(f);1274ASSERT_EQ(cg.value<int8_t>(), -2);1275}
1276
1277TEST(LLVM, LogicalRightShift) {1278auto a = ByteImm::make(0xfc);1279auto b = ByteImm::make(1);1280ExprHandle f = a >> b;1281LLVMExprEval cg(f);1282ASSERT_EQ(cg.value<uint8_t>(), 0x7e);1283}
1284
1285TEST(LLVM, DynamicShapeAdd) {1286auto testWithSize = [](int32_t size) {1287VarHandle n("n", kInt);1288BufHandle a("a", {n}, kFloat);1289BufHandle b("b", {n}, kFloat);1290BufHandle c("c", {n}, kFloat);1291VarHandle i("i", kInt);1292StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));1293std::vector<float> aData(size, 1.0f);1294std::vector<float> bData(size, 2.0f);1295std::vector<float> cData(size, 0.0f);1296LLVMCodeGen cg(s, {a, b, c, n});1297std::vector<void*> args({aData.data(), bData.data(), cData.data(), &size});1298cg.value<float>(args);1299ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);1300};1301testWithSize(1);1302testWithSize(16);1303testWithSize(37);1304}
1305
1306TEST(LLVM, BindDynamicShapeAdd) {1307auto testWithSize = [](int32_t size) {1308VarHandle n("n", kInt);1309BufHandle a("a", {n}, kFloat);1310BufHandle b("b", {n}, kFloat);1311BufHandle c("c", {n}, kFloat);1312VarHandle i("i", kInt);1313StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));1314std::vector<float> aData(size, 1.0f);1315std::vector<float> bData(size, 2.0f);1316std::vector<float> cData(size, 0.0f);1317LLVMCodeGen cg(s, {a, b, c, n});1318cg.call({aData, bData, cData, size});1319ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);1320};1321testWithSize(1);1322testWithSize(16);1323testWithSize(37);1324}
1325
1326TEST(LLVM, TensorDynamicShapeAdd) {1327auto testWithSize = [](int32_t size) {1328VarHandle n("n", kInt);1329BufHandle a("a", {n}, kFloat);1330BufHandle b("b", {n}, kFloat);1331Tensor c = Compute(1332"c", {n}, [&](const VarHandle& i) { return a.load(i) + b.load(i); });1333LoopNest l({c});1334StmtPtr s = l.root_stmt();1335LLVMCodeGen cg(s, {a, b, c, n});1336std::vector<float> aData(size, 1.0f);1337std::vector<float> bData(size, 2.0f);1338std::vector<float> cData(size, 0.0f);1339cg.call({aData, bData, cData, size});1340ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);1341};1342testWithSize(1);1343testWithSize(16);1344testWithSize(37);1345}
1346
1347TEST(LLVM, DynamicShape2D) {1348auto testWithSize = [](int32_t M, int32_t N) {1349VarHandle m("m", kInt);1350VarHandle n("n", kInt);1351BufHandle a("a", {m, n}, kFloat);1352BufHandle b("b", {m, n}, kFloat);1353Tensor c =1354Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) {1355return a.load(i, j) + b.load(i, j);1356});1357LoopNest l({c});1358l.prepareForCodegen();1359StmtPtr s = l.root_stmt();1360LLVMCodeGen cg(s, {a, b, c, m, n});1361std::vector<float> aData(M * N, 1.0f);1362std::vector<float> bData(M * N, 2.0f);1363std::vector<float> cData(M * N, 0.0f);1364cg.call({aData, bData, cData, M, N});1365ExpectAllNear(cData, std::vector<float>(M * N, 3.0f), 1e-7);1366};1367testWithSize(1, 8);1368testWithSize(16, 32);1369testWithSize(37, 11);1370}
1371
1372TEST(LLVM, EmptyStmt) {1373StmtPtr s = alloc<Block>(std::vector<StmtPtr>({}));1374
1375LLVMCodeGen cg(s, {});1376cg.call({});1377// Just don't crash.1378}
1379
1380TEST(LLVM, EliminatedStmt) {1381BufHandle a("a", {1}, kFloat);1382
1383Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; });1384
1385LoopNest l({c});1386l.prepareForCodegen();1387StmtPtr s = l.root_stmt();1388s = IRSimplifier::simplify(s);1389LLVMCodeGen cg(s, {a, c});1390std::vector<float> aData(1, 1.0f);1391std::vector<float> cData(0, 0.0f);1392cg.call({aData, cData});1393}
1394
1395TEST(LLVM, SimpleReduction) {1396int M = 128;1397int N = 64;1398
1399BufHandle a("a", {1, M, N}, kFloat);1400
1401Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});1402LoopNest loop({b});1403
1404loop.prepareForCodegen();1405StmtPtr s = loop.root_stmt();1406s = IRSimplifier::simplify(s);1407
1408LLVMCodeGen cg(s, {a, b});1409
1410PaddedBuffer<float> a_v(1, M, N, "a_v");1411PaddedBuffer<float> b_v(1, "b_v");1412PaddedBuffer<float> b_ref(1, "b_ref");1413
1414b_ref(0) = 0;1415for (const auto i : c10::irange(M)) {1416for (const auto j : c10::irange(N)) {1417int v = i + j;1418a_v(0, i, j) = v;1419b_ref(0) += v;1420}1421}1422
1423cg.call({a_v, b_v});1424
1425ExpectAllNear(b_v, b_ref, 1e-5);1426}
1427
1428TEST(LLVM, RFactorReduction) {1429int M = 128;1430int N = 64;1431
1432BufHandle a("a", {1, M, N}, kFloat);1433
1434Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});1435LoopNest loop({b});1436
1437std::vector<ForPtr> loops = loop.getLoopStmtsFor(b);1438ForPtr loop_m = loops.at(1);1439ForPtr loop_n = loops.at(2);1440loop.reorderAxis(loop_m, loop_n);1441
1442loops = loop.getLoopStmtsFor(b);1443loop_m = loops.at(2);1444loop_n = loops.at(1);1445auto b_body = loop.getAllWritesToBuf(b.buf())[1];1446ASSERT_TRUE(loop.rfactor(b_body, loop_n));1447
1448loop.prepareForCodegen();1449StmtPtr s = loop.root_stmt();1450s = IRSimplifier::simplify(s);1451
1452LLVMCodeGen cg(s, {a, b});1453
1454PaddedBuffer<float> a_v(1, M, N, "a_v");1455PaddedBuffer<float> b_v(1, "b_v");1456PaddedBuffer<float> b_ref(1, "b_ref");1457
1458b_ref(0) = 0;1459for (const auto i : c10::irange(M)) {1460for (const auto j : c10::irange(N)) {1461int v = i + j;1462a_v(0, i, j) = v;1463b_ref(0) += v;1464}1465}1466
1467cg.call({a_v, b_v});1468
1469ExpectAllNear(b_v, b_ref, 1e-5);1470}
1471
1472TEST(LLVM, RFactorVectorizedReduction) {1473int M = 128;1474int N = 64;1475
1476BufHandle a("a", {1, M, N}, kFloat);1477
1478Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});1479LoopNest loopnest({b});1480std::vector<ForPtr> loops = loopnest.getLoopStmtsFor(b);1481// Reorder n and m loops1482loopnest.reorderAxis(loops.at(1), loops.at(2));1483auto b_body = loopnest.getAllWritesToBuf(b.buf()).at(1);1484auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b.buf());1485ASSERT_TRUE(all_loops.size() == 2 && all_loops[1].size() == 3);1486ASSERT_TRUE(loopnest.rfactor(b_body, all_loops[1][1]));1487auto distributed_loops = loopnest.distributeLoop(all_loops[1][1]);1488
1489// Vectorize initializer of rfac_buf1490ASSERT_TRUE(LoopNest::vectorize(distributed_loops[0]));1491// Vectorize producer of rfac_buf1492ASSERT_TRUE(LoopNest::vectorize(distributed_loops[1]));1493loopnest.simplify();1494
1495loopnest.prepareForCodegen();1496
1497StmtPtr s = IRSimplifier::simplify(loopnest.root_stmt());1498LLVMCodeGen cg(s, {a, b});1499
1500PaddedBuffer<float> a_v(1, M, N, "a_v");1501PaddedBuffer<float> b_v(1, "b_v");1502PaddedBuffer<float> b_ref(1, "b_ref");1503
1504b_ref(0) = 0;1505for (const auto i : c10::irange(M)) {1506for (const auto j : c10::irange(N)) {1507int v = i + j;1508a_v(0, i, j) = v;1509b_ref(0) += v;1510}1511}1512
1513cg.call({a_v, b_v});1514
1515ExpectAllNear(b_v, b_ref, 1e-5);1516}
1517
1518template <bool outer, bool inner>1519static void testSimpleParallel() {1520// Compute a simple operation, and try all loop-axis combination to be1521// parallel or sequential.1522const int M = 4;1523const int N = 6;1524Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) {1525return cast<float>(m + n);1526});1527LoopNest loop_nest({f});1528auto const& loops = loop_nest.getLoopStmtsFor(f);1529ForPtr m = loops[0];1530ForPtr n = loops[1];1531if (outer) {1532m->set_parallel();1533}1534if (inner) {1535n->set_parallel();1536}1537loop_nest.prepareForCodegen();1538StmtPtr stmt = loop_nest.root_stmt();1539LLVMCodeGen cg(stmt, {f});1540
1541PaddedBuffer<float> f_v(M, N, "f_v");1542std::vector<void*> args({f_v.data()});1543int value = cg.value<int>(args);1544ASSERT_EQ(value, 0);1545PaddedBuffer<float> f_ref(M, N, "f_ref");1546for (const auto m : c10::irange(M)) {1547for (const auto n : c10::irange(N)) {1548f_ref(m, n) = m + n;1549}1550}1551ExpectAllNear(f_v, f_ref, 1e-5);1552}
1553
1554TEST(LLVM, SimpleParallelSS) {1555testSimpleParallel<false, false>();1556}
1557TEST(LLVM, SimpleParallelSP) {1558testSimpleParallel<false, true>();1559}
1560TEST(LLVM, SimpleParallelPS) {1561testSimpleParallel<true, false>();1562}
1563TEST(LLVM, SimpleParallelPP) {1564testSimpleParallel<true, true>();1565}
1566
1567TEST(LLVM, CompositeParallel) {1568int loop_count = 6;1569int test_count = 1 << loop_count;1570// Compute a composite operation, and try all loop-axis combination to be1571// parallel or sequential.1572for (const auto test_cfg : c10::irange(test_count)) {1573int M = 5;1574int N = 7;1575Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; });1576Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; });1577Tensor t3 =1578Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) {1579return t1.load(m) * t2.load(n);1580});1581Tensor t4 =1582Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) {1583return t3.load(m, n) + m + n;1584});1585LoopNest loop_nest({t4}, {t1, t2, t3, t4});1586std::vector<ForPtr> loop_list;1587{1588auto const& loops = loop_nest.getLoopStmtsFor(t1);1589loop_list.push_back(loops[0]);1590}1591{1592auto const& loops = loop_nest.getLoopStmtsFor(t2);1593loop_list.push_back(loops[0]);1594}1595{1596auto const& loops = loop_nest.getLoopStmtsFor(t3);1597loop_list.push_back(loops[0]);1598loop_list.push_back(loops[1]);1599}1600{1601auto const& loops = loop_nest.getLoopStmtsFor(t4);1602loop_list.push_back(loops[0]);1603loop_list.push_back(loops[1]);1604}1605ASSERT_EQ(loop_list.size(), loop_count);1606for (const auto i : c10::irange(loop_count)) {1607if (test_cfg & (1 << i)) {1608loop_list[i]->set_parallel();1609}1610}1611loop_nest.prepareForCodegen();1612StmtPtr stmt = loop_nest.root_stmt();1613LLVMCodeGen cg(stmt, {t4});1614
1615PaddedBuffer<float> t4_v(M, N, "t4_v");1616std::vector<void*> args({t4_v.data()});1617int value = cg.value<int>(args);1618ASSERT_EQ(value, 0);1619PaddedBuffer<float> t4_ref(M, N, "t4_ref");1620for (const auto m : c10::irange(M)) {1621for (const auto n : c10::irange(N)) {1622t4_ref(m, n) = (m + 1) * (n + 2) + m + n;1623}1624}1625ExpectAllNear(t4_v, t4_ref, 1e-5);1626}1627}
1628
1629TEST(LLVM, VectorizedGEMM) {1630int M = 32;1631int N = 32;1632int K = 48;1633
1634BufHandle AP("A", {M, K}, kFloat);1635BufHandle BP("B", {K, N}, kFloat);1636Tensor CT = Reduce(1637"gemm",1638{M, N},1639Sum(),1640[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {1641return AP.load(m, k) * BP.load(k, n);1642},1643{K});1644LoopNest loop({CT});1645
1646{1647auto const& loops = loop.getLoopStmtsFor(CT);1648ForPtr m = loops[0];1649loop.splitWithMask(m, 16);1650}1651{1652auto const& loops = loop.getLoopStmtsFor(CT);1653ForPtr n = loops[2];1654loop.splitWithMask(n, 16);1655}1656// mo, mi, no, ni, k ->1657// mo, no, mi, ni, k1658{1659auto const& loops = loop.getLoopStmtsFor(CT);1660ForPtr mi = loops[1];1661ForPtr no = loops[2];1662loop.reorderAxis(mi, no);1663}1664// mo, no, mi, ni, k ->1665// mo, no, mi, k, ni1666{1667auto const& loops = loop.getLoopStmtsFor(CT);1668ForPtr ni = loops[3];1669ForPtr k = loops[4];1670loop.reorderAxis(ni, k);1671}1672// mo, no, mi, k, ni ->1673// mo, no, k, mi, ni1674{1675auto const& loops = loop.getLoopStmtsFor(CT);1676ForPtr mi = loops[2];1677ForPtr k = loops[3];1678loop.reorderAxis(mi, k);1679}1680{1681auto loops = NodeFinder<For>::find(loop.root_stmt());1682ASSERT_TRUE(LoopNest::vectorize(loops[3]));1683ASSERT_TRUE(LoopNest::vectorize(loops.back()));1684}1685
1686loop.prepareForCodegen();1687
1688StmtPtr s = loop.root_stmt();1689s = IRSimplifier::simplify(s);1690LLVMCodeGen cg(s, {AP, BP, CT});1691
1692PaddedBuffer<float> a_v(M, K, "a_v");1693PaddedBuffer<float> b_v(K, N, "b_v");1694PaddedBuffer<float> c_v(M, N, "c_v");1695PaddedBuffer<float> c_ref(M, N, "c_ref");1696
1697for (const auto m : c10::irange(M)) {1698for (const auto n : c10::irange(N)) {1699c_ref(m, n) = 0.f;1700for (const auto k : c10::irange(K)) {1701c_ref(m, n) += a_v(m, k) * b_v(k, n);1702}1703}1704}1705
1706cg.call({a_v, b_v, c_v});1707
1708ExpectAllNear(c_v, c_ref, 1e-5);1709}
1710
1711TEST(LLVM, CallRaw) {1712const int M = 32;1713VarHandle N("N", kInt);1714BufHandle a("a", {M, N}, kFloat);1715BufHandle b("b", {N}, kFloat);1716Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {1717return a.load(i, j) + b.load(j);1718});1719
1720LoopNest l({c});1721l.prepareForCodegen();1722StmtPtr s = l.root_stmt();1723
1724int32_t N_value = 1024;1725std::vector<float> av(M * N_value);1726std::iota(av.begin(), av.end(), 0);1727std::vector<float> bv(N_value);1728std::iota(bv.begin(), bv.end(), 0);1729std::vector<float> cv(M * N_value, 0);1730std::vector<void*> args({av.data(), bv.data(), cv.data(), &N_value});1731
1732LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N});1733cg.call_raw(args);1734
1735for (const auto i : c10::irange(M)) {1736for (const auto j : c10::irange(N_value)) {1737ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]);1738}1739}1740
1741SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N});1742eval.call_raw(args);1743
1744for (const auto i : c10::irange(M)) {1745for (const auto j : c10::irange(N_value)) {1746ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]);1747}1748}1749}
1750
1751TEST(LLVM, CustomTarget) {1752constexpr int M = 16;1753BufHandle a("a", {M}, kFloat);1754BufHandle b("b", {M}, kFloat);1755BufHandle c("c", {M}, kFloat);1756Tensor d = Compute("d", {M}, [&](const VarHandle& m) {1757return a.load(m) * b.load(m) + c.load(m);1758});1759LoopNest nest({d});1760nest.prepareForCodegen();1761auto cg = LLVMCodeGenBuilder(nest.root_stmt(), {a, b, c, d})1762.triple("i686-elf")1763.cpu("i386")1764.build();1765std::ostringstream ss;1766ss << cg->getCodeText("asm");1767torch::jit::testing::FileCheck()1768.check("fadds")1769->check("fmuls")1770->check_not("vfmadd")1771->run(ss.str());1772}
1773
1774TEST(LLVM, CodeGenKernelFuncName) {1775BufHandle a("A", {1}, kInt);1776BufHandle b("B", {1}, kInt);1777std::vector<int32_t> a_buffer = {42};1778std::vector<int32_t> b_buffer = {-11};1779auto store = b.store({0}, a.load(0));1780
1781{1782LLVMCodeGen cg(store, {a, b});1783// Check that the kernel function name used by LLVMCodeGen1784// is not empty.1785ASSERT_NE(cg.kernel_func_name(), "");1786}1787
1788{1789LLVMCodeGen cg(store, {a, b}, at::kCPU, "new_func");1790// Check that the kernel function name used by LLVMCodeGen1791// is the one that was given above.1792ASSERT_EQ(cg.kernel_func_name(), "new_func");1793}1794}
1795
1796} // namespace jit1797} // namespace torch1798
1799#endif // TORCH_ENABLE_LLVM1800