pytorch
836 строк · 25.5 Кб
1#include <gtest/gtest.h>
2
3#include <test/cpp/tensorexpr/test_base.h>
4
5#include <c10/util/irange.h>
6#include <test/cpp/tensorexpr/padded_buffer.h>
7#include <test/cpp/tensorexpr/test_utils.h>
8#include <torch/csrc/jit/tensorexpr/eval.h>
9#include <torch/csrc/jit/tensorexpr/ir.h>
10#include <torch/csrc/jit/tensorexpr/ir_printer.h>
11#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
12#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
13#include <torch/csrc/jit/tensorexpr/loopnest.h>
14#include <torch/csrc/jit/tensorexpr/tensor.h>
15
16#include <cmath>
17#include <sstream>
18#include <stdexcept>
19#include <string>
20#include <vector>
21
22namespace torch {
23namespace jit {
24using namespace torch::jit::tensorexpr;
25
26using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
27
28TEST(Expr, BasicValueTest) {
29ExprHandle a = IntImm::make(2), b = IntImm::make(3);
30ExprHandle c = Add::make(a, b);
31SimpleIRExprEval eval(c);
32ASSERT_EQ(eval.value<int>(), 5);
33}
34
35TEST(Expr, BasicValueTest02) {
36ExprHandle a(2.0f);
37ExprHandle b(3.0f);
38ExprHandle c(4.0f);
39ExprHandle d(5.0f);
40ExprHandle f = (a + b) - (c + d);
41SimpleIRExprEval eval(f);
42ASSERT_EQ(eval.value<float>(), -4.0f);
43}
44
45TEST(Expr, IsChannelsLastContiguous) {
46std::vector<VarHandle> vars = {
47VarHandle("var1", kLong),
48VarHandle("var2", kLong),
49VarHandle("var3", kLong),
50VarHandle("var4", kLong),
51VarHandle("var5", kLong)};
52
53// {
54// key: ndims,
55// value: [
56// ...
57// [dim_2, dim_1, ..., dim_n]
58// ]
59// }
60using shapGenInfo = std::unordered_map<int, std::vector<std::vector<int>>>;
61
62// {
63// size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n],
64// strides: [
65// ...
66// [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z]
67// ]
68// }
69using shapeInfo =
70std::pair<std::vector<ExprHandle>, std::vector<std::vector<ExprHandle>>>;
71
72std::vector<int> dims = {3, 4, 5};
73
74std::unordered_map<int, std::vector<ExprHandle>> dims_expr_vec_conf = {
75{3, std::vector<ExprHandle>(vars.begin(), vars.begin() + 2)},
76{4, std::vector<ExprHandle>(vars.begin(), vars.begin() + 3)},
77{5, std::vector<ExprHandle>(vars.begin(), vars.begin() + 4)},
78};
79
80shapGenInfo channels_last_cont_shape_conf = {
81{3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}};
82shapGenInfo channels_last_non_cont_shape_conf = {
83{3, {{2, 1, 0}, {1, 0, 2}}},
84{4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}},
85{5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}};
86
87shapGenInfo cont_shape_conf = {
88{3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}};
89
90auto shape_gen_fn = [dims_expr_vec_conf](
91int ndims, shapGenInfo shape_gen_info) -> shapeInfo {
92auto dims_expr_vec = dims_expr_vec_conf.at(ndims);
93std::vector<std::vector<ExprHandle>> strides_expr_vec;
94for (size_t i = 0; i < strides_expr_vec.size(); i++) {
95strides_expr_vec[i].resize(ndims);
96}
97
98auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) {
99if (indicator % 2 == 0) {
100return a * b;
101} else {
102return b * a;
103}
104};
105
106auto stride_order_vec = shape_gen_info.at(ndims);
107for (size_t i = 0; i < strides_expr_vec.size(); i++) {
108auto stride_order = stride_order_vec[i];
109
110strides_expr_vec[i][stride_order[0]] = 1;
111for (size_t j = 1; j < stride_order.size(); j++) {
112auto cur_dim_idx = stride_order[j];
113auto adjacent_dim_idx = stride_order[j - 1];
114
115strides_expr_vec[i][cur_dim_idx] = stride_gen_fn(
116i,
117dims_expr_vec[adjacent_dim_idx],
118strides_expr_vec[i][adjacent_dim_idx]);
119}
120}
121
122return {dims_expr_vec, strides_expr_vec};
123};
124
125auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool {
126if (ndims == 3) {
127return buf_handle.is_channels_last_1d_contiguous();
128} else if (ndims == 4) {
129return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast);
130} else {
131return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d);
132}
133};
134
135// channels-last contiguous
136for (size_t i = 0; i < dims.size(); i++) {
137auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
138for (size_t j = 0; j < shape_info.second.size(); j++) {
139BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
140ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true);
141}
142}
143
144// channels-last non-contiguous
145for (size_t i = 0; i < dims.size(); i++) {
146auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf);
147for (size_t j = 0; j < shape_info.second.size(); j++) {
148BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
149ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false);
150}
151}
152
153// contiguous
154for (size_t i = 0; i < dims.size(); i++) {
155auto shape_info = shape_gen_fn(dims[i], cont_shape_conf);
156for (size_t j = 0; j < shape_info.second.size(); j++) {
157BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
158ASSERT_EQ(buf_handle.is_contiguous(), true);
159}
160}
161
162// non-contiguous
163for (size_t i = 0; i < dims.size(); i++) {
164auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf);
165for (size_t j = 0; j < shape_info.second.size(); j++) {
166BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat);
167ASSERT_EQ(buf_handle.is_contiguous(), false);
168}
169}
170}
171
172TEST(Expr, LetTest01) {
173VarHandle x("x", kFloat);
174ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
175SimpleIRExprEval eval(body);
176eval.bindVar(x, ExprHandle(3.f));
177ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
178}
179
180TEST(Expr, LetTest02) {
181VarHandle x("x", kFloat);
182VarHandle y("y", kFloat);
183ExprHandle body =
184ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y);
185SimpleIRExprEval eval(body);
186eval.bindVar(x, ExprHandle(3.f));
187eval.bindVar(y, ExprHandle(6.f));
188ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
189}
190
191TEST(Expr, LetStmtTest01) {
192BufHandle a_buf("a", {1}, kFloat);
193BufHandle b_buf("b", {1}, kFloat);
194
195ExprHandle load_a = a_buf.load(0);
196VarHandle var = VarHandle("v", kFloat);
197StmtPtr let_store = Let::make(var, load_a);
198StmtPtr store_b = b_buf.store({0}, var);
199BlockPtr block = Block::make({let_store, store_b});
200
201SimpleIREvaluator eval(block, {a_buf, b_buf});
202
203PaddedBuffer<float> a_v(1);
204PaddedBuffer<float> b_v(1);
205PaddedBuffer<float> b_ref(1);
206
207a_v(0) = 23;
208b_ref(0) = a_v(0);
209eval(a_v, b_v);
210
211ExpectAllNear(b_v, b_ref, 1e-5);
212}
213
214TEST(Expr, IntTest) {
215VarHandle x("x", kInt);
216ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4));
217SimpleIRExprEval eval(body);
218eval.bindVar(x, ExprHandle(3));
219ASSERT_EQ(eval.value<int>(), 2 + (3 * 3 + 4));
220}
221
222TEST(Expr, FloatTest) {
223VarHandle x("x", kFloat);
224ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f));
225SimpleIRExprEval eval(body);
226eval.bindVar(x, ExprHandle(3.f));
227ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
228}
229
230TEST(Expr, ByteTest) {
231VarHandle x("x", kByte);
232ExprHandle body = ExprHandle((uint8_t)2) +
233(x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4));
234SimpleIRExprEval eval(body);
235eval.bindVar(x, ExprHandle((uint8_t)3));
236ASSERT_EQ(eval.value<uint8_t>(), 2 + (3 * 3 + 4));
237}
238
239TEST(Expr, CharTest) {
240VarHandle x("x", kChar);
241ExprHandle body = ExprHandle((int8_t)2) +
242(x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4));
243SimpleIRExprEval eval(body);
244eval.bindVar(x, ExprHandle((int8_t)3));
245ASSERT_EQ(eval.value<int8_t>(), 2 + (3 * 3 + 4));
246}
247
248TEST(Expr, ShortTest) {
249VarHandle x("x", kShort);
250ExprHandle body = ExprHandle((int16_t)2) +
251(x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4));
252SimpleIRExprEval eval(body);
253eval.bindVar(x, ExprHandle((int16_t)3));
254ASSERT_EQ(eval.value<int16_t>(), 2 + (3 * 3 + 4));
255}
256
257TEST(Expr, LongTest) {
258VarHandle x("x", kLong);
259ExprHandle body = ExprHandle((int64_t)2) +
260(x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4));
261SimpleIRExprEval eval(body);
262eval.bindVar(x, ExprHandle((int64_t)3));
263ASSERT_EQ(eval.value<int64_t>(), 2 + (3 * 3 + 4));
264}
265
266TEST(Expr, HalfTest) {
267VarHandle x("x", kHalf);
268ExprHandle body = ExprHandle((at::Half)2) +
269(x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4));
270SimpleIRExprEval eval(body);
271eval.bindVar(x, ExprHandle((at::Half)3));
272ASSERT_EQ(eval.value<at::Half>(), 2 + (3 * 3 + 4));
273}
274
275TEST(Expr, DoubleTest) {
276VarHandle x("x", kDouble);
277ExprHandle body = ExprHandle((double)2) +
278(x * ExprHandle((double)3) + ExprHandle((double)4));
279SimpleIRExprEval eval(body);
280eval.bindVar(x, ExprHandle((double)3));
281ASSERT_EQ(eval.value<double>(), 2 + (3 * 3 + 4));
282}
283
284TEST(Expr, VectorAdd01) {
285const int kVectorSize = 8;
286const int kVectorCount = 128;
287const int kTotalSize = kVectorSize * kVectorCount;
288
289BufHandle a_buf("A", {kTotalSize}, kFloat);
290BufHandle b_buf("B", {kTotalSize}, kFloat);
291BufHandle c_buf("C", {kTotalSize}, kFloat);
292
293/*
294Build the following:
295for (const auto index : c10::irange(kVectorCount)) {
296store(c_buf, ramp(index * 8, 1, 8),
297load(a_buf, ramp(index * 8, 1, 8) +
298load(b_buf, ramp(index * 8, 1, 8))))
299}
300*/
301VarHandle index = VarHandle("index", kInt);
302ExprHandle load_a =
303a_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)});
304ExprHandle load_b =
305b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)});
306ExprHandle value = load_a + load_b;
307StmtPtr store_c =
308c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value);
309StmtPtr stmt = For::make(index, 0, kVectorCount, store_c);
310
311ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize));
312ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize));
313ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize));
314
315PaddedBuffer<float> a_v(kTotalSize);
316PaddedBuffer<float> b_v(kTotalSize);
317PaddedBuffer<float> c_v(kTotalSize);
318PaddedBuffer<float> c_ref(kTotalSize);
319for (const auto i : c10::irange(kTotalSize)) {
320a_v(i) = i * i;
321b_v(i) = i * i * 4;
322c_ref(i) = a_v(i) + b_v(i);
323}
324SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
325ir_eval(a_v, b_v, c_v);
326ExpectAllNear(c_v, c_ref, 1e-5);
327}
328
329TEST(Expr, CompareSelectEQ) {
330constexpr int N = 1024;
331BufHandle a("A", {N}, kInt);
332BufHandle b("B", {N}, kInt);
333BufHandle c("C", {N}, kInt);
334std::vector<int> a_buffer(N, 1);
335std::vector<int> b_buffer(N, 1);
336std::vector<int> c_buffer(N, 0);
337std::vector<int> c_ref(N, 0);
338
339VarHandle i("i", kInt);
340auto memcpy_expr = For::make(
341i,
3420,
343N,
344c.store(
345{i},
346CompareSelect::make(
347a.load(i), b.load(i), CompareSelectOperation::kEQ)));
348
349SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c});
350ir_eval(a_buffer, b_buffer, c_buffer);
351
352ASSERT_EQ(a_buffer.size(), N);
353ASSERT_EQ(b_buffer.size(), N);
354ASSERT_EQ(c_buffer.size(), N);
355
356assertAllEqual(a_buffer, 1);
357assertAllEqual(b_buffer, 1);
358assertAllEqual(c_buffer, 1);
359}
360
361TEST(Expr, CompareSelectDtypes) {
362// LHS and RHS expressions should have the same dtype, but this dtype could
363// differ from the dtype of the return values (but dtypes of true and false
364// return values should be the same).
365// This test constructs a CompareSelect expression where the input dtype is
366// different from the output dtype and verifies that it works correctly:
367// result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2
368constexpr int N = 1024;
369BufHandle a("A", {N}, kInt);
370BufHandle b("B", {N}, kInt);
371BufHandle c("C", {N}, kFloat);
372std::vector<int> a_buffer(N, 1);
373std::vector<int> b_buffer(N, 1);
374std::vector<float> c_buffer(N, 0.0f);
375std::vector<float> c_ref(N, 3.14f);
376
377VarHandle i("i", kInt);
378// C[i] = (A[i] == B[i]) ? 3.14f : 2.78f
379// A and B are int, C is float.
380auto select_expr = For::make(
381i,
3820,
383N,
384c.store(
385{i},
386CompareSelect::make(
387a.load(i),
388b.load(i),
389FloatImm::make(3.14f),
390FloatImm::make(2.78f),
391CompareSelectOperation::kEQ)));
392
393SimpleIREvaluator ir_eval(select_expr, {a, b, c});
394ir_eval(a_buffer, b_buffer, c_buffer);
395
396ASSERT_EQ(a_buffer.size(), N);
397ASSERT_EQ(b_buffer.size(), N);
398ASSERT_EQ(c_buffer.size(), N);
399
400assertAllEqual(a_buffer, 1);
401assertAllEqual(b_buffer, 1);
402ExpectAllNear(c_buffer, c_ref, 1e-7);
403}
404
405TEST(Expr, IntrinsicsDtypes) {
406constexpr int N = 256;
407BufHandle a("A", {N}, kDouble);
408BufHandle b("B", {N}, kDouble);
409std::vector<double> a_buffer(N, -10.0);
410std::vector<double> b_buffer(N, 0.0);
411std::vector<double> b_ref(N, 10.0);
412
413VarHandle i("i", kInt);
414auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i))));
415
416SimpleIREvaluator ir_eval(abs_expr, {a, b});
417ir_eval(a_buffer, b_buffer);
418
419ASSERT_EQ(a_buffer.size(), N);
420ASSERT_EQ(b_buffer.size(), N);
421
422assertAllEqual(a_buffer, -10.0);
423ExpectAllNear(b_buffer, b_ref, 1e-7);
424}
425
426TEST(Expr, Substitute01) {
427VarPtr x = alloc<Var>("x", kFloat);
428VarPtr y = alloc<Var>("y", kFloat);
429ExprPtr e =
430alloc<Mul>(alloc<Sub>(x, alloc<FloatImm>(1.0f)), alloc<Add>(x, y));
431
432VarPtr z = alloc<Var>("z", kFloat);
433ExprPtr e2 = Substitute(e, {{x, alloc<Add>(z, alloc<FloatImm>(5.0f))}});
434ExprPtr e2_ref = alloc<Mul>(
435alloc<Sub>(alloc<Add>(z, alloc<FloatImm>(5.0f)), alloc<FloatImm>(1.0f)),
436alloc<Add>(alloc<Add>(z, alloc<FloatImm>(5.0f)), y));
437std::ostringstream oss;
438oss << *e2;
439std::string e2_str = oss.str();
440
441oss.str("");
442oss << *e2_ref;
443std::string e2_ref_str = oss.str();
444ASSERT_EQ(e2_str, e2_ref_str);
445}
446
447TEST(Expr, Math01) {
448ExprHandle v = sin(ExprHandle(1.0f));
449
450std::ostringstream oss;
451oss << v;
452ASSERT_EQ(oss.str(), "sin(1.f)");
453
454SimpleIRExprEval eval(v);
455float v_ref = std::sin(1.0f);
456float res = eval.value<float>();
457ASSERT_NEAR(res, v_ref, 1e-6);
458}
459
460TEST(Expr, UnaryMath01) {
461struct TestConfig {
462std::function<ExprHandle(const ExprHandle&)> func;
463std::function<float(float)> ref_func;
464};
465
466std::vector<TestConfig> test_configs = {
467{[](const ExprHandle& v) { return sin(v); },
468[](float v) { return std::sin(v); }},
469{[](const ExprHandle& v) { return sin(v); },
470[](float v) { return std::sin(v); }},
471{[](const ExprHandle& v) { return tan(v); },
472[](float v) { return std::tan(v); }},
473{[](const ExprHandle& v) { return asin(v); },
474[](float v) { return std::asin(v); }},
475{[](const ExprHandle& v) { return acos(v); },
476[](float v) { return std::acos(v); }},
477{[](const ExprHandle& v) { return atan(v); },
478[](float v) { return std::atan(v); }},
479{[](const ExprHandle& v) { return sinh(v); },
480[](float v) { return std::sinh(v); }},
481{[](const ExprHandle& v) { return cosh(v); },
482[](float v) { return std::cosh(v); }},
483{[](const ExprHandle& v) { return tanh(v); },
484[](float v) { return std::tanh(v); }},
485{[](const ExprHandle& v) { return exp(v); },
486[](float v) { return std::exp(v); }},
487{[](const ExprHandle& v) { return tensorexpr::abs(v); },
488[](float v) { return std::fabs(v); }},
489{[](const ExprHandle& v) { return log(v); },
490[](float v) { return std::log(v); }},
491{[](const ExprHandle& v) { return log2(v); },
492[](float v) { return std::log2(v); }},
493{[](const ExprHandle& v) { return log10(v); },
494[](float v) { return std::log10(v); }},
495{[](const ExprHandle& v) { return erf(v); },
496[](float v) { return std::erf(v); }},
497{[](const ExprHandle& v) { return sqrt(v); },
498[](float v) { return std::sqrt(v); }},
499{[](const ExprHandle& v) { return rsqrt(v); },
500[](float v) { return 1.0f / std::sqrt(v); }},
501{[](const ExprHandle& v) { return ceil(v); },
502[](float v) { return std::ceil(v); }},
503{[](const ExprHandle& v) { return floor(v); },
504[](float v) { return std::floor(v); }},
505{[](const ExprHandle& v) { return round(v); },
506[](float v) { return std::round(v); }},
507{[](const ExprHandle& v) { return trunc(v); },
508[](float v) { return std::trunc(v); }},
509};
510
511for (const TestConfig& test_config : test_configs) {
512const float input_v = 0.8765f;
513ExprHandle v = test_config.func(ExprHandle(input_v));
514float v_ref = test_config.ref_func(input_v);
515SimpleIRExprEval eval(v);
516ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6);
517}
518
519// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
520for (float input_v : {std::nan("1"), 0., .5}) {
521ExprHandle v = FloatImm::make(input_v);
522SimpleIRExprEval eval(Intrinsics::make(kIsNan, v));
523ASSERT_NEAR(eval.value<int>(), std::isnan(input_v), 0);
524}
525}
526
527TEST(Expr, BinaryMath01) {
528struct TestConfig {
529std::function<ExprHandle(const ExprHandle&, const ExprHandle&)> func;
530std::function<float(float, float)> ref_func;
531};
532
533std::vector<TestConfig> test_configs = {
534{[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); },
535[](float v1, float v2) { return std::pow(v1, v2); }},
536{[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); },
537[](float v1, float v2) { return std::fmod(v1, v2); }},
538};
539
540for (const TestConfig& test_config : test_configs) {
541const float v1 = 0.8765f;
542float v2 = 1.2345f;
543ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2));
544float v_ref = test_config.ref_func(v1, v2);
545SimpleIRExprEval eval(v_expr);
546ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6);
547}
548}
549
550TEST(Expr, LogicalOps01) {
551ExprHandle a(23);
552ExprHandle b(11);
553ExprHandle c(0.72f);
554ExprHandle d(0.69f);
555ExprHandle f1 = (a > b) && (c > d);
556ExprHandle f2 = (a > b) && (c < d);
557ExprHandle f3 = (a < b) && (c > d);
558ExprHandle f4 = (a < b) && (c < d);
559ExprHandle f5 = (a < b) || (c > d);
560ExprHandle f6 = (a < b) || (c < d);
561ExprHandle f7 = (a > b) || (c < d);
562ExprHandle f8 = (a > b) || (c > d);
563
564SimpleIRExprEval eval1(f1);
565SimpleIRExprEval eval2(f2);
566SimpleIRExprEval eval3(f3);
567SimpleIRExprEval eval4(f4);
568SimpleIRExprEval eval5(f5);
569SimpleIRExprEval eval6(f6);
570SimpleIRExprEval eval7(f7);
571SimpleIRExprEval eval8(f8);
572ASSERT_EQ(eval1.value<int>(), 1);
573ASSERT_EQ(eval2.value<int>(), 0);
574ASSERT_EQ(eval3.value<int>(), 0);
575ASSERT_EQ(eval4.value<int>(), 0);
576ASSERT_EQ(eval5.value<int>(), 1);
577ASSERT_EQ(eval6.value<int>(), 0);
578ASSERT_EQ(eval7.value<int>(), 1);
579ASSERT_EQ(eval8.value<int>(), 1);
580}
581
582TEST(Expr, LogicalOps02) {
583ExprHandle a(23);
584ExprHandle b(11);
585ExprHandle c(0.72f);
586ExprHandle d(0.72f);
587
588ExprHandle f1 = (a > b) || (c > d);
589ExprHandle f2 = (a > b) && (c <= d);
590ExprHandle f3 = (a > b) && (c > d);
591ExprHandle ff1 = f1 && f2;
592ExprHandle ff2 = f2 || f3;
593
594SimpleIRExprEval eval1(ff1);
595SimpleIRExprEval eval2(ff2);
596ASSERT_EQ(eval1.value<int>(), 1);
597ASSERT_EQ(eval2.value<int>(), 1);
598}
599
600TEST(Expr, LogicalOps03) {
601ExprHandle a(23);
602ExprHandle b(11);
603ExprHandle c(0.72f);
604ExprHandle d(0.69f);
605
606// Bool types
607ExprHandle bool_f1 = (a > b) && BoolImm::make(true);
608ExprHandle bool_f2 = (c <= d) || BoolImm::make(true);
609
610// Int types
611ExprHandle int_f1 = (a > b) && IntImm::make(1);
612ExprHandle int_f2 = (c <= d) || IntImm::make(1);
613
614// Short types
615ExprHandle short_f1 = (a > b) && ShortImm::make(1);
616ExprHandle short_f2 = (c <= d) || ShortImm::make(1);
617
618// Long types
619ExprHandle long_f1 = (a > b) && LongImm::make(1);
620ExprHandle long_f2 = (c <= d) || LongImm::make(1);
621
622// Char types
623ExprHandle char_f1 = (a > b) && CharImm::make(1);
624ExprHandle char_f2 = (c <= d) || CharImm::make(1);
625
626// Byte types
627ExprHandle byte_f1 = (a > b) && ByteImm::make(1);
628ExprHandle byte_f2 = (c <= d) || ByteImm::make(1);
629
630SimpleIRExprEval eval1(bool_f1);
631SimpleIRExprEval eval2(bool_f2);
632SimpleIRExprEval eval3(int_f1);
633SimpleIRExprEval eval4(int_f2);
634SimpleIRExprEval eval5(short_f1);
635SimpleIRExprEval eval6(short_f2);
636SimpleIRExprEval eval7(long_f1);
637SimpleIRExprEval eval8(long_f2);
638SimpleIRExprEval eval9(char_f1);
639SimpleIRExprEval eval10(char_f2);
640SimpleIRExprEval eval11(byte_f1);
641SimpleIRExprEval eval12(byte_f2);
642
643ASSERT_EQ(eval1.value<bool>(), true);
644ASSERT_EQ(eval2.value<bool>(), true);
645ASSERT_EQ(eval3.value<int>(), 1);
646ASSERT_EQ(eval4.value<int>(), 1);
647ASSERT_EQ(eval5.value<int16_t>(), 1);
648ASSERT_EQ(eval6.value<int16_t>(), 1);
649ASSERT_EQ(eval7.value<int64_t>(), 1);
650ASSERT_EQ(eval8.value<int64_t>(), 1);
651ASSERT_EQ(eval9.value<int8_t>(), 1);
652ASSERT_EQ(eval10.value<int8_t>(), 1);
653ASSERT_EQ(eval11.value<uint8_t>(), 1);
654ASSERT_EQ(eval12.value<uint8_t>(), 1);
655}
656
657TEST(Expr, BitwiseOps) {
658ExprHandle a(59);
659ExprHandle b(11);
660ExprHandle c(101);
661ExprHandle d(2);
662ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d;
663
664SimpleIRExprEval eval(f);
665ASSERT_EQ(eval.value<int>(), 11);
666}
667
668TEST(Expr, DynamicShapeAdd) {
669auto testWithSize = [](int32_t size) {
670VarHandle n("n", kInt);
671BufHandle a("a", {n}, kFloat);
672BufHandle b("b", {n}, kFloat);
673BufHandle c("c", {n}, kFloat);
674VarHandle i("i", kInt);
675StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
676std::vector<float> aData(size, 1.0f);
677std::vector<float> bData(size, 2.0f);
678std::vector<float> cData(size, 0.0f);
679SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size);
680ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
681};
682testWithSize(1);
683testWithSize(16);
684testWithSize(37);
685}
686
687TEST(Expr, OutOfBounds) {
688ExprHandle N(10);
689ExprHandle start(0);
690ExprHandle stop(15);
691VarHandle i("i", kInt);
692
693BufHandle X("X", {N}, kInt);
694
695auto body = Store::make(X, {i}, i);
696auto stmt = For::make(i, start, stop, body);
697
698PaddedBuffer<int> data(20);
699
700EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
701}
702
703TEST(Expr, OutOfBounds2d) {
704std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}};
705for (auto sizes : size_options) {
706ExprHandle N(sizes.first);
707ExprHandle M(sizes.second);
708ExprHandle start(0);
709ExprHandle stopInner(15);
710ExprHandle stopOuter(15);
711VarHandle i("i", kInt);
712VarHandle j("j", kInt);
713
714BufHandle X("X", {N, M}, kInt);
715
716auto body = Store::make(X, {i, j}, i);
717auto inner = For::make(j, start, stopInner, body);
718auto stmt = For::make(i, start, stopOuter, inner);
719
720PaddedBuffer<int> data(400);
721
722EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
723}
724}
725
726TEST(Expr, OutOfBounds2dFlattenedIndex) {
727ExprHandle buf_size(149);
728ExprHandle start(0);
729ExprHandle stopInner(15);
730ExprHandle stopOuter(10);
731VarHandle i("i", kInt);
732VarHandle j("j", kInt);
733
734BufHandle X("X", {buf_size}, kInt);
735
736auto idx = Add::make(Mul::make(i, stopInner), j);
737auto body = Store::make(X, {idx}, i);
738auto inner = For::make(j, start, stopInner, body);
739auto stmt = For::make(i, start, stopOuter, inner);
740
741PaddedBuffer<int> data(400);
742
743EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
744}
745
746void testCond01() {
747const int N = 16;
748PaddedBuffer<float> a_v(N);
749BufHandle a_buf("a", {N}, kFloat);
750VarHandle index = VarHandle("index", kInt);
751StmtPtr assign_x2 = a_buf.store({index}, cast<float>(index) * 2);
752StmtPtr assign_x3 = a_buf.store({index}, cast<float>(index) * 3);
753ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
754StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3);
755StmtPtr for_stmt = For::make(index, 0, N, assign);
756SimpleIREvaluator(for_stmt, {a_buf})(a_v);
757
758PaddedBuffer<float> a_ref(N);
759for (const auto i : c10::irange(N)) {
760if (i % 2 == 0) {
761a_ref(i) = i * 2;
762} else {
763a_ref(i) = i * 3;
764}
765}
766ExpectAllNear(a_v, a_ref, 1e-5);
767}
768
769void testIfThenElse01() {
770ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f));
771
772std::ostringstream oss;
773oss << v;
774ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)");
775
776SimpleIRExprEval eval(v);
777ASSERT_EQ(eval.value<float>(), 1.0f);
778}
779
780void testIfThenElse02() {
781ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f));
782
783std::ostringstream oss;
784oss << v;
785ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
786
787SimpleIRExprEval eval(v);
788ASSERT_EQ(eval.value<float>(), 2.0f);
789}
790
791void testIfThenElse03() {
792ExprHandle v =
793ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f));
794
795std::ostringstream oss;
796oss << v;
797ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
798
799SimpleIRExprEval eval(v);
800ASSERT_EQ(eval.value<float>(), 2.0f);
801}
802
803void testStmtClone() {
804const int N = 16;
805
806BufHandle a_buf("a", {N}, kInt);
807VarHandle index = VarHandle("index", kInt);
808StmtPtr body = a_buf.store({index}, 5);
809StmtPtr loop = For::make(index, 0, N, body);
810
811StmtPtr cloned_loop = Stmt::clone(loop);
812std::vector<int> orig_loop_results(N);
813std::vector<int> cloned_loop_results(N);
814SimpleIREvaluator(loop, {a_buf})(orig_loop_results);
815SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results);
816
817assertAllEqual(orig_loop_results, 5);
818assertAllEqual(cloned_loop_results, 5);
819
820// Let's add another assign to the body in the cloned loop and verify that the
821// original statement hasn't changed while the cloned one has.
822StmtPtr body_addition = a_buf.store({index}, 33);
823BlockPtr cloned_body = static_to<Block>(static_to<For>(cloned_loop)->body());
824cloned_body->append_stmt(body_addition);
825
826std::vector<int> orig_loop_results_after_mutation(N);
827std::vector<int> cloned_loop_results_after_mutation(N);
828SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation);
829SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation);
830
831assertAllEqual(orig_loop_results_after_mutation, 5);
832assertAllEqual(cloned_loop_results_after_mutation, 33);
833}
834
835} // namespace jit
836} // namespace torch
837