// llvm-project
// 1001 lines · 36.9 KB
//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Pass/PassRegistry.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"

#include <algorithm>
#include <optional>
#include <string>
#include <tuple>
#include <utility>

using namespace mlir;
using namespace detail;

/// Static mapping of all of the registered passes, keyed by the pass argument
/// (filled by `registerPass`, queried by `PassInfo::lookup`).
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;

/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate. `registerPass` consults it to reject a
/// second allocator that produces a different pass for an existing argument.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;

/// Static mapping of all of the registered pass pipelines, keyed by the
/// pipeline argument (filled by `registerPassPipeline`, queried by
/// `PassPipelineInfo::lookup`).
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;

37/// Utility to create a default registry function from a pass instance.
38static PassRegistryFunction39buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {40return [=](OpPassManager &pm, StringRef options,41function_ref<LogicalResult(const Twine &)> errorHandler) {42std::unique_ptr<Pass> pass = allocator();43LogicalResult result = pass->initializeOptions(options, errorHandler);44
45std::optional<StringRef> pmOpName = pm.getOpName();46std::optional<StringRef> passOpName = pass->getOpName();47if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&48passOpName && *pmOpName != *passOpName) {49return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +50"' restricted to '" + *pass->getOpName() +51"' on a PassManager intended to run on '" +52pm.getOpAnchorName() + "', did you intend to nest?");53}54pm.addPass(std::move(pass));55return result;56};57}
58
59/// Utility to print the help string for a specific option.
60static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,61size_t descIndent, bool isTopLevel) {62size_t numSpaces = descIndent - indent - 4;63llvm::outs().indent(indent)64<< "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';65}
66
67//===----------------------------------------------------------------------===//
68// PassRegistry
69//===----------------------------------------------------------------------===//
70
71/// Print the help information for this pass. This includes the argument,
72/// description, and any pass options. `descIndent` is the indent that the
73/// descriptions should be aligned.
74void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {75printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,76/*isTopLevel=*/true);77// If this entry has options, print the help for those as well.78optHandler([=](const PassOptions &options) {79options.printHelp(indent, descIndent);80});81}
82
83/// Return the maximum width required when printing the options of this
84/// entry.
85size_t PassRegistryEntry::getOptionWidth() const {86size_t maxLen = 0;87optHandler([&](const PassOptions &options) mutable {88maxLen = options.getOptionWidth() + 2;89});90return maxLen;91}
92
93//===----------------------------------------------------------------------===//
94// PassPipelineInfo
95//===----------------------------------------------------------------------===//
96
97void mlir::registerPassPipeline(98StringRef arg, StringRef description, const PassRegistryFunction &function,99std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {100PassPipelineInfo pipelineInfo(arg, description, function,101std::move(optHandler));102bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;103#ifndef NDEBUG104if (!inserted)105report_fatal_error("Pass pipeline " + arg + " registered multiple times");106#endif107(void)inserted;108}
109
110//===----------------------------------------------------------------------===//
111// PassInfo
112//===----------------------------------------------------------------------===//
113
114PassInfo::PassInfo(StringRef arg, StringRef description,115const PassAllocatorFunction &allocator)116: PassRegistryEntry(117arg, description, buildDefaultRegistryFn(allocator),118// Use a temporary pass to provide an options instance.119[=](function_ref<void(const PassOptions &)> optHandler) {120optHandler(allocator()->passOptions);121}) {}122
123void mlir::registerPass(const PassAllocatorFunction &function) {124std::unique_ptr<Pass> pass = function();125StringRef arg = pass->getArgument();126if (arg.empty())127llvm::report_fatal_error(llvm::Twine("Trying to register '") +128pass->getName() +129"' pass that does not override `getArgument()`");130StringRef description = pass->getDescription();131PassInfo passInfo(arg, description, function);132passRegistry->try_emplace(arg, passInfo);133
134// Verify that the registered pass has the same ID as any registered to this135// arg before it.136TypeID entryTypeID = pass->getTypeID();137auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;138if (it->second != entryTypeID)139llvm::report_fatal_error(140"pass allocator creates a different pass than previously "141"registered for pass " +142arg);143}
144
145/// Returns the pass info for the specified pass argument or null if unknown.
146const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {147auto it = passRegistry->find(passArg);148return it == passRegistry->end() ? nullptr : &it->second;149}
150
151/// Returns the pass pipeline info for the specified pass pipeline argument or
152/// null if unknown.
153const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {154auto it = passPipelineRegistry->find(pipelineArg);155return it == passPipelineRegistry->end() ? nullptr : &it->second;156}
157
158//===----------------------------------------------------------------------===//
159// PassOptions
160//===----------------------------------------------------------------------===//
161
162LogicalResult detail::pass_options::parseCommaSeparatedList(163llvm::cl::Option &opt, StringRef argName, StringRef optionStr,164function_ref<LogicalResult(StringRef)> elementParseFn) {165// Functor used for finding a character in a string, and skipping over166// various "range" characters.167llvm::unique_function<size_t(StringRef, size_t, char)> findChar =168[&](StringRef str, size_t index, char c) -> size_t {169for (size_t i = index, e = str.size(); i < e; ++i) {170if (str[i] == c)171return i;172// Check for various range characters.173if (str[i] == '{')174i = findChar(str, i + 1, '}');175else if (str[i] == '(')176i = findChar(str, i + 1, ')');177else if (str[i] == '[')178i = findChar(str, i + 1, ']');179else if (str[i] == '\"')180i = str.find_first_of('\"', i + 1);181else if (str[i] == '\'')182i = str.find_first_of('\'', i + 1);183}184return StringRef::npos;185};186
187size_t nextElePos = findChar(optionStr, 0, ',');188while (nextElePos != StringRef::npos) {189// Process the portion before the comma.190if (failed(elementParseFn(optionStr.substr(0, nextElePos))))191return failure();192
193optionStr = optionStr.substr(nextElePos + 1);194nextElePos = findChar(optionStr, 0, ',');195}196return elementParseFn(optionStr.substr(0, nextElePos));197}
198
/// Out-of-line virtual function to provide a home for the class: per LLVM's
/// virtual-method-anchor convention, this gives OptionBase a single
/// translation unit (this one) to own its vtable.
void detail::PassOptions::OptionBase::anchor() {}

202/// Copy the option values from 'other'.
203void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {204assert(options.size() == other.options.size());205if (options.empty())206return;207for (auto optionsIt : llvm::zip(options, other.options))208std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));209}
210
211/// Parse in the next argument from the given options string. Returns a tuple
212/// containing [the key of the option, the value of the option, updated
213/// `options` string pointing after the parsed option].
214static std::tuple<StringRef, StringRef, StringRef>215parseNextArg(StringRef options) {216// Functor used to extract an argument from 'options' and update it to point217// after the arg.218auto extractArgAndUpdateOptions = [&](size_t argSize) {219StringRef str = options.take_front(argSize).trim();220options = options.drop_front(argSize).ltrim();221// Handle escape sequences222if (str.size() > 2) {223const auto escapePairs = {std::make_pair('\'', '\''),224std::make_pair('"', '"'),225std::make_pair('{', '}')};226for (const auto &escape : escapePairs) {227if (str.front() == escape.first && str.back() == escape.second) {228// Drop the escape characters and trim.229str = str.drop_front().drop_back().trim();230// Don't process additional escape sequences.231break;232}233}234}235return str;236};237// Try to process the given punctuation, properly escaping any contained238// characters.239auto tryProcessPunct = [&](size_t ¤tPos, char punct) {240if (options[currentPos] != punct)241return false;242size_t nextIt = options.find_first_of(punct, currentPos + 1);243if (nextIt != StringRef::npos)244currentPos = nextIt;245return true;246};247
248// Parse the argument name of the option.249StringRef argName;250for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {251// Check for the end of the full option.252if (argEndIt == optionsE || options[argEndIt] == ' ') {253argName = extractArgAndUpdateOptions(argEndIt);254return std::make_tuple(argName, StringRef(), options);255}256
257// Check for the end of the name and the start of the value.258if (options[argEndIt] == '=') {259argName = extractArgAndUpdateOptions(argEndIt);260options = options.drop_front();261break;262}263}264
265// Parse the value of the option.266for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {267// Handle the end of the options string.268if (argEndIt == optionsE || options[argEndIt] == ' ') {269StringRef value = extractArgAndUpdateOptions(argEndIt);270return std::make_tuple(argName, value, options);271}272
273// Skip over escaped sequences.274char c = options[argEndIt];275if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))276continue;277// '{...}' is used to specify options to passes, properly escape it so278// that we don't accidentally split any nested options.279if (c == '{') {280size_t braceCount = 1;281for (++argEndIt; argEndIt != optionsE; ++argEndIt) {282// Allow nested punctuation.283if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))284continue;285if (options[argEndIt] == '{')286++braceCount;287else if (options[argEndIt] == '}' && --braceCount == 0)288break;289}290// Account for the increment at the top of the loop.291--argEndIt;292}293}294llvm_unreachable("unexpected control flow in pass option parsing");295}
296
297LogicalResult detail::PassOptions::parseFromString(StringRef options,298raw_ostream &errorStream) {299// NOTE: `options` is modified in place to always refer to the unprocessed300// part of the string.301while (!options.empty()) {302StringRef key, value;303std::tie(key, value, options) = parseNextArg(options);304if (key.empty())305continue;306
307auto it = OptionsMap.find(key);308if (it == OptionsMap.end()) {309errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";310return failure();311}312if (llvm::cl::ProvidePositionalOption(it->second, value, 0))313return failure();314}315
316return success();317}
318
319/// Print the options held by this struct in a form that can be parsed via
320/// 'parseFromString'.
321void detail::PassOptions::print(raw_ostream &os) {322// If there are no options, there is nothing left to do.323if (OptionsMap.empty())324return;325
326// Sort the options to make the ordering deterministic.327SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());328auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {329return (*lhs)->getArgStr().compare((*rhs)->getArgStr());330};331llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);332
333// Interleave the options with ' '.334os << '{';335llvm::interleave(336orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");337os << '}';338}
339
340/// Print the help string for the options held by this struct. `descIndent` is
341/// the indent within the stream that the descriptions should be aligned.
342void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {343// Sort the options to make the ordering deterministic.344SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());345auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {346return (*lhs)->getArgStr().compare((*rhs)->getArgStr());347};348llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);349for (OptionBase *option : orderedOps) {350// TODO: printOptionInfo assumes a specific indent and will351// print options with values with incorrect indentation. We should add352// support to llvm::cl::Option for passing in a base indent to use when353// printing.354llvm::outs().indent(indent);355option->getOption()->printOptionInfo(descIndent - indent);356}357}
358
359/// Return the maximum width required when printing the help string.
360size_t detail::PassOptions::getOptionWidth() const {361size_t max = 0;362for (auto *option : options)363max = std::max(max, option->getOption()->getOptionWidth());364return max;365}
366
367//===----------------------------------------------------------------------===//
368// MLIR Options
369//===----------------------------------------------------------------------===//
370
371//===----------------------------------------------------------------------===//
372// OpPassManager: OptionValue
373
374llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;375llvm::cl::OptionValue<OpPassManager>::OptionValue(376const mlir::OpPassManager &value) {377setValue(value);378}
379llvm::cl::OptionValue<OpPassManager>::OptionValue(380const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) {381if (rhs.hasValue())382setValue(rhs.getValue());383}
384llvm::cl::OptionValue<OpPassManager> &385llvm::cl::OptionValue<OpPassManager>::operator=(386const mlir::OpPassManager &rhs) {387setValue(rhs);388return *this;389}
390
391llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;392
393void llvm::cl::OptionValue<OpPassManager>::setValue(394const OpPassManager &newValue) {395if (hasValue())396*value = newValue;397else398value = std::make_unique<mlir::OpPassManager>(newValue);399}
400void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {401FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);402assert(succeeded(pipeline) && "invalid pass pipeline");403setValue(*pipeline);404}
405
406bool llvm::cl::OptionValue<OpPassManager>::compare(407const mlir::OpPassManager &rhs) const {408std::string lhsStr, rhsStr;409{410raw_string_ostream lhsStream(lhsStr);411value->printAsTextualPipeline(lhsStream);412
413raw_string_ostream rhsStream(rhsStr);414rhs.printAsTextualPipeline(rhsStream);415}416
417// Use the textual format for pipeline comparisons.418return lhsStr == rhsStr;419}
420
/// Out-of-line definition of `anchor()`; gives this OptionValue
/// specialization a home translation unit (LLVM's virtual-method-anchor
/// convention).
void llvm::cl::OptionValue<OpPassManager>::anchor() {}

423//===----------------------------------------------------------------------===//
424// OpPassManager: Parser
425
426namespace llvm {427namespace cl {428template class basic_parser<OpPassManager>;429} // namespace cl430} // namespace llvm431
432bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,433ParsedPassManager &value) {434FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);435if (failed(pipeline))436return true;437value.value = std::make_unique<OpPassManager>(std::move(*pipeline));438return false;439}
440
/// Print `value` as a textual pass pipeline to `os`.
void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
                                            const OpPassManager &value) {
  value.printAsTextualPipeline(os);
}

446void llvm::cl::parser<OpPassManager>::printOptionDiff(447const Option &opt, OpPassManager &pm, const OptVal &defaultValue,448size_t globalWidth) const {449printOptionName(opt, globalWidth);450outs() << "= ";451pm.printAsTextualPipeline(outs());452
453if (defaultValue.hasValue()) {454outs().indent(2) << " (default: ";455defaultValue.getValue().printAsTextualPipeline(outs());456outs() << ")";457}458outs() << "\n";459}
460
/// Out-of-line definition of `anchor()`; gives this parser specialization a
/// home translation unit (LLVM's virtual-method-anchor convention).
void llvm::cl::parser<OpPassManager>::anchor() {}

// Special members of ParsedPassManager, defaulted out of line. Presumably this
// lets the header hold its `std::unique_ptr<OpPassManager>` member without a
// complete OpPassManager type — TODO confirm against the header.
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
    default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
    ParsedPassManager &&) = default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
    default;

//===----------------------------------------------------------------------===//
// TextualPassPipeline Parser
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a textual description of a pass pipeline.
///
/// `initialize` parses the text and resolves every element against the pass
/// and pipeline registries; `addToPipeline` then appends the described passes
/// to an OpPassManager.
class TextualPipeline {
public:
  /// Try to initialize this pipeline with the given pipeline text.
  /// `errorStream` is the output stream to emit errors to.
  LogicalResult initialize(StringRef text, raw_ostream &errorStream);

  /// Add the internal pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

private:
  /// A functor used to emit errors found during pipeline handling. The first
  /// parameter corresponds to the raw location within the pipeline string. This
  /// should always return failure.
  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;

  /// A struct to capture parsed pass pipeline names.
  ///
  /// A pipeline is defined as a series of names, each of which may in itself
  /// recursively contain a nested pipeline. A name is either the name of a pass
  /// (e.g. "cse") or the name of an operation type (e.g. "builtin.module"). If
  /// the name is the name of a pass, the InnerPipeline is empty, since passes
  /// cannot contain inner pipelines.
  ///
  /// The StringRef members are views into the text given to `initialize`.
  struct PipelineElement {
    PipelineElement(StringRef name) : name(name) {}

    /// The element name as written in the pipeline text.
    StringRef name;
    /// Raw text between `{` and the matching `}` after the name, if any.
    StringRef options;
    /// The registered pass or pass pipeline this name resolves to. Stays null
    /// for elements that carry an inner pipeline.
    const PassRegistryEntry *registryEntry = nullptr;
    /// Elements nested inside `name(...)`.
    std::vector<PipelineElement> innerPipeline;
  };

  /// Parse the given pipeline text into the internal pipeline vector. This
  /// function only parses the structure of the pipeline, and does not resolve
  /// its elements.
  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);

  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
  /// the corresponding registry entry.
  LogicalResult
  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
                          ErrorHandlerT errorHandler);

  /// Resolve a single element of the pipeline.
  LogicalResult resolvePipelineElement(PipelineElement &element,
                                       ErrorHandlerT errorHandler);

  /// Add the given pipeline elements to the provided pass manager.
  LogicalResult
  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
                function_ref<LogicalResult(const Twine &)> errorHandler) const;

  /// The top-level elements of the parsed pipeline.
  std::vector<PipelineElement> pipeline;
};

} // namespace

534/// Try to initialize this pipeline with the given pipeline text. An option is
535/// given to enable accurate error reporting.
536LogicalResult TextualPipeline::initialize(StringRef text,537raw_ostream &errorStream) {538if (text.empty())539return success();540
541// Build a source manager to use for error reporting.542llvm::SourceMgr pipelineMgr;543pipelineMgr.AddNewSourceBuffer(544llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",545/*RequiresNullTerminator=*/false),546SMLoc());547auto errorHandler = [&](const char *rawLoc, Twine msg) {548pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),549llvm::SourceMgr::DK_Error, msg);550return failure();551};552
553// Parse the provided pipeline string.554if (failed(parsePipelineText(text, errorHandler)))555return failure();556return resolvePipelineElements(pipeline, errorHandler);557}
558
559/// Add the internal pipeline elements to the provided pass manager.
560LogicalResult TextualPipeline::addToPipeline(561OpPassManager &pm,562function_ref<LogicalResult(const Twine &)> errorHandler) const {563// Temporarily disable implicit nesting while we append to the pipeline. We564// want the created pipeline to exactly match the parsed text pipeline, so565// it's preferrable to just error out if implicit nesting would be required.566OpPassManager::Nesting nesting = pm.getNesting();567pm.setNesting(OpPassManager::Nesting::Explicit);568auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });569
570return addToPipeline(pipeline, pm, errorHandler);571}
572
573/// Parse the given pipeline text into the internal pipeline vector. This
574/// function only parses the structure of the pipeline, and does not resolve
575/// its elements.
576LogicalResult TextualPipeline::parsePipelineText(StringRef text,577ErrorHandlerT errorHandler) {578SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};579for (;;) {580std::vector<PipelineElement> &pipeline = *pipelineStack.back();581size_t pos = text.find_first_of(",(){");582pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());583
584// If we have a single terminating name, we're done.585if (pos == StringRef::npos)586break;587
588text = text.substr(pos);589char sep = text[0];590
591// Handle pulling ... from 'pass{...}' out as PipelineElement.options.592if (sep == '{') {593text = text.substr(1);594
595// Skip over everything until the closing '}' and store as options.596size_t close = StringRef::npos;597for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {598if (text[i] == '{') {599++braceCount;600continue;601}602if (text[i] == '}' && --braceCount == 0) {603close = i;604break;605}606}607
608// Check to see if a closing options brace was found.609if (close == StringRef::npos) {610return errorHandler(611/*rawLoc=*/text.data() - 1,612"missing closing '}' while processing pass options");613}614pipeline.back().options = text.substr(0, close);615text = text.substr(close + 1);616
617// Consume space characters that an user might add for readability.618text = text.ltrim();619
620// Skip checking for '(' because nested pipelines cannot have options.621} else if (sep == '(') {622text = text.substr(1);623
624// Push the inner pipeline onto the stack to continue processing.625pipelineStack.push_back(&pipeline.back().innerPipeline);626continue;627}628
629// When handling the close parenthesis, we greedily consume them to avoid630// empty strings in the pipeline.631while (text.consume_front(")")) {632// If we try to pop the outer pipeline we have unbalanced parentheses.633if (pipelineStack.size() == 1)634return errorHandler(/*rawLoc=*/text.data() - 1,635"encountered extra closing ')' creating unbalanced "636"parentheses while parsing pipeline");637
638pipelineStack.pop_back();639// Consume space characters that an user might add for readability.640text = text.ltrim();641}642
643// Check if we've finished parsing.644if (text.empty())645break;646
647// Otherwise, the end of an inner pipeline always has to be followed by648// a comma, and then we can continue.649if (!text.consume_front(","))650return errorHandler(text.data(), "expected ',' after parsing pipeline");651}652
653// Check for unbalanced parentheses.654if (pipelineStack.size() > 1)655return errorHandler(656text.data(),657"encountered unbalanced parentheses while parsing pipeline");658
659assert(pipelineStack.back() == &pipeline &&660"wrong pipeline at the bottom of the stack");661return success();662}
663
664/// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
665/// the corresponding registry entry.
666LogicalResult TextualPipeline::resolvePipelineElements(667MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {668for (auto &elt : elements)669if (failed(resolvePipelineElement(elt, errorHandler)))670return failure();671return success();672}
673
674/// Resolve a single element of the pipeline.
675LogicalResult
676TextualPipeline::resolvePipelineElement(PipelineElement &element,677ErrorHandlerT errorHandler) {678// If the inner pipeline of this element is not empty, this is an operation679// pipeline.680if (!element.innerPipeline.empty())681return resolvePipelineElements(element.innerPipeline, errorHandler);682
683// Otherwise, this must be a pass or pass pipeline.684// Check to see if a pipeline was registered with this name.685if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))686return success();687
688// If not, then this must be a specific pass name.689if ((element.registryEntry = PassInfo::lookup(element.name)))690return success();691
692// Emit an error for the unknown pass.693auto *rawLoc = element.name.data();694return errorHandler(rawLoc, "'" + element.name +695"' does not refer to a "696"registered pass or pass pipeline");697}
698
699/// Add the given pipeline elements to the provided pass manager.
700LogicalResult TextualPipeline::addToPipeline(701ArrayRef<PipelineElement> elements, OpPassManager &pm,702function_ref<LogicalResult(const Twine &)> errorHandler) const {703for (auto &elt : elements) {704if (elt.registryEntry) {705if (failed(elt.registryEntry->addToPipeline(pm, elt.options,706errorHandler))) {707return errorHandler("failed to add `" + elt.name + "` with options `" +708elt.options + "`");709}710} else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),711errorHandler))) {712return errorHandler("failed to add `" + elt.name + "` with options `" +713elt.options + "` to inner pipeline");714}715}716return success();717}
718
719LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,720raw_ostream &errorStream) {721TextualPipeline pipelineParser;722if (failed(pipelineParser.initialize(pipeline, errorStream)))723return failure();724auto errorHandler = [&](Twine msg) {725errorStream << msg << "\n";726return failure();727};728if (failed(pipelineParser.addToPipeline(pm, errorHandler)))729return failure();730return success();731}
732
733FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,734raw_ostream &errorStream) {735pipeline = pipeline.trim();736// Pipelines are expected to be of the form `<op-name>(<pipeline>)`.737size_t pipelineStart = pipeline.find_first_of('(');738if (pipelineStart == 0 || pipelineStart == StringRef::npos ||739!pipeline.consume_back(")")) {740errorStream << "expected pass pipeline to be wrapped with the anchor "741"operation type, e.g. 'builtin.module(...)'";742return failure();743}744
745StringRef opName = pipeline.take_front(pipelineStart).rtrim();746OpPassManager pm(opName);747if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,748errorStream)))749return failure();750return pm;751}
752
//===----------------------------------------------------------------------===//
// PassNameParser
//===----------------------------------------------------------------------===//

namespace {
/// This struct represents the possible data entries in a parsed pass pipeline
/// list.
struct PassArgData {
  PassArgData() = default;
  // Implicit on purpose, presumably so registry entry pointers handed to
  // `addLiteralOption` convert to PassArgData — verify against the parser.
  PassArgData(const PassRegistryEntry *registryEntry)
      : registryEntry(registryEntry) {}

  /// This field is used when the parsed option corresponds to a registered pass
  /// or pass pipeline.
  const PassRegistryEntry *registryEntry{nullptr};

  /// This field is set when instance specific pass options have been provided
  /// on the command line.
  StringRef options;
};
} // namespace

namespace llvm {
namespace cl {
/// Define a valid OptionValue for the command line pass argument.
///
/// It stores a plain copy of the PassArgData and always reports that a value
/// is present.
template <>
struct OptionValue<PassArgData> final
    : OptionValueBase<PassArgData, /*isClass=*/true> {
  OptionValue(const PassArgData &value) { this->setValue(value); }
  OptionValue() = default;
  void anchor() override {}

  bool hasValue() const { return true; }
  const PassArgData &getValue() const { return value; }
  void setValue(const PassArgData &value) { this->value = value; }

  PassArgData value;
};
} // namespace cl
} // namespace llvm

namespace {

/// The name for the command line option used for parsing the textual pass
/// pipeline.
#define PASS_PIPELINE_ARG "pass-pipeline"

/// Adds command line option for each registered pass or pass pipeline, as well
/// as textual pass pipelines.
struct PassNameParser : public llvm::cl::parser<PassArgData> {
  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}

  /// Register a literal value for each registered pass, plus each registered
  /// pass pipeline unless `passNamesOnly` is set.
  void initialize();
  /// Print the help for this option, including the lists of available entries.
  void printOptionInfo(const llvm::cl::Option &opt,
                       size_t globalWidth) const override;
  /// Width needed to align the help of this option and of all entry options.
  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
  /// Parse an entry name, keeping `arg` as that entry's options. Returns true
  /// on error.
  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
             PassArgData &value);

  /// If true, this parser only parses entries that correspond to a concrete
  /// pass registry entry, and does not include pipeline entries or the options
  /// for pass entries.
  bool passNamesOnly = false;
};
} // namespace

819void PassNameParser::initialize() {820llvm::cl::parser<PassArgData>::initialize();821
822/// Add the pass entries.823for (const auto &kv : *passRegistry) {824addLiteralOption(kv.second.getPassArgument(), &kv.second,825kv.second.getPassDescription());826}827/// Add the pass pipeline entries.828if (!passNamesOnly) {829for (const auto &kv : *passPipelineRegistry) {830addLiteralOption(kv.second.getPassArgument(), &kv.second,831kv.second.getPassDescription());832}833}834}
835
836void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,837size_t globalWidth) const {838// If this parser is just parsing pass names, print a simplified option839// string.840if (passNamesOnly) {841llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>";842opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);843return;844}845
846// Print the information for the top-level option.847if (opt.hasArgStr()) {848llvm::outs() << " --" << opt.ArgStr;849opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);850} else {851llvm::outs() << " " << opt.HelpStr << '\n';852}853
854// Functor used to print the ordered entries of a registration map.855auto printOrderedEntries = [&](StringRef header, auto &map) {856llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;857for (auto &kv : map)858orderedEntries.push_back(&kv.second);859llvm::array_pod_sort(860orderedEntries.begin(), orderedEntries.end(),861[](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {862return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());863});864
865llvm::outs().indent(4) << header << ":\n";866for (PassRegistryEntry *entry : orderedEntries)867entry->printHelpStr(/*indent=*/6, globalWidth);868};869
870// Print the available passes.871printOrderedEntries("Passes", *passRegistry);872
873// Print the available pass pipelines.874if (!passPipelineRegistry->empty())875printOrderedEntries("Pass Pipelines", *passPipelineRegistry);876}
877
878size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {879size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;880
881// Check for any wider pass or pipeline options.882for (auto &entry : *passRegistry)883maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);884for (auto &entry : *passPipelineRegistry)885maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);886return maxWidth;887}
888
889bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,890StringRef arg, PassArgData &value) {891if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))892return true;893value.options = arg;894return false;895}
896
897//===----------------------------------------------------------------------===//
898// PassPipelineCLParser
899//===----------------------------------------------------------------------===//
900
901namespace mlir {902namespace detail {903struct PassPipelineCLParserImpl {904PassPipelineCLParserImpl(StringRef arg, StringRef description,905bool passNamesOnly)906: passList(arg, llvm::cl::desc(description)) {907passList.getParser().passNamesOnly = passNamesOnly;908passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);909}910
911/// Returns true if the given pass registry entry was registered at the912/// top-level of the parser, i.e. not within an explicit textual pipeline.913bool contains(const PassRegistryEntry *entry) const {914return llvm::any_of(passList, [&](const PassArgData &data) {915return data.registryEntry == entry;916});917}918
919/// The set of passes and pass pipelines to run.920llvm::cl::list<PassArgData, bool, PassNameParser> passList;921};922} // namespace detail923} // namespace mlir924
/// Construct a pass pipeline parser with the given command line description.
/// `arg` and `description` configure the option that lists individual passes
/// and pipelines. A textual pipeline is instead read from the separate
/// PASS_PIPELINE_ARG option initialized here.
PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
          arg, description, /*passNamesOnly=*/false)),
      passPipeline(
          PASS_PIPELINE_ARG,
          llvm::cl::desc("Textual description of the pass pipeline to run")) {}
933PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,934StringRef alias)935: PassPipelineCLParser(arg, description) {936passPipelineAlias.emplace(alias,937llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),938llvm::cl::aliasopt(passPipeline));939}
940
/// Out-of-line defaulted destructor. detail::PassPipelineCLParserImpl is
/// defined in this file, so `impl` (presumably a std::unique_ptr; confirm in
/// PassRegistry.h) is destroyed where that type is complete.
PassPipelineCLParser::~PassPipelineCLParser() = default;
943/// Returns true if this parser contains any valid options to add.
944bool PassPipelineCLParser::hasAnyOccurrences() const {945return passPipeline.getNumOccurrences() != 0 ||946impl->passList.getNumOccurrences() != 0;947}
948
949/// Returns true if the given pass registry entry was registered at the
950/// top-level of the parser, i.e. not within an explicit textual pipeline.
951bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {952return impl->contains(entry);953}
954
955/// Adds the passes defined by this parser entry to the given pass manager.
956LogicalResult PassPipelineCLParser::addToPipeline(957OpPassManager &pm,958function_ref<LogicalResult(const Twine &)> errorHandler) const {959if (passPipeline.getNumOccurrences()) {960if (impl->passList.getNumOccurrences())961return errorHandler(962"'-" PASS_PIPELINE_ARG963"' option can't be used with individual pass options");964std::string errMsg;965llvm::raw_string_ostream os(errMsg);966FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);967if (failed(parsed))968return errorHandler(errMsg);969pm = std::move(*parsed);970return success();971}972
973for (auto &passIt : impl->passList) {974if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,975errorHandler)))976return failure();977}978return success();979}
980
//===----------------------------------------------------------------------===//
// PassNameCLParser
//===----------------------------------------------------------------------===//
983
984/// Construct a pass pipeline parser with the given command line description.
985PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)986: impl(std::make_unique<detail::PassPipelineCLParserImpl>(987arg, description, /*passNamesOnly=*/true)) {988impl->passList.setMiscFlag(llvm::cl::CommaSeparated);989}
/// Out-of-line defaulted destructor. detail::PassPipelineCLParserImpl is
/// defined in this file, so `impl` (presumably a std::unique_ptr; confirm in
/// PassRegistry.h) is destroyed where that type is complete.
PassNameCLParser::~PassNameCLParser() = default;
992/// Returns true if this parser contains any valid options to add.
993bool PassNameCLParser::hasAnyOccurrences() const {994return impl->passList.getNumOccurrences() != 0;995}
996
997/// Returns true if the given pass registry entry was registered at the
998/// top-level of the parser, i.e. not within an explicit textual pipeline.
999bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {1000return impl->contains(entry);1001}
1002