llvm-project

Форк
0
86 строк · 2.6 Кб
1
//===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8

9
#include "mlir/IR/PatternMatch.h"
10
#include "mlir/Parser/Parser.h"
11
#include "mlir/Pass/PassManager.h"
12
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13
#include "mlir/Transforms/Passes.h"
14
#include "gtest/gtest.h"
15

16
using namespace mlir;
17

18
namespace {
19

20
struct DisabledPattern : public RewritePattern {
21
  DisabledPattern(MLIRContext *context)
22
      : RewritePattern("test.foo", /*benefit=*/0, context,
23
                       /*generatedNamed=*/{}) {
24
    setDebugName("DisabledPattern");
25
  }
26

27
  LogicalResult matchAndRewrite(Operation *op,
28
                                PatternRewriter &rewriter) const override {
29
    if (op->getNumResults() != 1)
30
      return failure();
31
    rewriter.eraseOp(op);
32
    return success();
33
  }
34
};
35

36
struct EnabledPattern : public RewritePattern {
37
  EnabledPattern(MLIRContext *context)
38
      : RewritePattern("test.foo", /*benefit=*/0, context,
39
                       /*generatedNamed=*/{}) {
40
    setDebugName("EnabledPattern");
41
  }
42

43
  LogicalResult matchAndRewrite(Operation *op,
44
                                PatternRewriter &rewriter) const override {
45
    if (op->getNumResults() == 1)
46
      return failure();
47
    rewriter.eraseOp(op);
48
    return success();
49
  }
50
};
51

52
struct TestDialect : public Dialect {
53
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
54

55
  static StringRef getDialectNamespace() { return "test"; }
56

57
  TestDialect(MLIRContext *context)
58
      : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {
59
    allowUnknownOperations();
60
  }
61

62
  void getCanonicalizationPatterns(RewritePatternSet &results) const override {
63
    results.add<DisabledPattern, EnabledPattern>(results.getContext());
64
  }
65
};
66

67
TEST(CanonicalizerTest, TestDisablePatterns) {
68
  MLIRContext context;
69
  context.getOrLoadDialect<TestDialect>();
70
  PassManager mgr(&context);
71
  mgr.addPass(
72
      createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));
73

74
  const char *const code = R"mlir(
75
    %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
76
    %1 = "test.foo"() {sym_name = "B"} : () -> (f32)
77
  )mlir";
78

79
  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
80
  ASSERT_TRUE(succeeded(mgr.run(*module)));
81

82
  EXPECT_TRUE(module->lookupSymbol("B"));
83
  EXPECT_FALSE(module->lookupSymbol("A"));
84
}
85

86
} // end anonymous namespace
87

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.