llvm-project
86 строк · 2.6 Кб
//===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "gtest/gtest.h"

16using namespace mlir;17
18namespace {19
20struct DisabledPattern : public RewritePattern {21DisabledPattern(MLIRContext *context)22: RewritePattern("test.foo", /*benefit=*/0, context,23/*generatedNamed=*/{}) {24setDebugName("DisabledPattern");25}26
27LogicalResult matchAndRewrite(Operation *op,28PatternRewriter &rewriter) const override {29if (op->getNumResults() != 1)30return failure();31rewriter.eraseOp(op);32return success();33}34};35
36struct EnabledPattern : public RewritePattern {37EnabledPattern(MLIRContext *context)38: RewritePattern("test.foo", /*benefit=*/0, context,39/*generatedNamed=*/{}) {40setDebugName("EnabledPattern");41}42
43LogicalResult matchAndRewrite(Operation *op,44PatternRewriter &rewriter) const override {45if (op->getNumResults() == 1)46return failure();47rewriter.eraseOp(op);48return success();49}50};51
52struct TestDialect : public Dialect {53MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)54
55static StringRef getDialectNamespace() { return "test"; }56
57TestDialect(MLIRContext *context)58: Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {59allowUnknownOperations();60}61
62void getCanonicalizationPatterns(RewritePatternSet &results) const override {63results.add<DisabledPattern, EnabledPattern>(results.getContext());64}65};66
67TEST(CanonicalizerTest, TestDisablePatterns) {68MLIRContext context;69context.getOrLoadDialect<TestDialect>();70PassManager mgr(&context);71mgr.addPass(72createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));73
74const char *const code = R"mlir(75%0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
76%1 = "test.foo"() {sym_name = "B"} : () -> (f32)
77)mlir";78
79OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);80ASSERT_TRUE(succeeded(mgr.run(*module)));81
82EXPECT_TRUE(module->lookupSymbol("B"));83EXPECT_FALSE(module->lookupSymbol("A"));84}
85
86} // end anonymous namespace87