llvm-project

Форк
0
/
ShapedTypeTest.cpp 
229 строк · 8.7 Кб
//===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9
#include "mlir/IR/AffineMap.h"
10
#include "mlir/IR/BuiltinAttributes.h"
11
#include "mlir/IR/BuiltinTypes.h"
12
#include "mlir/IR/Dialect.h"
13
#include "mlir/IR/DialectInterface.h"
14
#include "llvm/ADT/SmallVector.h"
15
#include "gtest/gtest.h"
16
#include <cstdint>
17

18
using namespace mlir;
19
using namespace mlir::detail;
20

21
namespace {
22
// Checks ShapedType::clone on ranked and unranked memrefs. Cloning with a new
// shape, a new element type, or both must keep the memory space (and, for the
// ranked memref, the strided layout). Cloning an unranked memref with a shape
// must produce a ranked memref in the same memory space.
TEST(ShapedTypeTest, CloneMemref) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = FloatType::getF32(&context);
  Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
  Type memrefOriginalType = i32;
  llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
  AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);

  ShapedType memrefType =
      (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
          .setMemorySpace(memSpace)
          .setLayout(AffineMapAttr::get(map));
  // Update shape.
  llvm::SmallVector<int64_t> memrefNewShape({30, 40});
  ASSERT_NE(memrefOriginalShape, memrefNewShape);
  ASSERT_EQ(memrefType.clone(memrefNewShape),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));
  // Update type.
  Type memrefNewType = f32;
  ASSERT_NE(memrefOriginalType, memrefNewType);
  ASSERT_EQ(memrefType.clone(memrefNewType),
            (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));
  // Update both.
  ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));

  // Test unranked memref cloning. The expected ranked results are built
  // without an explicit layout; only the memory space is carried over.
  ShapedType unrankedMemrefType =
      UnrankedMemRefType::get(memrefOriginalType, memSpace);
  // Update shape: the clone becomes a ranked memref.
  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                .setMemorySpace(memSpace));
  // Update type: the clone stays unranked.
  ASSERT_EQ(unrankedMemrefType.clone(memrefNewType),
            UnrankedMemRefType::get(memrefNewType, memSpace));
  // Update both.
  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape, memrefNewType),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                .setMemorySpace(memSpace));
}

// Checks ShapedType::clone on ranked and unranked tensors. The clone must carry
// the requested shape and/or element type. Cloning an unranked tensor with a
// shape must produce a ranked tensor, while cloning it with only a new element
// type keeps it unranked.
TEST(ShapedTypeTest, CloneTensor) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = FloatType::getF32(&context);

  Type tensorOriginalType = i32;
  llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});

  // Test ranked tensor cloning.
  ShapedType tensorType =
      RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
  // Update shape.
  llvm::SmallVector<int64_t> tensorNewShape({30, 40});
  ASSERT_NE(tensorOriginalShape, tensorNewShape);
  ASSERT_EQ(
      tensorType.clone(tensorNewShape),
      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
  // Update type.
  Type tensorNewType = f32;
  ASSERT_NE(tensorOriginalType, tensorNewType);
  ASSERT_EQ(
      tensorType.clone(tensorNewType),
      (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
  // Update both.
  ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
            (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));

  // Test unranked tensor cloning.
  ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
  // Update shape: the clone becomes a ranked tensor.
  ASSERT_EQ(
      unrankedTensorType.clone(tensorNewShape),
      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
  // Update type: the clone stays unranked.
  ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
            (ShapedType)UnrankedTensorType::get(tensorNewType));
  // Update both: exercises the (shape, elementType) overload for the unranked
  // case, mirroring the ranked section above.
  ASSERT_EQ(unrankedTensorType.clone(tensorNewShape, tensorNewType),
            (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
}

// Verifies that cloning a vector through the ShapedType interface swaps in a
// new shape, a new element type, or both, and that each result equals the
// directly constructed VectorType.
TEST(ShapedTypeTest, CloneVector) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = FloatType::getF32(&context);

  Type oldElementType = i32;
  llvm::SmallVector<int64_t> oldShape{10, 20};
  ShapedType original = VectorType::get(oldShape, oldElementType);

  // Shape-only clone; the new shape must genuinely differ from the old one.
  llvm::SmallVector<int64_t> newShape{30, 40};
  ASSERT_NE(oldShape, newShape);
  VectorType expectedReshaped = VectorType::get(newShape, oldElementType);
  ASSERT_EQ(original.clone(newShape), expectedReshaped);

  // Element-type-only clone; the new element type must genuinely differ.
  Type newElementType = f32;
  ASSERT_NE(oldElementType, newElementType);
  VectorType expectedRetyped = VectorType::get(oldShape, newElementType);
  ASSERT_EQ(original.clone(newElementType), expectedRetyped);

  // Clone replacing both the shape and the element type at once.
  VectorType expectedBoth = VectorType::get(newShape, newElementType);
  ASSERT_EQ(original.clone(newShape, newElementType), expectedBoth);
}

// Exercises VectorType::Builder. Dropping or setting dims must preserve the
// element type and keep the scalable-dim flags consistent with the shape.
// The last three scopes are regression checks for past use-after-free bugs
// and are meant to be run under sanitizers.
TEST(ShapedTypeTest, VectorTypeBuilder) {
  MLIRContext context;
  Type f32 = FloatType::getF32(&context);

  SmallVector<int64_t> shape{2, 4, 8, 9, 1};
  SmallVector<bool> scalableDims{true, false, true, false, false};
  VectorType vectorType = VectorType::get(shape, f32, scalableDims);

  {
    // Drop the two leading dims; the remaining dims keep their scalable flags.
    VectorType dropFrontTwoDims =
        VectorType::Builder(vectorType).dropDim(0).dropDim(0);
    ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
    ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
    ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
              dropFrontTwoDims.getScalableDims());
  }

  {
    // Set dims 0 and 3; the scalable flags are left unchanged.
    VectorType setTwoDims =
        VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
    ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
    ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
    ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
  }

  {
    // Test for bug from:
    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
    // Constructs a temporary builder, modifies it, copies it to `builder`.
    // This used to lead to a use-after-free. Running under sanitizers will
    // catch any issues.
    VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
    VectorType newVectorType = VectorType(builder);
    ASSERT_EQ(newVectorType.getDimSize(0), 16);
  }

  {
    // Make a builder from scratch (without scalable dims) -- this used to lead
    // to a use-after-free, see: https://github.com/llvm/llvm-project/pull/68969.
    // Running under sanitizers will catch any issues.
    // Named distinctly so it does not shadow the outer `shape`.
    SmallVector<int64_t> freshShape{1, 2, 3, 4};
    VectorType::Builder builder(freshShape, f32);
    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(freshShape));
  }

  {
    // Set the vector shape (without scalable dims) -- this used to lead to a
    // use-after-free, see: https://github.com/llvm/llvm-project/pull/68969.
    // Running under sanitizers will catch any issues.
    VectorType::Builder builder(vectorType);
    SmallVector<int64_t> newShape{2, 2};
    builder.setShape(newShape);
    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
  }
}

// Exercises RankedTensorType::Builder. Dropping and inserting dims must keep
// the element type and produce the expected shape. Copying a modified
// temporary builder is a regression check for a past use-after-free and is
// meant to be run under sanitizers.
TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
  MLIRContext context;
  Type f32 = FloatType::getF32(&context);

  SmallVector<int64_t> shape{2, 4, 8, 16, 32};
  RankedTensorType tensorType = RankedTensorType::get(shape, f32);

  {
    // Drop three dims. Each index refers to the shape after the previous drop:
    // {2, 4, 8, 16, 32} -> {4, 8, 16, 32} -> {4, 16, 32} -> {16, 32}.
    RankedTensorType dropThreeDims =
        RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0);
    ASSERT_EQ(tensorType.getElementType(), dropThreeDims.getElementType());
    ASSERT_EQ(dropThreeDims.getShape(), ArrayRef<int64_t>({16, 32}));
  }

  {
    // Insert two dims (value, position); each position refers to the current
    // shape: {2, 4, 8, 16, 32} -> {2, 4, 7, 8, 16, 32}
    //                          -> {2, 4, 7, 9, 8, 16, 32}.
    RankedTensorType insertTwoDims =
        RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
    ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
    ASSERT_EQ(insertTwoDims.getShape(),
              ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
  }

  {
    // Test for bug from:
    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
    // Constructs a temporary builder, modifies it, copies it to `builder`.
    // This used to lead to a use-after-free. Running under sanitizers will
    // catch any issues.
    RankedTensorType::Builder builder =
        RankedTensorType::Builder(tensorType).dropDim(0);
    RankedTensorType newTensorType = RankedTensorType(builder);
    ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
  }
}

229
} // namespace
230

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.