pytorch

Форк
0
/
test_memdependency.cpp 
3252 строки · 101.3 Кб
1
#include <gtest/gtest.h>
2
#include <test/cpp/tensorexpr/test_base.h>
3

4
#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
5
#include <torch/csrc/jit/tensorexpr/ir.h>
6
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
7
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
8
#include <torch/csrc/jit/tensorexpr/loopnest.h>
9
#include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h>
10
#include <torch/csrc/jit/tensorexpr/tensor.h>
11

12
namespace torch {
13
namespace jit {
14

15
using namespace torch::jit::tensorexpr;
16

17
// Test helper function used to determine if two regions of a buffer have an
18
// overlap. No Overlap & partial overlap is obvious. Contains means A is
19
// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal
20
// ranges are ContainedOrEqual.
21
TEST(MemDependency, BoundOverlap) {
22
  using namespace analysis;
23

24
  auto CB = [](int s, int e) {
25
    return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
26
  };
27

28
  // Sanity check 3 overlap cases.
29
  ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0)));
30
  ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5)));
31
  ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1)));
32

33
  // Partial overlap works in either order.
34
  ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14)));
35
  ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10)));
36

37
  // Total Overlap works when one bound encloses the other, and returns which.
38
  ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9)));
39
  ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16)));
40

41
  // Total overlap works when the bounds are an identical range, returns
42
  // ContainedOrEqual.
43
  ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15)));
44

45
  // Total overlap when only one end of the bound matches.
46
  ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10)));
47
  ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15)));
48
  ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9)));
49
  ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15)));
50
  ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15)));
51

52
  // No overlap when a < b.
53
  ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10)));
54
  ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3)));
55
  ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130)));
56

57
  // No overlap when a > b.
58
  ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2)));
59
  ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2)));
60
  ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120)));
61

62
  // No overlap when adjacent.
63
  ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120)));
64
  ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1)));
65

66
  // Partial overlap when middle bounds match.
67
  ASSERT_EQ(
68
      OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120)));
69
  ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4)));
70
  ASSERT_EQ(
71
      OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100)));
72
  ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2)));
73

74
  // Total overlap when one bound is single length over one end of the other.
75
  ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15)));
76
  ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2)));
77
  ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15)));
78
  ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15)));
79
}
80

81
// Exercises analysis::compareBound: for each comparison operator, checks
// whether comparing two constant bounds yields True, False or NotDetermined.
TEST(MemDependency, BoundComparison) {
  using namespace analysis;

  // Evaluates `[aLo, aHi] <op> [bLo, bHi]` on constant bounds.
  auto compare =
      [](int aLo, int aHi, int bLo, int bHi, CompareSelectOperation op) {
        return compareBound(
            Bound(alloc<IntImm>(aLo), alloc<IntImm>(aHi)),
            Bound(alloc<IntImm>(bLo), alloc<IntImm>(bHi)),
            op);
      };

  const CompareSelectOperation opEQ = CompareSelectOperation::kEQ;
  const CompareSelectOperation opNE = CompareSelectOperation::kNE;
  const CompareSelectOperation opLT = CompareSelectOperation::kLT;
  const CompareSelectOperation opGE = CompareSelectOperation::kGE;
  const CompareSelectOperation opGT = CompareSelectOperation::kGT;
  const CompareSelectOperation opLE = CompareSelectOperation::kLE;

  // kEQ: identical single points are equal, disjoint ranges are not, and
  // anything that touches or overlaps cannot be decided.
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(10, 100, 10, 100, opEQ));
  ASSERT_EQ(CmpEvalResult::True, compare(10, 10, 10, 10, opEQ));
  ASSERT_EQ(CmpEvalResult::False, compare(10, 20, 30, 40, opEQ));
  ASSERT_EQ(CmpEvalResult::False, compare(30, 40, 10, 20, opEQ));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 40, 40, 50, opEQ));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 40, 20, 30, opEQ));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 45, 40, 50, opEQ));

  // kNE: the same operand pairs as the kEQ group, with True and False swapped.
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(10, 100, 10, 100, opNE));
  ASSERT_EQ(CmpEvalResult::False, compare(10, 10, 10, 10, opNE));
  ASSERT_EQ(CmpEvalResult::True, compare(10, 20, 30, 40, opNE));
  ASSERT_EQ(CmpEvalResult::True, compare(30, 40, 10, 20, opNE));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 40, 40, 50, opNE));
  // The bounds share the single element 30, so inequality is undecidable.
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 40, 20, 30, opNE));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 45, 40, 50, opNE));

  // kLT.
  ASSERT_EQ(CmpEvalResult::True, compare(10, 20, 30, 40, opLT));
  ASSERT_EQ(CmpEvalResult::False, compare(30, 40, 10, 20, opLT));
  ASSERT_EQ(CmpEvalResult::False, compare(30, 40, 10, 30, opLT));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(10, 100, 10, 100, opLT));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 40, 40, 50, opLT));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 45, 40, 50, opLT));

  // kGE: the same operand pairs as the kLT group, with True and False swapped.
  ASSERT_EQ(CmpEvalResult::False, compare(10, 20, 30, 40, opGE));
  ASSERT_EQ(CmpEvalResult::True, compare(30, 40, 10, 20, opGE));
  ASSERT_EQ(CmpEvalResult::True, compare(30, 40, 10, 30, opGE));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(10, 100, 10, 100, opGE));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 40, 40, 50, opGE));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 45, 40, 50, opGE));

  // kGT.
  ASSERT_EQ(CmpEvalResult::False, compare(10, 20, 30, 40, opGT));
  ASSERT_EQ(CmpEvalResult::False, compare(30, 40, 40, 50, opGT));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(10, 100, 10, 100, opGT));
  ASSERT_EQ(CmpEvalResult::True, compare(30, 40, 10, 20, opGT));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 40, 10, 30, opGT));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 45, 40, 50, opGT));

  // kLE: the same operand pairs as the kGT group, with True and False swapped.
  ASSERT_EQ(CmpEvalResult::True, compare(10, 20, 30, 40, opLE));
  ASSERT_EQ(CmpEvalResult::True, compare(30, 40, 40, 50, opLE));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(10, 100, 10, 100, opLE));
  ASSERT_EQ(CmpEvalResult::False, compare(30, 40, 10, 20, opLE));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 40, 10, 30, opLE));
  ASSERT_EQ(CmpEvalResult::NotDetermined, compare(30, 45, 40, 50, opLE));
}
208

209
// Exercises analysis::boundOverlap on bounds whose endpoints are symbolic
// expressions rather than integer constants.
TEST(MemDependency, BoundOverlapSymbolic) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);
  VarHandle w("w", kInt);

  using namespace analysis;

  // Overlap kind of the bound [aLo, aHi] (A) against [bLo, bHi] (B).
  auto overlapOf =
      [](ExprHandle aLo, ExprHandle aHi, ExprHandle bLo, ExprHandle bHi) {
        return boundOverlap(
            Bound(aLo.node(), aHi.node()), Bound(bLo.node(), bHi.node()));
      };

  // Symbolic endpoints whose difference is a constant behave like the
  // constant cases.
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  ASSERT_EQ(OverlapKind::ContainedOrEqual, overlapOf(x, x, x, x));
  ASSERT_EQ(OverlapKind::PartialOverlap, overlapOf(x, x + 3, x + 2, x + 5));
  ASSERT_EQ(OverlapKind::NoOverlap, overlapOf(x, x, x + 1, x + 1));

  // The sign of y is unknown, so x + y cannot be ordered against x + y / 2.
  ASSERT_EQ(OverlapKind::PartialOverlap, overlapOf(x, x + y, x, x + y / 2));

  // Nothing is known about these endpoints; the conservative answer is that
  // the bounds may overlap.
  ASSERT_EQ(OverlapKind::PartialOverlap, overlapOf(x, y, z, w));

  // Structurally identical arithmetic on opaque terms is recognised...
  ASSERT_EQ(
      OverlapKind::ContainedOrEqual, overlapOf(x + w, y - z, x + w, y - z));
  // ...and so is arithmetic that only matches after simplification.
  ASSERT_EQ(
      OverlapKind::ContainedOrEqual, overlapOf(x - w - w, y, x - w * 2, y));
}
249

250
// Tests the helper function for overlap of multi dimensional indices bounds.
251
// This uses boundOverlap on each dimension and return the "lowest" kind of
252
// overlap.
253
TEST(MemDependency, BoundOverlapMultiDim) {
254
  using namespace analysis;
255

256
  auto CB = [](int s, int e) {
257
    return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
258
  };
259

260
  // Sanity check one dimensional cases.
261
  ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)}));
262
  ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)}));
263
  ASSERT_EQ(
264
      OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)}));
265

266
  // Total overlap in 3 dims.
267
  ASSERT_EQ(
268
      OverlapKind::ContainedOrEqual,
269
      overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)}));
270
  ASSERT_EQ(
271
      OverlapKind::ContainedOrEqual,
272
      overlaps(
273
          {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)}));
274

275
  // Total overlap in 2 dims, no overlap in another.
276
  ASSERT_EQ(
277
      OverlapKind::NoOverlap,
278
      overlaps(
279
          {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
280

281
  // Total overlap in 2 dims, partial overlap in another.
282
  ASSERT_EQ(
283
      OverlapKind::PartialOverlap,
284
      overlaps(
285
          {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
286
  // This case is most important, so verify the overlap in any dim. (dim 2)
287
  ASSERT_EQ(
288
      OverlapKind::PartialOverlap,
289
      overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)}));
290
  // Dim 1.
291
  ASSERT_EQ(
292
      OverlapKind::PartialOverlap,
293
      overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)}));
294
  // Total overlap in 1 dim, partial in 2.
295
  ASSERT_EQ(
296
      OverlapKind::PartialOverlap,
297
      overlaps(
298
          {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)}));
299
  // Total overlap, partial overlap, no overlap.
300
  ASSERT_EQ(
301
      OverlapKind::NoOverlap,
302
      overlaps(
303
          {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)}));
304

305
  // Total overlap (B) in 2 dims, total overlap (A) in another.
306
  ASSERT_EQ(
307
      OverlapKind::Contains,
308
      overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)}));
309

310
  // Total overlap (A) in 2 dims, total overlap (B) in another.
311
  ASSERT_EQ(
312
      OverlapKind::Contains,
313
      overlaps(
314
          {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)}));
315

316
  // Total (B), No Overlap, Total (A).
317
  ASSERT_EQ(
318
      OverlapKind::NoOverlap,
319
      overlaps(
320
          {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)}));
321
}
322

323
// Test the helper we use to subtract bounds: returns the regions(s) of A which
324
// remain after removing the region of B.
325
TEST(MemDependency, BoundSubtract) {
326
  using namespace analysis;
327

328
  auto CB = [](int s, int e) {
329
    return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
330
  };
331
  auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
332
    return indexBoundsEquals(x, y);
333
  };
334

335
  // One element subtract.
336
  ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0);
337
  ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0);
338

339
  // No Overlap.
340
  ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)}));
341
  ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)}));
342

343
  // one side overlap.
344
  ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)}));
345
  ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)}));
346
  ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)}));
347
  ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)}));
348

349
  // both sides overlap.
350
  ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {}));
351
  ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {}));
352

353
  // internal overlap.
354
  ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)}));
355
  ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)}));
356
}
357

358
// Exercises analysis::subtractBound on bounds whose endpoints are symbolic
// expressions rather than integer constants.
TEST(MemDependency, BoundSubtractSymbolic) {
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  VarHandle z("z", kInt);
  VarHandle w("w", kInt);

  using namespace analysis;

  // Bound covering [lo, hi].
  auto range = [](ExprHandle lo, ExprHandle hi) {
    return Bound(lo.node(), hi.node());
  };
  // What is left of [aLo, aHi] after removing [bLo, bHi].
  auto minus = [&range](
                   ExprHandle aLo,
                   ExprHandle aHi,
                   ExprHandle bLo,
                   ExprHandle bHi) {
    return subtractBound(range(aLo, aHi), range(bLo, bHi));
  };
  auto same = [](const IndexBounds& got, const IndexBounds& want) {
    return indexBoundsEquals(got, want);
  };

  // Removing a single symbolic element from itself leaves nothing.
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  ASSERT_TRUE(same(minus(x, x, x, x), {}));
  ASSERT_TRUE(same(minus(x + 1, x + 1, x + 1, x + 1), {}));
  ASSERT_TRUE(same(minus(x * 2, x * 2, x * 2, x * 2), {}));

  // Constant-sized range removed from the low end.
  ASSERT_TRUE(same(minus(x, x + 10, x, x + 4), {range(x + 5, x + 10)}));
  // Constant-sized range removed from the high end.
  ASSERT_TRUE(same(minus(x, x + 10, x + 6, x + 12), {range(x, x + 5)}));
  // Constant-sized range that covers everything.
  ASSERT_TRUE(same(minus(x, x + 10, x, x + 10), {}));
  ASSERT_TRUE(same(minus(x + 2, x + 10, x, x + 12), {}));
  // Constant-sized range removed from the interior.
  ASSERT_TRUE(same(
      minus(x, x + 10, x + 3, x + 7),
      {range(x, x + 2), range(x + 8, x + 10)}));

  // The size is inferable but not constant; this only works with one var.
  ASSERT_TRUE(same(minus(0, x, 0, x * 2), {}));
  ASSERT_TRUE(same(minus(0, x * 2, 0, x - 1), {range(x, x * 2)}));

  // The size cannot be inferred, so A comes back unchanged.
  ASSERT_TRUE(same(minus(x, y, z, w), {range(x, y)}));
  ASSERT_TRUE(same(minus(x, y, x, z), {range(x, y)}));
  ASSERT_TRUE(same(minus(x, y, 0, x), {range(x, y)}));
  ASSERT_TRUE(same(minus(x, x, 0, 0), {range(x, x)}));
}
403

404
// Tests the helper function that does subtraction, but for multi dimensional
405
// indices bounds.
406
TEST(MemDependency, BoundSubtractMultiDim) {
407
  using namespace analysis;
408

409
  auto CB = [](int s, int e) {
410
    return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
411
  };
412
  auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
413
    if (x.size() != y.size()) {
414
      return false;
415
    }
416
    for (auto i = 0U; i < x.size(); ++i) {
417
      if (!indexBoundsEquals(x[i], y[i])) {
418
        return false;
419
      }
420
    }
421
    return true;
422
  };
423

424
  // sanity check one dimension.
425
  ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {}));
426
  ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {}));
427
  ASSERT_TRUE(
428
      EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}}));
429
  ASSERT_TRUE(
430
      EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}}));
431
  ASSERT_TRUE(EQ(
432
      subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}}));
433

434
  // Multi dim total overlap.
435
  ASSERT_TRUE(EQ(
436
      subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {}));
437
  ASSERT_TRUE(EQ(
438
      subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {}));
439

440
  // Mutli dim one way partial in dim 1.
441
  ASSERT_TRUE(
442
      EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}),
443
         {{CB(4, 9), CB(0, 2)}}));
444

445
  // Mutli dim one way partial in dim 2.
446
  ASSERT_TRUE(
447
      EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}),
448
         {{CB(0, 9), CB(11, 20)}}));
449

450
  // Partial overlap in 2 dims.
451
  ASSERT_TRUE(
452
      EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}),
453
         {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}}));
454

455
  // Partial overlap in 3 dims.
456
  ASSERT_TRUE(
457
      EQ(subtractIndicesBounds(
458
             {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}),
459
         {{CB(0, 1), CB(0, 5), CB(0, 5)},
460
          {CB(2, 5), CB(0, 1), CB(0, 5)},
461
          {CB(2, 5), CB(2, 5), CB(0, 1)}}));
462
}
463

464
// Tests the multi dimensional subtraction code for bounds that cannot be fully
465
// materialized.
466
TEST(MemDependency, BoundSubtractMultiDimSymbolic) {
467
  VarHandle x("x", kInt);
468
  VarHandle y("y", kInt);
469

470
  using namespace analysis;
471

472
  auto CB = [](ExprHandle s, ExprHandle e) {
473
    return Bound(s.node(), e.node());
474
  };
475

476
  auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
477
    if (x.size() != y.size()) {
478
      return false;
479
    }
480
    for (auto i = 0U; i < x.size(); ++i) {
481
      if (!indexBoundsEquals(x[i], y[i])) {
482
        return false;
483
      }
484
    }
485
    return true;
486
  };
487

488
  // Cannot determine overlaps.
489
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
490
  ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}}));
491

492
  // Various total Overlaps.
493
  ASSERT_TRUE(EQ(
494
      subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {}));
495
  ASSERT_TRUE(EQ(
496
      subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {}));
497
  ASSERT_TRUE(EQ(
498
      subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {}));
499
  ASSERT_TRUE(EQ(
500
      subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {}));
501

502
  // one-way overlap in first dim.
503
  ASSERT_TRUE(
504
      EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}),
505
         {{CB(x - 4, x), CB(0, y)}}));
506
  // second dim.
507
  ASSERT_TRUE(
508
      EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}),
509
         {{CB(0, x), CB(0, 4)}}));
510

511
  // Internal overlap in first dim.
512
  ASSERT_TRUE(
513
      EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}),
514
         {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}}));
515
  // second dim.
516
  ASSERT_TRUE(EQ(
517
      subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}),
518
      {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}}));
519

520
  // Overlap in both dimensions.
521
  ASSERT_TRUE(
522
      EQ(subtractIndicesBounds(
523
             {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}),
524
         {
525
             {CB(0, 4), CB(0, y)},
526
             {CB(x - 4, x), CB(0, y)},
527
             {CB(0, x), CB(0, 9)},
528
             {CB(0, x), CB(y - 9, y)},
529
         }));
530
}
531

532
// Simple check that the analyzer does anything at all...
533
TEST(MemDependency, MemDependencyCheckerSimple) {
534
  BufHandle a("A", {1}, kInt);
535
  BufHandle b("B", {1}, kInt);
536

537
  analysis::MemDependencyChecker analyzer;
538

539
  /*
540
   * A[0] = 3;
541
   * B[0] = A[0] + 1;
542
   */
543

544
  StorePtr aStore = Store::make(a, {0}, 3);
545
  StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
546

547
  StmtPtr stmt = Block::make({aStore, bStore});
548

549
  stmt->accept(&analyzer);
550

551
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
552
  ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore));
553
  // sanity check, but anything that depends directly must depend indirectly.
554
  ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore));
555
}
556

557
// Check that there is a difference between direct and indirect dependence.
558
TEST(MemDependency, MemDependencyCheckerMultiStmt) {
559
  BufHandle a("A", {1}, kInt);
560
  BufHandle b("B", {1}, kInt);
561
  BufHandle c("C", {1}, kInt);
562

563
  analysis::MemDependencyChecker analyzer;
564

565
  /*
566
   * A[0] = 3;
567
   * B[0] = A[0];
568
   * C[0] = B[0] + 1;
569
   */
570

571
  StorePtr aStore = Store::make(a, {0}, 3);
572
  StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
573
  StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1));
574

575
  StmtPtr stmt = Block::make({aStore, bStore, cStore});
576

577
  stmt->accept(&analyzer);
578

579
  // C depends on A indirectly.
580
  ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore));
581
  ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore));
582

583
  // C depends on B directly, which depends on A directly.
584
  ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore));
585
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
586

587
  // Dependency goes top to bottom only.
588
  ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore));
589
  ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore));
590
  ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore));
591
}
592

593
// Verify that we do filter writes that are totally overlapped by later writes.
594
TEST(MemDependency, MemDependencyCheckerOverlap) {
595
  BufHandle a("A", {1}, kInt);
596
  BufHandle b("B", {1}, kInt);
597

598
  analysis::MemDependencyChecker analyzer;
599

600
  /*
601
   * A[0] = 3;
602
   * A[0] = 6;
603
   * B[0] = A[0] + 1;
604
   */
605

606
  StorePtr aStore = Store::make(a, {0}, 3);
607
  StorePtr a2Store = Store::make(a, {0}, 6);
608
  StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
609

610
  StmtPtr stmt = Block::make({aStore, a2Store, bStore});
611

612
  stmt->accept(&analyzer);
613

614
  // B store depends on second A store but not first since it is completely
615
  // overlapped.
616
  ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store));
617
  ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore));
618

619
  // No dependency between either A store.
620
  ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store));
621
  ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore));
622
}
623

624
// Verify that bounds match loop iterations, and that dependencies progress
625
// across loop scopes.
626
TEST(MemDependency, MemDependencyCheckerLoop) {
627
  BufHandle a("A", {1}, kInt);
628
  BufHandle b("B", {1}, kInt);
629
  VarHandle x("x", kInt);
630

631
  using namespace analysis;
632

633
  MemDependencyChecker analyzer;
634

635
  /*
636
   * for (int x = 0; x < 10; ++x) {
637
   *   A[x] = x;
638
   * }
639
   * B[0] = A[0] + 1;
640
   */
641

642
  StorePtr aStore = Store::make(a, {x}, x);
643
  StmtPtr loop = For::make(x, 0, 10, aStore);
644
  StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1));
645

646
  StmtPtr stmt = Block::make({loop, bStore});
647

648
  stmt->accept(&analyzer);
649

650
  // Same A->B dependency.
651
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
652

653
  // B depends on the loop.
654
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
655
  // A is in the loop but does not depend on any loop iteration.
656
  ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop));
657

658
  auto aStoreAccess = analyzer.accessFor(aStore);
659
  ASSERT_NE(aStoreAccess, nullptr);
660

661
  // It should have bounds covering the range of x: 0 <= x < 10.
662
  ASSERT_TRUE(indexBoundsEquals(
663
      aStoreAccess->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
664
}
665

666
// Reductions should promote dependencies as well.
667
TEST(MemDependency, MemDependencyCheckerLoopReduce) {
668
  BufHandle a("A", {10}, kInt);
669
  BufHandle b("B", {10}, kInt);
670
  VarHandle x("x", kInt);
671

672
  using namespace analysis;
673

674
  MemDependencyChecker analyzer;
675

676
  /*
677
   * A[0] = 0;
678
   * for (int x = 0; x < 10; ++x) {
679
   *   A[0] = A[x] + 1;
680
   * }
681
   * B[0] = A[0];
682
   */
683

684
  StorePtr aInit = Store::make(a, {0}, 0);
685
  ExprHandle reduce = Sum()(a, 1, {x}, {x});
686
  StorePtr aReduce = Store::make(a, {0}, reduce);
687
  StmtPtr loop = For::make(x, 0, 10, aReduce);
688
  StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
689

690
  StmtPtr stmt = Block::make({aInit, loop, bStore});
691

692
  stmt->accept(&analyzer);
693

694
  // B -> A.
695
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));
696

697
  // B depends indirectly on the initializer of A, since the reduction depends
698
  // on it.
699
  ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
700
  ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));
701

702
  ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));
703

704
  // B depends on the loop.
705
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
706
  // A is in the loop and depends on other iterations.
707
  ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));
708

709
  // The loop contents depend on the initializer too.
710
  ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));
711

712
  // Find loads within the reduction:
713
  auto reduceLoads = NodeFinder<Load>::find(reduce.node());
714
  // Pull out the access for the load inside the loop.
715
  for (auto load : reduceLoads) {
716
    auto loopLoad = analyzer.accessFor(load);
717
    // It should have 10 element long bounds.
718
    ASSERT_TRUE(indexBoundsEquals(
719
        loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
720
  }
721
}
722

723
// Lowering a reduction doesn't affect dependency analysis.
724
TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) {
725
  BufHandle a("A", {10}, kInt);
726
  BufHandle b("B", {10}, kInt);
727
  VarHandle x("x", kInt);
728

729
  using namespace analysis;
730

731
  MemDependencyChecker analyzer;
732

733
  /*
734
   * A[0] = 0;
735
   * for (int x = 0; x < 10; ++x) {
736
   *   A[0] = A[x] + 1;
737
   * }
738
   * B[0] = A[0];
739
   */
740

741
  StorePtr aInit = Store::make(a, {0}, 0);
742
  ExprHandle aLoad = Load::make(a, {x});
743
  StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1));
744
  StmtPtr loop = For::make(x, 0, 10, aReduce);
745
  StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
746

747
  StmtPtr stmt = Block::make({aInit, loop, bStore});
748

749
  stmt->accept(&analyzer);
750

751
  // B -> A.
752
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));
753

754
  // B depends indirectly on the initializer of A, since the reduction depends
755
  // on it.
756
  ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
757
  ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));
758

759
  ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));
760

761
  // B depends on the loop.
762
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
763
  // A is in the loop and depends on other iterations.
764
  ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));
765

766
  // The loop contents depend on the initializer too.
767
  ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));
768

769
  // Pull out the access for the store inside the loop.
770
  auto loopLoad = analyzer.accessFor(aLoad.node());
771
  // It should have 10 element long bounds.
772
  ASSERT_TRUE(indexBoundsEquals(
773
      loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
774
}
775

776
// Can determine dependencies of outputs, through to inputs.
777
TEST(MemDependency, MemDependencyCheckerInputsOutputs) {
778
  BufHandle a("A", {10}, kInt);
779
  BufHandle b("B", {10}, kInt);
780
  VarHandle x("x", kInt);
781

782
  // initialize analyzer with inputs and outputs.
783
  analysis::MemDependencyChecker analyzer({a}, {b});
784

785
  // Here's a Relu.
786
  /*
787
   * for (int x = 0; x < 10; ++x) {
788
   *   B[x] = Max(A[x], 0);
789
   * }
790
   */
791

792
  ExprHandle aLoad = Load::make(a, {x});
793
  StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true));
794
  StmtPtr loop = For::make(x, 0, 10, bStore);
795

796
  StmtPtr stmt = Block::make({loop});
797

798
  stmt->accept(&analyzer);
799

800
  // Output depends indirectly on input.
801
  ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
802
  // aLoad depends directly on the input A.
803
  ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node()));
804
  // bStore therefore depends directly on the input A.
805
  ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node()));
806
  // The output depends directly on the store.
807
  ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));
808

809
  // Check AccessInfo based overloads.
810
  auto input = analyzer.input(a.node());
811
  auto output = analyzer.output(b.node());
812

813
  // Output depends indirectly on input.
814
  ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
815
  // Not directly.
816
  ASSERT_FALSE(analyzer.dependsDirectly(output, input));
817
  // Not in reverse order.
818
  ASSERT_FALSE(analyzer.dependsIndirectly(input, output));
819

820
  // output -> bStore -> bLoad -> input.
821
  auto storeAccess = analyzer.accessFor(bStore);
822
  auto loadAccess = analyzer.accessFor(aLoad.node());
823

824
  ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess));
825
  ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input));
826
}
827

828
// Can tell if an output does not depend on an input.
829
TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) {
830
  BufHandle a("A", {10}, kInt);
831
  BufHandle b("B", {10}, kInt);
832
  VarHandle x("x", kInt);
833

834
  // initialize analyzer with inputs and outputs.
835
  analysis::MemDependencyChecker analyzer({a}, {b});
836

837
  // Here's a dumb Relu.
838
  /*
839
   * for (int x = 0; x < 10; ++x) {
840
   *   B[x] = Max(x, 0);
841
   * }
842
   */
843

844
  StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true));
845
  StmtPtr loop = For::make(x, 0, 10, bStore);
846

847
  StmtPtr stmt = Block::make({loop});
848

849
  stmt->accept(&analyzer);
850

851
  // Output does not depend indirectly on input.
852
  ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node()));
853

854
  // The output still depends directly on the store.
855
  ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));
856

857
  // Check AccessInfo based overloads.
858
  auto input = analyzer.input(a.node());
859
  auto output = analyzer.output(b.node());
860

861
  // Output does not depend indirectly on input.
862
  ASSERT_FALSE(analyzer.dependsIndirectly(output, input));
863
}
864

865
// Verify different loop extents produce accesses with different bounds, and
866
// that later accesses find dependencies that overlap their entire bound range.
867
TEST(MemDependency, MemDependencyCheckerLoopBounds) {
868
  BufHandle a("A", {10}, kInt);
869
  BufHandle b("B", {10}, kInt);
870
  BufHandle c("C", {10}, kInt);
871
  VarHandle x("x", kInt);
872
  using namespace analysis;
873

874
  MemDependencyChecker analyzer({a}, {c});
875

876
  // This enables using the execution order of the loops to determine if some
877
  // loops are self dependent or not.
878
  analyzer.allowLoopExecutionOrderAnalysis();
879

880
  /*
881
   * for (int x = 1; x < 10; ++x) {
882
   *   B[x] = A[x];
883
   * }
884
   * for (int x = 1; x < 9; ++x) {
885
   *   B[x] = B[x] * 2;
886
   * }
887
   * for (int x = 3; x < 4; ++x) {
888
   *   C[x] = A[x];
889
   * }
890
   * for (int x = 0; x < 10; ++x) {
891
   *   C[x] = B[x];
892
   * }
893
   */
894

895
  std::vector<StmtPtr> stmts(
896
      {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))),
897
       For::make(
898
           x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))),
899
       For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))),
900
       For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))});
901

902
  StmtPtr stmt = Block::make(stmts);
903

904
  stmt->accept(&analyzer);
905

906
  auto input = analyzer.input(a.node());
907
  auto output = analyzer.output(c.node());
908

909
  // sanity check Output -> Input.
910
  ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
911

912
  // Check the For loop dependencies:
913

914
  // Last write to C depends on both writes to B since they contain the last
915
  // write to at least one element.
916
  ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1]));
917
  ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0]));
918

919
  // The last write to C does not depend on the other write to C.
920
  ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2]));
921

922
  auto CB = [](int s, int e) {
923
    return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
924
  };
925
  auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
926
    return indexBoundsEquals(x, y);
927
  };
928

929
  /*  0. Input: A[(0, 9)] - dependents: 1 5
930
   *  1. Load: A[(1, 9)] - depends on: 0  - dependents: 2
931
   *  2. Store: B[(1, 9)] - depends on: 1  - dependents: 3 7
932
   *  3. Load: B[(1, 8)] - depends on: 2  - dependents: 4
933
   *  4. Store: B[(1, 8)] - depends on: 3  - dependents: 7
934
   *  5. Load: A[(3, 3)] - depends on: 0  - dependents: 6
935
   *  6. Store: C[(3, 3)] - depends on: 5
936
   *  7. Load: B[(0, 9)] - depends on: 2 4  - dependents: 8
937
   *  8. Store: C[(0, 9)] - depends on: 7  - dependents: 9
938
   *  9. Output: C[(0, 9)] - depends on: 8
939
   */
940

941
  // Now let's look at the bounds of each access.
942
  // There are 9 accesses in this Stmt, so this is exhaustive, we wont do this
943
  // much.
944
  auto history = analyzer.getHistory();
945
  ASSERT_EQ(history.size(), 10);
946
  VarPtr aVar = a.node()->base_handle();
947
  VarPtr bVar = b.node()->base_handle();
948
  VarPtr cVar = c.node()->base_handle();
949

950
  // The first access is the input A.
951
  ASSERT_EQ(history[0]->type(), AccessType::Input);
952
  ASSERT_EQ(history[0]->var(), aVar);
953
  // It has the bounds of the producing Input.
954
  ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
955
  // sanity check the input we retrieved earlier matches.
956
  ASSERT_EQ(history[0], input);
957

958
  // The second access is the load of A in the first loop.
959
  ASSERT_EQ(history[1]->type(), AccessType::Load);
960
  ASSERT_EQ(history[1]->var(), aVar);
961
  // It has the bounds of the loop, i.e. start == 1.
962
  ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)}));
963
  // It reads from A, so it should have a dependency on the last write to this
964
  // range - with is the input.
965
  ASSERT_EQ(history[1]->dependencies().size(), 1);
966
  ASSERT_TRUE(history[1]->hasDependency(history[0]));
967

968
  // The third access is the store into B in the first loop.
969
  ASSERT_EQ(history[2]->type(), AccessType::Store);
970
  ASSERT_EQ(history[2]->var(), bVar);
971
  // It also has the bounds of the loop, i.e. start == 1.
972
  ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
973
  // The previous load is in its RHS, so it depends on it.
974
  ASSERT_EQ(history[2]->dependencies().size(), 1);
975
  ASSERT_TRUE(history[2]->hasDependency(history[1]));
976

977
  // The third access is the load from B in the second loop.
978
  ASSERT_EQ(history[3]->type(), AccessType::Load);
979
  ASSERT_EQ(history[3]->var(), bVar);
980
  // It has the bounds of the second loop, i.e. >= 1 < 9.
981
  ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)}));
982
  // It reads from B in a smaller range, so should depend on the previous
983
  // store.
984
  ASSERT_EQ(history[3]->dependencies().size(), 1);
985
  ASSERT_TRUE(history[3]->hasDependency(history[2]));
986

987
  // The fourth: the store to B in the second loop.
988
  ASSERT_EQ(history[4]->type(), AccessType::Store);
989
  ASSERT_EQ(history[4]->var(), bVar);
990
  // It also has the bounds of the second loop.
991
  ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)}));
992
  // The previous load is in its RHS, so it depends on it as before.
993
  ASSERT_EQ(history[4]->dependencies().size(), 1);
994
  ASSERT_TRUE(history[4]->hasDependency(history[3]));
995

996
  // The fifth access is the load is from the 3rd loop, and skips previous B
997
  // accesses.
998
  ASSERT_EQ(history[5]->type(), AccessType::Load);
999
  ASSERT_EQ(history[5]->var(), aVar);
1000
  // It has the bounds of the third loop: >= 3 < 4.
1001
  ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)}));
1002
  // It depends on the last thing to write to A, which is the A input.
1003
  ASSERT_EQ(history[5]->dependencies().size(), 1);
1004
  ASSERT_TRUE(history[5]->hasDependency(history[0]));
1005

1006
  // Sixth: the store into the output C.
1007
  ASSERT_EQ(history[6]->type(), AccessType::Store);
1008
  ASSERT_EQ(history[6]->var(), cVar);
1009
  // It also has the bounds of the third loop.
1010
  ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)}));
1011
  // The previous load is in its RHS, so it depends on it as always.
1012
  ASSERT_EQ(history[6]->dependencies().size(), 1);
1013
  ASSERT_TRUE(history[6]->hasDependency(history[5]));
1014

1015
  // The seventh access is the load of B in the fourth loop.
1016
  ASSERT_EQ(history[7]->type(), AccessType::Load);
1017
  ASSERT_EQ(history[7]->var(), bVar);
1018
  // It has the bounds of the final loop, >= 0 < 10
1019
  ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
1020
  // The bounds of this read are larger than the bounds of the previous write,
1021
  // so it depends on both previous Stores to B.
1022
  ASSERT_EQ(history[7]->dependencies().size(), 2);
1023
  ASSERT_TRUE(history[7]->hasDependency(history[2]));
1024
  ASSERT_TRUE(history[7]->hasDependency(history[4]));
1025

1026
  // Eight: the final store into the output C.
1027
  ASSERT_EQ(history[8]->type(), AccessType::Store);
1028
  ASSERT_EQ(history[8]->var(), cVar);
1029
  // It also has the bounds of the final loop.
1030
  ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
1031
  // The previous load is in its RHS, so it depends on it as always.
1032
  ASSERT_EQ(history[8]->dependencies().size(), 1);
1033
  ASSERT_TRUE(history[8]->hasDependency(history[7]));
1034

1035
  // The last access represents the output Buf.
1036
  ASSERT_EQ(history[9]->type(), AccessType::Output);
1037
  ASSERT_EQ(history[9]->var(), cVar);
1038
  // It has the bounds of the output Buf.
1039
  ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)}));
1040
  // sanity check the input we retrieved earlier matches.
1041
  ASSERT_EQ(history[9], output);
1042
  // It depends on the last write to C only.
1043
  ASSERT_EQ(history[9]->dependencies().size(), 1);
1044
  ASSERT_TRUE(history[9]->hasDependency(history[8]));
1045
}
1046

1047
// Verify that we can still infer bounds when the loop var is offset.
1048
TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) {
1049
  BufHandle a("A", {10}, kInt);
1050
  BufHandle b("B", {10}, kInt);
1051
  VarHandle x("x", kInt);
1052

1053
  using namespace analysis;
1054

1055
  MemDependencyChecker analyzer({a}, {b});
1056

1057
  // This enables using the execution order of the loops to determine if some
1058
  // loops are self dependent or not.
1059
  analyzer.allowLoopExecutionOrderAnalysis();
1060

1061
  /*
1062
   * for (int x = 1; x < 10; x++) {
1063
   *   A[x] = A[x - 1];
1064
   * }
1065
   * for (int x = 0; x < 9; x++) {
1066
   *   A[x] = A[x + 1];
1067
   * }
1068
   * for (int x = 0; x < 9; x++) {
1069
   *   A[9 - x] = A[8 - x];
1070
   * }
1071
   * for (int x = 0; x < 10; x++) {
1072
   *   A[x] = A[9 - x];
1073
   * }
1074
   * for (int x = 0; x < 10; x++) {
1075
   *   B[x] = A[x];
1076
   * }
1077
   */
1078

1079
  StmtPtr stmt = Block::make(
1080
      {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
1081
       For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))),
1082
       For::make(
1083
           x,
1084
           0,
1085
           9,
1086
           Store::make(
1087
               a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))),
1088
       For::make(
1089
           x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))),
1090
       For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))});
1091

1092
  stmt->accept(&analyzer);
1093

1094
  // Sanity check output depends on Input.
1095
  ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
1096

1097
  auto CB = [](int s, int e) {
1098
    return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
1099
  };
1100
  auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
1101
    return indexBoundsEquals(x, y);
1102
  };
1103

1104
  /*  0. Input: A[(0, 9)] - dependents: 1
1105
   *  1. Load: A[(0, 8)] - depends on: 0 2  - dependents: 2
1106
   *  2. Store: A[(1, 9)] - depends on: 1  - dependents: 1 3
1107
   *  3. Load: A[(1, 9)] - depends on: 2  - dependents: 4
1108
   *  4. Store: A[(0, 8)] - depends on: 3  - dependents: 5 7
1109
   *  5. Load: A[(0, 8)] - depends on: 4  - dependents: 6
1110
   *  6. Store: A[(1, 9)] - depends on: 5  - dependents: 7
1111
   *  7. Load: A[(0, 9)] - depends on: 4 6 8  - dependents: 8
1112
   *  8. Store: A[(0, 9)] - depends on: 7  - dependents: 7 9
1113
   *  9. Load: A[(0, 9)] - depends on: 8  - dependents: 10
1114
   *  10. Store: B[(0, 9)] - depends on: 9  - dependents: 11
1115
   *  11. Output: B[(0, 9)] - depends on: 10
1116
   */
1117

1118
  // Now let's look at the bounds of each access.
1119
  auto history = analyzer.getHistory();
1120
  ASSERT_EQ(history.size(), 12);
1121
  VarPtr aVar = a.node()->base_handle();
1122
  VarPtr bVar = b.node()->base_handle();
1123

1124
  // The first access is the input A.
1125
  ASSERT_EQ(history[0]->type(), AccessType::Input);
1126
  ASSERT_EQ(history[0]->var(), aVar);
1127
  // It has the bounds of the producing Input.
1128
  ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
1129

1130
  // The second access is the load A[x-1].
1131
  ASSERT_EQ(history[1]->type(), AccessType::Load);
1132
  ASSERT_EQ(history[1]->var(), aVar);
1133
  // It has the bounds of the loop modified by the offset of each index, in
1134
  // this case -1.
1135
  ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)}));
1136
  // It depends on the input, but also the store in the same loop, since
1137
  // different interations of the loop depend on each other.
1138
  ASSERT_EQ(history[1]->dependencies().size(), 2);
1139
  ASSERT_TRUE(history[1]->hasDependency(history[0]));
1140
  ASSERT_TRUE(history[1]->hasDependency(history[2]));
1141

1142
  // The third access is the Store to A[x] in the first loop.
1143
  ASSERT_EQ(history[2]->type(), AccessType::Store);
1144
  ASSERT_EQ(history[2]->var(), aVar);
1145
  // It has no offset on x, so should have the same bounds as the loop.
1146
  ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
1147

1148
  // The fourth access is the load A[x+1] in the second loop.
1149
  ASSERT_EQ(history[3]->type(), AccessType::Load);
1150
  ASSERT_EQ(history[3]->var(), aVar);
1151
  // It has the bounds of the loop (0 <= x < 9) modified by the offset of each
1152
  // index, in this case 1.
1153
  ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)}));
1154
  // This load totally overlaps the previous write to A, so it depends only on
1155
  // it and not the input.
1156
  ASSERT_EQ(history[3]->dependencies().size(), 1);
1157
  ASSERT_TRUE(history[3]->hasDependency(history[2]));
1158

1159
  // The fifth access is the store to A[x] in the second loop.
1160
  ASSERT_EQ(history[4]->type(), AccessType::Store);
1161
  ASSERT_EQ(history[4]->var(), aVar);
1162
  // It has no offset on x, so should have the same bounds as the loop.
1163
  ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)}));
1164

1165
  // The sixth access is the load to A[8 - x] in the third loop.
1166
  ASSERT_EQ(history[5]->type(), AccessType::Load);
1167
  ASSERT_EQ(history[5]->var(), aVar);
1168
  // It has the bounds of the loop (0 <= x < 9) modified by the offset of each
1169
  // index, in this case 8 - x.
1170
  // This access has a negative stride, which will be normalized.
1171
  ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)}));
1172
  // This load totally overlaps the most recent write to A, so it depends only
1173
  // on it and not the input or the first write to A.
1174
  ASSERT_EQ(history[5]->dependencies().size(), 1);
1175
  ASSERT_TRUE(history[5]->hasDependency(history[4]));
1176

1177
  // The seventh access is the store to A[9 - x] in the third loop.
1178
  ASSERT_EQ(history[6]->type(), AccessType::Store);
1179
  ASSERT_EQ(history[6]->var(), aVar);
1180
  // This store has a negative stride on it's indices, but is normalized
1181
  // internally.
1182
  ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)}));
1183

1184
  // The eighth access is the load A[9-x] in the second loop.
1185
  ASSERT_EQ(history[7]->type(), AccessType::Load);
1186
  ASSERT_EQ(history[7]->var(), aVar);
1187
  // It has the bounds of the loop (0 <= x < 9), modified by the offset 9 - x,
1188
  // which essentially traverses the loop backwards.
1189
  ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
1190
  // This Load has three write dependencies:
1191
  ASSERT_EQ(history[7]->dependencies().size(), 3);
1192
  //  * The previous store (#6) for elements 1-9
1193
  ASSERT_TRUE(history[7]->hasDependency(history[6]));
1194
  //  * An earlier store (#4) covering element 0
1195
  ASSERT_TRUE(history[7]->hasDependency(history[4]));
1196
  //  * A future store inside this loop, since this loop modifies the buffer
1197
  //  in a non distinct way (due to the load and store having different access
1198
  //  strides).
1199
  ASSERT_TRUE(history[7]->hasDependency(history[8]));
1200

1201
  // The ninth access is the store to A[x] in the fourth loop.
1202
  ASSERT_EQ(history[8]->type(), AccessType::Store);
1203
  ASSERT_EQ(history[8]->var(), aVar);
1204
  // This store has a negative stride on it's indices, but is normalized
1205
  // internally.
1206
  ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
1207

1208
  // The tenth and 11th accesses are the copy from A[x] to B[x].
1209
  ASSERT_EQ(history[9]->type(), AccessType::Load);
1210
  ASSERT_EQ(history[9]->var(), aVar);
1211
  ASSERT_EQ(history[10]->type(), AccessType::Store);
1212
  ASSERT_EQ(history[10]->var(), bVar);
1213

1214
  // The last access represents the output Buf.
1215
  ASSERT_EQ(history[11]->type(), AccessType::Output);
1216
  ASSERT_EQ(history[11]->var(), bVar);
1217
  // It has the bounds of the output Buf.
1218
  ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)}));
1219
  // It depends on the last write to B only.
1220
  ASSERT_EQ(history[11]->dependencies().size(), 1);
1221
  ASSERT_TRUE(history[11]->hasDependency(history[10]));
1222

1223
  // ok that's enough of that.
1224
}
1225

1226
// Check many different cases of loop self dependency - when a load within a
// loop is dependent on a Store later in the same loop but in a different
// iteration. This is affected by whether or not we can trust the execution
// order of the loop.
1230
TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) {
1231
  BufHandle a("A", {5}, kInt);
1232
  BufHandle b("B", {5}, kInt);
1233
  VarHandle x("x", kInt);
1234
  VarHandle y("y", kInt);
1235
  VarHandle z("z", kInt);
1236

1237
  using namespace analysis;
1238

1239
  // This check assumes that the Stmt has a single Store with a single Load on
1240
  // the RHS.
1241
  auto isSelfDependent =
1242
      [](const std::vector<std::shared_ptr<AccessInfo>>& history) -> bool {
1243
    return history.front()->hasDependency(history.back());
1244
  };
1245

1246
  {
1247
    /* for (int y = 0; y < 10; y++) {
1248
     *   A[y] = (A[y]) + 1;
1249
     * } */
1250

1251
    // Not self dependent since all loop iterations use a different y.
1252

1253
    MemDependencyChecker analyzer;
1254
    StmtPtr stmt = For::make(
1255
        y,
1256
        0,
1257
        10,
1258
        Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))}));
1259

1260
    stmt->accept(&analyzer);
1261

1262
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1263
  }
1264

1265
  {
1266
    /* for (int y = 0; y < 10; y++) {
1267
     *   A[y + 1] = (A[y + 1]) + 1;
1268
     * }
1269
     */
1270

1271
    // Not self dependent due to different y (with offset).
1272

1273
    MemDependencyChecker analyzer;
1274
    StmtPtr stmt = For::make(
1275
        y,
1276
        0,
1277
        10,
1278
        Block::make(
1279
            {Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))}));
1280

1281
    stmt->accept(&analyzer);
1282

1283
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1284
  }
1285

1286
  {
1287
    /* for (int x = 0; x < 10; x++) {
1288
     *   A[0] = (A[0]) + x;
1289
     * }
1290
     */
1291

1292
    // Is self dependent since all loops use a common constant element of A.
1293

1294
    MemDependencyChecker analyzer;
1295
    StmtPtr stmt = For::make(
1296
        x,
1297
        0,
1298
        10,
1299
        Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}));
1300
    stmt->accept(&analyzer);
1301

1302
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1303
  }
1304

1305
  {
1306
    /* for (int x = 0; x < 10; x++) {
1307
     *   A[0] = (B[0]) + x;
1308
     * }
1309
     */
1310

1311
    // Is not self dependent because there is no store to the buffer that is
1312
    // read.
1313

1314
    MemDependencyChecker analyzer;
1315
    StmtPtr stmt = For::make(
1316
        x,
1317
        0,
1318
        10,
1319
        Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))}));
1320
    stmt->accept(&analyzer);
1321

1322
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1323
  }
1324

1325
  {
1326
    /* for (int x = 0; x < 10; x++) {
1327
     *   A[y] = (A[y]) + x;
1328
     * }
1329
     */
1330

1331
    // Is self dependent since all loops use a common symbolic element of A.
1332

1333
    MemDependencyChecker analyzer;
1334
    StmtPtr stmt = For::make(
1335
        x,
1336
        0,
1337
        10,
1338
        Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))}));
1339
    stmt->accept(&analyzer);
1340

1341
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1342
  }
1343

1344
  {
1345
    /* for (int x = 0; x < 10; x++) {
1346
     *   A[x] = A[x + 1];
1347
     * }
1348
     */
1349

1350
    // In this case it depends if we are considering execution order.
1351

1352
    MemDependencyChecker analyzer;
1353

1354
    StmtPtr stmt =
1355
        For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
1356
    stmt->accept(&analyzer);
1357

1358
    // With analysis of order disabled, this is self dependent since the read
1359
    // from X+1 and the write to X+1 could be in reverse order.
1360
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1361
  }
1362

1363
  {
1364
    /* for (int x = 0; x < 10; x++) {
1365
     *   A[x] = A[x + 1];
1366
     * }
1367
     */
1368

1369
    MemDependencyChecker analyzer;
1370
    analyzer.allowLoopExecutionOrderAnalysis();
1371

1372
    StmtPtr stmt =
1373
        For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
1374
    stmt->accept(&analyzer);
1375

1376
    // If order analysis is enabled, this is not dependent since the read for
1377
    // each element occurs before the write to that element.
1378
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1379
  }
1380

1381
  {
1382
    /* for (int x = 1; x < 10; x++) {
1383
     *   A[x] = A[x - 1];
1384
     * }
1385
     */
1386

1387
    MemDependencyChecker analyzer;
1388

1389
    StmtPtr stmt =
1390
        For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
1391
    stmt->accept(&analyzer);
1392

1393
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1394
  }
1395

1396
  {
1397
    /* for (int x = 1; x < 10; x++) {
1398
     *   A[x] = A[x - 1];
1399
     * }
1400
     */
1401

1402
    MemDependencyChecker analyzer;
1403
    analyzer.allowLoopExecutionOrderAnalysis();
1404

1405
    StmtPtr stmt =
1406
        For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
1407
    stmt->accept(&analyzer);
1408

1409
    // In this case, even with order analysis the Load is dependent on the
1410
    // Store, since the write to X occurs before the read from X.
1411
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1412
  }
1413

1414
  {
1415
    /* for (int x = 0; x < 9; x++) {
1416
     *   A[9 - x] = A[8 - x];
1417
     * }
1418
     */
1419

1420
    // Still works if the execution order is reversed, so long as the read
1421
    // comes before the write.
1422

1423
    MemDependencyChecker analyzer;
1424
    analyzer.allowLoopExecutionOrderAnalysis();
1425

1426
    StmtPtr stmt = For::make(
1427
        x,
1428
        3,
1429
        10,
1430
        Store::make(
1431
            a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
1432
    stmt->accept(&analyzer);
1433

1434
    // However here we can determine the A store is earlier in the order than
1435
    // the load.
1436
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1437
  }
1438

1439
  {
1440
    /* for (int x = 0; x < 9; x++) {
1441
     *   A[8 - x] = A[9 - x];
1442
     * }
1443
     */
1444

1445
    // But not if it doesn't.
1446

1447
    MemDependencyChecker analyzer;
1448
    analyzer.allowLoopExecutionOrderAnalysis();
1449

1450
    StmtPtr stmt = For::make(
1451
        x,
1452
        3,
1453
        10,
1454
        Store::make(
1455
            a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x})));
1456
    stmt->accept(&analyzer);
1457

1458
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1459
  }
1460

1461
  {
1462
    /* for (int x = 0; x < 9; x++) {
1463
     *   A[9 - x] = A[8 - x];
1464
     * }
1465
     */
1466

1467
    // And not if we're not relying on execution order.
1468

1469
    MemDependencyChecker analyzer;
1470

1471
    StmtPtr stmt = For::make(
1472
        x,
1473
        3,
1474
        10,
1475
        Store::make(
1476
            a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
1477
    stmt->accept(&analyzer);
1478

1479
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1480
  }
1481

1482
  {
1483
    /* for (int x = 3; x < 10; x++) {
1484
     *   A[x - 2] = A[x - 1];
1485
     * }
1486
     */
1487

1488
    // Forward order but negative indices.
1489

1490
    MemDependencyChecker analyzer;
1491
    analyzer.allowLoopExecutionOrderAnalysis();
1492

1493
    StmtPtr stmt =
1494
        For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1})));
1495
    stmt->accept(&analyzer);
1496

1497
    // However here we can determine the A store is earlier in the order than
1498
    // the load.
1499
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1500
  }
1501

1502
  {
1503
    /* for (int x = 0; x < 10; x++) {
1504
     *   A[x * 2] = A[x * 2];
1505
     * }
1506
     */
1507

1508
    // With an access stride.
1509

1510
    MemDependencyChecker analyzer;
1511
    // Execution order doesn't matter since the read and the write are totally
1512
    // distinct.
1513

1514
    StmtPtr stmt =
1515
        For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2})));
1516
    stmt->accept(&analyzer);
1517

1518
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1519
  }
1520

1521
  {
1522
    /* for (int x = 0; x < 10; x++) {
1523
     *   A[x * 2] = A[x * 2 + 1];
1524
     * }
1525
     */
1526

1527
    // Here we can use the common stride of the accesses to determine they are
1528
    // distinct.
1529
    // Note, this is the only place (loop self dependency) we use this stride
1530
    // to avoid unnecessary dependence.
1531

1532
    MemDependencyChecker analyzer;
1533
    // Execution order doesn't matter since the read and the write are totally
1534
    // distinct.
1535

1536
    StmtPtr stmt = For::make(
1537
        x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1})));
1538
    stmt->accept(&analyzer);
1539

1540
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1541
  }
1542

1543
  {
1544
    /* for (int x = 1; x < 10; x++) {
1545
     *   A[x * 2] = A[x * 2 - 1];
1546
     * }
1547
     */
1548

1549
    // same if the read is behind the write so long as they are distinct.
1550

1551
    MemDependencyChecker analyzer;
1552
    StmtPtr stmt = For::make(
1553
        x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1})));
1554
    stmt->accept(&analyzer);
1555

1556
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1557
  }
1558

1559
  {
1560
    /* for (int x = 0; x < 10; x++) {
1561
     *   A[x * 2] = A[x * 2 + 2];
1562
     * }
1563
     */
1564

1565
    // But not if the offset is in the stride.
1566

1567
    MemDependencyChecker analyzer;
1568
    StmtPtr stmt = For::make(
1569
        x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2})));
1570
    stmt->accept(&analyzer);
1571

1572
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1573
  }
1574

1575
  {
1576
    /* for (int x = 1; x < 10; x++) {
1577
     *   A[x * 2] = A[x * 2 - 2];
1578
     * }
1579
     */
1580

1581
    // Works with negative offsets too.
1582

1583
    MemDependencyChecker analyzer;
1584
    StmtPtr stmt = For::make(
1585
        x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2})));
1586
    stmt->accept(&analyzer);
1587

1588
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1589
  }
1590

1591
  {
1592
    /* for (int x = 0; x < 10; x++) {
1593
     *   A[x * 2] = A[x * 2 + 7];
1594
     * }
1595
     */
1596

1597
    // Detects accesses are distinct when offset is large but not a multiple
1598
    // of stride.
1599
    MemDependencyChecker analyzer;
1600
    StmtPtr stmt = For::make(
1601
        x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7})));
1602
    stmt->accept(&analyzer);
1603

1604
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1605
  }
1606

1607
  {
1608
    /* for (int x = 0; x < 10; x++) {
1609
     *   A[x * 2] = A[x * 2 + 4];
1610
     * }
1611
     */
1612

1613
    // Works with offsets which are multiples of the stride.
1614
    MemDependencyChecker analyzer;
1615
    StmtPtr stmt = For::make(
1616
        x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4})));
1617
    stmt->accept(&analyzer);
1618

1619
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1620
  }
1621

1622
  {
1623
    /* for (int x = 0; x < 10; x++) {
1624
     *   A[x * 6] = A[x * 6 + 5];
1625
     * }
1626
     */
1627

1628
    // detects accesses are distinct with large strides when the offset is
1629
    // within.
1630

1631
    MemDependencyChecker analyzer;
1632
    StmtPtr stmt = For::make(
1633
        x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5})));
1634
    stmt->accept(&analyzer);
1635

1636
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1637
  }
1638

1639
  {
1640
    /* for (int x = 0; x < 10; x++) {
1641
     *   A[x * 2] = A[x * 6];
1642
     * }
1643
     */
1644

1645
    // detects accesses are overlapping when stride is different but a
1646
    // multiple.
1647

1648
    MemDependencyChecker analyzer;
1649
    StmtPtr stmt =
1650
        For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6})));
1651
    stmt->accept(&analyzer);
1652

1653
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1654
  }
1655

1656
  {
1657
    /* for (int x = 0; x < 10; x++) {
1658
     *   A[x * 4] = A[x * 2];
1659
     * }
1660
     */
1661

1662
    // still works when the read axis is the smaller stride.
1663

1664
    MemDependencyChecker analyzer;
1665
    StmtPtr stmt =
1666
        For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2})));
1667
    stmt->accept(&analyzer);
1668

1669
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1670
  }
1671

1672
  {
1673
    /* for (int x = 0; x < 10; x++) {
1674
     *   A[x * 2] = A[x * 6 + 1];
1675
     * }
1676
     */
1677

1678
    // detects accesses are distinct when stride is different but a multiple
1679
    // and there is an offset.
1680

1681
    MemDependencyChecker analyzer;
1682
    StmtPtr stmt = For::make(
1683
        x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1})));
1684
    stmt->accept(&analyzer);
1685

1686
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1687
  }
1688

1689
  {
1690
    /* for (int x = 0; x < 10; x++) {
1691
     *   A[x * 2] = A[x * 6 + 4];
1692
     * }
1693
     */
1694

1695
    // The smaller stride determines whether there is overlap.
1696

1697
    MemDependencyChecker analyzer;
1698
    StmtPtr stmt = For::make(
1699
        x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4})));
1700
    stmt->accept(&analyzer);
1701

1702
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1703
  }
1704

1705
  {
1706
    /* for (int x = 0; x < 10; x++) {
1707
     *   A[x * 2 + 3] = A[x * 6];
1708
     * }
1709
     */
1710

1711
    // The smaller stride determines whether there is overlap, not the larger.
1712

1713
    MemDependencyChecker analyzer;
1714
    StmtPtr stmt = For::make(
1715
        x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6})));
1716
    stmt->accept(&analyzer);
1717

1718
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1719
  }
1720

1721
  {
1722
    /* for (int x = 0; x < 10; x++) {
1723
     *   A[x * 2] = A[x * 3 + 1];
1724
     * }
1725
     */
1726

1727
    // If their strides have no common factor > 1, they overlap.
1728
    MemDependencyChecker analyzer;
1729
    StmtPtr stmt = For::make(
1730
        x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1})));
1731
    stmt->accept(&analyzer);
1732

1733
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1734
  }
1735

1736
  {
1737
    /* for (int x = 0; x < 10; x++) {
1738
     *   A[x] = A[x + 10];
1739
     * }
1740
     */
1741

1742
    // If the offset is greater than the size of the loop, they can't overlap.
1743

1744
    MemDependencyChecker analyzer;
1745
    StmtPtr stmt =
1746
        For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10})));
1747
    stmt->accept(&analyzer);
1748

1749
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1750
  }
1751

1752
  {
1753
    /* for (int x = 0; x < 10; x++) {
1754
     *   A[x] = A[9 - x];
1755
     * }
1756
     */
1757

1758
    // If they have different execution orders they may overlap.
1759
    MemDependencyChecker analyzer;
1760
    StmtPtr stmt = For::make(
1761
        x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x})));
1762
    stmt->accept(&analyzer);
1763

1764
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1765
  }
1766

1767
  {
1768
    /* for (int x = 0; x < 10; x++) {
1769
     *   A[x * 2] = A[19 - x * 2];
1770
     * }
1771
     */
1772

1773
    // Or they may not, depending on their start offset and strides.
1774
    MemDependencyChecker analyzer;
1775
    StmtPtr stmt = For::make(
1776
        x,
1777
        0,
1778
        10,
1779
        Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2})));
1780
    stmt->accept(&analyzer);
1781

1782
    ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1783
  }
1784

1785
  {
1786
    /* for (int x = 0; x < 10; x++) {
1787
     *   A[x / 2] = A[x / 2];
1788
     * }
1789
     */
1790

1791
    // If the stride is not monotonic, they overlap.
1792

1793
    MemDependencyChecker analyzer;
1794
    StmtPtr stmt =
1795
        For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2})));
1796
    stmt->accept(&analyzer);
1797

1798
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1799
  }
1800

1801
  {
1802
    /* for (int x = 0; x < 10; x++) {
1803
     *   A[x / 2] = A[x / 2 + 1];
1804
     * }
1805
     */
1806

1807
    // If the stride is not monotonic, they overlap - even with an offset.
1808
    MemDependencyChecker analyzer;
1809
    StmtPtr stmt = For::make(
1810
        x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1})));
1811
    stmt->accept(&analyzer);
1812

1813
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1814
  }
1815

1816
  {
1817
    /* for (int x = 0; x < 10; x++) {
1818
     *   A[x % 2] = A[x % 2];
1819
     * }
1820
     */
1821

1822
    // Mod too...
1823

1824
    analysis::MemDependencyChecker analyzer;
1825
    StmtPtr stmt = For::make(
1826
        x,
1827
        0,
1828
        10,
1829
        Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)})));
1830
    stmt->accept(&analyzer);
1831

1832
    ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1833
  }
1834

1835
  {
1836
    /* for (int x = y; x < z; x++) {
1837
     *   A[x] = A[x + 1];
1838
     * }
1839
     */
1840

1841
    // Still works with symbolic loop extents.
1842

1843
    {
1844
      MemDependencyChecker analyzer;
1845
      StmtPtr stmt =
1846
          For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
1847
      stmt->accept(&analyzer);
1848

1849
      ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1850
    }
1851

1852
    {
1853
      MemDependencyChecker analyzer;
1854
      analyzer.allowLoopExecutionOrderAnalysis();
1855
      StmtPtr stmt =
1856
          For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
1857
      stmt->accept(&analyzer);
1858

1859
      ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1860
    }
1861
  }
1862
}
1863

1864
// Verify that a strided access still works.
1865
// TODO: actually this only works because of the size of the ranges, revisit
1866
// this test after strided overlap is implemented.
1867
TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
1868
  BufHandle a("A", {20}, kInt);
1869
  BufHandle b("B", {20}, kInt);
1870
  VarHandle x("x", kInt);
1871
  VarHandle y("y", kInt);
1872

1873
  using namespace analysis;
1874
  MemDependencyChecker analyzer({a.node()}, {b.node()});
1875
  StmtPtr stmt = Block::make(
1876
      {For::make(
1877
           x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
1878
       For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2})))
1879

1880
      });
1881
  stmt->accept(&analyzer);
1882

1883
  // Sanity check output depends on input.
1884
  ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
1885

1886
  // Output has 2 dependencies... the store in each loop.
1887
  auto outputAccess = analyzer.output(b.node());
1888
  ASSERT_EQ(outputAccess->dependencies().size(), 2);
1889
}
1890

1891
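/* TODO(nickg) - this test will fail due to the lack of stride math in Bound.
 * NOTE: it reuses the name of the enabled test above, so it must be renamed
 * before it can be re-enabled.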
1892
TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
1893
  BufHandle a("A", {20}, kInt);
1894
  BufHandle b("B", {20}, kInt);
1895
  BufHandle c("C", {10}, kInt);
1896
  VarHandle x("x", kInt);
1897
  VarHandle y("y", kInt);
1898

1899
  {
1900
    analysis::MemDependencyChecker analyzer({a.node()}, {c.node()});
1901
    StmtPtr stmt = Block::make(
1902
        {For::make(
1903
             x,
1904
             0,
1905
             10,
1906
             Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
1907
         For::make(
1908
             x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))),
1909
         For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))
1910

1911
        });
1912
    stmt->accept(&analyzer);
1913

1914
    std::cout << *stmt << "\n";
1915
    for (auto& wi : analyzer.getHistory()) {
1916
      wi->print();
1917
    }
1918
  }
1919
}*/
1920

1921
// analysis on Stmts using Cond.
1922
TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) {
1923
  BufHandle a("A", {10}, kInt);
1924
  BufHandle b("B", {10}, kInt);
1925
  BufHandle c("C", {10}, kInt);
1926
  VarHandle x("x", kInt);
1927
  VarHandle y("y", kInt);
1928

1929
  using namespace analysis;
1930

1931
  {
1932
    /* for (int x = 0; x < 10; x++) {
1933
     *   C[x] = A[x];
1934
     * }
1935
     * if (y<5 ? 1 : 0) {
1936
     *   C[0] = (B[0]) + 1;
1937
     * } else {
1938
     *   C[0] = (B[1]) + 1;
1939
     * }
1940
     */
1941

1942
    // Future usages may depend on accesses in both branches of a condition.
1943

1944
    MemDependencyChecker analyzer({a, b}, {c});
1945
    StmtPtr stmt = Block::make(
1946
        {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
1947
         Cond::make(
1948
             CompareSelect::make(y, 5, CompareSelectOperation::kLT),
1949
             Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)),
1950
             Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))});
1951

1952
    stmt->accept(&analyzer);
1953

1954
    // Output C should have 3 dependencies, each of the three stores.
1955
    auto outputAccess = analyzer.output(c.node());
1956
    ASSERT_NE(outputAccess, nullptr);
1957
    ASSERT_EQ(outputAccess->dependencies().size(), 3);
1958

1959
    // C depends indirectly on A and B.
1960
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
1961
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
1962
  }
1963

1964
  {
1965
    /* for (int x = 0; x < 10; x++) {
1966
     *   C[x] = A[x];
1967
     * }
1968
     * if (y<5 ? 1 : 0) {
1969
     *   for (int x = 0; x < 10; x++) {
1970
     *     C[x] = B[x];
1971
     *   }
1972
     * } else {
1973
     *   for (int x = 0; x < 10; x++) {
1974
     *     C[x] = (B[x]) + 1;
1975
     *   }
1976
     * }
1977
     */
1978

1979
    // Future usages may depend on accesses in both branches of a condition.
1980

1981
    MemDependencyChecker analyzer({a, b}, {c});
1982
    StmtPtr stmt = Block::make(
1983
        {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
1984
         Cond::make(
1985
             CompareSelect::make(y, 5, CompareSelectOperation::kLT),
1986
             For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))),
1987
             For::make(
1988
                 x,
1989
                 0,
1990
                 10,
1991
                 Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
1992

1993
    stmt->accept(&analyzer);
1994

1995
    // Output C should have 3 dependencies, each of the three stores.
1996
    auto outputAccess = analyzer.output(c.node());
1997
    ASSERT_NE(outputAccess, nullptr);
1998
    ASSERT_EQ(outputAccess->dependencies().size(), 3);
1999

2000
    // TODO(nickg): actually since the true and false branch cover the total
2001
    // range of the first store this should have 2 dependencies, but we don't
2002
    // do that yet.
2003

2004
    // C depends indirectly on A and B.
2005
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2006
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2007
  }
2008

2009
  {
2010
    /* for (int x = 0; x < 10; x++) {
2011
     *   C[x] = A[x];
2012
     * }
2013
     * if (y<5 ? 1 : 0) {
2014
     *   for (int x = 0; x < 10; x++) {
2015
     *     C[x] = (B[x]) + 1;
2016
     *   }
2017
     * }
2018
     */
2019

2020
    // Only has true branch.
2021

2022
    MemDependencyChecker analyzer({a, b}, {c});
2023
    StmtPtr stmt = Block::make(
2024
        {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2025
         Cond::make(
2026
             CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2027
             For::make(
2028
                 x,
2029
                 0,
2030
                 10,
2031
                 Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))),
2032
             nullptr)});
2033

2034
    stmt->accept(&analyzer);
2035

2036
    // Output C should have 3 dependencies, each of the three stores.
2037
    auto outputAccess = analyzer.output(c.node());
2038
    ASSERT_NE(outputAccess, nullptr);
2039
    ASSERT_EQ(outputAccess->dependencies().size(), 2);
2040

2041
    // C depends indirectly on A and B.
2042
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2043
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2044
  }
2045

2046
  {
2047
    /* for (int x = 0; x < 10; x++) {
2048
     *   C[x] = A[x];
2049
     * }
2050
     * if (y<5 ? 1 : 0) {
2051
     * } else {
2052
     *   for (int x = 0; x < 10; x++) {
2053
     *     C[x] = (B[x]) + 1;
2054
     *   }
2055
     * }
2056
     */
2057

2058
    // Only has false branch.
2059

2060
    MemDependencyChecker analyzer({a, b}, {c});
2061
    StmtPtr stmt = Block::make(
2062
        {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2063
         Cond::make(
2064
             CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2065
             nullptr,
2066
             For::make(
2067
                 x,
2068
                 0,
2069
                 10,
2070
                 Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
2071

2072
    stmt->accept(&analyzer);
2073

2074
    // Output C should have 3 dependencies, each of the three stores.
2075
    auto outputAccess = analyzer.output(c.node());
2076
    ASSERT_NE(outputAccess, nullptr);
2077
    ASSERT_EQ(outputAccess->dependencies().size(), 2);
2078

2079
    // C depends indirectly on A and B.
2080
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2081
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2082
  }
2083

2084
  {
2085
    /* for (int x = 0; x < 10; x++) {
2086
     *   C[x] = A[x];
2087
     * }
2088
     * if (C[0]<5 ? 1 : 0) {
2089
     *   C[0] = 5;
2090
     * }
2091
     */
2092

2093
    // Cond's Condition depends on a previous access.
2094

2095
    MemDependencyChecker analyzer({a}, {c});
2096
    StorePtr initStore = Store::make(c, {x}, Load::make(a, {x}));
2097
    ExprHandle conditionalLoad = Load::make(c, {0});
2098
    StmtPtr stmt = Block::make(
2099
        {For::make(x, 0, 10, initStore),
2100
         Cond::make(
2101
             CompareSelect::make(
2102
                 conditionalLoad, 5, CompareSelectOperation::kLT),
2103
             Store::make(c, {0}, 5),
2104
             nullptr)});
2105

2106
    stmt->accept(&analyzer);
2107

2108
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2109

2110
    ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore));
2111
    ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node()));
2112
    ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node()));
2113
  }
2114
}
2115

2116
// Stmts using IfThenElse.
2117
TEST(MemDependency, MemDependencyCheckerIfThenElse) {
2118
  BufHandle a("A", {10}, kInt);
2119
  BufHandle b("B", {10}, kInt);
2120
  BufHandle c("C", {10}, kInt);
2121
  VarHandle x("x", kInt);
2122
  VarHandle y("y", kInt);
2123

2124
  using namespace analysis;
2125

2126
  {
2127
    /* for (int x = 0; x < 10; x++) {
2128
     *   C[x] = A[x];
2129
     * }
2130
     * C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1;
2131
     */
2132

2133
    // Future usages may depend on accesses in both branches of a condition.
2134

2135
    MemDependencyChecker analyzer({a, b}, {c});
2136
    StorePtr ifStore = Store::make(
2137
        c,
2138
        {0},
2139
        IfThenElse::make(
2140
            CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2141
            Add::make(Load::make(b, {0}), 1),
2142
            Add::make(Load::make(b, {1}), 1)));
2143
    StmtPtr stmt = Block::make(
2144
        {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2145
         ifStore});
2146

2147
    stmt->accept(&analyzer);
2148

2149
    // Output C should have 2 dependencies, each of the two stores.
2150
    auto outputAccess = analyzer.output(c.node());
2151
    ASSERT_NE(outputAccess, nullptr);
2152
    ASSERT_EQ(outputAccess->dependencies().size(), 2);
2153

2154
    // Now we need to check the Store containing the IfThenElse.
2155
    auto ifStoreAccess = analyzer.accessFor(ifStore);
2156

2157
    // It should have 2 dependencies.
2158
    ASSERT_EQ(ifStoreAccess->dependencies().size(), 2);
2159

2160
    // C depends indirectly on A and B.
2161
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2162
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2163
  }
2164

2165
  {
2166
    /* for (int x = 0; x < 10; x++) {
2167
     *   C[x] = A[x];
2168
     * }
2169
     * C[0] = (y < 5 ? (B[0]) + 1 : 42;
2170
     */
2171

2172
    // If the load appears in only one side of an IfThenElse the output may be
2173
    // dependent on it.
2174

2175
    MemDependencyChecker analyzer({a, b}, {c});
2176
    StorePtr ifStore = Store::make(
2177
        c,
2178
        {0},
2179
        IfThenElse::make(
2180
            CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2181
            Add::make(Load::make(b, {0}), 1),
2182
            42));
2183
    StmtPtr stmt = Block::make(
2184
        {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2185
         ifStore});
2186

2187
    stmt->accept(&analyzer);
2188

2189
    // C depends indirectly on A and B.
2190
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2191
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2192
  }
2193

2194
  {
2195
    /* for (int x = 0; x < 10; x++) {
2196
     *   C[x] = (x < 5 ? B[x] : A[x];
2197
     * }
2198
     */
2199

2200
    // In this case C is dependent on both A and B.
2201

2202
    // TODO: in cases like this it would be possible to split the range of B
2203
    // into two bounds, one dependent on A and one dependent on B. We'd need to
2204
    // examine conditions relative to previously encountered loop variables. I'm
2205
    // uncertain if this would be helpful.
2206

2207
    MemDependencyChecker analyzer({a, b}, {c});
2208
    StorePtr ifStore = Store::make(
2209
        c,
2210
        {0},
2211
        IfThenElse::make(
2212
            CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2213
            Load::make(b, {x}),
2214
            Load::make(a, {x})));
2215
    StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)});
2216

2217
    stmt->accept(&analyzer);
2218

2219
    // C depends indirectly on A and B.
2220
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2221
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2222
  }
2223
}
2224

2225
// Cutting a loop with single elem writes
2226
TEST(MemDependency, MemDependencyCheckerCutLoop) {
2227
  BufHandle a("A", {10}, kInt);
2228
  BufHandle b("B", {10}, kInt);
2229
  VarHandle x("x", kInt);
2230

2231
  using namespace analysis;
2232

2233
  {
2234
    /* for (int x = 0; x < 10; x++) {
2235
     *   B[x] = A[x];
2236
     * }
2237
     * B[5] = 100;
2238
     */
2239

2240
    // Cutting a loop with single element writes.
2241

2242
    MemDependencyChecker analyzer({a}, {b});
2243
    StmtPtr stmt = Block::make(
2244
        {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))),
2245
         Store::make(b, {5}, 100)});
2246

2247
    stmt->accept(&analyzer);
2248

2249
    // Output depends on input.
2250
    ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2251

2252
    // Output has 2 dependencies.
2253
    auto outputAccess = analyzer.output(b.node());
2254
    ASSERT_NE(outputAccess, nullptr);
2255
    ASSERT_EQ(outputAccess->dependencies().size(), 2);
2256
  }
2257

2258
  {
2259
    /* for (int x = 0; x < 10; x++) {
2260
     *   B[x] = A[x];
2261
     * }
2262
     * for (int x = 4; x < 7; x++) {
2263
     *   B[x] = B[x] + 3;
2264
     * }
2265
     * B[5] = 100;
2266
     * B[6] = 101;
2267
     * B[7] = 102;
2268
     */
2269

2270
    // Cutting a loop with a smaller loop but then totally overlap that second
2271
    // loop with one element writes.
2272

2273
    MemDependencyChecker analyzer({a}, {b});
2274
    ForPtr firstLoop =
2275
        For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})));
2276
    StorePtr secondStore =
2277
        Store::make(b, {x}, Add::make(Load::make(b, {x}), 1));
2278
    ForPtr secondLoop = For::make(x, 4, 7, secondStore);
2279

2280
    StmtPtr stmt = Block::make(
2281
        {firstLoop,
2282
         secondLoop,
2283
         Store::make(b, {4}, 100),
2284
         Store::make(b, {5}, 101),
2285
         Store::make(b, {6}, 102)});
2286

2287
    stmt->accept(&analyzer);
2288

2289
    // Output depends on input.
2290
    ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2291

2292
    // Output has 4 dependencies.
2293
    auto outputAccess = analyzer.output(b.node());
2294
    ASSERT_NE(outputAccess, nullptr);
2295
    ASSERT_EQ(outputAccess->dependencies().size(), 4);
2296

2297
    // Second loop depends on first loop.
2298
    ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop));
2299

2300
    // Output does not depend on second loop or store.
2301
    ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop));
2302
    ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore));
2303
  }
2304
}
2305

2306
// Dynamic shapes (load in indices).
2307
TEST(MemDependency, MemDependencyCheckerDynamicShapes) {
2308
  BufHandle a("A", {100}, kInt);
2309
  BufHandle b("B", {100}, kInt);
2310
  BufHandle c("C", {100}, kInt);
2311
  VarHandle x("x", kInt);
2312

2313
  using namespace analysis;
2314

2315
  auto CB = [](ExprHandle s, ExprHandle e) {
2316
    return Bound(s.node(), e.node());
2317
  };
2318

2319
  auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
2320
    return indexBoundsEquals(x, y);
2321
  };
2322

2323
  {
2324
    /* for (int x = 0; x < B[0]; x++) {
2325
     *   C[x] = A[x];
2326
     * }
2327
     */
2328
    MemDependencyChecker analyzer({a, b}, {c});
2329
    StmtPtr stmt = Block::make({For::make(
2330
        x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))});
2331

2332
    stmt->accept(&analyzer);
2333

2334
    /*  0. Input: B[(0, 99)] - dependents: 2
2335
     *  1. Input: A[(0, 99)] - dependents: 3
2336
     *  2. Load: B[(0, 0)] - depends on: 0  - dependents: 3 4
2337
     *  3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2  - dependents: 4
2338
     *  4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3  - dependents: 5
2339
     *  5. Output: C[(0, 99)] - depends on: 4
2340
     */
2341

2342
    // Output dependent on A input.
2343
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2344
    // Also dependent on B input to determine the size of the region written.
2345
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2346

2347
    auto history = analyzer.getHistory();
2348
    ASSERT_EQ(history.size(), 6);
2349

2350
    // The accesses in the loop depend on the load in the stop condition.
2351
    ASSERT_TRUE(history[4]->hasDependency(history[2]));
2352
    ASSERT_TRUE(history[3]->hasDependency(history[2]));
2353

2354
    // Make a load from B to compare against.
2355
    ExprHandle loadFromB = Load::make(b, {0});
2356

2357
    ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)}));
2358
    ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)}));
2359
  }
2360

2361
  {
2362
    /* for (int x = B[0]; x < B[1]; x++) {
2363
     *   C[x] = A[x];
2364
     * }
2365
     */
2366
    MemDependencyChecker analyzer({a, b}, {c});
2367
    StmtPtr stmt = Block::make({For::make(
2368
        x,
2369
        Load::make(b, {0}),
2370
        Load::make(b, {1}),
2371
        Store::make(c, {x}, Load::make(a, {x})))});
2372

2373
    stmt->accept(&analyzer);
2374

2375
    /*  0. Input: B[(0, 99)] - dependents: 2 3
2376
     *  1. Input: A[(0, 99)] - dependents: 4
2377
     *  2. Load: B[(0, 0)] - depends on: 0  - dependents: 4 5
2378
     *  3. Load: B[(1, 1)] - depends on: 0  - dependents: 4 5
2379
     *  4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3  - dependents: 5
2380
     *  5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4  - dependents: 6
2381
     *  6. Output: C[(0, 99)] - depends on: 5
2382
     */
2383

2384
    // Sanity check output depends on input.
2385
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2386
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2387

2388
    auto history = analyzer.getHistory();
2389
    ASSERT_EQ(history.size(), 7);
2390

2391
    // The accesses in the loop depend on the load in the start condition.
2392
    ASSERT_TRUE(history[5]->hasDependency(history[2]));
2393
    ASSERT_TRUE(history[4]->hasDependency(history[2]));
2394

2395
    // also the stop condition.
2396
    ASSERT_TRUE(history[5]->hasDependency(history[3]));
2397
    ASSERT_TRUE(history[4]->hasDependency(history[3]));
2398

2399
    // Make loads from B to compare against.
2400
    ExprHandle loadFromB0 = Load::make(b, {0});
2401
    ExprHandle loadFromB1 = Load::make(b, {1});
2402
    ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
2403
    ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
2404
  }
2405

2406
  {
2407
    /* for (int x = 0; x < 10; x++) {
2408
     *   C[x] = A[B[x]];
2409
     * }
2410
     */
2411
    MemDependencyChecker analyzer({a, b}, {c});
2412
    StmtPtr stmt = Block::make({For::make(
2413
        x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))});
2414

2415
    stmt->accept(&analyzer);
2416

2417
    /*  0. Input: B[(0, 99)] - dependents: 2
2418
     *  1. Input: A[(0, 99)] - dependents: 3
2419
     *  2. Load: B[(0, 9)] - depends on: 0  - dependents: 3 4
2420
     *  3. Load: A[(B[0], B[9])] - depends on: 1 2  - dependents: 4
2421
     *  4. Store: C[(0, 9)] - depends on: 2 3  - dependents: 5
2422
     *  5. Output: C[(0, 99)] - depends on: 4
2423
     */
2424

2425
    // Sanity check output depends on input.
2426
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2427
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2428

2429
    auto history = analyzer.getHistory();
2430
    ASSERT_EQ(history.size(), 6);
2431

2432
    // The store depends on both loads, the load of A depends on the load of B.
2433
    ASSERT_TRUE(history[4]->hasDependency(history[2]));
2434
    ASSERT_TRUE(history[4]->hasDependency(history[3]));
2435

2436
    ASSERT_TRUE(history[3]->hasDependency(history[2]));
2437

2438
    // The loads in the indices depend on the relevant input buffer.
2439
    ASSERT_TRUE(history[3]->hasDependency(history[1]));
2440
    ASSERT_TRUE(history[2]->hasDependency(history[0]));
2441

2442
    // The load from B has the loop bounds.
2443
    ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2444

2445
    // The load from A has bounds B[0] to B[9].
2446
    ExprHandle loadFromB0 = Load::make(b, {0});
2447
    ExprHandle loadFromB9 = Load::make(b, {9});
2448
    ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)}));
2449
  }
2450

2451
  {
2452
    /* for (int x = 0; x < 10; x++) {
2453
     *   C[B[x]] = A[x];
2454
     * }
2455
     */
2456
    MemDependencyChecker analyzer({a, b}, {c});
2457
    StmtPtr stmt = Block::make({For::make(
2458
        x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))});
2459

2460
    stmt->accept(&analyzer);
2461

2462
    /*  0. Input: B[(0, 99)] - dependents: 3
2463
     *  1. Input: A[(0, 99)] - dependents: 2
2464
     *  2. Load: A[(0, 9)] - depends on: 1  - dependents: 4
2465
     *  3. Load: B[(0, 9)] - depends on: 0  - dependents: 4
2466
     *  4. Store: C[(B[0], B[9])] - depends on: 2 3  - dependents: 5
2467
     *  5. Output: C[(0, 99)] - depends on: 4
2468
     */
2469
    // Sanity check output depends on input.
2470
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2471
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2472

2473
    auto history = analyzer.getHistory();
2474
    ASSERT_EQ(history.size(), 6);
2475

2476
    // The store depends on both loads, neither load is dependent.
2477
    ASSERT_TRUE(history[4]->hasDependency(history[2]));
2478
    ASSERT_TRUE(history[4]->hasDependency(history[3]));
2479

2480
    ASSERT_FALSE(history[3]->hasDependency(history[2]));
2481
    ASSERT_FALSE(history[2]->hasDependency(history[3]));
2482

2483
    // The loads each depend on their relevant input. (but accesses are in a
2484
    // different order than the last case).
2485
    ASSERT_TRUE(history[3]->hasDependency(history[0]));
2486
    ASSERT_TRUE(history[2]->hasDependency(history[1]));
2487

2488
    // The load from B has the loop bounds.
2489
    ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)}));
2490

2491
    // And so does the load from A.
2492
    ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2493
  }
2494

2495
  {
2496
    /* for (int x = 0; x < 10; x++) {
2497
     *   C[B[A[x]]] = x;
2498
     * }
2499
     */
2500
    MemDependencyChecker analyzer({a, b}, {c});
2501
    StmtPtr stmt = Block::make({For::make(
2502
        x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))});
2503

2504
    stmt->accept(&analyzer);
2505

2506
    /*  0. Input: B[(0, 99)] - dependents: 3
2507
     *  1. Input: A[(0, 99)] - dependents: 2
2508
     *  2. Load: A[(0, 9)] - depends on: 1  - dependents: 3 4
2509
     *  3. Load: B[(A[0], A[9])] - depends on: 0 2  - dependents: 4
2510
     *  4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3  - dependents: 5
2511
     *  5. Output: C[(0, 99)] - depends on: 4
2512
     */
2513

2514
    // Sanity check output depends on input.
2515
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2516
    ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2517

2518
    auto history = analyzer.getHistory();
2519
    ASSERT_EQ(history.size(), 6);
2520

2521
    // The store depends on both loads.
2522
    ASSERT_TRUE(history[4]->hasDependency(history[2]));
2523
    ASSERT_TRUE(history[4]->hasDependency(history[3]));
2524

2525
    // The outer load depends on the inner.
2526
    ASSERT_TRUE(history[3]->hasDependency(history[2]));
2527

2528
    // The loads each depend on their relevant input. (but accesses are in a
2529
    // different order than the last case).
2530
    ASSERT_TRUE(history[3]->hasDependency(history[0]));
2531
    ASSERT_TRUE(history[2]->hasDependency(history[1]));
2532

2533
    // The load from A has the loop bounds.
2534
    ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2535
    // The load from B as bounds A[0] to A[9].
2536
    ExprHandle loadFromA0 = Load::make(a, {0});
2537
    ExprHandle loadFromA9 = Load::make(a, {9});
2538
    ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)}));
2539

2540
    // The store has bounds of B[A[0]] to B[A[9]].
2541
    ExprHandle loadFromBA0 = Load::make(b, {loadFromA0});
2542
    ExprHandle loadFromBA9 = Load::make(b, {loadFromA9});
2543
    ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)}));
2544
  }
2545
}
2546

2547
// Verify multi dimensional bounds work.
2548
TEST(MemDependency, MemDependencyCheckerMultiDim) {
2549
  int M = 10, N = 9, K = 12;
2550
  BufHandle a("A", {M, N, K}, kInt);
2551
  BufHandle b("B", {M, N, K}, kInt);
2552
  BufHandle c("C", {M, K}, kInt);
2553
  VarHandle x("x", kInt);
2554
  VarHandle y("y", kInt);
2555
  VarHandle z("z", kInt);
2556

2557
  using namespace analysis;
2558

2559
  auto CB = [](ExprHandle s, ExprHandle e) {
2560
    return Bound(s.node(), e.node());
2561
  };
2562

2563
  auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
2564
    return indexBoundsEquals(x, y);
2565
  };
2566

2567
  {
2568
    /* for (int x = 0; x < 10; x++) {
2569
     *   for (int y = 0; y < 9; y++) {
2570
     *     for (int z = 0; z < 12; z++) {
2571
     *       B[x, y, z] = A[x, y, z];
2572
     *     }
2573
     *   }
2574
     * }
2575
     */
2576
    // Full range.
2577

2578
    MemDependencyChecker analyzer({a}, {b});
2579
    StmtPtr stmt = Block::make({For::make(
2580
        x,
2581
        0,
2582
        M,
2583
        For::make(
2584
            y,
2585
            0,
2586
            N,
2587
            For::make(
2588
                z,
2589
                0,
2590
                K,
2591
                Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
2592

2593
    stmt->accept(&analyzer);
2594

2595
    // Sanity test: Output depends on input.
2596
    ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2597

2598
    // 4 accesses: input, load, store, output.
2599
    auto history = analyzer.getHistory();
2600
    ASSERT_EQ(history.size(), 4);
2601

2602
    // Simple chain from input to output.
2603
    ASSERT_TRUE(history[3]->hasDependency(history[2]));
2604
    ASSERT_TRUE(history[2]->hasDependency(history[1]));
2605
    ASSERT_TRUE(history[1]->hasDependency(history[0]));
2606

2607
    ASSERT_TRUE(
2608
        EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2609
    ASSERT_TRUE(
2610
        EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2611
  }
2612

2613
  {
2614
    /* for (int x = 0; x < 5; x++) {
2615
     *   for (int y = 0; y < 5; y++) {
2616
     *     for (int z = 0; z < 5; z++) {
2617
     *       B[x, y, z] = A[x, y, z];
2618
     *     }
2619
     *   }
2620
     * }
2621
     */
2622
    // Partial range.
2623

2624
    MemDependencyChecker analyzer({a}, {b});
2625
    StmtPtr stmt = Block::make({For::make(
2626
        x,
2627
        0,
2628
        5,
2629
        For::make(
2630
            y,
2631
            0,
2632
            5,
2633
            For::make(
2634
                z,
2635
                0,
2636
                5,
2637
                Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
2638

2639
    stmt->accept(&analyzer);
2640

2641
    // Sanity test: Output depends on input.
2642
    ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2643

2644
    // 4 accesses: input, load, store, output.
2645
    auto history = analyzer.getHistory();
2646
    ASSERT_EQ(history.size(), 4);
2647

2648
    // Simple chain from input to output.
2649
    ASSERT_TRUE(history[3]->hasDependency(history[2]));
2650
    ASSERT_TRUE(history[2]->hasDependency(history[1]));
2651
    ASSERT_TRUE(history[1]->hasDependency(history[0]));
2652

2653
    ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
2654
    ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
2655
  }
2656

2657
  {
2658
    /* for (int x = 0; x < 10; x++) {
2659
     *   for (int y = 0; y < 12; y++) {
2660
     *     B[x, 0, y] = A[x, 0, y];
2661
     *   }
2662
     * }
2663
     */
2664

2665
    // Partial loops.
2666

2667
    MemDependencyChecker analyzer({a}, {b});
2668
    StmtPtr stmt = Block::make({For::make(
2669
        x,
2670
        0,
2671
        N,
2672
        For::make(
2673
            y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))});
2674

2675
    stmt->accept(&analyzer);
2676

2677
    // Sanity test: Output depends on input.
2678
    ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2679

2680
    // 4 accesses: input, load, store, output.
2681
    auto history = analyzer.getHistory();
2682
    ASSERT_EQ(history.size(), 4);
2683

2684
    // Simple chain from input to output.
2685
    ASSERT_TRUE(history[3]->hasDependency(history[2]));
2686
    ASSERT_TRUE(history[2]->hasDependency(history[1]));
2687
    ASSERT_TRUE(history[1]->hasDependency(history[0]));
2688

2689
    ASSERT_TRUE(
2690
        EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
2691
    ASSERT_TRUE(
2692
        EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
2693
  }
2694

2695
  {
2696
    /* for (int x = 0; x < 10; x++) {
2697
     *   for (int y = 0; y < 100; y++) {
2698
     *     for (int z = 0; z < 12; z++) {
2699
     *       B[x, 0, z] = (A[x, 0, z]) + (C[x, z]);
2700
     *     }
2701
     *   }
2702
     * }
2703
     */
2704

2705
    // Loops that don't correspond to an index, bufs with different
2706
    // dimensionality.
2707

2708
    MemDependencyChecker analyzer({a, c}, {b});
2709
    StmtPtr stmt = Block::make({For::make(
2710
        x,
2711
        0,
2712
        M,
2713
        For::make(
2714
            y,
2715
            0,
2716
            100,
2717
            For::make(
2718
                z,
2719
                0,
2720
                K,
2721
                Store::make(
2722
                    b,
2723
                    {x, 0, z},
2724
                    Add::make(
2725
                        Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))});
2726

2727
    stmt->accept(&analyzer);
2728

2729
    // Sanity test: Output depends on both inputs.
2730
    ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2731
    ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node()));
2732

2733
    // 6 accesses: 2 inputs, 2 loads, store, output.
2734
    auto history = analyzer.getHistory();
2735
    ASSERT_EQ(history.size(), 6);
2736

2737
    // Simple chain from input to output over the A buf.
2738
    // history[0] is the C input, history[3] is the load from C.
2739
    ASSERT_TRUE(history[5]->hasDependency(history[4]));
2740
    ASSERT_TRUE(history[4]->hasDependency(history[2]));
2741
    ASSERT_TRUE(history[2]->hasDependency(history[1]));
2742
    // The store also depends on the load from the C input.
2743
    ASSERT_TRUE(history[4]->hasDependency(history[3]));
2744
    ASSERT_TRUE(history[3]->hasDependency(history[0]));
2745

2746
    // A Buf accesses.
2747
    ASSERT_TRUE(
2748
        EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
2749
    ASSERT_TRUE(
2750
        EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
2751

2752
    // C buf access.
2753
    ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)}));
2754
  }
2755

2756
  {
2757
    /* for (int x = 0; x < 9; x++) {
2758
     *   for (int y = 0; y < 10; y++) {
2759
     *     for (int z = 0; z < 12; z++) {
2760
     *       B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]);
2761
     *     }
2762
     *   }
2763
     * }
2764
     */
2765
    // Multi-dim reductions.
2766

2767
    MemDependencyChecker analyzer({a}, {b});
2768
    StmtPtr stmt = Block::make({For::make(
2769
        x,
2770
        0,
2771
        M,
2772
        For::make(
2773
            y,
2774
            0,
2775
            N,
2776
            For::make(
2777
                z,
2778
                0,
2779
                K,
2780
                Store::make(
2781
                    b,
2782
                    {x, 0, 0},
2783
                    Add::make(
2784
                        Load::make(b, {x, y, z}),
2785
                        Load::make(a, {x, y, z}))))))});
2786

2787
    stmt->accept(&analyzer);
2788

2789
    // Sanity test: Output depends on input.
2790
    ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2791

2792
    // 4 accesses: input, 2 loads, store, output.
2793
    auto history = analyzer.getHistory();
2794
    ASSERT_EQ(history.size(), 5);
2795

2796
    // Simple chain from input to output.
2797
    ASSERT_TRUE(history[4]->hasDependency(history[3]));
2798
    ASSERT_TRUE(history[3]->hasDependency(history[2]));
2799
    ASSERT_TRUE(history[3]->hasDependency(history[1]));
2800
    ASSERT_TRUE(history[2]->hasDependency(history[0]));
2801

2802
    // The load from B depends on the store to B.
2803
    ASSERT_TRUE(history[1]->hasDependency(history[3]));
2804

2805
    ASSERT_TRUE(
2806
        EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2807
    ASSERT_TRUE(
2808
        EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2809
    ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)}));
2810
  }
2811
}
2812

2813
// Various tests using the external Compute/Reduce API.
2814
TEST(MemDependency, MemDependencyCheckerComputeAPI) {
2815
  using namespace analysis;
2816

2817
  /* for (int m = 0; m < 4; m++) {
2818
   *   for (int n = 0; n < 5; n++) {
2819
   *     for (int k = 0; k < 6; k++) {
2820
   *       broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]);
2821
   *     }
2822
   *   }
2823
   * }
2824
   * for (int m_1 = 0; m_1 < 4; m_1++) {
2825
   *   for (int n_1 = 0; n_1 < 5; n_1++) {
2826
   *     for (int k_1 = 0; k_1 < 6; k_1++) {
2827
   *       d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1);
2828
   *     }
2829
   *   }
2830
   * }
2831
   */
2832

2833
  // Can determine if 2 loops created by Compute are dependent.
2834
  BufHandle a_buf("a", {4, 5}, kFloat);
2835
  BufHandle b_buf("b", {5, 6}, kFloat);
2836
  Tensor c = Compute(
2837
      "broadcast_add",
2838
      {4, 5, 6},
2839
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2840
        return a_buf.load(m, n) + b_buf.load(n, k);
2841
      });
2842
  Tensor d = Compute(
2843
      "d",
2844
      {4, 5, 6},
2845
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2846
        return c.load(m, n, k) + 1;
2847
      });
2848

2849
  LoopNest l({d}, {c, d});
2850

2851
  MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()});
2852

2853
  l.root_stmt()->accept(&analyzer);
2854

2855
  // Sanity test: Output depends on input.
2856
  ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node()));
2857
  ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node()));
2858

2859
  // Second loop depends on first loop.
2860
  auto c_loop = l.getLoopStmtsFor(c)[0];
2861
  auto d_loop = l.getLoopStmtsFor(d)[0];
2862
  ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));
2863
}
2864

2865
// After inlining the intermediate "broadcast_add" tensor into its consumer,
// the output must still depend on both inputs, and no access in the analyzer's
// history may refer to the inlined buffer.
TEST(MemDependency, MemDependencyCheckerComputeInline) {
  using namespace analysis;

  /* for (int m = 0; m < 4; m++) {
   *   for (int n = 0; n < 5; n++) {
   *     for (int k = 0; k < 6; k++) {
   *       d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1);
   *     }
   *   }
   * }
   */

  // Check inlining affects the number of accesses returned.

  BufHandle a_buf("a", {4, 5}, kFloat);
  BufHandle b_buf("b", {5, 6}, kFloat);
  Tensor c = Compute(
      "broadcast_add",
      {4, 5, 6},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n) + b_buf.load(n, k);
      });
  Tensor d = Compute(
      "d",
      {4, 5, 6},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return c.load(m, n, k) + 1;
      });

  LoopNest l({d}, {c, d});
  l.computeInline(c.buf());

  MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()});
  l.root_stmt()->accept(&analyzer);

  // Sanity test: Output depends on input.
  ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node()));
  ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node()));

  // broadcast_add tensor should not appear in trace at all.
  for (auto& wi : analyzer.getHistory()) {
    ASSERT_NE(wi->var(), c.buf()->base_handle());
  }
}
2909

2910
// Analyzes the same computation before and after splitting its outermost loop
// (with a tail) and checks that the two access histories match entry by entry:
// same access type, same variable, equal bounds, and the same number of
// dependencies and dependents.
TEST(MemDependency, MemDependencyCheckerComputeSplit) {
  using namespace analysis;
  // Split an axis, so the number of loops != the number of dimensions.

  BufHandle a_buf("a", {4, 5}, kFloat);
  BufHandle b_buf("b", {5, 6}, kFloat);
  Tensor c = Compute(
      "broadcast_add",
      {4, 5, 6},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n) + b_buf.load(n, k);
      });

  LoopNest l({c});

  MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()});
  l.root_stmt()->accept(&analyzer_before);

  l.splitWithTail(l.getLoopStmtsFor(c)[0], 2);

  // The transformed statement is simplified before being analyzed.
  MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()});
  StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
  stmt->accept(&analyzer_after);

  // Splitting should not change accesses at all.
  auto history_before = analyzer_before.getHistory();
  auto history_after = analyzer_after.getHistory();

  ASSERT_EQ(history_before.size(), history_after.size());

  for (size_t i = 0; i < history_before.size(); ++i) {
    ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
    ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
    ASSERT_EQ(
        history_before[i]->bounds().size(), history_after[i]->bounds().size());
    ASSERT_TRUE(indexBoundsEquals(
        history_before[i]->bounds(), history_after[i]->bounds()));
    ASSERT_EQ(
        history_before[i]->dependencies().size(),
        history_after[i]->dependencies().size());
    ASSERT_EQ(
        history_before[i]->dependents().size(),
        history_after[i]->dependents().size());
  }
}
2955

2956
// Analyzes the same computation before and after swapping its two outermost
// loops and checks that the two access histories match entry by entry: same
// access type, same variable, equal bounds, and the same number of
// dependencies and dependents.
TEST(MemDependency, MemDependencyCheckerComputeReorder) {
  using namespace analysis;
  // Reorder an axis, so the loop order doesn't match the indexing order.

  BufHandle a_buf("a", {4, 5}, kFloat);
  BufHandle b_buf("b", {5, 6}, kFloat);
  Tensor c = Compute(
      "broadcast_add",
      {4, 5, 6},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n) + b_buf.load(n, k);
      });

  LoopNest l({c});

  MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()});
  l.root_stmt()->accept(&analyzer_before);

  auto loops = l.getLoopStmtsFor(c);
  l.reorderAxis(loops[0], loops[1]);

  // The transformed statement is simplified before being analyzed.
  MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()});
  StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
  stmt->accept(&analyzer_after);

  // Reordering should not change accesses at all.
  auto history_before = analyzer_before.getHistory();
  auto history_after = analyzer_after.getHistory();

  ASSERT_EQ(history_before.size(), history_after.size());

  for (size_t i = 0; i < history_before.size(); ++i) {
    ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
    ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
    ASSERT_EQ(
        history_before[i]->bounds().size(), history_after[i]->bounds().size());
    ASSERT_TRUE(indexBoundsEquals(
        history_before[i]->bounds(), history_after[i]->bounds()));
    ASSERT_EQ(
        history_before[i]->dependencies().size(),
        history_after[i]->dependencies().size());
    ASSERT_EQ(
        history_before[i]->dependents().size(),
        history_after[i]->dependents().size());
  }
}
3002

3003
// Builds an elementwise "scale" tensor followed by a Sum reduction over it and
// checks that the output, the reduction loop and the ReduceOp node itself are
// all reported as depending on the expected producers.
TEST(MemDependency, MemDependencyCheckerComputeReduce) {
  using namespace analysis;
  /* for (int l2 = 0; l2 < 2; l2++) {
   *   for (int n1 = 0; n1 < 3; n1++) {
   *     for (int m1 = 0; m1 < 6; m1++) {
   *       scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]);
   *     }
   *   }
   * }
   * for (int l1 = 0; l1 < 2; l1++) {
   *   sum[l1] = float(0);
   *   for (int n1_1 = 0; n1_1 < 3; n1_1++) {
   *     for (int m1_1 = 0; m1_1 < 6; m1_1++) {
   *       sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)),
   *                    out_args={l1}, reduce_args={n1, m1});
   *     }
   *   }
   * }
   */

  // Can determine dependencies of a Reduction.

  BufHandle a("a", {2, 3, 6}, kFloat);
  BufHandle b("b", {2, 3, 6}, kFloat);

  Tensor c = Compute(
      "scale",
      {2, 3, 6},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6});
  LoopNest l({d}, {c, d});

  MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()});

  l.root_stmt()->accept(&analyzer);

  // Sanity test: Output depends on input.
  ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node()));
  ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node()));

  // Second loop depends on first loop.
  auto c_loop = l.getLoopStmtsFor(c)[0];
  auto d_loop = l.getLoopStmtsFor(d)[0];
  ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));

  // Reduction depends on both inputs.
  auto reduces = NodeFinder<ReduceOp>::find(l.root_stmt());
  ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node()));
  ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node()));
}
3055

3056
// Builds a GEMM as a Sum reduction, applies a series of loop transformations
// (two masked splits, three axis reorders and an access cache named "C_regs"),
// then analyzes the result twice: once before prepareForCodegen() and once on
// the statement held by a SimpleIREvaluator afterwards. The second half of the
// test checks that the lowered history matches the unlowered one apart from
// the inserted Alloc/Free accesses and the flattening of access bounds.
TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
  int M = 1024;
  int N = 1024;
  int K = 2048;
  using namespace analysis;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);
  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  LoopNest loop({CT});

  // The loop list is re-queried after every transformation, since each one
  // restructures the nest.
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr m = loops[0];
    loop.splitWithMask(m, 4);
  }
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr n = loops[2];
    loop.splitWithMask(n, 16);
  }
  // mo, mi, no, ni, k ->
  // mo, no, mi, ni, k
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr mi = loops[1];
    ForPtr no = loops[2];
    loop.reorderAxis(mi, no);
  }
  // mo, no, mi, ni, k ->
  // mo, no, mi, k, ni
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr ni = loops[3];
    ForPtr k = loops[4];
    loop.reorderAxis(ni, k);
  }
  // mo, no, mi, k, ni ->
  // mo, no, k, mi, ni
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr mi = loops[2];
    ForPtr k = loops[3];
    loop.reorderAxis(mi, k);
  }
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    loop.cacheAccesses(CT.buf(), "C_regs", loops[2]);
  }

  MemDependencyChecker analyzer_unlowered(
      loop.getInputBufs(), loop.getOutputBufs());

  MemDependencyChecker analyzer_lowered(
      loop.getInputBufs(), loop.getOutputBufs());

  // Test both unlowered and lowered form.
  {
    StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt());
    stmt->accept(&analyzer_unlowered);

    // Outputs depend on inputs.
    ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node()));
    ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node()));

    // The last write to gemm should cover the total bound of the output.
    std::shared_ptr<AccessInfo> outputAccess =
        analyzer_unlowered.output(CT.buf());
    // A single dependency.
    ASSERT_EQ(outputAccess->dependencies().size(), 1);

    // dependencies has exactly 1 element (asserted above), so we can just
    // deref begin() and take the mapped access.
    std::shared_ptr<AccessInfo> gemmStore =
        outputAccess->dependencies().begin()->second;
    // Check it's a store.
    ASSERT_EQ(gemmStore->type(), AccessType::Store);

    ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds()));

    // Likewise the first read from each input cover the entire range of the
    // input.
    auto aInput = analyzer_unlowered.input(AP.node());
    auto bInput = analyzer_unlowered.input(BP.node());

    // A single dependent each.
    ASSERT_EQ(aInput->dependents().size(), 1);
    ASSERT_EQ(bInput->dependents().size(), 1);

    // They're both loads.
    std::shared_ptr<AccessInfo> aLoad = aInput->dependents().begin()->second;
    std::shared_ptr<AccessInfo> bLoad = bInput->dependents().begin()->second;
    ASSERT_EQ(aLoad->type(), AccessType::Load);
    ASSERT_EQ(bLoad->type(), AccessType::Load);

    ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds()));
    ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds()));
  }

  loop.prepareForCodegen();
  SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT});

  // now check lowered dependency graph.
  {
    StmtPtr stmt = IRSimplifier::simplify(cg.stmt());
    stmt->accept(&analyzer_lowered);

    // Lowering will change the dimensionality of all bounds due to index
    // flattening and will insert Allocates and Frees.

    auto history_before = analyzer_unlowered.getHistory();
    auto history_after = analyzer_lowered.getHistory();

    // Exactly two extra accesses are expected after lowering.
    ASSERT_EQ(history_before.size() + 2, history_after.size());

    // Filter out the alloc/free;
    auto isAllocFree = [](const auto& info) {
      return info->type() == AccessType::Alloc ||
          info->type() == AccessType::Free;
    };
    history_after.erase(
        std::remove_if(history_after.begin(), history_after.end(), isAllocFree),
        history_after.end());

    ASSERT_EQ(history_before.size(), history_after.size());

    for (size_t i = 0; i < history_before.size(); ++i) {
      ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
      ASSERT_EQ(history_before[i]->var(), history_after[i]->var());

      if (history_before[i]->dependencies().size() !=
          history_after[i]->dependencies().size()) {
        // Must depend on an Alloc.
        ASSERT_TRUE(std::any_of(
            history_after[i]->dependencies().begin(),
            history_after[i]->dependencies().end(),
            [](const auto& pair) {
              return pair.second->type() == AccessType::Alloc;
            }));

        // ...and that Alloc must be the only additional dependency.
        ASSERT_EQ(
            history_before[i]->dependencies().size() + 1,
            history_after[i]->dependencies().size());
      }

      if (history_before[i]->dependents().size() !=
          history_after[i]->dependents().size()) {
        // Must have a Free among its dependents.
        ASSERT_TRUE(std::any_of(
            history_after[i]->dependents().begin(),
            history_after[i]->dependents().end(),
            [](const auto& pair) {
              return pair.second->type() == AccessType::Free;
            }));

        // ...and that Free must be the only additional dependent.
        ASSERT_EQ(
            history_before[i]->dependents().size() + 1,
            history_after[i]->dependents().size());
      }

      // Inputs and outputs are not flattened, only accesses.
      if (history_before[i]->type() == AccessType::Input ||
          history_before[i]->type() == AccessType::Output) {
        ASSERT_EQ(
            history_before[i]->bounds().size(),
            history_after[i]->bounds().size());
        ASSERT_TRUE(indexBoundsEquals(
            history_before[i]->bounds(), history_after[i]->bounds()));
      } else {
        ASSERT_EQ(history_after[i]->bounds().size(), 1);

        // Rebuild the expected flattened extent as the product of
        // (end + 1) over every dimension of the unlowered bounds.
        ExprPtr flat_bounds = alloc<IntImm>(1);

        for (auto& b : history_before[i]->bounds()) {
          flat_bounds =
              alloc<Mul>(flat_bounds, alloc<Add>(b.end, alloc<IntImm>(1)));

          // Every unlowered dimension must start where the flattened bound
          // starts.
          // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
          ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start));
        }

        // Compare against (flattened end + 1) once both sides are simplified.
        flat_bounds = IRSimplifier::simplify(flat_bounds);
        ExprPtr after_bounds = IRSimplifier::simplify(
            alloc<Add>(history_after[i]->bounds()[0].end, alloc<IntImm>(1)));
        ASSERT_TRUE(exprEquals(flat_bounds, after_bounds));
      }
    }
  }
}
3250

3251
} // namespace jit
3252
} // namespace torch
3253

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.