//===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
#include <cstdint>

using namespace mlir;
using namespace mlir::detail;

namespace {
TEST(ShapedTypeTest, CloneMemref) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = FloatType::getF32(&context);
  Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7);
  Type memrefOriginalType = i32;
  llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
  AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);

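  // Test ranked memref cloning.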
  ShapedType memrefType =
      (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
          .setMemorySpace(memSpace)
          .setLayout(AffineMapAttr::get(map));
  // Update shape.
  llvm::SmallVector<int64_t> memrefNewShape({30, 40});
  ASSERT_NE(memrefOriginalShape, memrefNewShape);
  ASSERT_EQ(memrefType.clone(memrefNewShape),
            (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));
  // Update type.
  Type memrefNewType = f32;
  ASSERT_NE(memrefOriginalType, memrefNewType);
  ASSERT_EQ(memrefType.clone(memrefNewType),
            (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));
  // Update both.
  ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                .setMemorySpace(memSpace)
                .setLayout(AffineMapAttr::get(map)));

  // Test unranked memref cloning.
  ShapedType unrankedTensorType =
      UnrankedMemRefType::get(memrefOriginalType, memSpace);
  ASSERT_EQ(unrankedTensorType.clone(memrefNewShape),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                .setMemorySpace(memSpace));
  ASSERT_EQ(unrankedTensorType.clone(memrefNewType),
            UnrankedMemRefType::get(memrefNewType, memSpace));
  ASSERT_EQ(unrankedTensorType.clone(memrefNewShape, memrefNewType),
            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
                .setMemorySpace(memSpace));
}

TEST(ShapedTypeTest, CloneTensor) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = FloatType::getF32(&context);

  Type tensorOriginalType = i32;
  llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});

  // Test ranked tensor cloning.
  ShapedType tensorType =
      RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
  // Update shape.
  llvm::SmallVector<int64_t> tensorNewShape({30, 40});
  ASSERT_NE(tensorOriginalShape, tensorNewShape);
  ASSERT_EQ(
      tensorType.clone(tensorNewShape),
      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
  // Update type.
  Type tensorNewType = f32;
  ASSERT_NE(tensorOriginalType, tensorNewType);
  ASSERT_EQ(
      tensorType.clone(tensorNewType),
      (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
  // Update both.
  ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
            (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));

  // Test unranked tensor cloning.
  ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
  ASSERT_EQ(
      unrankedTensorType.clone(tensorNewShape),
      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
  ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
            (ShapedType)UnrankedTensorType::get(tensorNewType));
  ASSERT_EQ(
      unrankedTensorType.clone(tensorNewShape),
      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
}

TEST(ShapedTypeTest, CloneVector) {
  MLIRContext context;

  Type i32 = IntegerType::get(&context, 32);
  Type f32 = FloatType::getF32(&context);

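  // Test vector cloning.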
  Type vectorOriginalType = i32;
  llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
  ShapedType vectorType =
      VectorType::get(vectorOriginalShape, vectorOriginalType);
  // Update shape.
  llvm::SmallVector<int64_t> vectorNewShape({30, 40});
  ASSERT_NE(vectorOriginalShape, vectorNewShape);
  ASSERT_EQ(vectorType.clone(vectorNewShape),
            VectorType::get(vectorNewShape, vectorOriginalType));
  // Update type.
  Type vectorNewType = f32;
  ASSERT_NE(vectorOriginalType, vectorNewType);
  ASSERT_EQ(vectorType.clone(vectorNewType),
            VectorType::get(vectorOriginalShape, vectorNewType));
  // Update both.
  ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
            VectorType::get(vectorNewShape, vectorNewType));
}

TEST(ShapedTypeTest, VectorTypeBuilder) {
  MLIRContext context;
  Type f32 = FloatType::getF32(&context);

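  // Build a 5-D vector type whose dims 0 and 2 are scalable.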
  SmallVector<int64_t> shape{2, 4, 8, 9, 1};
  SmallVector<bool> scalableDims{true, false, true, false, false};
  VectorType vectorType = VectorType::get(shape, f32, scalableDims);

  {
    // Drop some dims.
    VectorType dropFrontTwoDims =
        VectorType::Builder(vectorType).dropDim(0).dropDim(0);
    ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
    ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
    ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
              dropFrontTwoDims.getScalableDims());
  }

  {
    // Set some dims.
    VectorType setTwoDims =
        VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
    ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
    ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
    ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
  }

  {
    // Test for bug from:
    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
    // Constructs a temporary builder, modifies it, copies it to `builder`.
    // This used to lead to a use-after-free. Running under sanitizers will
    // catch any issues.
    VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
    VectorType newVectorType = VectorType(builder);
    ASSERT_EQ(newVectorType.getDimSize(0), 16);
  }

  {
    // Make builder from scratch (without scalable dims) -- this used to lead
    // to a use-after-free, see:
    // https://github.com/llvm/llvm-project/pull/68969.
    // Running under sanitizers will catch any issues.
    SmallVector<int64_t> shape{1, 2, 3, 4};
    VectorType::Builder builder(shape, f32);
    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape));
  }

  {
    // Set vector shape (without scalable dims) -- this used to lead to
    // a use-after-free, see: https://github.com/llvm/llvm-project/pull/68969.
    // Running under sanitizers will catch any issues.
    VectorType::Builder builder(vectorType);
    SmallVector<int64_t> newShape{2, 2};
    builder.setShape(newShape);
    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
  }
}

TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
  MLIRContext context;
  Type f32 = FloatType::getF32(&context);

  SmallVector<int64_t> shape{2, 4, 8, 16, 32};
  RankedTensorType tensorType = RankedTensorType::get(shape, f32);

  {
    // Drop some dims.
    RankedTensorType dropFrontTwoDims =
        RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0);
    ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType());
    ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef<int64_t>({16, 32}));
  }

  {
    // Insert some dims.
    RankedTensorType insertTwoDims =
        RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
    ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
    ASSERT_EQ(insertTwoDims.getShape(),
              ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
  }

  {
    // Test for bug from:
    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
    // Constructs a temporary builder, modifies it, copies it to `builder`.
    // This used to lead to a use-after-free. Running under sanitizers will
    // catch any issues.
    RankedTensorType::Builder builder =
        RankedTensorType::Builder(tensorType).dropDim(0);
    RankedTensorType newTensorType = RankedTensorType(builder);
    ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
  }
}

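// A small additional sketch, not part of the original file: it reuses only the
// APIs already exercised above (RankedTensorType::get, FloatType::getF32) plus
// common ShapedType queries, and spot-checks rank and static-shape reporting
// for a tensor with one dynamic dimension. The test name is made up here.
TEST(ShapedTypeTest, RankAndDynamicDimsSketch) {
  MLIRContext context;
  Type f32 = FloatType::getF32(&context);

  // Fully static 2x3 tensor: ranked, static shape, 6 elements.
  ShapedType staticType = RankedTensorType::get({2, 3}, f32);
  ASSERT_TRUE(staticType.hasRank());
  ASSERT_TRUE(staticType.hasStaticShape());
  ASSERT_EQ(staticType.getNumElements(), 6);

  // Tensor with a dynamic leading dimension: still ranked, but not static.
  ShapedType dynamicType =
      RankedTensorType::get({ShapedType::kDynamic, 3}, f32);
  ASSERT_TRUE(dynamicType.hasRank());
  ASSERT_FALSE(dynamicType.hasStaticShape());
  ASSERT_TRUE(dynamicType.isDynamicDim(0));
  ASSERT_FALSE(dynamicType.isDynamicDim(1));
}
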
} // namespace