llvm-project

Форк
0
/
PassRegistry.cpp 
1001 строка · 36.9 Кб
1
//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8

9
#include "mlir/Pass/PassRegistry.h"
10

11
#include "mlir/Pass/Pass.h"
12
#include "mlir/Pass/PassManager.h"
13
#include "llvm/ADT/DenseMap.h"
14
#include "llvm/ADT/ScopeExit.h"
15
#include "llvm/Support/Format.h"
16
#include "llvm/Support/ManagedStatic.h"
17
#include "llvm/Support/MemoryBuffer.h"
18
#include "llvm/Support/SourceMgr.h"
19

20
#include <optional>
21
#include <utility>
22

23
using namespace mlir;
24
using namespace detail;
25

26
/// Static mapping of all of the registered passes.
27
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
28

29
/// A mapping of the above pass registry entries to the corresponding TypeID
30
/// of the pass that they generate.
31
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
32

33
/// Static mapping of all of the registered pass pipelines.
34
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
35
    passPipelineRegistry;
36

37
/// Utility to create a default registry function from a pass instance.
38
static PassRegistryFunction
39
buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
40
  return [=](OpPassManager &pm, StringRef options,
41
             function_ref<LogicalResult(const Twine &)> errorHandler) {
42
    std::unique_ptr<Pass> pass = allocator();
43
    LogicalResult result = pass->initializeOptions(options, errorHandler);
44

45
    std::optional<StringRef> pmOpName = pm.getOpName();
46
    std::optional<StringRef> passOpName = pass->getOpName();
47
    if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
48
        passOpName && *pmOpName != *passOpName) {
49
      return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
50
                          "' restricted to '" + *pass->getOpName() +
51
                          "' on a PassManager intended to run on '" +
52
                          pm.getOpAnchorName() + "', did you intend to nest?");
53
    }
54
    pm.addPass(std::move(pass));
55
    return result;
56
  };
57
}
58

59
/// Utility to print the help string for a specific option.
60
static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
61
                            size_t descIndent, bool isTopLevel) {
62
  size_t numSpaces = descIndent - indent - 4;
63
  llvm::outs().indent(indent)
64
      << "--" << llvm::left_justify(arg, numSpaces) << "-   " << desc << '\n';
65
}
66

67
//===----------------------------------------------------------------------===//
68
// PassRegistry
69
//===----------------------------------------------------------------------===//
70

71
/// Print the help information for this pass. This includes the argument,
72
/// description, and any pass options. `descIndent` is the indent that the
73
/// descriptions should be aligned.
74
void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
75
  printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
76
                  /*isTopLevel=*/true);
77
  // If this entry has options, print the help for those as well.
78
  optHandler([=](const PassOptions &options) {
79
    options.printHelp(indent, descIndent);
80
  });
81
}
82

83
/// Return the maximum width required when printing the options of this
84
/// entry.
85
size_t PassRegistryEntry::getOptionWidth() const {
86
  size_t maxLen = 0;
87
  optHandler([&](const PassOptions &options) mutable {
88
    maxLen = options.getOptionWidth() + 2;
89
  });
90
  return maxLen;
91
}
92

93
//===----------------------------------------------------------------------===//
94
// PassPipelineInfo
95
//===----------------------------------------------------------------------===//
96

97
void mlir::registerPassPipeline(
98
    StringRef arg, StringRef description, const PassRegistryFunction &function,
99
    std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
100
  PassPipelineInfo pipelineInfo(arg, description, function,
101
                                std::move(optHandler));
102
  bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
103
#ifndef NDEBUG
104
  if (!inserted)
105
    report_fatal_error("Pass pipeline " + arg + " registered multiple times");
106
#endif
107
  (void)inserted;
108
}
109

110
//===----------------------------------------------------------------------===//
111
// PassInfo
112
//===----------------------------------------------------------------------===//
113

114
/// Build the registry entry for a single pass.
///  - The registry function allocates a new pass each time it is used.
///  - The options handler allocates a scratch pass instance only to expose its
///    options to the callback.
PassInfo::PassInfo(StringRef arg, StringRef description,
                   const PassAllocatorFunction &allocator)
    : PassRegistryEntry(arg, description, buildDefaultRegistryFn(allocator),
                        [=](function_ref<void(const PassOptions &)> callback) {
                          std::unique_ptr<Pass> scratch = allocator();
                          callback(scratch->passOptions);
                        }) {}
122

123
void mlir::registerPass(const PassAllocatorFunction &function) {
124
  std::unique_ptr<Pass> pass = function();
125
  StringRef arg = pass->getArgument();
126
  if (arg.empty())
127
    llvm::report_fatal_error(llvm::Twine("Trying to register '") +
128
                             pass->getName() +
129
                             "' pass that does not override `getArgument()`");
130
  StringRef description = pass->getDescription();
131
  PassInfo passInfo(arg, description, function);
132
  passRegistry->try_emplace(arg, passInfo);
133

134
  // Verify that the registered pass has the same ID as any registered to this
135
  // arg before it.
136
  TypeID entryTypeID = pass->getTypeID();
137
  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
138
  if (it->second != entryTypeID)
139
    llvm::report_fatal_error(
140
        "pass allocator creates a different pass than previously "
141
        "registered for pass " +
142
        arg);
143
}
144

145
/// Returns the pass info for the specified pass argument or null if unknown.
146
const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
147
  auto it = passRegistry->find(passArg);
148
  return it == passRegistry->end() ? nullptr : &it->second;
149
}
150

151
/// Returns the pass pipeline info for the specified pass pipeline argument or
152
/// null if unknown.
153
const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
154
  auto it = passPipelineRegistry->find(pipelineArg);
155
  return it == passPipelineRegistry->end() ? nullptr : &it->second;
156
}
157

158
//===----------------------------------------------------------------------===//
159
// PassOptions
160
//===----------------------------------------------------------------------===//
161

162
LogicalResult detail::pass_options::parseCommaSeparatedList(
163
    llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
164
    function_ref<LogicalResult(StringRef)> elementParseFn) {
165
  // Functor used for finding a character in a string, and skipping over
166
  // various "range" characters.
167
  llvm::unique_function<size_t(StringRef, size_t, char)> findChar =
168
      [&](StringRef str, size_t index, char c) -> size_t {
169
    for (size_t i = index, e = str.size(); i < e; ++i) {
170
      if (str[i] == c)
171
        return i;
172
      // Check for various range characters.
173
      if (str[i] == '{')
174
        i = findChar(str, i + 1, '}');
175
      else if (str[i] == '(')
176
        i = findChar(str, i + 1, ')');
177
      else if (str[i] == '[')
178
        i = findChar(str, i + 1, ']');
179
      else if (str[i] == '\"')
180
        i = str.find_first_of('\"', i + 1);
181
      else if (str[i] == '\'')
182
        i = str.find_first_of('\'', i + 1);
183
    }
184
    return StringRef::npos;
185
  };
186

187
  size_t nextElePos = findChar(optionStr, 0, ',');
188
  while (nextElePos != StringRef::npos) {
189
    // Process the portion before the comma.
190
    if (failed(elementParseFn(optionStr.substr(0, nextElePos))))
191
      return failure();
192

193
    optionStr = optionStr.substr(nextElePos + 1);
194
    nextElePos = findChar(optionStr, 0, ',');
195
  }
196
  return elementParseFn(optionStr.substr(0, nextElePos));
197
}
198

199
/// Out of line virtual function to provide home for the class.
200
void detail::PassOptions::OptionBase::anchor() {}
201

202
/// Copy the option values from 'other'.
203
void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
204
  assert(options.size() == other.options.size());
205
  if (options.empty())
206
    return;
207
  for (auto optionsIt : llvm::zip(options, other.options))
208
    std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
209
}
210

211
/// Parse in the next argument from the given options string. Returns a tuple
212
/// containing [the key of the option, the value of the option, updated
213
/// `options` string pointing after the parsed option].
214
static std::tuple<StringRef, StringRef, StringRef>
215
parseNextArg(StringRef options) {
216
  // Functor used to extract an argument from 'options' and update it to point
217
  // after the arg.
218
  auto extractArgAndUpdateOptions = [&](size_t argSize) {
219
    StringRef str = options.take_front(argSize).trim();
220
    options = options.drop_front(argSize).ltrim();
221
    // Handle escape sequences
222
    if (str.size() > 2) {
223
      const auto escapePairs = {std::make_pair('\'', '\''),
224
                                std::make_pair('"', '"'),
225
                                std::make_pair('{', '}')};
226
      for (const auto &escape : escapePairs) {
227
        if (str.front() == escape.first && str.back() == escape.second) {
228
          // Drop the escape characters and trim.
229
          str = str.drop_front().drop_back().trim();
230
          // Don't process additional escape sequences.
231
          break;
232
        }
233
      }
234
    }
235
    return str;
236
  };
237
  // Try to process the given punctuation, properly escaping any contained
238
  // characters.
239
  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
240
    if (options[currentPos] != punct)
241
      return false;
242
    size_t nextIt = options.find_first_of(punct, currentPos + 1);
243
    if (nextIt != StringRef::npos)
244
      currentPos = nextIt;
245
    return true;
246
  };
247

248
  // Parse the argument name of the option.
249
  StringRef argName;
250
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
251
    // Check for the end of the full option.
252
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
253
      argName = extractArgAndUpdateOptions(argEndIt);
254
      return std::make_tuple(argName, StringRef(), options);
255
    }
256

257
    // Check for the end of the name and the start of the value.
258
    if (options[argEndIt] == '=') {
259
      argName = extractArgAndUpdateOptions(argEndIt);
260
      options = options.drop_front();
261
      break;
262
    }
263
  }
264

265
  // Parse the value of the option.
266
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
267
    // Handle the end of the options string.
268
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
269
      StringRef value = extractArgAndUpdateOptions(argEndIt);
270
      return std::make_tuple(argName, value, options);
271
    }
272

273
    // Skip over escaped sequences.
274
    char c = options[argEndIt];
275
    if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
276
      continue;
277
    // '{...}' is used to specify options to passes, properly escape it so
278
    // that we don't accidentally split any nested options.
279
    if (c == '{') {
280
      size_t braceCount = 1;
281
      for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
282
        // Allow nested punctuation.
283
        if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
284
          continue;
285
        if (options[argEndIt] == '{')
286
          ++braceCount;
287
        else if (options[argEndIt] == '}' && --braceCount == 0)
288
          break;
289
      }
290
      // Account for the increment at the top of the loop.
291
      --argEndIt;
292
    }
293
  }
294
  llvm_unreachable("unexpected control flow in pass option parsing");
295
}
296

297
/// Parse `options`, a space-separated list of `key[=value]` entries, and apply
/// each value to the registered option with that key.
///
/// Returns failure in either of these cases:
///  - a key is unknown (reported to `errorStream`);
///  - the option's own command-line parser rejects the value.
LogicalResult detail::PassOptions::parseFromString(StringRef options,
                                                   raw_ostream &errorStream) {
  // `options` always holds the not-yet-consumed suffix of the input.
  while (!options.empty()) {
    StringRef key;
    StringRef value;
    std::tie(key, value, options) = parseNextArg(options);
    // An entry can parse to an empty key (e.g. a leading '='); nothing to set.
    if (key.empty())
      continue;

    auto optionIt = OptionsMap.find(key);
    if (optionIt == OptionsMap.end()) {
      errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
      return failure();
    }
    // Hand the raw value to the option's command-line parser.
    if (llvm::cl::ProvidePositionalOption(optionIt->second, value, 0))
      return failure();
  }
  return success();
}
318

319
/// Print the options held by this struct in a form that can be parsed via
320
/// 'parseFromString'.
321
void detail::PassOptions::print(raw_ostream &os) {
322
  // If there are no options, there is nothing left to do.
323
  if (OptionsMap.empty())
324
    return;
325

326
  // Sort the options to make the ordering deterministic.
327
  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
328
  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
329
    return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
330
  };
331
  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
332

333
  // Interleave the options with ' '.
334
  os << '{';
335
  llvm::interleave(
336
      orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
337
  os << '}';
338
}
339

340
/// Print the help string for the options held by this struct. `descIndent` is
341
/// the indent within the stream that the descriptions should be aligned.
342
void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
343
  // Sort the options to make the ordering deterministic.
344
  SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
345
  auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
346
    return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
347
  };
348
  llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
349
  for (OptionBase *option : orderedOps) {
350
    // TODO: printOptionInfo assumes a specific indent and will
351
    // print options with values with incorrect indentation. We should add
352
    // support to llvm::cl::Option for passing in a base indent to use when
353
    // printing.
354
    llvm::outs().indent(indent);
355
    option->getOption()->printOptionInfo(descIndent - indent);
356
  }
357
}
358

359
/// Return the maximum width required when printing the help string.
360
size_t detail::PassOptions::getOptionWidth() const {
361
  size_t max = 0;
362
  for (auto *option : options)
363
    max = std::max(max, option->getOption()->getOptionWidth());
364
  return max;
365
}
366

367
//===----------------------------------------------------------------------===//
368
// MLIR Options
369
//===----------------------------------------------------------------------===//
370

371
//===----------------------------------------------------------------------===//
372
// OpPassManager: OptionValue
373

374
llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
375
llvm::cl::OptionValue<OpPassManager>::OptionValue(
376
    const mlir::OpPassManager &value) {
377
  setValue(value);
378
}
379
llvm::cl::OptionValue<OpPassManager>::OptionValue(
380
    const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) {
381
  if (rhs.hasValue())
382
    setValue(rhs.getValue());
383
}
384
llvm::cl::OptionValue<OpPassManager> &
385
llvm::cl::OptionValue<OpPassManager>::operator=(
386
    const mlir::OpPassManager &rhs) {
387
  setValue(rhs);
388
  return *this;
389
}
390

391
llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
392

393
void llvm::cl::OptionValue<OpPassManager>::setValue(
394
    const OpPassManager &newValue) {
395
  if (hasValue())
396
    *value = newValue;
397
  else
398
    value = std::make_unique<mlir::OpPassManager>(newValue);
399
}
400
/// Parse `pipelineStr` as a textual pass pipeline and store the result. The
/// string must describe a valid pipeline; an invalid one is a fatal error.
void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
  FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
  assert(succeeded(pipeline) && "invalid pass pipeline");
  // With NDEBUG the assert above disappears, and dereferencing a failed result
  // would be undefined behavior. Fail loudly and deterministically instead.
  if (failed(pipeline))
    llvm::report_fatal_error(llvm::Twine("invalid pass pipeline: ") +
                             pipelineStr);
  setValue(*pipeline);
}
405

406
/// Return true if this option holds a pass manager whose textual pipeline is
/// identical to that of `rhs`. An option that holds no value compares
/// unequal, like `cl::OptionValueCopy::compare`. Previously the null `value`
/// was dereferenced in that case.
bool llvm::cl::OptionValue<OpPassManager>::compare(
    const mlir::OpPassManager &rhs) const {
  if (!hasValue())
    return false;

  std::string lhsStr, rhsStr;
  {
    raw_string_ostream lhsStream(lhsStr);
    value->printAsTextualPipeline(lhsStream);

    raw_string_ostream rhsStream(rhsStr);
    rhs.printAsTextualPipeline(rhsStream);
  }

  // Use the textual format for pipeline comparisons.
  return lhsStr == rhsStr;
}
420

421
// Out-of-line virtual anchor for the OptionValue<OpPassManager> specialization.
void llvm::cl::OptionValue<OpPassManager>::anchor() {
  // Intentionally empty.
}
422

423
//===----------------------------------------------------------------------===//
424
// OpPassManager: Parser
425

426
namespace llvm {
427
namespace cl {
428
template class basic_parser<OpPassManager>;
429
} // namespace cl
430
} // namespace llvm
431

432
/// Parse `arg` as a textual pass pipeline into `value`. Returns true on error
/// and false on success.
bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
                                            ParsedPassManager &value) {
  FailureOr<OpPassManager> parsed = parsePassPipeline(arg);
  if (succeeded(parsed)) {
    value.value = std::make_unique<OpPassManager>(std::move(*parsed));
    return false;
  }
  return true;
}
440

441
/// Write `value` to `os` in textual pass-pipeline form.
void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
                                            const OpPassManager &value) {
  const OpPassManager &pipeline = value;
  pipeline.printAsTextualPipeline(os);
}
445

446
/// Print to llvm::outs(), in this order:
///  - the option name;
///  - the current pipeline `pm`;
///  - the default pipeline in parentheses, if a default is set.
void llvm::cl::parser<OpPassManager>::printOptionDiff(
    const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
    size_t globalWidth) const {
  printOptionName(opt, globalWidth);
  raw_ostream &out = outs();
  out << "= ";
  pm.printAsTextualPipeline(out);

  if (defaultValue.hasValue()) {
    out.indent(2) << " (default: ";
    defaultValue.getValue().printAsTextualPipeline(out);
    out << ")";
  }
  out << "\n";
}
460

461
void llvm::cl::parser<OpPassManager>::anchor() {}
462

463
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
464
    default;
465
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
466
    ParsedPassManager &&) = default;
467
llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
468
    default;
469

470
//===----------------------------------------------------------------------===//
471
// TextualPassPipeline Parser
472
//===----------------------------------------------------------------------===//
473

474
namespace {
475
/// This class represents a textual description of a pass pipeline.
476
class TextualPipeline {
477
public:
478
  /// Try to initialize this pipeline with the given pipeline text.
479
  /// `errorStream` is the output stream to emit errors to.
480
  LogicalResult initialize(StringRef text, raw_ostream &errorStream);
481

482
  /// Add the internal pipeline elements to the provided pass manager.
483
  LogicalResult
484
  addToPipeline(OpPassManager &pm,
485
                function_ref<LogicalResult(const Twine &)> errorHandler) const;
486

487
private:
488
  /// A functor used to emit errors found during pipeline handling. The first
489
  /// parameter corresponds to the raw location within the pipeline string. This
490
  /// should always return failure.
491
  using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
492

493
  /// A struct to capture parsed pass pipeline names.
494
  ///
495
  /// A pipeline is defined as a series of names, each of which may in itself
496
  /// recursively contain a nested pipeline. A name is either the name of a pass
497
  /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
498
  /// the name is the name of a pass, the InnerPipeline is empty, since passes
499
  /// cannot contain inner pipelines.
500
  struct PipelineElement {
501
    PipelineElement(StringRef name) : name(name) {}
502

503
    StringRef name;
504
    StringRef options;
505
    const PassRegistryEntry *registryEntry = nullptr;
506
    std::vector<PipelineElement> innerPipeline;
507
  };
508

509
  /// Parse the given pipeline text into the internal pipeline vector. This
510
  /// function only parses the structure of the pipeline, and does not resolve
511
  /// its elements.
512
  LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
513

514
  /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
515
  /// the corresponding registry entry.
516
  LogicalResult
517
  resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
518
                          ErrorHandlerT errorHandler);
519

520
  /// Resolve a single element of the pipeline.
521
  LogicalResult resolvePipelineElement(PipelineElement &element,
522
                                       ErrorHandlerT errorHandler);
523

524
  /// Add the given pipeline elements to the provided pass manager.
525
  LogicalResult
526
  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
527
                function_ref<LogicalResult(const Twine &)> errorHandler) const;
528

529
  std::vector<PipelineElement> pipeline;
530
};
531

532
} // namespace
533

534
/// Try to initialize this pipeline with the given pipeline text. An option is
535
/// given to enable accurate error reporting.
536
LogicalResult TextualPipeline::initialize(StringRef text,
537
                                          raw_ostream &errorStream) {
538
  if (text.empty())
539
    return success();
540

541
  // Build a source manager to use for error reporting.
542
  llvm::SourceMgr pipelineMgr;
543
  pipelineMgr.AddNewSourceBuffer(
544
      llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
545
                                       /*RequiresNullTerminator=*/false),
546
      SMLoc());
547
  auto errorHandler = [&](const char *rawLoc, Twine msg) {
548
    pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
549
                             llvm::SourceMgr::DK_Error, msg);
550
    return failure();
551
  };
552

553
  // Parse the provided pipeline string.
554
  if (failed(parsePipelineText(text, errorHandler)))
555
    return failure();
556
  return resolvePipelineElements(pipeline, errorHandler);
557
}
558

559
/// Add the internal pipeline elements to the provided pass manager.
560
LogicalResult TextualPipeline::addToPipeline(
561
    OpPassManager &pm,
562
    function_ref<LogicalResult(const Twine &)> errorHandler) const {
563
  // Temporarily disable implicit nesting while we append to the pipeline. We
564
  // want the created pipeline to exactly match the parsed text pipeline, so
565
  // it's preferrable to just error out if implicit nesting would be required.
566
  OpPassManager::Nesting nesting = pm.getNesting();
567
  pm.setNesting(OpPassManager::Nesting::Explicit);
568
  auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
569

570
  return addToPipeline(pipeline, pm, errorHandler);
571
}
572

573
/// Parse the given pipeline text into the internal pipeline vector. This
574
/// function only parses the structure of the pipeline, and does not resolve
575
/// its elements.
576
LogicalResult TextualPipeline::parsePipelineText(StringRef text,
577
                                                 ErrorHandlerT errorHandler) {
578
  SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
579
  for (;;) {
580
    std::vector<PipelineElement> &pipeline = *pipelineStack.back();
581
    size_t pos = text.find_first_of(",(){");
582
    pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
583

584
    // If we have a single terminating name, we're done.
585
    if (pos == StringRef::npos)
586
      break;
587

588
    text = text.substr(pos);
589
    char sep = text[0];
590

591
    // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
592
    if (sep == '{') {
593
      text = text.substr(1);
594

595
      // Skip over everything until the closing '}' and store as options.
596
      size_t close = StringRef::npos;
597
      for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
598
        if (text[i] == '{') {
599
          ++braceCount;
600
          continue;
601
        }
602
        if (text[i] == '}' && --braceCount == 0) {
603
          close = i;
604
          break;
605
        }
606
      }
607

608
      // Check to see if a closing options brace was found.
609
      if (close == StringRef::npos) {
610
        return errorHandler(
611
            /*rawLoc=*/text.data() - 1,
612
            "missing closing '}' while processing pass options");
613
      }
614
      pipeline.back().options = text.substr(0, close);
615
      text = text.substr(close + 1);
616

617
      // Consume space characters that an user might add for readability.
618
      text = text.ltrim();
619

620
      // Skip checking for '(' because nested pipelines cannot have options.
621
    } else if (sep == '(') {
622
      text = text.substr(1);
623

624
      // Push the inner pipeline onto the stack to continue processing.
625
      pipelineStack.push_back(&pipeline.back().innerPipeline);
626
      continue;
627
    }
628

629
    // When handling the close parenthesis, we greedily consume them to avoid
630
    // empty strings in the pipeline.
631
    while (text.consume_front(")")) {
632
      // If we try to pop the outer pipeline we have unbalanced parentheses.
633
      if (pipelineStack.size() == 1)
634
        return errorHandler(/*rawLoc=*/text.data() - 1,
635
                            "encountered extra closing ')' creating unbalanced "
636
                            "parentheses while parsing pipeline");
637

638
      pipelineStack.pop_back();
639
      // Consume space characters that an user might add for readability.
640
      text = text.ltrim();
641
    }
642

643
    // Check if we've finished parsing.
644
    if (text.empty())
645
      break;
646

647
    // Otherwise, the end of an inner pipeline always has to be followed by
648
    // a comma, and then we can continue.
649
    if (!text.consume_front(","))
650
      return errorHandler(text.data(), "expected ',' after parsing pipeline");
651
  }
652

653
  // Check for unbalanced parentheses.
654
  if (pipelineStack.size() > 1)
655
    return errorHandler(
656
        text.data(),
657
        "encountered unbalanced parentheses while parsing pipeline");
658

659
  assert(pipelineStack.back() == &pipeline &&
660
         "wrong pipeline at the bottom of the stack");
661
  return success();
662
}
663

664
/// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
665
/// the corresponding registry entry.
666
LogicalResult TextualPipeline::resolvePipelineElements(
667
    MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
668
  for (auto &elt : elements)
669
    if (failed(resolvePipelineElement(elt, errorHandler)))
670
      return failure();
671
  return success();
672
}
673

674
/// Resolve a single element of the pipeline.
675
LogicalResult
676
TextualPipeline::resolvePipelineElement(PipelineElement &element,
677
                                        ErrorHandlerT errorHandler) {
678
  // If the inner pipeline of this element is not empty, this is an operation
679
  // pipeline.
680
  if (!element.innerPipeline.empty())
681
    return resolvePipelineElements(element.innerPipeline, errorHandler);
682

683
  // Otherwise, this must be a pass or pass pipeline.
684
  // Check to see if a pipeline was registered with this name.
685
  if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
686
    return success();
687

688
  // If not, then this must be a specific pass name.
689
  if ((element.registryEntry = PassInfo::lookup(element.name)))
690
    return success();
691

692
  // Emit an error for the unknown pass.
693
  auto *rawLoc = element.name.data();
694
  return errorHandler(rawLoc, "'" + element.name +
695
                                  "' does not refer to a "
696
                                  "registered pass or pass pipeline");
697
}
698

699
/// Add the given pipeline elements to the provided pass manager.
700
LogicalResult TextualPipeline::addToPipeline(
701
    ArrayRef<PipelineElement> elements, OpPassManager &pm,
702
    function_ref<LogicalResult(const Twine &)> errorHandler) const {
703
  for (auto &elt : elements) {
704
    if (elt.registryEntry) {
705
      if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
706
                                                  errorHandler))) {
707
        return errorHandler("failed to add `" + elt.name + "` with options `" +
708
                            elt.options + "`");
709
      }
710
    } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
711
                                    errorHandler))) {
712
      return errorHandler("failed to add `" + elt.name + "` with options `" +
713
                          elt.options + "` to inner pipeline");
714
    }
715
  }
716
  return success();
717
}
718

719
/// Parse the textual pipeline elements in `pipeline` and append them to `pm`.
/// Errors are written to `errorStream`.
LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
                                      raw_ostream &errorStream) {
  TextualPipeline textualPipeline;
  if (failed(textualPipeline.initialize(pipeline, errorStream)))
    return failure();
  auto reportError = [&errorStream](Twine msg) {
    errorStream << msg << "\n";
    return failure();
  };
  return textualPipeline.addToPipeline(pm, reportError);
}
732

733
/// Parse an anchored pipeline of the form `<op-name>(<pipeline>)` into a new
/// pass manager anchored on `<op-name>`. Surrounding whitespace is ignored.
/// Errors are written to `errorStream`.
FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
                                                 raw_ostream &errorStream) {
  pipeline = pipeline.trim();
  size_t openParen = pipeline.find_first_of('(');
  bool hasAnchor = openParen != 0 && openParen != StringRef::npos;
  // Only strip the trailing ')' once an anchor name is known to be present.
  if (!hasAnchor || !pipeline.consume_back(")")) {
    errorStream << "expected pass pipeline to be wrapped with the anchor "
                   "operation type, e.g. 'builtin.module(...)'";
    return failure();
  }

  StringRef anchorName = pipeline.take_front(openParen).rtrim();
  StringRef body = pipeline.drop_front(openParen + 1);
  OpPassManager pm(anchorName);
  if (failed(parsePassPipeline(body, pm, errorStream)))
    return failure();
  return pm;
}
752

753
//===----------------------------------------------------------------------===//
754
// PassNameParser
755
//===----------------------------------------------------------------------===//
756

757
namespace {
758
/// This struct represents the possible data entries in a parsed pass pipeline
759
/// list.
760
struct PassArgData {
761
  PassArgData() = default;
762
  PassArgData(const PassRegistryEntry *registryEntry)
763
      : registryEntry(registryEntry) {}
764

765
  /// This field is used when the parsed option corresponds to a registered pass
766
  /// or pass pipeline.
767
  const PassRegistryEntry *registryEntry{nullptr};
768

769
  /// This field is set when instance specific pass options have been provided
770
  /// on the command line.
771
  StringRef options;
772
};
773
} // namespace
774

775
namespace llvm {
776
namespace cl {
777
/// Define a valid OptionValue for the command line pass argument.
778
template <>
779
struct OptionValue<PassArgData> final
780
    : OptionValueBase<PassArgData, /*isClass=*/true> {
781
  OptionValue(const PassArgData &value) { this->setValue(value); }
782
  OptionValue() = default;
783
  void anchor() override {}
784

785
  bool hasValue() const { return true; }
786
  const PassArgData &getValue() const { return value; }
787
  void setValue(const PassArgData &value) { this->value = value; }
788

789
  PassArgData value;
790
};
791
} // namespace cl
792
} // namespace llvm
793

794
namespace {
795

796
/// The name for the command line option used for parsing the textual pass
797
/// pipeline.
798
#define PASS_PIPELINE_ARG "pass-pipeline"
799

800
/// Adds command line option for each registered pass or pass pipeline, as well
801
/// as textual pass pipelines.
802
struct PassNameParser : public llvm::cl::parser<PassArgData> {
803
  PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
804

805
  void initialize();
806
  void printOptionInfo(const llvm::cl::Option &opt,
807
                       size_t globalWidth) const override;
808
  size_t getOptionWidth(const llvm::cl::Option &opt) const override;
809
  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
810
             PassArgData &value);
811

812
  /// If true, this parser only parses entries that correspond to a concrete
813
  /// pass registry entry, and does not include pipeline entries or the options
814
  /// for pass entries.
815
  bool passNamesOnly = false;
816
};
817
} // namespace
818

819
void PassNameParser::initialize() {
820
  llvm::cl::parser<PassArgData>::initialize();
821

822
  /// Add the pass entries.
823
  for (const auto &kv : *passRegistry) {
824
    addLiteralOption(kv.second.getPassArgument(), &kv.second,
825
                     kv.second.getPassDescription());
826
  }
827
  /// Add the pass pipeline entries.
828
  if (!passNamesOnly) {
829
    for (const auto &kv : *passPipelineRegistry) {
830
      addLiteralOption(kv.second.getPassArgument(), &kv.second,
831
                       kv.second.getPassDescription());
832
    }
833
  }
834
}
835

836
void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
                                     size_t globalWidth) const {
  // In names-only mode, print a single summary line and skip the registry
  // listing.
  if (passNamesOnly) {
    llvm::outs() << "  --" << opt.ArgStr << "=<pass-arg>";
    opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
    return;
  }

  // Summary line for the option itself. With an argument string, print the
  // flag followed by its aligned help text; otherwise print only the help.
  if (!opt.hasArgStr()) {
    llvm::outs() << "  " << opt.HelpStr << '\n';
  } else {
    llvm::outs() << "  --" << opt.ArgStr;
    opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
  }

  // Prints `header`, then the help of every entry of `registry` sorted by
  // pass argument.
  auto printSortedSection = [&](StringRef header, auto &registry) {
    llvm::SmallVector<PassRegistryEntry *, 32> entries;
    for (auto &it : registry)
      entries.push_back(&it.second);
    llvm::array_pod_sort(entries.begin(), entries.end(),
                         [](PassRegistryEntry *const *a,
                            PassRegistryEntry *const *b) {
                           StringRef lhsArg = (*a)->getPassArgument();
                           return lhsArg.compare((*b)->getPassArgument());
                         });

    llvm::outs().indent(4) << header << ":\n";
    for (PassRegistryEntry *entry : entries)
      entry->printHelpStr(/*indent=*/6, globalWidth);
  };

  printSortedSection("Passes", *passRegistry);

  // Only list pipelines when at least one is registered.
  if (!passPipelineRegistry->empty())
    printSortedSection("Pass Pipelines", *passPipelineRegistry);
}
877

878
size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
  // Start from the generic parser's width plus two columns of padding.
  size_t width = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;

  // Widen to fit the option width of every registered pass and pipeline,
  // plus four columns of padding.
  auto widenFor = [&width](auto &registry) {
    for (auto &entry : registry)
      width = std::max(width, entry.second.getOptionWidth() + 4);
  };
  widenFor(*passRegistry);
  widenFor(*passPipelineRegistry);
  return width;
}
888

889
bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
                           StringRef arg, PassArgData &value) {
  // Let the generic parser resolve the registry entry first.
  bool hadError =
      llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value);

  // On success, keep the raw argument as the entry's instance options.
  if (!hadError)
    value.options = arg;
  return hadError;
}
896

897
//===----------------------------------------------------------------------===//
898
// PassPipelineCLParser
899
//===----------------------------------------------------------------------===//
900

901
namespace mlir {
902
namespace detail {
903
struct PassPipelineCLParserImpl {
904
  PassPipelineCLParserImpl(StringRef arg, StringRef description,
905
                           bool passNamesOnly)
906
      : passList(arg, llvm::cl::desc(description)) {
907
    passList.getParser().passNamesOnly = passNamesOnly;
908
    passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
909
  }
910

911
  /// Returns true if the given pass registry entry was registered at the
912
  /// top-level of the parser, i.e. not within an explicit textual pipeline.
913
  bool contains(const PassRegistryEntry *entry) const {
914
    return llvm::any_of(passList, [&](const PassArgData &data) {
915
      return data.registryEntry == entry;
916
    });
917
  }
918

919
  /// The set of passes and pass pipelines to run.
920
  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
921
};
922
} // namespace detail
923
} // namespace mlir
924

925
/// Construct a pass pipeline parser with the given command line description.
926
PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
927
    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
928
          arg, description, /*passNamesOnly=*/false)),
929
      passPipeline(
930
          PASS_PIPELINE_ARG,
931
          llvm::cl::desc("Textual description of the pass pipeline to run")) {}
932

933
/// Same as the two-argument constructor, but also registers `alias` as an
/// alternative spelling of the textual pass pipeline option.
PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
                                           StringRef alias)
    : PassPipelineCLParser(arg, description) {
  llvm::cl::desc aliasDesc("Alias for --" PASS_PIPELINE_ARG);
  passPipelineAlias.emplace(alias, aliasDesc,
                            llvm::cl::aliasopt(passPipeline));
}

/// Defaulted here, where `detail::PassPipelineCLParserImpl` is a complete type.
/// This is presumably required because `impl` owns it through a smart pointer.
PassPipelineCLParser::~PassPipelineCLParser() = default;
942

943
/// Returns true if this parser contains any valid options to add.
944
bool PassPipelineCLParser::hasAnyOccurrences() const {
945
  return passPipeline.getNumOccurrences() != 0 ||
946
         impl->passList.getNumOccurrences() != 0;
947
}
948

949
/// Returns true if the given pass registry entry was registered at the
950
/// top-level of the parser, i.e. not within an explicit textual pipeline.
951
bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
952
  return impl->contains(entry);
953
}
954

955
/// Adds the passes defined by this parser entry to the given pass manager.
956
LogicalResult PassPipelineCLParser::addToPipeline(
957
    OpPassManager &pm,
958
    function_ref<LogicalResult(const Twine &)> errorHandler) const {
959
  if (passPipeline.getNumOccurrences()) {
960
    if (impl->passList.getNumOccurrences())
961
      return errorHandler(
962
          "'-" PASS_PIPELINE_ARG
963
          "' option can't be used with individual pass options");
964
    std::string errMsg;
965
    llvm::raw_string_ostream os(errMsg);
966
    FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
967
    if (failed(parsed))
968
      return errorHandler(errMsg);
969
    pm = std::move(*parsed);
970
    return success();
971
  }
972

973
  for (auto &passIt : impl->passList) {
974
    if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
975
                                                   errorHandler)))
976
      return failure();
977
  }
978
  return success();
979
}
980

981
//===----------------------------------------------------------------------===//
// PassNameCLParser
//===----------------------------------------------------------------------===//
983

984
/// Construct a pass pipeline parser with the given command line description.
985
PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
986
    : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
987
          arg, description, /*passNamesOnly=*/true)) {
988
  impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
989
}
990
PassNameCLParser::~PassNameCLParser() = default;
991

992
/// Returns true if this parser contains any valid options to add.
993
bool PassNameCLParser::hasAnyOccurrences() const {
994
  return impl->passList.getNumOccurrences() != 0;
995
}
996

997
/// Returns true if the given pass registry entry was registered at the
998
/// top-level of the parser, i.e. not within an explicit textual pipeline.
999
bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
1000
  return impl->contains(entry);
1001
}
1002

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.