pytorch
3252 строки · 101.3 Кб
1#include <gtest/gtest.h>
2#include <test/cpp/tensorexpr/test_base.h>
3
4#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
5#include <torch/csrc/jit/tensorexpr/ir.h>
6#include <torch/csrc/jit/tensorexpr/ir_printer.h>
7#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
8#include <torch/csrc/jit/tensorexpr/loopnest.h>
9#include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h>
10#include <torch/csrc/jit/tensorexpr/tensor.h>
11
12namespace torch {
13namespace jit {
14
15using namespace torch::jit::tensorexpr;
16
17// Test helper function used to determine if two regions of a buffer have an
18// overlap. No Overlap & partial overlap is obvious. Contains means A is
19// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal
20// ranges are ContainedOrEqual.
21TEST(MemDependency, BoundOverlap) {
22using namespace analysis;
23
24auto CB = [](int s, int e) {
25return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
26};
27
28// Sanity check 3 overlap cases.
29ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0)));
30ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5)));
31ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1)));
32
33// Partial overlap works in either order.
34ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14)));
35ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10)));
36
37// Total Overlap works when one bound encloses the other, and returns which.
38ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9)));
39ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16)));
40
41// Total overlap works when the bounds are an identical range, returns
42// ContainedOrEqual.
43ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15)));
44
45// Total overlap when only one end of the bound matches.
46ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10)));
47ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15)));
48ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9)));
49ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15)));
50ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15)));
51
52// No overlap when a < b.
53ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10)));
54ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3)));
55ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130)));
56
57// No overlap when a > b.
58ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2)));
59ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2)));
60ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120)));
61
62// No overlap when adjacent.
63ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120)));
64ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1)));
65
66// Partial overlap when middle bounds match.
67ASSERT_EQ(
68OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120)));
69ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4)));
70ASSERT_EQ(
71OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100)));
72ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2)));
73
74// Total overlap when one bound is single length over one end of the other.
75ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15)));
76ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2)));
77ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15)));
78ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15)));
79}
80
81TEST(MemDependency, BoundComparison) {
82using namespace analysis;
83
84auto CB = [](int s, int e) {
85return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
86};
87
88ASSERT_EQ(
89CmpEvalResult::NotDetermined,
90compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ));
91ASSERT_EQ(
92CmpEvalResult::True,
93compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ));
94ASSERT_EQ(
95CmpEvalResult::False,
96compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ));
97ASSERT_EQ(
98CmpEvalResult::False,
99compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ));
100ASSERT_EQ(
101CmpEvalResult::NotDetermined,
102compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ));
103ASSERT_EQ(
104CmpEvalResult::NotDetermined,
105compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ));
106ASSERT_EQ(
107CmpEvalResult::NotDetermined,
108compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ));
109
110ASSERT_EQ(
111CmpEvalResult::NotDetermined,
112compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE));
113ASSERT_EQ(
114CmpEvalResult::False,
115compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE));
116ASSERT_EQ(
117CmpEvalResult::True,
118compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE));
119ASSERT_EQ(
120CmpEvalResult::True,
121compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE));
122ASSERT_EQ(
123CmpEvalResult::NotDetermined,
124compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE));
125ASSERT_EQ(
126CmpEvalResult::NotDetermined,
127compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ));
128ASSERT_EQ(
129CmpEvalResult::NotDetermined,
130compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE));
131
132ASSERT_EQ(
133CmpEvalResult::True,
134compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT));
135ASSERT_EQ(
136CmpEvalResult::False,
137compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT));
138ASSERT_EQ(
139CmpEvalResult::False,
140compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT));
141ASSERT_EQ(
142CmpEvalResult::NotDetermined,
143compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT));
144ASSERT_EQ(
145CmpEvalResult::NotDetermined,
146compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT));
147ASSERT_EQ(
148CmpEvalResult::NotDetermined,
149compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT));
150
151ASSERT_EQ(
152CmpEvalResult::False,
153compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE));
154ASSERT_EQ(
155CmpEvalResult::True,
156compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE));
157ASSERT_EQ(
158CmpEvalResult::True,
159compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE));
160ASSERT_EQ(
161CmpEvalResult::NotDetermined,
162compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE));
163ASSERT_EQ(
164CmpEvalResult::NotDetermined,
165compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE));
166ASSERT_EQ(
167CmpEvalResult::NotDetermined,
168compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE));
169
170ASSERT_EQ(
171CmpEvalResult::False,
172compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT));
173ASSERT_EQ(
174CmpEvalResult::False,
175compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGT));
176ASSERT_EQ(
177CmpEvalResult::NotDetermined,
178compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT));
179ASSERT_EQ(
180CmpEvalResult::True,
181compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT));
182ASSERT_EQ(
183CmpEvalResult::NotDetermined,
184compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT));
185ASSERT_EQ(
186CmpEvalResult::NotDetermined,
187compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT));
188
189ASSERT_EQ(
190CmpEvalResult::True,
191compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE));
192ASSERT_EQ(
193CmpEvalResult::True,
194compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE));
195ASSERT_EQ(
196CmpEvalResult::NotDetermined,
197compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE));
198ASSERT_EQ(
199CmpEvalResult::False,
200compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE));
201ASSERT_EQ(
202CmpEvalResult::NotDetermined,
203compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE));
204ASSERT_EQ(
205CmpEvalResult::NotDetermined,
206compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE));
207}
208
209TEST(MemDependency, BoundOverlapSymbolic) {
210VarHandle x("x", kInt);
211VarHandle y("y", kInt);
212VarHandle z("z", kInt);
213VarHandle w("w", kInt);
214
215using namespace analysis;
216
217auto CB = [](ExprHandle s, ExprHandle e) {
218return Bound(s.node(), e.node());
219};
220
221// Sanity check cases where the start and end is symbolic but the diff is
222// constant.
223// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
224ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x)));
225ASSERT_EQ(
226OverlapKind::PartialOverlap,
227boundOverlap(CB(x, x + 3), CB(x + 2, x + 5)));
228ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1)));
229
230// We can't infer the sign of y, so cannot tell whether adding y is larger or
231// smaller than y/2.
232ASSERT_EQ(
233OverlapKind::PartialOverlap,
234boundOverlap(CB(x, x + y), CB(x, x + y / 2)));
235
236// No information about this bound, have to take the most conservative option:
237// there may be an overlap.
238ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(x, y), CB(z, w)));
239
240// Math on opaque terms works.
241ASSERT_EQ(
242OverlapKind::ContainedOrEqual,
243boundOverlap(CB(x + w, y - z), CB(x + w, y - z)));
244// Even requiring simplification.
245ASSERT_EQ(
246OverlapKind::ContainedOrEqual,
247boundOverlap(CB(x - w - w, y), CB(x - w * 2, y)));
248}
249
250// Tests the helper function for overlap of multi dimensional indices bounds.
251// This uses boundOverlap on each dimension and return the "lowest" kind of
252// overlap.
253TEST(MemDependency, BoundOverlapMultiDim) {
254using namespace analysis;
255
256auto CB = [](int s, int e) {
257return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
258};
259
260// Sanity check one dimensional cases.
261ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)}));
262ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)}));
263ASSERT_EQ(
264OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)}));
265
266// Total overlap in 3 dims.
267ASSERT_EQ(
268OverlapKind::ContainedOrEqual,
269overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)}));
270ASSERT_EQ(
271OverlapKind::ContainedOrEqual,
272overlaps(
273{CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)}));
274
275// Total overlap in 2 dims, no overlap in another.
276ASSERT_EQ(
277OverlapKind::NoOverlap,
278overlaps(
279{CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
280
281// Total overlap in 2 dims, partial overlap in another.
282ASSERT_EQ(
283OverlapKind::PartialOverlap,
284overlaps(
285{CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)}));
286// This case is most important, so verify the overlap in any dim. (dim 2)
287ASSERT_EQ(
288OverlapKind::PartialOverlap,
289overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)}));
290// Dim 1.
291ASSERT_EQ(
292OverlapKind::PartialOverlap,
293overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)}));
294// Total overlap in 1 dim, partial in 2.
295ASSERT_EQ(
296OverlapKind::PartialOverlap,
297overlaps(
298{CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)}));
299// Total overlap, partial overlap, no overlap.
300ASSERT_EQ(
301OverlapKind::NoOverlap,
302overlaps(
303{CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)}));
304
305// Total overlap (B) in 2 dims, total overlap (A) in another.
306ASSERT_EQ(
307OverlapKind::Contains,
308overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)}));
309
310// Total overlap (A) in 2 dims, total overlap (B) in another.
311ASSERT_EQ(
312OverlapKind::Contains,
313overlaps(
314{CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)}));
315
316// Total (B), No Overlap, Total (A).
317ASSERT_EQ(
318OverlapKind::NoOverlap,
319overlaps(
320{CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)}));
321}
322
// Test the helper we use to subtract bounds: returns the region(s) of A which
// remain after removing the region of B.
325TEST(MemDependency, BoundSubtract) {
326using namespace analysis;
327
328auto CB = [](int s, int e) {
329return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
330};
331auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
332return indexBoundsEquals(x, y);
333};
334
335// One element subtract.
336ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0);
337ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0);
338
339// No Overlap.
340ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)}));
341ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)}));
342
343// one side overlap.
344ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)}));
345ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)}));
346ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)}));
347ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)}));
348
349// both sides overlap.
350ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {}));
351ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {}));
352
353// internal overlap.
354ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)}));
355ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)}));
356}
357
358TEST(MemDependency, BoundSubtractSymbolic) {
359VarHandle x("x", kInt);
360VarHandle y("y", kInt);
361VarHandle z("z", kInt);
362VarHandle w("w", kInt);
363
364using namespace analysis;
365
366auto CB = [](ExprHandle s, ExprHandle e) {
367return Bound(s.node(), e.node());
368};
369auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
370return indexBoundsEquals(x, y);
371};
372
373// One element subtract.
374// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
375ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {}));
376ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {}));
377ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {}));
378
379// Subtract constant range low.
380ASSERT_TRUE(
381EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)}));
382// Subtract constant range high.
383ASSERT_TRUE(
384EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)}));
385// Subtract constant range total overlap.
386ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {}));
387ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {}));
388// Subtract constant range internal.
389ASSERT_TRUE(
390EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)),
391{CB(x, x + 2), CB(x + 8, x + 10)}));
392
393// Size is inferable but not constant, only works with a single var.
394ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {}));
395ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)}));
396
397// Size is not inferable.
398ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)}));
399ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)}));
400ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)}));
401ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)}));
402}
403
404// Tests the helper function that does subtraction, but for multi dimensional
405// indices bounds.
406TEST(MemDependency, BoundSubtractMultiDim) {
407using namespace analysis;
408
409auto CB = [](int s, int e) {
410return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
411};
412auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
413if (x.size() != y.size()) {
414return false;
415}
416for (auto i = 0U; i < x.size(); ++i) {
417if (!indexBoundsEquals(x[i], y[i])) {
418return false;
419}
420}
421return true;
422};
423
424// sanity check one dimension.
425ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {}));
426ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {}));
427ASSERT_TRUE(
428EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}}));
429ASSERT_TRUE(
430EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}}));
431ASSERT_TRUE(EQ(
432subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}}));
433
434// Multi dim total overlap.
435ASSERT_TRUE(EQ(
436subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {}));
437ASSERT_TRUE(EQ(
438subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {}));
439
440// Mutli dim one way partial in dim 1.
441ASSERT_TRUE(
442EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}),
443{{CB(4, 9), CB(0, 2)}}));
444
445// Mutli dim one way partial in dim 2.
446ASSERT_TRUE(
447EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}),
448{{CB(0, 9), CB(11, 20)}}));
449
450// Partial overlap in 2 dims.
451ASSERT_TRUE(
452EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}),
453{{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}}));
454
455// Partial overlap in 3 dims.
456ASSERT_TRUE(
457EQ(subtractIndicesBounds(
458{CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}),
459{{CB(0, 1), CB(0, 5), CB(0, 5)},
460{CB(2, 5), CB(0, 1), CB(0, 5)},
461{CB(2, 5), CB(2, 5), CB(0, 1)}}));
462}
463
464// Tests the multi dimensional subtraction code for bounds that cannot be fully
465// materialized.
466TEST(MemDependency, BoundSubtractMultiDimSymbolic) {
467VarHandle x("x", kInt);
468VarHandle y("y", kInt);
469
470using namespace analysis;
471
472auto CB = [](ExprHandle s, ExprHandle e) {
473return Bound(s.node(), e.node());
474};
475
476auto EQ = [](std::vector<IndexBounds> x, std::vector<IndexBounds> y) {
477if (x.size() != y.size()) {
478return false;
479}
480for (auto i = 0U; i < x.size(); ++i) {
481if (!indexBoundsEquals(x[i], y[i])) {
482return false;
483}
484}
485return true;
486};
487
488// Cannot determine overlaps.
489// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
490ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}}));
491
492// Various total Overlaps.
493ASSERT_TRUE(EQ(
494subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {}));
495ASSERT_TRUE(EQ(
496subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {}));
497ASSERT_TRUE(EQ(
498subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {}));
499ASSERT_TRUE(EQ(
500subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {}));
501
502// one-way overlap in first dim.
503ASSERT_TRUE(
504EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}),
505{{CB(x - 4, x), CB(0, y)}}));
506// second dim.
507ASSERT_TRUE(
508EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}),
509{{CB(0, x), CB(0, 4)}}));
510
511// Internal overlap in first dim.
512ASSERT_TRUE(
513EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}),
514{{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}}));
515// second dim.
516ASSERT_TRUE(EQ(
517subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}),
518{{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}}));
519
520// Overlap in both dimensions.
521ASSERT_TRUE(
522EQ(subtractIndicesBounds(
523{CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}),
524{
525{CB(0, 4), CB(0, y)},
526{CB(x - 4, x), CB(0, y)},
527{CB(0, x), CB(0, 9)},
528{CB(0, x), CB(y - 9, y)},
529}));
530}
531
532// Simple check that the analyzer does anything at all...
533TEST(MemDependency, MemDependencyCheckerSimple) {
534BufHandle a("A", {1}, kInt);
535BufHandle b("B", {1}, kInt);
536
537analysis::MemDependencyChecker analyzer;
538
539/*
540* A[0] = 3;
541* B[0] = A[0] + 1;
542*/
543
544StorePtr aStore = Store::make(a, {0}, 3);
545StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
546
547StmtPtr stmt = Block::make({aStore, bStore});
548
549stmt->accept(&analyzer);
550
551ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
552ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore));
553// sanity check, but anything that depends directly must depend indirectly.
554ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore));
555}
556
557// Check that there is a difference between direct and indirect dependence.
558TEST(MemDependency, MemDependencyCheckerMultiStmt) {
559BufHandle a("A", {1}, kInt);
560BufHandle b("B", {1}, kInt);
561BufHandle c("C", {1}, kInt);
562
563analysis::MemDependencyChecker analyzer;
564
565/*
566* A[0] = 3;
567* B[0] = A[0];
568* C[0] = B[0] + 1;
569*/
570
571StorePtr aStore = Store::make(a, {0}, 3);
572StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
573StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1));
574
575StmtPtr stmt = Block::make({aStore, bStore, cStore});
576
577stmt->accept(&analyzer);
578
579// C depends on A indirectly.
580ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore));
581ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore));
582
583// C depends on B directly, which depends on A directly.
584ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore));
585ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
586
587// Dependency goes top to bottom only.
588ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore));
589ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore));
590ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore));
591}
592
593// Verify that we do filter writes that are totally overlapped by later writes.
594TEST(MemDependency, MemDependencyCheckerOverlap) {
595BufHandle a("A", {1}, kInt);
596BufHandle b("B", {1}, kInt);
597
598analysis::MemDependencyChecker analyzer;
599
600/*
601* A[0] = 3;
602* A[0] = 6;
603* B[0] = A[0] + 1;
604*/
605
606StorePtr aStore = Store::make(a, {0}, 3);
607StorePtr a2Store = Store::make(a, {0}, 6);
608StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1));
609
610StmtPtr stmt = Block::make({aStore, a2Store, bStore});
611
612stmt->accept(&analyzer);
613
614// B store depends on second A store but not first since it is completely
615// overlapped.
616ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store));
617ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore));
618
619// No dependency between either A store.
620ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store));
621ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore));
622}
623
624// Verify that bounds match loop iterations, and that dependencies progress
625// across loop scopes.
626TEST(MemDependency, MemDependencyCheckerLoop) {
627BufHandle a("A", {1}, kInt);
628BufHandle b("B", {1}, kInt);
629VarHandle x("x", kInt);
630
631using namespace analysis;
632
633MemDependencyChecker analyzer;
634
635/*
636* for (int x = 0; x < 10; ++x) {
637* A[x] = x;
638* }
639* B[0] = A[0] + 1;
640*/
641
642StorePtr aStore = Store::make(a, {x}, x);
643StmtPtr loop = For::make(x, 0, 10, aStore);
644StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1));
645
646StmtPtr stmt = Block::make({loop, bStore});
647
648stmt->accept(&analyzer);
649
650// Same A->B dependency.
651ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore));
652
653// B depends on the loop.
654ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
655// A is in the loop but does not depend on any loop iteration.
656ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop));
657
658auto aStoreAccess = analyzer.accessFor(aStore);
659ASSERT_NE(aStoreAccess, nullptr);
660
661// It should have bounds covering the range of x: 0 <= x < 10.
662ASSERT_TRUE(indexBoundsEquals(
663aStoreAccess->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
664}
665
666// Reductions should promote dependencies as well.
667TEST(MemDependency, MemDependencyCheckerLoopReduce) {
668BufHandle a("A", {10}, kInt);
669BufHandle b("B", {10}, kInt);
670VarHandle x("x", kInt);
671
672using namespace analysis;
673
674MemDependencyChecker analyzer;
675
676/*
677* A[0] = 0;
678* for (int x = 0; x < 10; ++x) {
679* A[0] = A[x] + 1;
680* }
681* B[0] = A[0];
682*/
683
684StorePtr aInit = Store::make(a, {0}, 0);
685ExprHandle reduce = Sum()(a, 1, {x}, {x});
686StorePtr aReduce = Store::make(a, {0}, reduce);
687StmtPtr loop = For::make(x, 0, 10, aReduce);
688StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
689
690StmtPtr stmt = Block::make({aInit, loop, bStore});
691
692stmt->accept(&analyzer);
693
694// B -> A.
695ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));
696
697// B depends indirectly on the initializer of A, since the reduction depends
698// on it.
699ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
700ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));
701
702ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));
703
704// B depends on the loop.
705ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
706// A is in the loop and depends on other iterations.
707ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));
708
709// The loop contents depend on the initializer too.
710ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));
711
712// Find loads within the reduction:
713auto reduceLoads = NodeFinder<Load>::find(reduce.node());
714// Pull out the access for the load inside the loop.
715for (auto load : reduceLoads) {
716auto loopLoad = analyzer.accessFor(load);
717// It should have 10 element long bounds.
718ASSERT_TRUE(indexBoundsEquals(
719loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
720}
721}
722
723// Lowering a reduction doesn't affect dependency analysis.
724TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) {
725BufHandle a("A", {10}, kInt);
726BufHandle b("B", {10}, kInt);
727VarHandle x("x", kInt);
728
729using namespace analysis;
730
731MemDependencyChecker analyzer;
732
733/*
734* A[0] = 0;
735* for (int x = 0; x < 10; ++x) {
736* A[0] = A[x] + 1;
737* }
738* B[0] = A[0];
739*/
740
741StorePtr aInit = Store::make(a, {0}, 0);
742ExprHandle aLoad = Load::make(a, {x});
743StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1));
744StmtPtr loop = For::make(x, 0, 10, aReduce);
745StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
746
747StmtPtr stmt = Block::make({aInit, loop, bStore});
748
749stmt->accept(&analyzer);
750
751// B -> A.
752ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce));
753
754// B depends indirectly on the initializer of A, since the reduction depends
755// on it.
756ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit));
757ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit));
758
759ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit));
760
761// B depends on the loop.
762ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop));
763// A is in the loop and depends on other iterations.
764ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop));
765
766// The loop contents depend on the initializer too.
767ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit));
768
769// Pull out the access for the store inside the loop.
770auto loopLoad = analyzer.accessFor(aLoad.node());
771// It should have 10 element long bounds.
772ASSERT_TRUE(indexBoundsEquals(
773loopLoad->bounds(), {Bound(alloc<IntImm>(0), alloc<IntImm>(9))}));
774}
775
776// Can determine dependencies of outputs, through to inputs.
777TEST(MemDependency, MemDependencyCheckerInputsOutputs) {
778BufHandle a("A", {10}, kInt);
779BufHandle b("B", {10}, kInt);
780VarHandle x("x", kInt);
781
782// initialize analyzer with inputs and outputs.
783analysis::MemDependencyChecker analyzer({a}, {b});
784
785// Here's a Relu.
786/*
787* for (int x = 0; x < 10; ++x) {
788* B[x] = Max(A[x], 0);
789* }
790*/
791
792ExprHandle aLoad = Load::make(a, {x});
793StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true));
794StmtPtr loop = For::make(x, 0, 10, bStore);
795
796StmtPtr stmt = Block::make({loop});
797
798stmt->accept(&analyzer);
799
800// Output depends indirectly on input.
801ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
802// aLoad depends directly on the input A.
803ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node()));
804// bStore therefore depends directly on the input A.
805ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node()));
806// The output depends directly on the store.
807ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));
808
809// Check AccessInfo based overloads.
810auto input = analyzer.input(a.node());
811auto output = analyzer.output(b.node());
812
813// Output depends indirectly on input.
814ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
815// Not directly.
816ASSERT_FALSE(analyzer.dependsDirectly(output, input));
817// Not in reverse order.
818ASSERT_FALSE(analyzer.dependsIndirectly(input, output));
819
820// output -> bStore -> bLoad -> input.
821auto storeAccess = analyzer.accessFor(bStore);
822auto loadAccess = analyzer.accessFor(aLoad.node());
823
824ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess));
825ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input));
826}
827
828// Can tell if an output does not depend on an input.
829TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) {
830BufHandle a("A", {10}, kInt);
831BufHandle b("B", {10}, kInt);
832VarHandle x("x", kInt);
833
834// initialize analyzer with inputs and outputs.
835analysis::MemDependencyChecker analyzer({a}, {b});
836
837// Here's a dumb Relu.
838/*
839* for (int x = 0; x < 10; ++x) {
840* B[x] = Max(x, 0);
841* }
842*/
843
844StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true));
845StmtPtr loop = For::make(x, 0, 10, bStore);
846
847StmtPtr stmt = Block::make({loop});
848
849stmt->accept(&analyzer);
850
851// Output does not depend indirectly on input.
852ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node()));
853
854// The output still depends directly on the store.
855ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore));
856
857// Check AccessInfo based overloads.
858auto input = analyzer.input(a.node());
859auto output = analyzer.output(b.node());
860
861// Output does not depend indirectly on input.
862ASSERT_FALSE(analyzer.dependsIndirectly(output, input));
863}
864
865// Verify different loop extents produce accesses with different bounds, and
866// that later accesses find dependencies that overlap their entire bound range.
867TEST(MemDependency, MemDependencyCheckerLoopBounds) {
868BufHandle a("A", {10}, kInt);
869BufHandle b("B", {10}, kInt);
870BufHandle c("C", {10}, kInt);
871VarHandle x("x", kInt);
872using namespace analysis;
873
874MemDependencyChecker analyzer({a}, {c});
875
876// This enables using the execution order of the loops to determine if some
877// loops are self dependent or not.
878analyzer.allowLoopExecutionOrderAnalysis();
879
880/*
881* for (int x = 1; x < 10; ++x) {
882* B[x] = A[x];
883* }
884* for (int x = 1; x < 9; ++x) {
885* B[x] = B[x] * 2;
886* }
887* for (int x = 3; x < 4; ++x) {
888* C[x] = A[x];
889* }
890* for (int x = 0; x < 10; ++x) {
891* C[x] = B[x];
892* }
893*/
894
895std::vector<StmtPtr> stmts(
896{For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))),
897For::make(
898x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))),
899For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))),
900For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))});
901
902StmtPtr stmt = Block::make(stmts);
903
904stmt->accept(&analyzer);
905
906auto input = analyzer.input(a.node());
907auto output = analyzer.output(c.node());
908
909// sanity check Output -> Input.
910ASSERT_TRUE(analyzer.dependsIndirectly(output, input));
911
912// Check the For loop dependencies:
913
914// Last write to C depends on both writes to B since they contain the last
915// write to at least one element.
916ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1]));
917ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0]));
918
919// The last write to C does not depend on the other write to C.
920ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2]));
921
922auto CB = [](int s, int e) {
923return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
924};
925auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
926return indexBoundsEquals(x, y);
927};
928
929/* 0. Input: A[(0, 9)] - dependents: 1 5
930* 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2
931* 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7
932* 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4
933* 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7
934* 5. Load: A[(3, 3)] - depends on: 0 - dependents: 6
935* 6. Store: C[(3, 3)] - depends on: 5
936* 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8
937* 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9
938* 9. Output: C[(0, 9)] - depends on: 8
939*/
940
941// Now let's look at the bounds of each access.
942// There are 9 accesses in this Stmt, so this is exhaustive, we wont do this
943// much.
944auto history = analyzer.getHistory();
945ASSERT_EQ(history.size(), 10);
946VarPtr aVar = a.node()->base_handle();
947VarPtr bVar = b.node()->base_handle();
948VarPtr cVar = c.node()->base_handle();
949
950// The first access is the input A.
951ASSERT_EQ(history[0]->type(), AccessType::Input);
952ASSERT_EQ(history[0]->var(), aVar);
953// It has the bounds of the producing Input.
954ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
955// sanity check the input we retrieved earlier matches.
956ASSERT_EQ(history[0], input);
957
958// The second access is the load of A in the first loop.
959ASSERT_EQ(history[1]->type(), AccessType::Load);
960ASSERT_EQ(history[1]->var(), aVar);
961// It has the bounds of the loop, i.e. start == 1.
962ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)}));
963// It reads from A, so it should have a dependency on the last write to this
964// range - with is the input.
965ASSERT_EQ(history[1]->dependencies().size(), 1);
966ASSERT_TRUE(history[1]->hasDependency(history[0]));
967
968// The third access is the store into B in the first loop.
969ASSERT_EQ(history[2]->type(), AccessType::Store);
970ASSERT_EQ(history[2]->var(), bVar);
971// It also has the bounds of the loop, i.e. start == 1.
972ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
973// The previous load is in its RHS, so it depends on it.
974ASSERT_EQ(history[2]->dependencies().size(), 1);
975ASSERT_TRUE(history[2]->hasDependency(history[1]));
976
977// The third access is the load from B in the second loop.
978ASSERT_EQ(history[3]->type(), AccessType::Load);
979ASSERT_EQ(history[3]->var(), bVar);
980// It has the bounds of the second loop, i.e. >= 1 < 9.
981ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)}));
982// It reads from B in a smaller range, so should depend on the previous
983// store.
984ASSERT_EQ(history[3]->dependencies().size(), 1);
985ASSERT_TRUE(history[3]->hasDependency(history[2]));
986
987// The fourth: the store to B in the second loop.
988ASSERT_EQ(history[4]->type(), AccessType::Store);
989ASSERT_EQ(history[4]->var(), bVar);
990// It also has the bounds of the second loop.
991ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)}));
992// The previous load is in its RHS, so it depends on it as before.
993ASSERT_EQ(history[4]->dependencies().size(), 1);
994ASSERT_TRUE(history[4]->hasDependency(history[3]));
995
996// The fifth access is the load is from the 3rd loop, and skips previous B
997// accesses.
998ASSERT_EQ(history[5]->type(), AccessType::Load);
999ASSERT_EQ(history[5]->var(), aVar);
1000// It has the bounds of the third loop: >= 3 < 4.
1001ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)}));
1002// It depends on the last thing to write to A, which is the A input.
1003ASSERT_EQ(history[5]->dependencies().size(), 1);
1004ASSERT_TRUE(history[5]->hasDependency(history[0]));
1005
1006// Sixth: the store into the output C.
1007ASSERT_EQ(history[6]->type(), AccessType::Store);
1008ASSERT_EQ(history[6]->var(), cVar);
1009// It also has the bounds of the third loop.
1010ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)}));
1011// The previous load is in its RHS, so it depends on it as always.
1012ASSERT_EQ(history[6]->dependencies().size(), 1);
1013ASSERT_TRUE(history[6]->hasDependency(history[5]));
1014
1015// The seventh access is the load of B in the fourth loop.
1016ASSERT_EQ(history[7]->type(), AccessType::Load);
1017ASSERT_EQ(history[7]->var(), bVar);
1018// It has the bounds of the final loop, >= 0 < 10
1019ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
1020// The bounds of this read are larger than the bounds of the previous write,
1021// so it depends on both previous Stores to B.
1022ASSERT_EQ(history[7]->dependencies().size(), 2);
1023ASSERT_TRUE(history[7]->hasDependency(history[2]));
1024ASSERT_TRUE(history[7]->hasDependency(history[4]));
1025
1026// Eight: the final store into the output C.
1027ASSERT_EQ(history[8]->type(), AccessType::Store);
1028ASSERT_EQ(history[8]->var(), cVar);
1029// It also has the bounds of the final loop.
1030ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
1031// The previous load is in its RHS, so it depends on it as always.
1032ASSERT_EQ(history[8]->dependencies().size(), 1);
1033ASSERT_TRUE(history[8]->hasDependency(history[7]));
1034
1035// The last access represents the output Buf.
1036ASSERT_EQ(history[9]->type(), AccessType::Output);
1037ASSERT_EQ(history[9]->var(), cVar);
1038// It has the bounds of the output Buf.
1039ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)}));
1040// sanity check the input we retrieved earlier matches.
1041ASSERT_EQ(history[9], output);
1042// It depends on the last write to C only.
1043ASSERT_EQ(history[9]->dependencies().size(), 1);
1044ASSERT_TRUE(history[9]->hasDependency(history[8]));
1045}
1046
1047// Verify that we can still infer bounds when the loop var is offset.
1048TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) {
1049BufHandle a("A", {10}, kInt);
1050BufHandle b("B", {10}, kInt);
1051VarHandle x("x", kInt);
1052
1053using namespace analysis;
1054
1055MemDependencyChecker analyzer({a}, {b});
1056
1057// This enables using the execution order of the loops to determine if some
1058// loops are self dependent or not.
1059analyzer.allowLoopExecutionOrderAnalysis();
1060
1061/*
1062* for (int x = 1; x < 10; x++) {
1063* A[x] = A[x - 1];
1064* }
1065* for (int x = 0; x < 9; x++) {
1066* A[x] = A[x + 1];
1067* }
1068* for (int x = 0; x < 9; x++) {
1069* A[9 - x] = A[8 - x];
1070* }
1071* for (int x = 0; x < 10; x++) {
1072* A[x] = A[9 - x];
1073* }
1074* for (int x = 0; x < 10; x++) {
1075* B[x] = A[x];
1076* }
1077*/
1078
1079StmtPtr stmt = Block::make(
1080{For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
1081For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))),
1082For::make(
1083x,
10840,
10859,
1086Store::make(
1087a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))),
1088For::make(
1089x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))),
1090For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))});
1091
1092stmt->accept(&analyzer);
1093
1094// Sanity check output depends on Input.
1095ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
1096
1097auto CB = [](int s, int e) {
1098return Bound(alloc<IntImm>(s), alloc<IntImm>(e));
1099};
1100auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
1101return indexBoundsEquals(x, y);
1102};
1103
1104/* 0. Input: A[(0, 9)] - dependents: 1
1105* 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2
1106* 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3
1107* 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4
1108* 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7
1109* 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6
1110* 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7
1111* 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8
1112* 8. Store: A[(0, 9)] - depends on: 7 - dependents: 7 9
1113* 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10
1114* 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11
1115* 11. Output: B[(0, 9)] - depends on: 10
1116*/
1117
1118// Now let's look at the bounds of each access.
1119auto history = analyzer.getHistory();
1120ASSERT_EQ(history.size(), 12);
1121VarPtr aVar = a.node()->base_handle();
1122VarPtr bVar = b.node()->base_handle();
1123
1124// The first access is the input A.
1125ASSERT_EQ(history[0]->type(), AccessType::Input);
1126ASSERT_EQ(history[0]->var(), aVar);
1127// It has the bounds of the producing Input.
1128ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)}));
1129
1130// The second access is the load A[x-1].
1131ASSERT_EQ(history[1]->type(), AccessType::Load);
1132ASSERT_EQ(history[1]->var(), aVar);
1133// It has the bounds of the loop modified by the offset of each index, in
1134// this case -1.
1135ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)}));
1136// It depends on the input, but also the store in the same loop, since
1137// different interations of the loop depend on each other.
1138ASSERT_EQ(history[1]->dependencies().size(), 2);
1139ASSERT_TRUE(history[1]->hasDependency(history[0]));
1140ASSERT_TRUE(history[1]->hasDependency(history[2]));
1141
1142// The third access is the Store to A[x] in the first loop.
1143ASSERT_EQ(history[2]->type(), AccessType::Store);
1144ASSERT_EQ(history[2]->var(), aVar);
1145// It has no offset on x, so should have the same bounds as the loop.
1146ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)}));
1147
1148// The fourth access is the load A[x+1] in the second loop.
1149ASSERT_EQ(history[3]->type(), AccessType::Load);
1150ASSERT_EQ(history[3]->var(), aVar);
1151// It has the bounds of the loop (0 <= x < 9) modified by the offset of each
1152// index, in this case 1.
1153ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)}));
1154// This load totally overlaps the previous write to A, so it depends only on
1155// it and not the input.
1156ASSERT_EQ(history[3]->dependencies().size(), 1);
1157ASSERT_TRUE(history[3]->hasDependency(history[2]));
1158
1159// The fifth access is the store to A[x] in the second loop.
1160ASSERT_EQ(history[4]->type(), AccessType::Store);
1161ASSERT_EQ(history[4]->var(), aVar);
1162// It has no offset on x, so should have the same bounds as the loop.
1163ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)}));
1164
1165// The sixth access is the load to A[8 - x] in the third loop.
1166ASSERT_EQ(history[5]->type(), AccessType::Load);
1167ASSERT_EQ(history[5]->var(), aVar);
1168// It has the bounds of the loop (0 <= x < 9) modified by the offset of each
1169// index, in this case 8 - x.
1170// This access has a negative stride, which will be normalized.
1171ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)}));
1172// This load totally overlaps the most recent write to A, so it depends only
1173// on it and not the input or the first write to A.
1174ASSERT_EQ(history[5]->dependencies().size(), 1);
1175ASSERT_TRUE(history[5]->hasDependency(history[4]));
1176
1177// The seventh access is the store to A[9 - x] in the third loop.
1178ASSERT_EQ(history[6]->type(), AccessType::Store);
1179ASSERT_EQ(history[6]->var(), aVar);
1180// This store has a negative stride on it's indices, but is normalized
1181// internally.
1182ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)}));
1183
1184// The eighth access is the load A[9-x] in the second loop.
1185ASSERT_EQ(history[7]->type(), AccessType::Load);
1186ASSERT_EQ(history[7]->var(), aVar);
1187// It has the bounds of the loop (0 <= x < 9), modified by the offset 9 - x,
1188// which essentially traverses the loop backwards.
1189ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)}));
1190// This Load has three write dependencies:
1191ASSERT_EQ(history[7]->dependencies().size(), 3);
1192// * The previous store (#6) for elements 1-9
1193ASSERT_TRUE(history[7]->hasDependency(history[6]));
1194// * An earlier store (#4) covering element 0
1195ASSERT_TRUE(history[7]->hasDependency(history[4]));
1196// * A future store inside this loop, since this loop modifies the buffer
1197// in a non distinct way (due to the load and store having different access
1198// strides).
1199ASSERT_TRUE(history[7]->hasDependency(history[8]));
1200
1201// The ninth access is the store to A[x] in the fourth loop.
1202ASSERT_EQ(history[8]->type(), AccessType::Store);
1203ASSERT_EQ(history[8]->var(), aVar);
1204// This store has a negative stride on it's indices, but is normalized
1205// internally.
1206ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)}));
1207
1208// The tenth and 11th accesses are the copy from A[x] to B[x].
1209ASSERT_EQ(history[9]->type(), AccessType::Load);
1210ASSERT_EQ(history[9]->var(), aVar);
1211ASSERT_EQ(history[10]->type(), AccessType::Store);
1212ASSERT_EQ(history[10]->var(), bVar);
1213
1214// The last access represents the output Buf.
1215ASSERT_EQ(history[11]->type(), AccessType::Output);
1216ASSERT_EQ(history[11]->var(), bVar);
1217// It has the bounds of the output Buf.
1218ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)}));
1219// It depends on the last write to B only.
1220ASSERT_EQ(history[11]->dependencies().size(), 1);
1221ASSERT_TRUE(history[11]->hasDependency(history[10]));
1222
1223// ok that's enough of that.
1224}
1225
1226// Check many different cases of loop self dependency - when a load within a
1227// loop is dependent on a Store later in the same loop but in different
1228// iteration. This is affected by whether or not we can trust the execution
1229// order of the loop.
1230TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) {
1231BufHandle a("A", {5}, kInt);
1232BufHandle b("B", {5}, kInt);
1233VarHandle x("x", kInt);
1234VarHandle y("y", kInt);
1235VarHandle z("z", kInt);
1236
1237using namespace analysis;
1238
1239// This check assumes that the Stmt has a single Store with a single Load on
1240// the RHS.
1241auto isSelfDependent =
1242[](const std::vector<std::shared_ptr<AccessInfo>>& history) -> bool {
1243return history.front()->hasDependency(history.back());
1244};
1245
1246{
1247/* for (int y = 0; y < 10; y++) {
1248* A[y] = (A[y]) + 1;
1249* } */
1250
1251// Not self dependent since all loop iterations use a different y.
1252
1253MemDependencyChecker analyzer;
1254StmtPtr stmt = For::make(
1255y,
12560,
125710,
1258Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))}));
1259
1260stmt->accept(&analyzer);
1261
1262ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1263}
1264
1265{
1266/* for (int y = 0; y < 10; y++) {
1267* A[y + 1] = (A[y + 1]) + 1;
1268* }
1269*/
1270
1271// Not self dependent due to different y (with offset).
1272
1273MemDependencyChecker analyzer;
1274StmtPtr stmt = For::make(
1275y,
12760,
127710,
1278Block::make(
1279{Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))}));
1280
1281stmt->accept(&analyzer);
1282
1283ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1284}
1285
1286{
1287/* for (int x = 0; x < 10; x++) {
1288* A[0] = (A[0]) + x;
1289* }
1290*/
1291
1292// Is self dependent since all loops use a common constant element of A.
1293
1294MemDependencyChecker analyzer;
1295StmtPtr stmt = For::make(
1296x,
12970,
129810,
1299Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}));
1300stmt->accept(&analyzer);
1301
1302ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1303}
1304
1305{
1306/* for (int x = 0; x < 10; x++) {
1307* A[0] = (B[0]) + x;
1308* }
1309*/
1310
1311// Is not self dependent because there is no store to the buffer that is
1312// read.
1313
1314MemDependencyChecker analyzer;
1315StmtPtr stmt = For::make(
1316x,
13170,
131810,
1319Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))}));
1320stmt->accept(&analyzer);
1321
1322ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1323}
1324
1325{
1326/* for (int x = 0; x < 10; x++) {
1327* A[y] = (A[y]) + x;
1328* }
1329*/
1330
1331// Is self dependent since all loops use a common symbolic element of A.
1332
1333MemDependencyChecker analyzer;
1334StmtPtr stmt = For::make(
1335x,
13360,
133710,
1338Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))}));
1339stmt->accept(&analyzer);
1340
1341ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1342}
1343
1344{
1345/* for (int x = 0; x < 10; x++) {
1346* A[x] = A[x + 1];
1347* }
1348*/
1349
1350// In this case it depends if we are considering execution order.
1351
1352MemDependencyChecker analyzer;
1353
1354StmtPtr stmt =
1355For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
1356stmt->accept(&analyzer);
1357
1358// With analysis of order disabled, this is self dependent since the read
1359// from X+1 and the write to X+1 could be in reverse order.
1360ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1361}
1362
1363{
1364/* for (int x = 0; x < 10; x++) {
1365* A[x] = A[x + 1];
1366* }
1367*/
1368
1369MemDependencyChecker analyzer;
1370analyzer.allowLoopExecutionOrderAnalysis();
1371
1372StmtPtr stmt =
1373For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1})));
1374stmt->accept(&analyzer);
1375
1376// If order analysis is enabled, this is not dependent since the read for
1377// each element occurs before the write to that element.
1378ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1379}
1380
1381{
1382/* for (int x = 1; x < 10; x++) {
1383* A[x] = A[x - 1];
1384* }
1385*/
1386
1387MemDependencyChecker analyzer;
1388
1389StmtPtr stmt =
1390For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
1391stmt->accept(&analyzer);
1392
1393ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1394}
1395
1396{
1397/* for (int x = 1; x < 10; x++) {
1398* A[x] = A[x - 1];
1399* }
1400*/
1401
1402MemDependencyChecker analyzer;
1403analyzer.allowLoopExecutionOrderAnalysis();
1404
1405StmtPtr stmt =
1406For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})));
1407stmt->accept(&analyzer);
1408
1409// In this case, even with order analysis the Load is dependent on the
1410// Store, since the write to X occurs before the read from X.
1411ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1412}
1413
1414{
1415/* for (int x = 0; x < 9; x++) {
1416* A[9 - x] = A[8 - x];
1417* }
1418*/
1419
1420// Still works if the execution order is reversed, so long as the read
1421// comes before the write.
1422
1423MemDependencyChecker analyzer;
1424analyzer.allowLoopExecutionOrderAnalysis();
1425
1426StmtPtr stmt = For::make(
1427x,
14283,
142910,
1430Store::make(
1431a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
1432stmt->accept(&analyzer);
1433
1434// However here was can determine the A store is earlier in the order than
1435// the load.
1436ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1437}
1438
1439{
1440/* for (int x = 0; x < 9; x++) {
1441* A[8 - x] = A[9 - x];
1442* }
1443*/
1444
1445// But not if it doesn't.
1446
1447MemDependencyChecker analyzer;
1448analyzer.allowLoopExecutionOrderAnalysis();
1449
1450StmtPtr stmt = For::make(
1451x,
14523,
145310,
1454Store::make(
1455a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x})));
1456stmt->accept(&analyzer);
1457
1458ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1459}
1460
1461{
1462/* for (int x = 0; x < 9; x++) {
1463* A[9 - x] = A[8 - x];
1464* }
1465*/
1466
1467// And not if we're not relying on execution order.
1468
1469MemDependencyChecker analyzer;
1470
1471StmtPtr stmt = For::make(
1472x,
14733,
147410,
1475Store::make(
1476a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x})));
1477stmt->accept(&analyzer);
1478
1479ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1480}
1481
1482{
1483/* for (int x = 3; x < 10; x++) {
1484* A[x - 2] = A[x - 1];
1485* }
1486*/
1487
1488// Forward order but negative indices.
1489
1490MemDependencyChecker analyzer;
1491analyzer.allowLoopExecutionOrderAnalysis();
1492
1493StmtPtr stmt =
1494For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1})));
1495stmt->accept(&analyzer);
1496
1497// However here was can determine the A store is earlier in the order than
1498// the load.
1499ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1500}
1501
1502{
1503/* for (int x = 0; x < 10; x++) {
1504* A[x * 2] = A[x * 2];
1505* }
1506*/
1507
1508// With an access stride.
1509
1510MemDependencyChecker analyzer;
1511// Execution order doesn't matter since the read and the write are totally
1512// distinct.
1513
1514StmtPtr stmt =
1515For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2})));
1516stmt->accept(&analyzer);
1517
1518ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1519}
1520
1521{
1522/* for (int x = 0; x < 10; x++) {
1523* A[x * 2] = A[x * 2 + 1];
1524* }
1525*/
1526
1527// Here we can use the common stride of the accesses to determine they are
1528// distinct.
1529// Note, this is the only place (loop self dependency) we use this stride
1530// to avoid unnecessary dependence.
1531
1532MemDependencyChecker analyzer;
1533// Execution order doesn't matter since the read and the write are totally
1534// distinct.
1535
1536StmtPtr stmt = For::make(
1537x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1})));
1538stmt->accept(&analyzer);
1539
1540ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1541}
1542
1543{
1544/* for (int x = 0; x < 10; x++) {
1545* A[x * 2] = A[x * 2 - 1];
1546* }
1547*/
1548
1549// same if the read is behind the write so long as they are distinct.
1550
1551MemDependencyChecker analyzer;
1552StmtPtr stmt = For::make(
1553x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1})));
1554stmt->accept(&analyzer);
1555
1556ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1557}
1558
1559{
1560/* for (int x = 0; x < 10; x++) {
1561* A[x * 2] = A[x * 2 + 2];
1562* }
1563*/
1564
1565// But not if the offset is in the stride.
1566
1567MemDependencyChecker analyzer;
1568StmtPtr stmt = For::make(
1569x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2})));
1570stmt->accept(&analyzer);
1571
1572ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1573}
1574
1575{
1576/* for (int x = 0; x < 10; x++) {
1577* A[x * 2] = A[x * 2 - 2];
1578* }
1579*/
1580
1581// Works with negative offsets too.
1582
1583MemDependencyChecker analyzer;
1584StmtPtr stmt = For::make(
1585x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2})));
1586stmt->accept(&analyzer);
1587
1588ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1589}
1590
1591{
1592/* for (int x = 0; x < 10; x++) {
1593* A[x * 2] = A[x * 2 + 7];
1594* }
1595*/
1596
1597// Detects accesses are distinct when offset is large but not a multiple
1598// of stride.
1599MemDependencyChecker analyzer;
1600StmtPtr stmt = For::make(
1601x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7})));
1602stmt->accept(&analyzer);
1603
1604ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1605}
1606
1607{
1608/* for (int x = 0; x < 10; x++) {
1609* A[x * 2] = A[x * 2 + 4];
1610* }
1611*/
1612
1613// Works with offsets which are multiples of the stride.
1614MemDependencyChecker analyzer;
1615StmtPtr stmt = For::make(
1616x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4})));
1617stmt->accept(&analyzer);
1618
1619ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1620}
1621
1622{
1623/* for (int x = 0; x < 10; x++) {
1624* A[x * 6] = A[x * 6 + 5];
1625* }
1626*/
1627
1628// detects accesses are distinct with large strides when the offset is
1629// within.
1630
1631MemDependencyChecker analyzer;
1632StmtPtr stmt = For::make(
1633x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5})));
1634stmt->accept(&analyzer);
1635
1636ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1637}
1638
1639{
1640/* for (int x = 0; x < 10; x++) {
1641* A[x * 2] = A[x * 6];
1642* }
1643*/
1644
1645// detects accesses are overlapping when stride is different but a
1646// multiple.
1647
1648MemDependencyChecker analyzer;
1649StmtPtr stmt =
1650For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6})));
1651stmt->accept(&analyzer);
1652
1653ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1654}
1655
1656{
1657/* for (int x = 0; x < 10; x++) {
1658* A[x * 4] = A[x * 2];
1659* }
1660*/
1661
1662// still works when the read axis is the smaller stride.
1663
1664MemDependencyChecker analyzer;
1665StmtPtr stmt =
1666For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2})));
1667stmt->accept(&analyzer);
1668
1669ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1670}
1671
1672{
1673/* for (int x = 0; x < 10; x++) {
1674* A[x * 2] = A[x * 6 + 1];
1675* }
1676*/
1677
1678// detects accesses are distinct when stride is different but a multiple
1679// and there is an offset.
1680
1681MemDependencyChecker analyzer;
1682StmtPtr stmt = For::make(
1683x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1})));
1684stmt->accept(&analyzer);
1685
1686ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1687}
1688
1689{
1690/* for (int x = 0; x < 10; x++) {
1691* A[x * 2] = A[x * 6 + 4];
1692* }
1693*/
1694
1695// The smaller stride determines whether there is overlap.
1696
1697MemDependencyChecker analyzer;
1698StmtPtr stmt = For::make(
1699x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4})));
1700stmt->accept(&analyzer);
1701
1702ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1703}
1704
1705{
1706/* for (int x = 0; x < 10; x++) {
1707* A[x * 2 + 3] = A[x * 6];
1708* }
1709*/
1710
1711// The smaller stride determines whether there is overlap, not the larger.
1712
1713MemDependencyChecker analyzer;
1714StmtPtr stmt = For::make(
1715x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6})));
1716stmt->accept(&analyzer);
1717
1718ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1719}
1720
1721{
1722/* for (int x = 0; x < 10; x++) {
1723* A[x * 2] = A[x * 3 + 1];
1724* }
1725*/
1726
1727// If they have strides with no common multiple > 1, they overlap.
1728MemDependencyChecker analyzer;
1729StmtPtr stmt = For::make(
1730x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1})));
1731stmt->accept(&analyzer);
1732
1733ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1734}
1735
1736{
1737/* for (int x = 0; x < 10; x++) {
1738* A[x] = A[x + 10];
1739* }
1740*/
1741
1742// If the offset is greater than the size of the loop, they can't overlap.
1743
1744MemDependencyChecker analyzer;
1745StmtPtr stmt =
1746For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10})));
1747stmt->accept(&analyzer);
1748
1749ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1750}
1751
1752{
1753/* for (int x = 0; x < 10; x++) {
1754* A[x] = A[9 - x];
1755* }
1756*/
1757
1758// If they have different execution orders they may overlap.
1759MemDependencyChecker analyzer;
1760StmtPtr stmt = For::make(
1761x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x})));
1762stmt->accept(&analyzer);
1763
1764ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1765}
1766
1767{
1768/* for (int x = 0; x < 10; x++) {
1769* A[x * 2] = A[19 - x * 2];
1770* }
1771*/
1772
1773// Or they may not, depending on their start offset and strides.
1774MemDependencyChecker analyzer;
1775StmtPtr stmt = For::make(
1776x,
17770,
177810,
1779Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2})));
1780stmt->accept(&analyzer);
1781
1782ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1783}
1784
1785{
1786/* for (int x = 0; x < 10; x++) {
1787* A[x / 2] = A[x / 2];
1788* }
1789*/
1790
1791// If the stride is not monotonic, they overlap.
1792
1793MemDependencyChecker analyzer;
1794StmtPtr stmt =
1795For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2})));
1796stmt->accept(&analyzer);
1797
1798ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1799}
1800
1801{
1802/* for (int x = 0; x < 10; x++) {
1803* A[x / 2] = A[x / 2] + 1;
1804* }
1805*/
1806
1807// If the stride is not monotonic, they overlap - even with an offset.
1808MemDependencyChecker analyzer;
1809StmtPtr stmt = For::make(
1810x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1})));
1811stmt->accept(&analyzer);
1812
1813ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1814}
1815
1816{
1817/* for (int x = 0; x < 10; x++) {
1818* A[x % 2] = A[x % 2];
1819* }
1820*/
1821
1822// Mod too...
1823
1824analysis::MemDependencyChecker analyzer;
1825StmtPtr stmt = For::make(
1826x,
18270,
182810,
1829Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)})));
1830stmt->accept(&analyzer);
1831
1832ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1833}
1834
1835{
1836/* for (int x = y; x < z; x++) {
1837* A[x] = A[x + 1];
1838* }
1839*/
1840
1841// Still works with symbolic loop extents.
1842
1843{
1844MemDependencyChecker analyzer;
1845StmtPtr stmt =
1846For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
1847stmt->accept(&analyzer);
1848
1849ASSERT_TRUE(isSelfDependent(analyzer.getHistory()));
1850}
1851
1852{
1853MemDependencyChecker analyzer;
1854analyzer.allowLoopExecutionOrderAnalysis();
1855StmtPtr stmt =
1856For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1})));
1857stmt->accept(&analyzer);
1858
1859ASSERT_FALSE(isSelfDependent(analyzer.getHistory()));
1860}
1861}
1862}
1863
1864// Verify that a strided access still works.
1865// TODO: actually this only works because of the size of the ranges, revisit
1866// this test after strided overlap is implemented.
1867TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
1868BufHandle a("A", {20}, kInt);
1869BufHandle b("B", {20}, kInt);
1870VarHandle x("x", kInt);
1871VarHandle y("y", kInt);
1872
1873using namespace analysis;
1874MemDependencyChecker analyzer({a.node()}, {b.node()});
1875StmtPtr stmt = Block::make(
1876{For::make(
1877x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
1878For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2})))
1879
1880});
1881stmt->accept(&analyzer);
1882
1883// Sanity check output depends on input.
1884ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
1885
1886// Output has 2 dependencies... the store in each loop.
1887auto outputAccess = analyzer.output(b.node());
1888ASSERT_EQ(outputAccess->dependencies().size(), 2);
1889}
1890
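/* TODO(nickg) - this test will fail due to the lack of stride math in Bound.
 * NOTE: it reuses the name of the enabled MemDependencyCheckerLoopDistinctStrides
 * test above, so it must be renamed before it is re-enabled.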
1892TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) {
1893BufHandle a("A", {20}, kInt);
1894BufHandle b("B", {20}, kInt);
1895BufHandle c("C", {10}, kInt);
1896VarHandle x("x", kInt);
1897VarHandle y("y", kInt);
1898
1899{
1900analysis::MemDependencyChecker analyzer({a.node()}, {c.node()});
1901StmtPtr stmt = Block::make(
1902{For::make(
1903x,
19040,
190510,
1906Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))),
1907For::make(
1908x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))),
1909For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))
1910
1911});
1912stmt->accept(&analyzer);
1913
1914std::cout << *stmt << "\n";
1915for (auto& wi : analyzer.getHistory()) {
1916wi->print();
1917}
1918}
1919}*/
1920
1921// analysis on Stmts using Cond.
1922TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) {
1923BufHandle a("A", {10}, kInt);
1924BufHandle b("B", {10}, kInt);
1925BufHandle c("C", {10}, kInt);
1926VarHandle x("x", kInt);
1927VarHandle y("y", kInt);
1928
1929using namespace analysis;
1930
1931{
1932/* for (int x = 0; x < 10; x++) {
1933* C[x] = A[x];
1934* }
1935* if (y<5 ? 1 : 0) {
1936* C[0] = (B[0]) + 1;
1937* } else {
1938* C[0] = (B[1]) + 1;
1939* }
1940*/
1941
1942// Future usages may depend on accesses in both branches of a condition.
1943
1944MemDependencyChecker analyzer({a, b}, {c});
1945StmtPtr stmt = Block::make(
1946{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
1947Cond::make(
1948CompareSelect::make(y, 5, CompareSelectOperation::kLT),
1949Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)),
1950Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))});
1951
1952stmt->accept(&analyzer);
1953
1954// Output C should have 3 dependencies, each of the three stores.
1955auto outputAccess = analyzer.output(c.node());
1956ASSERT_NE(outputAccess, nullptr);
1957ASSERT_EQ(outputAccess->dependencies().size(), 3);
1958
1959// C depends indirectly on A and B.
1960ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
1961ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
1962}
1963
1964{
1965/* for (int x = 0; x < 10; x++) {
1966* C[x] = A[x];
1967* }
1968* if (y<5 ? 1 : 0) {
1969* for (int x = 0; x < 10; x++) {
1970* C[x] = B[x];
1971* }
1972* } else {
1973* for (int x = 0; x < 10; x++) {
1974* C[x] = (B[x]) + 1;
1975* }
1976* }
1977*/
1978
1979// Future usages may depend on accesses in both branches of a condition.
1980
1981MemDependencyChecker analyzer({a, b}, {c});
1982StmtPtr stmt = Block::make(
1983{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
1984Cond::make(
1985CompareSelect::make(y, 5, CompareSelectOperation::kLT),
1986For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))),
1987For::make(
1988x,
19890,
199010,
1991Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
1992
1993stmt->accept(&analyzer);
1994
1995// Output C should have 3 dependencies, each of the three stores.
1996auto outputAccess = analyzer.output(c.node());
1997ASSERT_NE(outputAccess, nullptr);
1998ASSERT_EQ(outputAccess->dependencies().size(), 3);
1999
2000// TODO(nickg): actually since the true and false branch cover the total
2001// range of the first store this should have 2 dependencies, but we don't
2002// do that yet.
2003
2004// C depends indirectly on A and B.
2005ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2006ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2007}
2008
2009{
2010/* for (int x = 0; x < 10; x++) {
2011* C[x] = A[x];
2012* }
2013* if (y<5 ? 1 : 0) {
2014* for (int x = 0; x < 10; x++) {
2015* C[x] = (B[x]) + 1;
2016* }
2017* }
2018*/
2019
2020// Only has true branch.
2021
2022MemDependencyChecker analyzer({a, b}, {c});
2023StmtPtr stmt = Block::make(
2024{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2025Cond::make(
2026CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2027For::make(
2028x,
20290,
203010,
2031Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))),
2032nullptr)});
2033
2034stmt->accept(&analyzer);
2035
2036// Output C should have 3 dependencies, each of the three stores.
2037auto outputAccess = analyzer.output(c.node());
2038ASSERT_NE(outputAccess, nullptr);
2039ASSERT_EQ(outputAccess->dependencies().size(), 2);
2040
2041// C depends indirectly on A and B.
2042ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2043ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2044}
2045
2046{
2047/* for (int x = 0; x < 10; x++) {
2048* C[x] = A[x];
2049* }
2050* if (y<5 ? 1 : 0) {
2051* } else {
2052* for (int x = 0; x < 10; x++) {
2053* C[x] = (B[x]) + 1;
2054* }
2055* }
2056*/
2057
2058// Only has false branch.
2059
2060MemDependencyChecker analyzer({a, b}, {c});
2061StmtPtr stmt = Block::make(
2062{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2063Cond::make(
2064CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2065nullptr,
2066For::make(
2067x,
20680,
206910,
2070Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))});
2071
2072stmt->accept(&analyzer);
2073
2074// Output C should have 3 dependencies, each of the three stores.
2075auto outputAccess = analyzer.output(c.node());
2076ASSERT_NE(outputAccess, nullptr);
2077ASSERT_EQ(outputAccess->dependencies().size(), 2);
2078
2079// C depends indirectly on A and B.
2080ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2081ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2082}
2083
2084{
2085/* for (int x = 0; x < 10; x++) {
2086* C[x] = A[x];
2087* }
2088* if (C[0]<5 ? 1 : 0) {
2089* C[0] = 5;
2090* }
2091*/
2092
2093// Cond's Condition depends on a previous access.
2094
2095MemDependencyChecker analyzer({a}, {c});
2096StorePtr initStore = Store::make(c, {x}, Load::make(a, {x}));
2097ExprHandle conditionalLoad = Load::make(c, {0});
2098StmtPtr stmt = Block::make(
2099{For::make(x, 0, 10, initStore),
2100Cond::make(
2101CompareSelect::make(
2102conditionalLoad, 5, CompareSelectOperation::kLT),
2103Store::make(c, {0}, 5),
2104nullptr)});
2105
2106stmt->accept(&analyzer);
2107
2108ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2109
2110ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore));
2111ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node()));
2112ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node()));
2113}
2114}
2115
2116// Stmts using IfThenElse.
2117TEST(MemDependency, MemDependencyCheckerIfThenElse) {
2118BufHandle a("A", {10}, kInt);
2119BufHandle b("B", {10}, kInt);
2120BufHandle c("C", {10}, kInt);
2121VarHandle x("x", kInt);
2122VarHandle y("y", kInt);
2123
2124using namespace analysis;
2125
2126{
2127/* for (int x = 0; x < 10; x++) {
2128* C[x] = A[x];
2129* }
2130* C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1;
2131*/
2132
2133// Future usages may depend on accesses in both branches of a condition.
2134
2135MemDependencyChecker analyzer({a, b}, {c});
2136StorePtr ifStore = Store::make(
2137c,
2138{0},
2139IfThenElse::make(
2140CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2141Add::make(Load::make(b, {0}), 1),
2142Add::make(Load::make(b, {1}), 1)));
2143StmtPtr stmt = Block::make(
2144{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2145ifStore});
2146
2147stmt->accept(&analyzer);
2148
2149// Output C should have 2 dependencies, each of the two stores.
2150auto outputAccess = analyzer.output(c.node());
2151ASSERT_NE(outputAccess, nullptr);
2152ASSERT_EQ(outputAccess->dependencies().size(), 2);
2153
2154// Now we need to check the Store containing the IfThenElse.
2155auto ifStoreAccess = analyzer.accessFor(ifStore);
2156
2157// It should have 2 dependencies.
2158ASSERT_EQ(ifStoreAccess->dependencies().size(), 2);
2159
2160// C depends indirectly on A and B.
2161ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2162ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2163}
2164
2165{
2166/* for (int x = 0; x < 10; x++) {
2167* C[x] = A[x];
2168* }
2169* C[0] = (y < 5 ? (B[0]) + 1 : 42;
2170*/
2171
2172// If the load appears in only one side of an IfThenElse the output may be
2173// dependent on it.
2174
2175MemDependencyChecker analyzer({a, b}, {c});
2176StorePtr ifStore = Store::make(
2177c,
2178{0},
2179IfThenElse::make(
2180CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2181Add::make(Load::make(b, {0}), 1),
218242));
2183StmtPtr stmt = Block::make(
2184{For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))),
2185ifStore});
2186
2187stmt->accept(&analyzer);
2188
2189// C depends indirectly on A and B.
2190ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2191ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2192}
2193
2194{
2195/* for (int x = 0; x < 10; x++) {
2196* C[x] = (x < 5 ? B[x] : A[x];
2197* }
2198*/
2199
2200// In this case C is dependent on both A and B.
2201
2202// TODO: in cases like this it would be possible to split the range of B
2203// into two bounds, one dependent on A and one dependent on B. We'd need to
2204// examine conditions relative to previously encountered loop variables. I'm
2205// uncertain if this would be helpful.
2206
2207MemDependencyChecker analyzer({a, b}, {c});
2208StorePtr ifStore = Store::make(
2209c,
2210{0},
2211IfThenElse::make(
2212CompareSelect::make(y, 5, CompareSelectOperation::kLT),
2213Load::make(b, {x}),
2214Load::make(a, {x})));
2215StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)});
2216
2217stmt->accept(&analyzer);
2218
2219// C depends indirectly on A and B.
2220ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2221ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2222}
2223}
2224
2225// Cutting a loop with single elem writes
2226TEST(MemDependency, MemDependencyCheckerCutLoop) {
2227BufHandle a("A", {10}, kInt);
2228BufHandle b("B", {10}, kInt);
2229VarHandle x("x", kInt);
2230
2231using namespace analysis;
2232
2233{
2234/* for (int x = 0; x < 10; x++) {
2235* B[x] = A[x];
2236* }
2237* B[5] = 100;
2238*/
2239
2240// Cutting a loop with single element writes.
2241
2242MemDependencyChecker analyzer({a}, {b});
2243StmtPtr stmt = Block::make(
2244{For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))),
2245Store::make(b, {5}, 100)});
2246
2247stmt->accept(&analyzer);
2248
2249// Output depends on input.
2250ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2251
2252// Output has 2 dependencies.
2253auto outputAccess = analyzer.output(b.node());
2254ASSERT_NE(outputAccess, nullptr);
2255ASSERT_EQ(outputAccess->dependencies().size(), 2);
2256}
2257
2258{
2259/* for (int x = 0; x < 10; x++) {
2260* B[x] = A[x];
2261* }
2262* for (int x = 4; x < 7; x++) {
2263* B[x] = B[x] + 3;
2264* }
2265* B[5] = 100;
2266* B[6] = 101;
2267* B[7] = 102;
2268*/
2269
2270// Cutting a loop with a smaller loop but then totally overlap that second
2271// loop with one element writes.
2272
2273MemDependencyChecker analyzer({a}, {b});
2274ForPtr firstLoop =
2275For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})));
2276StorePtr secondStore =
2277Store::make(b, {x}, Add::make(Load::make(b, {x}), 1));
2278ForPtr secondLoop = For::make(x, 4, 7, secondStore);
2279
2280StmtPtr stmt = Block::make(
2281{firstLoop,
2282secondLoop,
2283Store::make(b, {4}, 100),
2284Store::make(b, {5}, 101),
2285Store::make(b, {6}, 102)});
2286
2287stmt->accept(&analyzer);
2288
2289// Output depends on input.
2290ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2291
2292// Output has 4 dependencies.
2293auto outputAccess = analyzer.output(b.node());
2294ASSERT_NE(outputAccess, nullptr);
2295ASSERT_EQ(outputAccess->dependencies().size(), 4);
2296
2297// Second loop depends on first loop.
2298ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop));
2299
2300// Output does not depend on second loop or store.
2301ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop));
2302ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore));
2303}
2304}
2305
2306// Dynamic shapes (load in indices).
2307TEST(MemDependency, MemDependencyCheckerDynamicShapes) {
2308BufHandle a("A", {100}, kInt);
2309BufHandle b("B", {100}, kInt);
2310BufHandle c("C", {100}, kInt);
2311VarHandle x("x", kInt);
2312
2313using namespace analysis;
2314
2315auto CB = [](ExprHandle s, ExprHandle e) {
2316return Bound(s.node(), e.node());
2317};
2318
2319auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
2320return indexBoundsEquals(x, y);
2321};
2322
2323{
2324/* for (int x = 0; x < B[0]; x++) {
2325* C[x] = A[x];
2326* }
2327*/
2328MemDependencyChecker analyzer({a, b}, {c});
2329StmtPtr stmt = Block::make({For::make(
2330x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))});
2331
2332stmt->accept(&analyzer);
2333
2334/* 0. Input: B[(0, 99)] - dependents: 2
2335* 1. Input: A[(0, 99)] - dependents: 3
2336* 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4
2337* 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4
2338* 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5
2339* 5. Output: C[(0, 99)] - depends on: 4
2340*/
2341
2342// Output dependent on A input.
2343ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2344// Also dependent on B input to determine the size of the region written.
2345ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2346
2347auto history = analyzer.getHistory();
2348ASSERT_EQ(history.size(), 6);
2349
2350// The accesses in the loop depend on the load in the stop condition.
2351ASSERT_TRUE(history[4]->hasDependency(history[2]));
2352ASSERT_TRUE(history[3]->hasDependency(history[2]));
2353
2354// Make a load from B to compare against.
2355ExprHandle loadFromB = Load::make(b, {0});
2356
2357ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)}));
2358ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)}));
2359}
2360
2361{
2362/* for (int x = B[0]; x < B[1]; x++) {
2363* C[x] = A[x];
2364* }
2365*/
2366MemDependencyChecker analyzer({a, b}, {c});
2367StmtPtr stmt = Block::make({For::make(
2368x,
2369Load::make(b, {0}),
2370Load::make(b, {1}),
2371Store::make(c, {x}, Load::make(a, {x})))});
2372
2373stmt->accept(&analyzer);
2374
2375/* 0. Input: B[(0, 99)] - dependents: 2 3
2376* 1. Input: A[(0, 99)] - dependents: 4
2377* 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5
2378* 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5
2379* 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5
2380* 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6
2381* 6. Output: C[(0, 99)] - depends on: 5
2382*/
2383
2384// Sanity check output depends on input.
2385ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2386ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2387
2388auto history = analyzer.getHistory();
2389ASSERT_EQ(history.size(), 7);
2390
2391// The accesses in the loop depend on the load in the start condition.
2392ASSERT_TRUE(history[5]->hasDependency(history[2]));
2393ASSERT_TRUE(history[4]->hasDependency(history[2]));
2394
2395// also the stop condition.
2396ASSERT_TRUE(history[5]->hasDependency(history[3]));
2397ASSERT_TRUE(history[4]->hasDependency(history[3]));
2398
2399// Make loads from B to compare against.
2400ExprHandle loadFromB0 = Load::make(b, {0});
2401ExprHandle loadFromB1 = Load::make(b, {1});
2402ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
2403ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)}));
2404}
2405
2406{
2407/* for (int x = 0; x < 10; x++) {
2408* C[x] = A[B[x]];
2409* }
2410*/
2411MemDependencyChecker analyzer({a, b}, {c});
2412StmtPtr stmt = Block::make({For::make(
2413x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))});
2414
2415stmt->accept(&analyzer);
2416
2417/* 0. Input: B[(0, 99)] - dependents: 2
2418* 1. Input: A[(0, 99)] - dependents: 3
2419* 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4
2420* 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4
2421* 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5
2422* 5. Output: C[(0, 99)] - depends on: 4
2423*/
2424
2425// Sanity check output depends on input.
2426ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2427ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2428
2429auto history = analyzer.getHistory();
2430ASSERT_EQ(history.size(), 6);
2431
2432// The store depends on both loads, the load of A depends on the load of B.
2433ASSERT_TRUE(history[4]->hasDependency(history[2]));
2434ASSERT_TRUE(history[4]->hasDependency(history[3]));
2435
2436ASSERT_TRUE(history[3]->hasDependency(history[2]));
2437
2438// The loads in the indices depend on the relevant input buffer.
2439ASSERT_TRUE(history[3]->hasDependency(history[1]));
2440ASSERT_TRUE(history[2]->hasDependency(history[0]));
2441
2442// The load from B has the loop bounds.
2443ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2444
2445// The load from A has bounds B[0] to B[9].
2446ExprHandle loadFromB0 = Load::make(b, {0});
2447ExprHandle loadFromB9 = Load::make(b, {9});
2448ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)}));
2449}
2450
2451{
2452/* for (int x = 0; x < 10; x++) {
2453* C[B[x]] = A[x];
2454* }
2455*/
2456MemDependencyChecker analyzer({a, b}, {c});
2457StmtPtr stmt = Block::make({For::make(
2458x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))});
2459
2460stmt->accept(&analyzer);
2461
2462/* 0. Input: B[(0, 99)] - dependents: 3
2463* 1. Input: A[(0, 99)] - dependents: 2
2464* 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4
2465* 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4
2466* 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5
2467* 5. Output: C[(0, 99)] - depends on: 4
2468*/
2469// Sanity check output depends on input.
2470ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2471ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2472
2473auto history = analyzer.getHistory();
2474ASSERT_EQ(history.size(), 6);
2475
2476// The store depends on both loads, neither load is dependent.
2477ASSERT_TRUE(history[4]->hasDependency(history[2]));
2478ASSERT_TRUE(history[4]->hasDependency(history[3]));
2479
2480ASSERT_FALSE(history[3]->hasDependency(history[2]));
2481ASSERT_FALSE(history[2]->hasDependency(history[3]));
2482
2483// The loads each depend on their relevant input. (but accesses are in a
2484// different order than the last case).
2485ASSERT_TRUE(history[3]->hasDependency(history[0]));
2486ASSERT_TRUE(history[2]->hasDependency(history[1]));
2487
2488// The load from B has the loop bounds.
2489ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)}));
2490
2491// And so does the load from A.
2492ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2493}
2494
2495{
2496/* for (int x = 0; x < 10; x++) {
2497* C[B[A[x]]] = x;
2498* }
2499*/
2500MemDependencyChecker analyzer({a, b}, {c});
2501StmtPtr stmt = Block::make({For::make(
2502x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))});
2503
2504stmt->accept(&analyzer);
2505
2506/* 0. Input: B[(0, 99)] - dependents: 3
2507* 1. Input: A[(0, 99)] - dependents: 2
2508* 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4
2509* 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4
2510* 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5
2511* 5. Output: C[(0, 99)] - depends on: 4
2512*/
2513
2514// Sanity check output depends on input.
2515ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node()));
2516ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node()));
2517
2518auto history = analyzer.getHistory();
2519ASSERT_EQ(history.size(), 6);
2520
2521// The store depends on both loads.
2522ASSERT_TRUE(history[4]->hasDependency(history[2]));
2523ASSERT_TRUE(history[4]->hasDependency(history[3]));
2524
2525// The outer load depends on the inner.
2526ASSERT_TRUE(history[3]->hasDependency(history[2]));
2527
2528// The loads each depend on their relevant input. (but accesses are in a
2529// different order than the last case).
2530ASSERT_TRUE(history[3]->hasDependency(history[0]));
2531ASSERT_TRUE(history[2]->hasDependency(history[1]));
2532
2533// The load from A has the loop bounds.
2534ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)}));
2535// The load from B as bounds A[0] to A[9].
2536ExprHandle loadFromA0 = Load::make(a, {0});
2537ExprHandle loadFromA9 = Load::make(a, {9});
2538ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)}));
2539
2540// The store has bounds of B[A[0]] to B[A[9]].
2541ExprHandle loadFromBA0 = Load::make(b, {loadFromA0});
2542ExprHandle loadFromBA9 = Load::make(b, {loadFromA9});
2543ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)}));
2544}
2545}
2546
2547// Verify multi dimensional bounds work.
2548TEST(MemDependency, MemDependencyCheckerMultiDim) {
2549int M = 10, N = 9, K = 12;
2550BufHandle a("A", {M, N, K}, kInt);
2551BufHandle b("B", {M, N, K}, kInt);
2552BufHandle c("C", {M, K}, kInt);
2553VarHandle x("x", kInt);
2554VarHandle y("y", kInt);
2555VarHandle z("z", kInt);
2556
2557using namespace analysis;
2558
2559auto CB = [](ExprHandle s, ExprHandle e) {
2560return Bound(s.node(), e.node());
2561};
2562
2563auto EQ = [](const IndexBounds& x, const IndexBounds& y) {
2564return indexBoundsEquals(x, y);
2565};
2566
2567{
2568/* for (int x = 0; x < 10; x++) {
2569* for (int y = 0; y < 9; y++) {
2570* for (int z = 0; z < 12; z++) {
2571* B[x, y, z] = A[x, y, z];
2572* }
2573* }
2574* }
2575*/
2576// Full range.
2577
2578MemDependencyChecker analyzer({a}, {b});
2579StmtPtr stmt = Block::make({For::make(
2580x,
25810,
2582M,
2583For::make(
2584y,
25850,
2586N,
2587For::make(
2588z,
25890,
2590K,
2591Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
2592
2593stmt->accept(&analyzer);
2594
2595// Sanity test: Output depends on input.
2596ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2597
2598// 4 accesses: input, load, store, output.
2599auto history = analyzer.getHistory();
2600ASSERT_EQ(history.size(), 4);
2601
2602// Simple chain from input to output.
2603ASSERT_TRUE(history[3]->hasDependency(history[2]));
2604ASSERT_TRUE(history[2]->hasDependency(history[1]));
2605ASSERT_TRUE(history[1]->hasDependency(history[0]));
2606
2607ASSERT_TRUE(
2608EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2609ASSERT_TRUE(
2610EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2611}
2612
2613{
2614/* for (int x = 0; x < 5; x++) {
2615* for (int y = 0; y < 5; y++) {
2616* for (int z = 0; z < 5; z++) {
2617* B[x, y, z] = A[x, y, z];
2618* }
2619* }
2620* }
2621*/
2622// Partial range.
2623
2624MemDependencyChecker analyzer({a}, {b});
2625StmtPtr stmt = Block::make({For::make(
2626x,
26270,
26285,
2629For::make(
2630y,
26310,
26325,
2633For::make(
2634z,
26350,
26365,
2637Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))});
2638
2639stmt->accept(&analyzer);
2640
2641// Sanity test: Output depends on input.
2642ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2643
2644// 4 accesses: input, load, store, output.
2645auto history = analyzer.getHistory();
2646ASSERT_EQ(history.size(), 4);
2647
2648// Simple chain from input to output.
2649ASSERT_TRUE(history[3]->hasDependency(history[2]));
2650ASSERT_TRUE(history[2]->hasDependency(history[1]));
2651ASSERT_TRUE(history[1]->hasDependency(history[0]));
2652
2653ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
2654ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)}));
2655}
2656
2657{
2658/* for (int x = 0; x < 10; x++) {
2659* for (int y = 0; y < 12; y++) {
2660* B[x, 0, y] = A[x, 0, y];
2661* }
2662* }
2663*/
2664
2665// Partial loops.
2666
2667MemDependencyChecker analyzer({a}, {b});
2668StmtPtr stmt = Block::make({For::make(
2669x,
26700,
2671N,
2672For::make(
2673y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))});
2674
2675stmt->accept(&analyzer);
2676
2677// Sanity test: Output depends on input.
2678ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2679
2680// 4 accesses: input, load, store, output.
2681auto history = analyzer.getHistory();
2682ASSERT_EQ(history.size(), 4);
2683
2684// Simple chain from input to output.
2685ASSERT_TRUE(history[3]->hasDependency(history[2]));
2686ASSERT_TRUE(history[2]->hasDependency(history[1]));
2687ASSERT_TRUE(history[1]->hasDependency(history[0]));
2688
2689ASSERT_TRUE(
2690EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
2691ASSERT_TRUE(
2692EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)}));
2693}
2694
2695{
2696/* for (int x = 0; x < 10; x++) {
2697* for (int y = 0; y < 100; y++) {
2698* for (int z = 0; z < 12; z++) {
2699* B[x, 0, z] = (A[x, 0, z]) + (C[x, z]);
2700* }
2701* }
2702* }
2703*/
2704
2705// Loops that don't correspond to an index, bufs with different
2706// dimensionality.
2707
2708MemDependencyChecker analyzer({a, c}, {b});
2709StmtPtr stmt = Block::make({For::make(
2710x,
27110,
2712M,
2713For::make(
2714y,
27150,
2716100,
2717For::make(
2718z,
27190,
2720K,
2721Store::make(
2722b,
2723{x, 0, z},
2724Add::make(
2725Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))});
2726
2727stmt->accept(&analyzer);
2728
2729// Sanity test: Output depends on both inputs.
2730ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2731ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node()));
2732
2733// 6 accesses: 2 inputs, 2 loads, store, output.
2734auto history = analyzer.getHistory();
2735ASSERT_EQ(history.size(), 6);
2736
2737// Simple chain from input to output over the A buf.
2738// history[0] is the C input, history[3] is the load from C.
2739ASSERT_TRUE(history[5]->hasDependency(history[4]));
2740ASSERT_TRUE(history[4]->hasDependency(history[2]));
2741ASSERT_TRUE(history[2]->hasDependency(history[1]));
2742// The store also depends on the load from the C input.
2743ASSERT_TRUE(history[4]->hasDependency(history[3]));
2744ASSERT_TRUE(history[3]->hasDependency(history[0]));
2745
2746// A Buf accesses.
2747ASSERT_TRUE(
2748EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
2749ASSERT_TRUE(
2750EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)}));
2751
2752// C buf access.
2753ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)}));
2754}
2755
2756{
2757/* for (int x = 0; x < 9; x++) {
2758* for (int y = 0; y < 10; y++) {
2759* for (int z = 0; z < 12; z++) {
2760* B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]);
2761* }
2762* }
2763* }
2764*/
2765// Multi-dim reductions.
2766
2767MemDependencyChecker analyzer({a}, {b});
2768StmtPtr stmt = Block::make({For::make(
2769x,
27700,
2771M,
2772For::make(
2773y,
27740,
2775N,
2776For::make(
2777z,
27780,
2779K,
2780Store::make(
2781b,
2782{x, 0, 0},
2783Add::make(
2784Load::make(b, {x, y, z}),
2785Load::make(a, {x, y, z}))))))});
2786
2787stmt->accept(&analyzer);
2788
2789// Sanity test: Output depends on input.
2790ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node()));
2791
2792// 4 accesses: input, 2 loads, store, output.
2793auto history = analyzer.getHistory();
2794ASSERT_EQ(history.size(), 5);
2795
2796// Simple chain from input to output.
2797ASSERT_TRUE(history[4]->hasDependency(history[3]));
2798ASSERT_TRUE(history[3]->hasDependency(history[2]));
2799ASSERT_TRUE(history[3]->hasDependency(history[1]));
2800ASSERT_TRUE(history[2]->hasDependency(history[0]));
2801
2802// The load from B depends on the store to B.
2803ASSERT_TRUE(history[1]->hasDependency(history[3]));
2804
2805ASSERT_TRUE(
2806EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2807ASSERT_TRUE(
2808EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)}));
2809ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)}));
2810}
2811}
2812
2813// Various tests using the external Compute/Reduce API.
2814TEST(MemDependency, MemDependencyCheckerComputeAPI) {
2815using namespace analysis;
2816
2817/* for (int m = 0; m < 4; m++) {
2818* for (int n = 0; n < 5; n++) {
2819* for (int k = 0; k < 6; k++) {
2820* broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]);
2821* }
2822* }
2823* }
2824* for (int m_1 = 0; m_1 < 4; m_1++) {
2825* for (int n_1 = 0; n_1 < 5; n_1++) {
2826* for (int k_1 = 0; k_1 < 6; k_1++) {
2827* d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1);
2828* }
2829* }
2830* }
2831*/
2832
2833// Can determine if 2 loops created by Compute are dependent.
2834BufHandle a_buf("a", {4, 5}, kFloat);
2835BufHandle b_buf("b", {5, 6}, kFloat);
2836Tensor c = Compute(
2837"broadcast_add",
2838{4, 5, 6},
2839[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2840return a_buf.load(m, n) + b_buf.load(n, k);
2841});
2842Tensor d = Compute(
2843"d",
2844{4, 5, 6},
2845[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2846return c.load(m, n, k) + 1;
2847});
2848
2849LoopNest l({d}, {c, d});
2850
2851MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()});
2852
2853l.root_stmt()->accept(&analyzer);
2854
2855// Sanity test: Output depends on input.
2856ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node()));
2857ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node()));
2858
2859// Second loop depends on first loop.
2860auto c_loop = l.getLoopStmtsFor(c)[0];
2861auto d_loop = l.getLoopStmtsFor(d)[0];
2862ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));
2863}
2864
2865TEST(MemDependency, MemDependencyCheckerComputeInline) {
2866using namespace analysis;
2867
2868/* for (int m = 0; m < 4; m++) {
2869* for (int n = 0; n < 5; n++) {
2870* for (int k = 0; k < 6; k++) {
2871* d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1);
2872* }
2873* }
2874* }
2875*/
2876
2877// Check inlining affects the number of accesses returned.
2878
2879BufHandle a_buf("a", {4, 5}, kFloat);
2880BufHandle b_buf("b", {5, 6}, kFloat);
2881Tensor c = Compute(
2882"broadcast_add",
2883{4, 5, 6},
2884[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2885return a_buf.load(m, n) + b_buf.load(n, k);
2886});
2887Tensor d = Compute(
2888"d",
2889{4, 5, 6},
2890[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2891return c.load(m, n, k) + 1;
2892});
2893
2894LoopNest l({d}, {c, d});
2895l.computeInline(c.buf());
2896
2897MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()});
2898l.root_stmt()->accept(&analyzer);
2899
2900// Sanity test: Output depends on input.
2901ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node()));
2902ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node()));
2903
2904// broadcast_add tensor should not appear in trace at all.
2905for (auto& wi : analyzer.getHistory()) {
2906ASSERT_NE(wi->var(), c.buf()->base_handle());
2907}
2908}
2909
2910TEST(MemDependency, MemDependencyCheckerComputeSplit) {
2911using namespace analysis;
2912// Split an axis, so the number of loops != the number of dimensions.
2913
2914BufHandle a_buf("a", {4, 5}, kFloat);
2915BufHandle b_buf("b", {5, 6}, kFloat);
2916Tensor c = Compute(
2917"broadcast_add",
2918{4, 5, 6},
2919[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2920return a_buf.load(m, n) + b_buf.load(n, k);
2921});
2922
2923LoopNest l({c});
2924
2925MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()});
2926l.root_stmt()->accept(&analyzer_before);
2927
2928l.splitWithTail(l.getLoopStmtsFor(c)[0], 2);
2929
2930MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()});
2931StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
2932stmt->accept(&analyzer_after);
2933
2934// Splitting should not change accesses at all.
2935auto history_before = analyzer_before.getHistory();
2936auto history_after = analyzer_after.getHistory();
2937
2938ASSERT_EQ(history_before.size(), history_after.size());
2939
2940for (size_t i = 0; i < history_before.size(); ++i) {
2941ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
2942ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
2943ASSERT_EQ(
2944history_before[i]->bounds().size(), history_after[i]->bounds().size());
2945ASSERT_TRUE(indexBoundsEquals(
2946history_before[i]->bounds(), history_after[i]->bounds()));
2947ASSERT_EQ(
2948history_before[i]->dependencies().size(),
2949history_after[i]->dependencies().size());
2950ASSERT_EQ(
2951history_before[i]->dependents().size(),
2952history_after[i]->dependents().size());
2953}
2954}
2955
2956TEST(MemDependency, MemDependencyCheckerComputeReorder) {
2957using namespace analysis;
2958// Reorder an axis, so the loop order doesn't match the indexing order.
2959
2960BufHandle a_buf("a", {4, 5}, kFloat);
2961BufHandle b_buf("b", {5, 6}, kFloat);
2962Tensor c = Compute(
2963"broadcast_add",
2964{4, 5, 6},
2965[&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
2966return a_buf.load(m, n) + b_buf.load(n, k);
2967});
2968
2969LoopNest l({c});
2970
2971MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()});
2972l.root_stmt()->accept(&analyzer_before);
2973
2974auto loops = l.getLoopStmtsFor(c);
2975l.reorderAxis(loops[0], loops[1]);
2976
2977MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()});
2978StmtPtr stmt = IRSimplifier::simplify(l.root_stmt());
2979stmt->accept(&analyzer_after);
2980
2981// Reordering should not change accesses at all.
2982auto history_before = analyzer_before.getHistory();
2983auto history_after = analyzer_after.getHistory();
2984
2985ASSERT_EQ(history_before.size(), history_after.size());
2986
2987for (size_t i = 0; i < history_before.size(); ++i) {
2988ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
2989ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
2990ASSERT_EQ(
2991history_before[i]->bounds().size(), history_after[i]->bounds().size());
2992ASSERT_TRUE(indexBoundsEquals(
2993history_before[i]->bounds(), history_after[i]->bounds()));
2994ASSERT_EQ(
2995history_before[i]->dependencies().size(),
2996history_after[i]->dependencies().size());
2997ASSERT_EQ(
2998history_before[i]->dependents().size(),
2999history_after[i]->dependents().size());
3000}
3001}
3002
3003TEST(MemDependency, MemDependencyCheckerComputeReduce) {
3004using namespace analysis;
3005/* for (int l2 = 0; l2 < 2; l2++) {
3006* for (int n1 = 0; n1 < 3; n1++) {
3007* for (int m1 = 0; m1 < 6; m1++) {
3008* scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]);
3009* }
3010* }
3011* }
3012* for (int l1 = 0; l1 < 2; l1++) {
3013* sum[l1] = float(0);
3014* for (int n1_1 = 0; n1_1 < 3; n1_1++) {
3015* for (int m1_1 = 0; m1_1 < 6; m1_1++) {
3016* sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)),
3017* out_args={l1}, reduce_args={n1, m1});
3018* }
3019* }
3020* }
3021*/
3022
3023// Can determine dependencies of a Reduction.
3024
3025BufHandle a("a", {2, 3, 6}, kFloat);
3026BufHandle b("b", {2, 3, 6}, kFloat);
3027
3028Tensor c = Compute(
3029"scale",
3030{2, 3, 6},
3031[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
3032return b.load(l, n, m) * a.load(l, n, m);
3033});
3034Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6});
3035LoopNest l({d}, {c, d});
3036
3037MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()});
3038
3039l.root_stmt()->accept(&analyzer);
3040
3041// Sanity test: Output depends on input.
3042ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node()));
3043ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node()));
3044
3045// Second loop depends on first loop.
3046auto c_loop = l.getLoopStmtsFor(c)[0];
3047auto d_loop = l.getLoopStmtsFor(d)[0];
3048ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop));
3049
3050// Reduction depends on both inputs.
3051auto reduces = NodeFinder<ReduceOp>::find(l.root_stmt());
3052ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node()));
3053ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node()));
3054}
3055
3056TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
3057int M = 1024;
3058int N = 1024;
3059int K = 2048;
3060using namespace analysis;
3061
3062BufHandle AP("A", {M, K}, kFloat);
3063BufHandle BP("B", {K, N}, kFloat);
3064Tensor CT = Reduce(
3065"gemm",
3066{M, N},
3067Sum(),
3068[&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
3069return AP.load(m, k) * BP.load(k, n);
3070},
3071{K});
3072LoopNest loop({CT});
3073
3074{
3075auto const& loops = loop.getLoopStmtsFor(CT);
3076ForPtr m = loops[0];
3077loop.splitWithMask(m, 4);
3078}
3079{
3080auto const& loops = loop.getLoopStmtsFor(CT);
3081ForPtr n = loops[2];
3082loop.splitWithMask(n, 16);
3083}
3084// mo, mi, no, ni, k ->
3085// mo, no, mi, ni, k
3086{
3087auto const& loops = loop.getLoopStmtsFor(CT);
3088ForPtr mi = loops[1];
3089ForPtr no = loops[2];
3090loop.reorderAxis(mi, no);
3091}
3092// mo, no, mi, ni, k ->
3093// mo, no, mi, k, ni
3094{
3095auto const& loops = loop.getLoopStmtsFor(CT);
3096ForPtr ni = loops[3];
3097ForPtr k = loops[4];
3098loop.reorderAxis(ni, k);
3099}
3100// mo, no, mi, k, ni ->
3101// mo, no, k, mi, ni
3102{
3103auto const& loops = loop.getLoopStmtsFor(CT);
3104ForPtr mi = loops[2];
3105ForPtr k = loops[3];
3106loop.reorderAxis(mi, k);
3107}
3108{
3109auto const& loops = loop.getLoopStmtsFor(CT);
3110loop.cacheAccesses(CT.buf(), "C_regs", loops[2]);
3111}
3112
3113MemDependencyChecker analyzer_unlowered(
3114loop.getInputBufs(), loop.getOutputBufs());
3115
3116MemDependencyChecker analyzer_lowered(
3117loop.getInputBufs(), loop.getOutputBufs());
3118
3119// Test both unlowered and lowered form.
3120{
3121StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt());
3122stmt->accept(&analyzer_unlowered);
3123
3124// Outputs depend on inputs.
3125ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node()));
3126ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node()));
3127
3128// The last write to gemm should cover the total bound of the output.
3129std::shared_ptr<AccessInfo> outputAccess =
3130analyzer_unlowered.output(CT.buf());
3131// A single dependency.
3132ASSERT_EQ(outputAccess->dependencies().size(), 1);
3133
3134// dependencies is a set with 1 element, so can just deref begin().
3135std::shared_ptr<AccessInfo> gemmStore =
3136outputAccess->dependencies().begin()->second;
3137// Check its a store.
3138ASSERT_EQ(gemmStore->type(), AccessType::Store);
3139
3140ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds()));
3141
3142// Likewise the first read from each input cover the entire range of the
3143// input.
3144auto aInput = analyzer_unlowered.input(AP.node());
3145auto bInput = analyzer_unlowered.input(BP.node());
3146
3147// A single dependent each.
3148ASSERT_EQ(aInput->dependents().size(), 1);
3149ASSERT_EQ(bInput->dependents().size(), 1);
3150
3151// They're both loads.
3152std::shared_ptr<AccessInfo> aLoad = aInput->dependents().begin()->second;
3153std::shared_ptr<AccessInfo> bLoad = bInput->dependents().begin()->second;
3154ASSERT_EQ(aLoad->type(), AccessType::Load);
3155ASSERT_EQ(bLoad->type(), AccessType::Load);
3156
3157ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds()));
3158ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds()));
3159}
3160
3161loop.prepareForCodegen();
3162SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT});
3163
3164// now check lowered dependency graph.
3165{
3166StmtPtr stmt = IRSimplifier::simplify(cg.stmt());
3167stmt->accept(&analyzer_lowered);
3168
3169// Lowering will change the dimensionality of all bounds due to index
3170// flattening and will insert Allocates and Frees.
3171
3172auto history_before = analyzer_unlowered.getHistory();
3173auto history_after = analyzer_lowered.getHistory();
3174
3175ASSERT_EQ(history_before.size() + 2, history_after.size());
3176
3177// Filter out the alloc/free;
3178auto isAllocFree = [](const auto& info) {
3179return info->type() == AccessType::Alloc ||
3180info->type() == AccessType::Free;
3181};
3182history_after.erase(
3183std::remove_if(history_after.begin(), history_after.end(), isAllocFree),
3184history_after.end());
3185
3186ASSERT_EQ(history_before.size(), history_after.size());
3187
3188for (size_t i = 0; i < history_before.size(); ++i) {
3189ASSERT_EQ(history_before[i]->type(), history_after[i]->type());
3190ASSERT_EQ(history_before[i]->var(), history_after[i]->var());
3191
3192if (history_before[i]->dependencies().size() !=
3193history_after[i]->dependencies().size()) {
3194// Must depend on an Alloc.
3195ASSERT_TRUE(std::any_of(
3196history_after[i]->dependencies().begin(),
3197history_after[i]->dependencies().end(),
3198[](const auto& pair) {
3199return pair.second->type() == AccessType::Alloc;
3200}));
3201
3202ASSERT_EQ(
3203history_before[i]->dependencies().size() + 1,
3204history_after[i]->dependencies().size());
3205}
3206
3207if (history_before[i]->dependents().size() !=
3208history_after[i]->dependents().size()) {
3209// Must depend on an Free.
3210ASSERT_TRUE(std::any_of(
3211history_after[i]->dependents().begin(),
3212history_after[i]->dependents().end(),
3213[](const auto& pair) {
3214return pair.second->type() == AccessType::Free;
3215}));
3216
3217ASSERT_EQ(
3218history_before[i]->dependents().size() + 1,
3219history_after[i]->dependents().size());
3220}
3221
3222// Inputs and outputs are not flattened, only accesses.
3223if (history_before[i]->type() == AccessType::Input ||
3224history_before[i]->type() == AccessType::Output) {
3225ASSERT_EQ(
3226history_before[i]->bounds().size(),
3227history_after[i]->bounds().size());
3228ASSERT_TRUE(indexBoundsEquals(
3229history_before[i]->bounds(), history_after[i]->bounds()));
3230} else {
3231ASSERT_EQ(history_after[i]->bounds().size(), 1);
3232ExprPtr flat_bounds = alloc<IntImm>(1);
3233
3234for (auto& b : history_before[i]->bounds()) {
3235flat_bounds =
3236alloc<Mul>(flat_bounds, alloc<Add>(b.end, alloc<IntImm>(1)));
3237
3238// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3239ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start));
3240}
3241
3242flat_bounds = IRSimplifier::simplify(flat_bounds);
3243ExprPtr after_bounds = IRSimplifier::simplify(
3244alloc<Add>(history_after[i]->bounds()[0].end, alloc<IntImm>(1)));
3245ASSERT_TRUE(exprEquals(flat_bounds, after_bounds));
3246}
3247}
3248}
3249}
3250
3251} // namespace jit
3252} // namespace torch
3253