// Scraped page header: repository "pytorch" (Fork: 0) / test_expr.cpp —
// 836 lines · 25.5 KB.
#include <gtest/gtest.h>

#include <test/cpp/tensorexpr/test_base.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <test/cpp/tensorexpr/test_utils.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <cmath>
#include <functional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
23
namespace jit {
24
using namespace torch::jit::tensorexpr;
25

26
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
27

28
// 2 + 3 built directly from IntImm nodes must evaluate to 5.
TEST(Expr, BasicValueTest) {
  const ExprHandle lhs = IntImm::make(2);
  const ExprHandle rhs = IntImm::make(3);
  const ExprHandle sum = Add::make(lhs, rhs);
  SimpleIRExprEval eval(sum);
  ASSERT_EQ(eval.value<int>(), 5);
}

// (2 + 3) - (4 + 5) over float immediates must evaluate to -4.
TEST(Expr, BasicValueTest02) {
  const ExprHandle two(2.0f);
  const ExprHandle three(3.0f);
  const ExprHandle four(4.0f);
  const ExprHandle five(5.0f);
  const ExprHandle left = two + three;
  const ExprHandle right = four + five;
  SimpleIRExprEval eval(left - right);
  ASSERT_EQ(eval.value<float>(), -4.0f);
}

// Checks BufHandle's contiguity predicates (default, channels-last 1d,
// ChannelsLast and ChannelsLast3d) on buffers whose sizes are symbolic kLong
// variables. Strides are generated from a "stride order": a permutation of
// dimension indices listed from the innermost dimension (stride 1) outwards.
// Each following dimension's stride is the previous dimension's size times
// the previous dimension's stride.
TEST(Expr, IsChannelsLastContiguous) {
  std::vector<VarHandle> vars = {
      VarHandle("var1", kLong),
      VarHandle("var2", kLong),
      VarHandle("var3", kLong),
      VarHandle("var4", kLong),
      VarHandle("var5", kLong)};

  // {
  //   key: ndims,
  //   value: [
  //     ...
  //     [innermost_dim, ..., outermost_dim]   // one stride order
  //   ]
  // }
  using shapGenInfo = std::unordered_map<int, std::vector<std::vector<int>>>;

  // {
  //   size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n],
  //   strides: [
  //     ...
  //     [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z]  // one per order
  //   ]
  // }
  using shapeInfo =
      std::pair<std::vector<ExprHandle>, std::vector<std::vector<ExprHandle>>>;

  std::vector<int> dims = {3, 4, 5};

  // Exactly one symbolic size per dimension, because stride orders index
  // every dimension in [0, ndims).
  std::unordered_map<int, std::vector<ExprHandle>> dims_expr_vec_conf = {
      {3, std::vector<ExprHandle>(vars.begin(), vars.begin() + 3)},
      {4, std::vector<ExprHandle>(vars.begin(), vars.begin() + 4)},
      {5, std::vector<ExprHandle>(vars.begin(), vars.begin() + 5)},
  };

  // Channels-last orders: dim 1 innermost, dim 0 outermost.
  shapGenInfo channels_last_cont_shape_conf = {
      {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}};
  // Orders that must NOT be reported as channels-last contiguous; none of
  // them may coincide with the corresponding channels-last order above.
  shapGenInfo channels_last_non_cont_shape_conf = {
      {3, {{2, 1, 0}, {1, 0, 2}}},
      {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}},
      {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 2, 3, 0}}}};

  // Default contiguity (as for PyTorch tensors): the last dimension is
  // innermost.
  shapGenInfo cont_shape_conf = {
      {3, {{2, 1, 0}}}, {4, {{3, 2, 1, 0}}}, {5, {{4, 3, 2, 1, 0}}}};

  auto shape_gen_fn = [dims_expr_vec_conf](
                          int ndims, shapGenInfo shape_gen_info) -> shapeInfo {
    auto dims_expr_vec = dims_expr_vec_conf.at(ndims);
    const auto& stride_order_vec = shape_gen_info.at(ndims);

    // One stride vector per stride order, each with one entry per dimension.
    std::vector<std::vector<ExprHandle>> strides_expr_vec(
        stride_order_vec.size(), std::vector<ExprHandle>(ndims));

    // Alternate the operand order between stride orders so that both the
    // `size * stride` and the `stride * size` forms are exercised.
    auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) {
      if (indicator % 2 == 0) {
        return a * b;
      } else {
        return b * a;
      }
    };

    for (size_t i = 0; i < strides_expr_vec.size(); i++) {
      const auto& stride_order = stride_order_vec[i];

      // The first dimension of the order is innermost: stride 1.
      strides_expr_vec[i][stride_order[0]] = 1;
      for (size_t j = 1; j < stride_order.size(); j++) {
        auto cur_dim_idx = stride_order[j];
        auto adjacent_dim_idx = stride_order[j - 1];

        strides_expr_vec[i][cur_dim_idx] = stride_gen_fn(
            static_cast<int>(i),
            dims_expr_vec[adjacent_dim_idx],
            strides_expr_vec[i][adjacent_dim_idx]);
      }
    }

    return {dims_expr_vec, strides_expr_vec};
  };

  auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool {
    if (ndims == 3) {
      return buf_handle.is_channels_last_1d_contiguous();
    } else if (ndims == 4) {
      return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast);
    } else {
      return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d);
    }
  };

  // channels-last contiguous
  for (int ndims : dims) {
    auto shape_info = shape_gen_fn(ndims, channels_last_cont_shape_conf);
    // Guard against the test silently checking nothing.
    ASSERT_FALSE(shape_info.second.empty());
    for (const auto& strides : shape_info.second) {
      BufHandle buf_handle("a", shape_info.first, strides, kFloat);
      ASSERT_EQ(check_channels_last_fn(ndims, buf_handle), true);
    }
  }

  // channels-last non-contiguous
  for (int ndims : dims) {
    auto shape_info = shape_gen_fn(ndims, channels_last_non_cont_shape_conf);
    ASSERT_FALSE(shape_info.second.empty());
    for (const auto& strides : shape_info.second) {
      BufHandle buf_handle("a", shape_info.first, strides, kFloat);
      ASSERT_EQ(check_channels_last_fn(ndims, buf_handle), false);
    }
  }

  // contiguous
  for (int ndims : dims) {
    auto shape_info = shape_gen_fn(ndims, cont_shape_conf);
    ASSERT_FALSE(shape_info.second.empty());
    for (const auto& strides : shape_info.second) {
      BufHandle buf_handle("a", shape_info.first, strides, kFloat);
      ASSERT_EQ(buf_handle.is_contiguous(), true);
    }
  }

  // non-contiguous: channels-last strides are not default-contiguous.
  for (int ndims : dims) {
    auto shape_info = shape_gen_fn(ndims, channels_last_cont_shape_conf);
    ASSERT_FALSE(shape_info.second.empty());
    for (const auto& strides : shape_info.second) {
      BufHandle buf_handle("a", shape_info.first, strides, kFloat);
      ASSERT_EQ(buf_handle.is_contiguous(), false);
    }
  }
}

// A free float variable bound at evaluation time:
// 2 + (x * 3 + 4) with x = 3 must give 15.
TEST(Expr, LetTest01) {
  VarHandle x("x", kFloat);
  const ExprHandle scaled = x * ExprHandle(3.f);
  const ExprHandle inner = scaled + ExprHandle(4.f);
  const ExprHandle body = ExprHandle(2.f) + inner;
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(3.f));
  ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
}

// Two free float variables bound at evaluation time:
// 2 + (x * 3 + 4 * y) with x = 3, y = 6 must give 35.
TEST(Expr, LetTest02) {
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  const ExprHandle x_term = x * ExprHandle(3.f);
  const ExprHandle y_term = ExprHandle(4.f) * y;
  const ExprHandle body = ExprHandle(2.f) + (x_term + y_term);

  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(3.f));
  eval.bindVar(y, ExprHandle(6.f));
  ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
}

// Statement form of a let-binding: `v = a[0]; b[0] = v;` must copy a into b.
TEST(Expr, LetStmtTest01) {
  BufHandle a_buf("a", {1}, kFloat);
  BufHandle b_buf("b", {1}, kFloat);

  ExprHandle a0 = a_buf.load(0);
  VarHandle v = VarHandle("v", kFloat);
  StmtPtr let_v = Let::make(v, a0);
  StmtPtr store_b0 = b_buf.store({0}, v);
  BlockPtr program = Block::make({let_v, store_b0});

  SimpleIREvaluator evaluator(program, {a_buf, b_buf});

  PaddedBuffer<float> a_v(1);
  PaddedBuffer<float> b_v(1);
  PaddedBuffer<float> b_ref(1);
  a_v(0) = 23;
  b_ref(0) = a_v(0);

  evaluator(a_v, b_v);
  ExpectAllNear(b_v, b_ref, 1e-5);
}

// kInt arithmetic: 2 + (x * 3 + 4) with x = 3 must give 15.
TEST(Expr, IntTest) {
  VarHandle x("x", kInt);
  const ExprHandle product = x * ExprHandle(3);
  const ExprHandle body = ExprHandle(2) + (product + ExprHandle(4));
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(3));
  ASSERT_EQ(eval.value<int>(), 2 + (3 * 3 + 4));
}

// kFloat arithmetic: 2 + (x * 3 + 4) with x = 3 must give 15.
TEST(Expr, FloatTest) {
  VarHandle x("x", kFloat);
  const ExprHandle two(2.f);
  const ExprHandle three(3.f);
  const ExprHandle four(4.f);
  const ExprHandle body = two + (x * three + four);
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(3.f));
  ASSERT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
}

// kByte (uint8_t) arithmetic: 2 + (x * 3 + 4) with x = 3 must give 15.
TEST(Expr, ByteTest) {
  VarHandle x("x", kByte);
  const ExprHandle two(static_cast<uint8_t>(2));
  const ExprHandle three(static_cast<uint8_t>(3));
  const ExprHandle four(static_cast<uint8_t>(4));
  const ExprHandle body = two + (x * three + four);
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(static_cast<uint8_t>(3)));
  ASSERT_EQ(eval.value<uint8_t>(), 2 + (3 * 3 + 4));
}

// kChar (int8_t) arithmetic: 2 + (x * 3 + 4) with x = 3 must give 15.
TEST(Expr, CharTest) {
  VarHandle x("x", kChar);
  const ExprHandle two(static_cast<int8_t>(2));
  const ExprHandle three(static_cast<int8_t>(3));
  const ExprHandle four(static_cast<int8_t>(4));
  const ExprHandle body = two + (x * three + four);
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(static_cast<int8_t>(3)));
  ASSERT_EQ(eval.value<int8_t>(), 2 + (3 * 3 + 4));
}

// kShort (int16_t) arithmetic: 2 + (x * 3 + 4) with x = 3 must give 15.
TEST(Expr, ShortTest) {
  VarHandle x("x", kShort);
  const ExprHandle two(static_cast<int16_t>(2));
  const ExprHandle three(static_cast<int16_t>(3));
  const ExprHandle four(static_cast<int16_t>(4));
  const ExprHandle body = two + (x * three + four);
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(static_cast<int16_t>(3)));
  ASSERT_EQ(eval.value<int16_t>(), 2 + (3 * 3 + 4));
}

// kLong (int64_t) arithmetic: 2 + (x * 3 + 4) with x = 3 must give 15.
TEST(Expr, LongTest) {
  VarHandle x("x", kLong);
  const ExprHandle two(static_cast<int64_t>(2));
  const ExprHandle three(static_cast<int64_t>(3));
  const ExprHandle four(static_cast<int64_t>(4));
  const ExprHandle body = two + (x * three + four);
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(static_cast<int64_t>(3)));
  ASSERT_EQ(eval.value<int64_t>(), 2 + (3 * 3 + 4));
}

// kHalf (at::Half) arithmetic: 2 + (x * 3 + 4) with x = 3 must give 15.
TEST(Expr, HalfTest) {
  VarHandle x("x", kHalf);
  const ExprHandle two(static_cast<at::Half>(2));
  const ExprHandle three(static_cast<at::Half>(3));
  const ExprHandle four(static_cast<at::Half>(4));
  const ExprHandle body = two + (x * three + four);
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(static_cast<at::Half>(3)));
  ASSERT_EQ(eval.value<at::Half>(), 2 + (3 * 3 + 4));
}

// kDouble arithmetic: 2 + (x * 3 + 4) with x = 3 must give 15.
TEST(Expr, DoubleTest) {
  VarHandle x("x", kDouble);
  const ExprHandle two(2.0);
  const ExprHandle three(3.0);
  const ExprHandle four(4.0);
  const ExprHandle body = two + (x * three + four);
  SimpleIRExprEval eval(body);
  eval.bindVar(x, ExprHandle(3.0));
  ASSERT_EQ(eval.value<double>(), 2 + (3 * 3 + 4));
}

// Vectorized add: each loop iteration loads/stores kVectorSize consecutive
// elements through a Ramp index; the result is compared with a scalar
// reference computed on the host.
TEST(Expr, VectorAdd01) {
  constexpr int kVectorSize = 8;
  constexpr int kVectorCount = 128;
  constexpr int kTotalSize = kVectorSize * kVectorCount;

  BufHandle a_buf("A", {kTotalSize}, kFloat);
  BufHandle b_buf("B", {kTotalSize}, kFloat);
  BufHandle c_buf("C", {kTotalSize}, kFloat);

  /*
  Build the following:
    for (const auto index : c10::irange(kVectorCount)) {
      store(c_buf, ramp(index * 8, 1, 8),
            load(a_buf, ramp(index * 8, 1, 8)) +
            load(b_buf, ramp(index * 8, 1, 8)))
    }
  */
  VarHandle index = VarHandle("index", kInt);
  // A fresh Ramp node per access, covering elements
  // [index * kVectorSize, index * kVectorSize + kVectorSize).
  auto make_lanes = [&]() {
    return Ramp::make(index * kVectorSize, 1, kVectorSize);
  };
  ExprHandle load_a = a_buf.load({make_lanes()});
  ExprHandle load_b = b_buf.load({make_lanes()});
  ExprHandle value = load_a + load_b;
  StmtPtr store_c = c_buf.store({make_lanes()}, value);
  StmtPtr stmt = For::make(index, 0, kVectorCount, store_c);

  const Dtype vector_float(kFloat, kVectorSize);
  ASSERT_EQ(load_a.dtype(), vector_float);
  ASSERT_EQ(load_b.dtype(), vector_float);
  ASSERT_EQ(value.dtype(), vector_float);

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);
  PaddedBuffer<float> c_v(kTotalSize);
  PaddedBuffer<float> c_ref(kTotalSize);
  for (const auto i : c10::irange(kTotalSize)) {
    a_v(i) = i * i;
    b_v(i) = i * i * 4;
    c_ref(i) = a_v(i) + b_v(i);
  }

  SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf});
  ir_eval(a_v, b_v, c_v);
  ExpectAllNear(c_v, c_ref, 1e-5);
}

// CompareSelect with kEQ in its two-operand form: A and B are equal
// everywhere, so every element of C must be 1.
TEST(Expr, CompareSelectEQ) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N, 1);
  std::vector<int> b_buffer(N, 1);
  std::vector<int> c_buffer(N, 0);

  VarHandle i("i", kInt);
  // C[i] = (A[i] == B[i])
  auto select_expr = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kEQ)));

  SimpleIREvaluator ir_eval(select_expr, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  // Compare as size_t to avoid a signed/unsigned comparison.
  ASSERT_EQ(a_buffer.size(), static_cast<size_t>(N));
  ASSERT_EQ(b_buffer.size(), static_cast<size_t>(N));
  ASSERT_EQ(c_buffer.size(), static_cast<size_t>(N));

  // Inputs must be untouched; every comparison must be true.
  assertAllEqual(a_buffer, 1);
  assertAllEqual(b_buffer, 1);
  assertAllEqual(c_buffer, 1);
}

// CompareSelect whose compared operands and returned values have different
// dtypes. The operands (A, B) are int and the selected values are float:
//   C[i] = (A[i] == B[i]) ? 3.14f : 2.78f
// The true/false values share one dtype, which may differ from the operands'
// dtype. A == B everywhere, so C must be 3.14f everywhere.
TEST(Expr, CompareSelectDtypes) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kFloat);
  std::vector<int> a_buffer(N, 1);
  std::vector<int> b_buffer(N, 1);
  std::vector<float> c_buffer(N, 0.0f);
  std::vector<float> c_ref(N, 3.14f);

  VarHandle i("i", kInt);
  ExprHandle selected = CompareSelect::make(
      a.load(i),
      b.load(i),
      FloatImm::make(3.14f),
      FloatImm::make(2.78f),
      CompareSelectOperation::kEQ);
  auto select_expr = For::make(i, 0, N, c.store({i}, selected));

  SimpleIREvaluator ir_eval(select_expr, {a, b, c});
  ir_eval(a_buffer, b_buffer, c_buffer);

  ASSERT_EQ(a_buffer.size(), static_cast<size_t>(N));
  ASSERT_EQ(b_buffer.size(), static_cast<size_t>(N));
  ASSERT_EQ(c_buffer.size(), static_cast<size_t>(N));

  assertAllEqual(a_buffer, 1);
  assertAllEqual(b_buffer, 1);
  ExpectAllNear(c_buffer, c_ref, 1e-7);
}

// abs() over a kDouble buffer: every -10.0 input must become 10.0, and the
// input buffer must be left unchanged.
TEST(Expr, IntrinsicsDtypes) {
  constexpr int N = 256;
  BufHandle a("A", {N}, kDouble);
  BufHandle b("B", {N}, kDouble);
  std::vector<double> a_buffer(N, -10.0);
  std::vector<double> b_buffer(N, 0.0);
  std::vector<double> b_ref(N, 10.0);

  VarHandle i("i", kInt);
  ExprHandle magnitude = tensorexpr::abs(a.load(i));
  auto abs_loop = For::make(i, 0, N, b.store({i}, magnitude));

  SimpleIREvaluator ir_eval(abs_loop, {a, b});
  ir_eval(a_buffer, b_buffer);

  ASSERT_EQ(a_buffer.size(), static_cast<size_t>(N));
  ASSERT_EQ(b_buffer.size(), static_cast<size_t>(N));

  assertAllEqual(a_buffer, -10.0);
  ExpectAllNear(b_buffer, b_ref, 1e-7);
}

// Substitutes x := z + 5 into (x - 1) * (x + y). Checks that the printed
// result matches a hand-built reference expression.
TEST(Expr, Substitute01) {
  VarPtr x = alloc<Var>("x", kFloat);
  VarPtr y = alloc<Var>("y", kFloat);
  ExprPtr e =
      alloc<Mul>(alloc<Sub>(x, alloc<FloatImm>(1.0f)), alloc<Add>(x, y));

  VarPtr z = alloc<Var>("z", kFloat);
  ExprPtr e2 = Substitute(e, {{x, alloc<Add>(z, alloc<FloatImm>(5.0f))}});
  ExprPtr e2_ref = alloc<Mul>(
      alloc<Sub>(alloc<Add>(z, alloc<FloatImm>(5.0f)), alloc<FloatImm>(1.0f)),
      alloc<Add>(alloc<Add>(z, alloc<FloatImm>(5.0f)), y));

  // Each expression is printed through a fresh stream.
  auto render = [](const ExprPtr& expr) {
    std::ostringstream out;
    out << *expr;
    return out.str();
  };
  const std::string e2_str = render(e2);
  const std::string e2_ref_str = render(e2_ref);
  ASSERT_EQ(e2_str, e2_ref_str);
}

// sin(1.0f) must print as "sin(1.f)" and evaluate to std::sin(1.0f).
TEST(Expr, Math01) {
  const ExprHandle v = sin(ExprHandle(1.0f));

  std::ostringstream printed;
  printed << v;
  ASSERT_EQ(printed.str(), "sin(1.f)");

  SimpleIRExprEval eval(v);
  const float expected = std::sin(1.0f);
  const float actual = eval.value<float>();
  ASSERT_NEAR(actual, expected, 1e-6);
}

// Evaluates each unary math function on a float immediate and compares the
// result with the corresponding <cmath> reference. Then checks the kIsNan
// intrinsic on a NaN and on two ordinary values.
TEST(Expr, UnaryMath01) {
  struct TestConfig {
    std::function<ExprHandle(const ExprHandle&)> func;
    std::function<float(float)> ref_func;
  };

  std::vector<TestConfig> test_configs = {
      {[](const ExprHandle& v) { return sin(v); },
       [](float v) { return std::sin(v); }},
      {[](const ExprHandle& v) { return cos(v); },
       [](float v) { return std::cos(v); }},
      {[](const ExprHandle& v) { return tan(v); },
       [](float v) { return std::tan(v); }},
      {[](const ExprHandle& v) { return asin(v); },
       [](float v) { return std::asin(v); }},
      {[](const ExprHandle& v) { return acos(v); },
       [](float v) { return std::acos(v); }},
      {[](const ExprHandle& v) { return atan(v); },
       [](float v) { return std::atan(v); }},
      {[](const ExprHandle& v) { return sinh(v); },
       [](float v) { return std::sinh(v); }},
      {[](const ExprHandle& v) { return cosh(v); },
       [](float v) { return std::cosh(v); }},
      {[](const ExprHandle& v) { return tanh(v); },
       [](float v) { return std::tanh(v); }},
      {[](const ExprHandle& v) { return exp(v); },
       [](float v) { return std::exp(v); }},
      {[](const ExprHandle& v) { return tensorexpr::abs(v); },
       [](float v) { return std::fabs(v); }},
      {[](const ExprHandle& v) { return log(v); },
       [](float v) { return std::log(v); }},
      {[](const ExprHandle& v) { return log2(v); },
       [](float v) { return std::log2(v); }},
      {[](const ExprHandle& v) { return log10(v); },
       [](float v) { return std::log10(v); }},
      {[](const ExprHandle& v) { return erf(v); },
       [](float v) { return std::erf(v); }},
      {[](const ExprHandle& v) { return sqrt(v); },
       [](float v) { return std::sqrt(v); }},
      {[](const ExprHandle& v) { return rsqrt(v); },
       [](float v) { return 1.0f / std::sqrt(v); }},
      {[](const ExprHandle& v) { return ceil(v); },
       [](float v) { return std::ceil(v); }},
      {[](const ExprHandle& v) { return floor(v); },
       [](float v) { return std::floor(v); }},
      {[](const ExprHandle& v) { return round(v); },
       [](float v) { return std::round(v); }},
      {[](const ExprHandle& v) { return trunc(v); },
       [](float v) { return std::trunc(v); }},
  };

  const float input_v = 0.8765f;
  for (const TestConfig& test_config : test_configs) {
    ExprHandle v = test_config.func(ExprHandle(input_v));
    float v_ref = test_config.ref_func(input_v);
    SimpleIRExprEval eval(v);
    ASSERT_NEAR(eval.value<float>(), v_ref, 1e-6);
  }

  // Float literals (and nanf) avoid narrowing a double initializer list.
  for (float nan_input : {std::nanf("1"), 0.f, .5f}) {
    ExprHandle v = FloatImm::make(nan_input);
    SimpleIRExprEval eval(Intrinsics::make(kIsNan, v));
    ASSERT_EQ(eval.value<int>(), std::isnan(nan_input) ? 1 : 0);
  }
}

// Evaluates binary math functions on two float immediates. Each result is
// compared with the corresponding <cmath> reference.
TEST(Expr, BinaryMath01) {
  struct TestConfig {
    std::function<ExprHandle(const ExprHandle&, const ExprHandle&)> func;
    std::function<float(float, float)> ref_func;
  };

  std::vector<TestConfig> test_configs = {
      {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); },
       [](float v1, float v2) { return std::pow(v1, v2); }},
      {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); },
       [](float v1, float v2) { return std::fmod(v1, v2); }},
  };

  const float lhs = 0.8765f;
  const float rhs = 1.2345f;
  for (const auto& config : test_configs) {
    SimpleIRExprEval eval(config.func(ExprHandle(lhs), ExprHandle(rhs)));
    const float expected = config.ref_func(lhs, rhs);
    ASSERT_NEAR(eval.value<float>(), expected, 1e-6);
  }
}

// && and || over int and float comparisons. Here a > b and c > d both hold,
// while their reverses do not.
TEST(Expr, LogicalOps01) {
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.69f);

  const std::vector<std::pair<ExprHandle, int>> cases = {
      {(a > b) && (c > d), 1},
      {(a > b) && (c < d), 0},
      {(a < b) && (c > d), 0},
      {(a < b) && (c < d), 0},
      {(a < b) || (c > d), 1},
      {(a < b) || (c < d), 0},
      {(a > b) || (c < d), 1},
      {(a > b) || (c > d), 1},
  };

  for (const auto& [expr, expected] : cases) {
    SimpleIRExprEval eval(expr);
    ASSERT_EQ(eval.value<int>(), expected);
  }
}

// Nested logical expressions with c == d: f1 and f2 are true and f3 is
// false, so both f1 && f2 and f2 || f3 must be true.
TEST(Expr, LogicalOps02) {
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.72f);

  ExprHandle f1 = (a > b) || (c > d);
  ExprHandle f2 = (a > b) && (c <= d);
  ExprHandle f3 = (a > b) && (c > d);

  SimpleIRExprEval conjunction(f1 && f2);
  SimpleIRExprEval disjunction(f2 || f3);
  ASSERT_EQ(conjunction.value<int>(), 1);
  ASSERT_EQ(disjunction.value<int>(), 1);
}

// Mixes a comparison with an immediate of each integral/bool dtype under &&
// and ||. Checks that every combination evaluates to true (1) in that dtype.
TEST(Expr, LogicalOps03) {
  ExprHandle a(23);
  ExprHandle b(11);
  ExprHandle c(0.72f);
  ExprHandle d(0.69f);

  // (a > b) is true and (c <= d) is false.
  const ExprHandle truthy = a > b;
  const ExprHandle falsy = c <= d;

  SimpleIRExprEval bool_and(truthy && BoolImm::make(true));
  SimpleIRExprEval bool_or(falsy || BoolImm::make(true));
  SimpleIRExprEval int_and(truthy && IntImm::make(1));
  SimpleIRExprEval int_or(falsy || IntImm::make(1));
  SimpleIRExprEval short_and(truthy && ShortImm::make(1));
  SimpleIRExprEval short_or(falsy || ShortImm::make(1));
  SimpleIRExprEval long_and(truthy && LongImm::make(1));
  SimpleIRExprEval long_or(falsy || LongImm::make(1));
  SimpleIRExprEval char_and(truthy && CharImm::make(1));
  SimpleIRExprEval char_or(falsy || CharImm::make(1));
  SimpleIRExprEval byte_and(truthy && ByteImm::make(1));
  SimpleIRExprEval byte_or(falsy || ByteImm::make(1));

  ASSERT_EQ(bool_and.value<bool>(), true);
  ASSERT_EQ(bool_or.value<bool>(), true);
  ASSERT_EQ(int_and.value<int>(), 1);
  ASSERT_EQ(int_or.value<int>(), 1);
  ASSERT_EQ(short_and.value<int16_t>(), 1);
  ASSERT_EQ(short_or.value<int16_t>(), 1);
  ASSERT_EQ(long_and.value<int64_t>(), 1);
  ASSERT_EQ(long_or.value<int64_t>(), 1);
  ASSERT_EQ(char_and.value<int8_t>(), 1);
  ASSERT_EQ(char_or.value<int8_t>(), 1);
  ASSERT_EQ(byte_and.value<uint8_t>(), 1);
  ASSERT_EQ(byte_or.value<uint8_t>(), 1);
}

// (((59 ^ (11 << 1)) & 101) >> 2) | 2 must evaluate to 11.
TEST(Expr, BitwiseOps) {
  ExprHandle a(59);
  ExprHandle b(11);
  ExprHandle c(101);
  ExprHandle d(2);

  const ExprHandle shifted = b << 1;
  const ExprHandle masked = (a ^ shifted) & c;
  const ExprHandle result = (masked >> 2) | d;

  SimpleIRExprEval eval(result);
  ASSERT_EQ(eval.value<int>(), 11);
}

// Element-wise add over buffers whose length `n` is a runtime argument.
// Runs for several concrete lengths.
TEST(Expr, DynamicShapeAdd) {
  auto run_with_length = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);
    VarHandle i("i", kInt);
    ExprHandle sum = a.load(i) + b.load(i);
    StmtPtr loop = For::make(i, 0, n, c.store({i}, sum));

    std::vector<float> a_data(size, 1.0f);
    std::vector<float> b_data(size, 2.0f);
    std::vector<float> c_data(size, 0.0f);
    SimpleIREvaluator(loop, {a, b, c, n})(a_data, b_data, c_data, size);
    ExpectAllNear(c_data, std::vector<float>(size, 3.0f), 1e-7);
  };

  for (const int32_t size : {1, 16, 37}) {
    run_with_length(size);
  }
}

// Writing X[i] for i in [0, 15) into a 10-element buffer must make the
// evaluator throw.
TEST(Expr, OutOfBounds) {
  const ExprHandle buf_len(10);
  const ExprHandle loop_begin(0);
  const ExprHandle loop_end(15);
  VarHandle i("i", kInt);

  BufHandle X("X", {buf_len}, kInt);
  auto store = Store::make(X, {i}, i);
  auto loop = For::make(i, loop_begin, loop_end, store);

  PaddedBuffer<int> data(20);
  EXPECT_ANY_THROW(SimpleIREvaluator(loop, {X})(data));
}

// A 15x15 iteration space over a 2-d buffer that is too small in one of its
// dimensions (rows in the first case, columns in the second). The evaluator
// must throw in both cases.
TEST(Expr, OutOfBounds2d) {
  const std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}};
  for (const auto& [rows, cols] : size_options) {
    ExprHandle N(rows);
    ExprHandle M(cols);
    ExprHandle start(0);
    ExprHandle stopInner(15);
    ExprHandle stopOuter(15);
    VarHandle i("i", kInt);
    VarHandle j("j", kInt);

    BufHandle X("X", {N, M}, kInt);

    auto store = Store::make(X, {i, j}, i);
    auto inner_loop = For::make(j, start, stopInner, store);
    auto outer_loop = For::make(i, start, stopOuter, inner_loop);

    PaddedBuffer<int> data(400);
    EXPECT_ANY_THROW(SimpleIREvaluator(outer_loop, {X})(data));
  }
}

// A flattened index i * 15 + j over a 10x15 iteration space reaches 149,
// one past the end of a 149-element buffer, so the evaluator must throw.
TEST(Expr, OutOfBounds2dFlattenedIndex) {
  const ExprHandle buf_size(149);
  const ExprHandle start(0);
  const ExprHandle stopInner(15);
  const ExprHandle stopOuter(10);
  VarHandle i("i", kInt);
  VarHandle j("j", kInt);

  BufHandle X("X", {buf_size}, kInt);

  auto flat_index = Add::make(Mul::make(i, stopInner), j);
  auto store = Store::make(X, {flat_index}, i);
  auto inner_loop = For::make(j, start, stopInner, store);
  auto outer_loop = For::make(i, start, stopOuter, inner_loop);

  PaddedBuffer<int> data(400);
  EXPECT_ANY_THROW(SimpleIREvaluator(outer_loop, {X})(data));
}

void testCond01() {
747
  const int N = 16;
748
  PaddedBuffer<float> a_v(N);
749
  BufHandle a_buf("a", {N}, kFloat);
750
  VarHandle index = VarHandle("index", kInt);
751
  StmtPtr assign_x2 = a_buf.store({index}, cast<float>(index) * 2);
752
  StmtPtr assign_x3 = a_buf.store({index}, cast<float>(index) * 3);
753
  ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ);
754
  StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3);
755
  StmtPtr for_stmt = For::make(index, 0, N, assign);
756
  SimpleIREvaluator(for_stmt, {a_buf})(a_v);
757

758
  PaddedBuffer<float> a_ref(N);
759
  for (const auto i : c10::irange(N)) {
760
    if (i % 2 == 0) {
761
      a_ref(i) = i * 2;
762
    } else {
763
      a_ref(i) = i * 3;
764
    }
765
  }
766
  ExpectAllNear(a_v, a_ref, 1e-5);
767
}
768

769
void testIfThenElse01() {
770
  ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f));
771

772
  std::ostringstream oss;
773
  oss << v;
774
  ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)");
775

776
  SimpleIRExprEval eval(v);
777
  ASSERT_EQ(eval.value<float>(), 1.0f);
778
}
779

780
void testIfThenElse02() {
781
  ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f));
782

783
  std::ostringstream oss;
784
  oss << v;
785
  ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
786

787
  SimpleIRExprEval eval(v);
788
  ASSERT_EQ(eval.value<float>(), 2.0f);
789
}
790

791
void testIfThenElse03() {
792
  ExprHandle v =
793
      ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f));
794

795
  std::ostringstream oss;
796
  oss << v;
797
  ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)");
798

799
  SimpleIRExprEval eval(v);
800
  ASSERT_EQ(eval.value<float>(), 2.0f);
801
}
802

803
void testStmtClone() {
804
  const int N = 16;
805

806
  BufHandle a_buf("a", {N}, kInt);
807
  VarHandle index = VarHandle("index", kInt);
808
  StmtPtr body = a_buf.store({index}, 5);
809
  StmtPtr loop = For::make(index, 0, N, body);
810

811
  StmtPtr cloned_loop = Stmt::clone(loop);
812
  std::vector<int> orig_loop_results(N);
813
  std::vector<int> cloned_loop_results(N);
814
  SimpleIREvaluator(loop, {a_buf})(orig_loop_results);
815
  SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results);
816

817
  assertAllEqual(orig_loop_results, 5);
818
  assertAllEqual(cloned_loop_results, 5);
819

820
  // Let's add another assign to the body in the cloned loop and verify that the
821
  // original statement hasn't changed while the cloned one has.
822
  StmtPtr body_addition = a_buf.store({index}, 33);
823
  BlockPtr cloned_body = static_to<Block>(static_to<For>(cloned_loop)->body());
824
  cloned_body->append_stmt(body_addition);
825

826
  std::vector<int> orig_loop_results_after_mutation(N);
827
  std::vector<int> cloned_loop_results_after_mutation(N);
828
  SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation);
829
  SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation);
830

831
  assertAllEqual(orig_loop_results_after_mutation, 5);
832
  assertAllEqual(cloned_loop_results_after_mutation, 33);
833
}
834

835
} // namespace jit
} // namespace torch

/*
 * Cookie usage (site footer captured together with the source file; not part
 * of the test code).
 *
 * We use cookies in accordance with the Privacy Policy and the Cookie Policy.
 *
 * By clicking "Accept", you consent to JSC "SberTech" processing your personal
 * data in order to improve our website and the GitVerse Service and to make
 * them more convenient to use.
 *
 * You can disable cookies yourself in your browser settings.
 */