llvm-project
85 строк · 3.3 Кб
1//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass converts operations into their canonical forms by
10// folding constants, applying operation identity transformations etc.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Transforms/Passes.h"
15
16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18
// Pull in the generated pass definitions selected by GEN_PASS_DEF_CANONICALIZER.
// The .inc file is included inside namespace mlir so its definitions land
// there (this is where `impl::CanonicalizerBase`, used below, comes from).
namespace mlir {
#define GEN_PASS_DEF_CANONICALIZER
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
25
26namespace {
27/// Canonicalize operations in nested regions.
28struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
29Canonicalizer() = default;
30Canonicalizer(const GreedyRewriteConfig &config,
31ArrayRef<std::string> disabledPatterns,
32ArrayRef<std::string> enabledPatterns)
33: config(config) {
34this->topDownProcessingEnabled = config.useTopDownTraversal;
35this->enableRegionSimplification = config.enableRegionSimplification;
36this->maxIterations = config.maxIterations;
37this->maxNumRewrites = config.maxNumRewrites;
38this->disabledPatterns = disabledPatterns;
39this->enabledPatterns = enabledPatterns;
40}
41
42/// Initialize the canonicalizer by building the set of patterns used during
43/// execution.
44LogicalResult initialize(MLIRContext *context) override {
45// Set the config from possible pass options set in the meantime.
46config.useTopDownTraversal = topDownProcessingEnabled;
47config.enableRegionSimplification = enableRegionSimplification;
48config.maxIterations = maxIterations;
49config.maxNumRewrites = maxNumRewrites;
50
51RewritePatternSet owningPatterns(context);
52for (auto *dialect : context->getLoadedDialects())
53dialect->getCanonicalizationPatterns(owningPatterns);
54for (RegisteredOperationName op : context->getRegisteredOperations())
55op.getCanonicalizationPatterns(owningPatterns, context);
56
57patterns = std::make_shared<FrozenRewritePatternSet>(
58std::move(owningPatterns), disabledPatterns, enabledPatterns);
59return success();
60}
61void runOnOperation() override {
62LogicalResult converged =
63applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
64// Canonicalization is best-effort. Non-convergence is not a pass failure.
65if (testConvergence && failed(converged))
66signalPassFailure();
67}
68GreedyRewriteConfig config;
69std::shared_ptr<const FrozenRewritePatternSet> patterns;
70};
71} // namespace
72
73/// Create a Canonicalizer pass.
74std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
75return std::make_unique<Canonicalizer>();
76}
77
78/// Creates an instance of the Canonicalizer pass with the specified config.
79std::unique_ptr<Pass>
80mlir::createCanonicalizerPass(const GreedyRewriteConfig &config,
81ArrayRef<std::string> disabledPatterns,
82ArrayRef<std::string> enabledPatterns) {
83return std::make_unique<Canonicalizer>(config, disabledPatterns,
84enabledPatterns);
85}
86