llvm-project

Форк
0
/
ViewOpGraph.cpp 
387 строк · 12.4 Кб
1
//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8

9
#include "mlir/Transforms/ViewOpGraph.h"
10

11
#include "mlir/Analysis/TopologicalSortUtils.h"
12
#include "mlir/IR/Block.h"
13
#include "mlir/IR/BuiltinTypes.h"
14
#include "mlir/IR/Operation.h"
15
#include "mlir/Pass/Pass.h"
16
#include "mlir/Support/IndentedOstream.h"
17
#include "llvm/Support/Format.h"
18
#include "llvm/Support/GraphWriter.h"
19
#include <map>
20
#include <optional>
21
#include <utility>
22

23
namespace mlir {
24
#define GEN_PASS_DEF_VIEWOPGRAPH
25
#include "mlir/Transforms/Passes.h.inc"
26
} // namespace mlir
27

28
using namespace mlir;
29

30
/// Graphviz `style` attribute values used for control-flow and data-flow
/// edges, respectively.
static constexpr StringRef kLineStyleControlFlow = "dashed";
static constexpr StringRef kLineStyleDataFlow = "solid";
/// Graphviz `shape` attribute values: the default shape for emitted nodes and
/// the shape used for the invisible anchor node of a cluster.
static constexpr StringRef kShapeNode = "ellipse";
static constexpr StringRef kShapeNone = "plain";
34

35
/// Return the size limits for eliding large attributes.
36
static int64_t getLargeAttributeSizeLimit() {
37
  // Use the default from the printer flags if possible.
38
  if (std::optional<int64_t> limit =
39
          OpPrintingFlags().getLargeElementsAttrLimit())
40
    return *limit;
41
  return 16;
42
}
43

44
/// Return all values printed onto a stream as a string.
45
static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46
  std::string buf;
47
  llvm::raw_string_ostream os(buf);
48
  func(os);
49
  return os.str();
50
}
51

52
/// Escape special characters such as '\n' and quotation marks.
53
static std::string escapeString(std::string str) {
54
  return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
55
}
56

57
/// Put quotation marks around a given string.
///
/// The input is copied verbatim between two '"' characters; no escaping is
/// performed here.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  // One allocation for the payload plus the two surrounding quotes.
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
61

62
/// Map from a DOT attribute name to its already formatted value. Being a
/// `std::map`, iteration (and thus emission) is ordered by attribute name.
using AttributeMap = std::map<std::string, std::string>;
63

64
namespace {
65

66
/// This struct represents a node in the DOT language. Each node has an
/// identifier and an optional identifier for the cluster (subgraph) that
/// contains the node.
/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
/// not between clusters. However, edges can be clipped to the boundary of a
/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
/// cluster, an invisible "anchor" node is created.
struct Node {
  /// Construct a node with the given identifier. `clusterId` is left empty
  /// unless the node stands for a cluster.
  Node(int id = 0, std::optional<int> clusterId = std::nullopt)
      : id(id), clusterId(clusterId) {}

  /// Identifier of the node.
  int id;
  /// Identifier of the cluster this node is the anchor of, if any.
  std::optional<int> clusterId;
};
81

82
/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
83
/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
84
/// about the Graphviz DOT language.
85
class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
86
public:
87
  PrintOpPass(raw_ostream &os) : os(os) {}
88
  PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
89

90
  void runOnOperation() override {
91
    initColorMapping(*getOperation());
92
    emitGraph([&]() {
93
      processOperation(getOperation());
94
      emitAllEdgeStmts();
95
    });
96
  }
97

98
  /// Create a CFG graph for a region. Used in `Region::viewGraph`.
99
  void emitRegionCFG(Region &region) {
100
    printControlFlowEdges = true;
101
    printDataFlowEdges = false;
102
    initColorMapping(region);
103
    emitGraph([&]() { processRegion(region); });
104
  }
105

106
private:
107
  /// Generate a color mapping that will color every operation with the same
108
  /// name the same way. It'll interpolate the hue in the HSV color-space,
109
  /// attempting to keep the contrast suitable for black text.
110
  template <typename T>
111
  void initColorMapping(T &irEntity) {
112
    backgroundColors.clear();
113
    SmallVector<Operation *> ops;
114
    irEntity.walk([&](Operation *op) {
115
      auto &entry = backgroundColors[op->getName()];
116
      if (entry.first == 0)
117
        ops.push_back(op);
118
      ++entry.first;
119
    });
120
    for (auto indexedOps : llvm::enumerate(ops)) {
121
      double hue = ((double)indexedOps.index()) / ops.size();
122
      backgroundColors[indexedOps.value()->getName()].second =
123
          std::to_string(hue) + " 1.0 1.0";
124
    }
125
  }
126

127
  /// Emit all edges. This function should be called after all nodes have been
128
  /// emitted.
129
  void emitAllEdgeStmts() {
130
    if (printDataFlowEdges) {
131
      for (const auto &[value, node, label] : dataFlowEdges) {
132
        emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow);
133
      }
134
    }
135

136
    for (const std::string &edge : edges)
137
      os << edge << ";\n";
138
    edges.clear();
139
  }
140

141
  /// Emit a cluster (subgraph). The specified builder generates the body of the
142
  /// cluster. Return the anchor node of the cluster.
143
  Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
144
    int clusterId = ++counter;
145
    os << "subgraph cluster_" << clusterId << " {\n";
146
    os.indent();
147
    // Emit invisible anchor node from/to which arrows can be drawn.
148
    Node anchorNode = emitNodeStmt(" ", kShapeNone);
149
    os << attrStmt("label", quoteString(escapeString(std::move(label))))
150
       << ";\n";
151
    builder();
152
    os.unindent();
153
    os << "}\n";
154
    return Node(anchorNode.id, clusterId);
155
  }
156

157
  /// Generate an attribute statement.
158
  std::string attrStmt(const Twine &key, const Twine &value) {
159
    return (key + " = " + value).str();
160
  }
161

162
  /// Emit an attribute list.
163
  void emitAttrList(raw_ostream &os, const AttributeMap &map) {
164
    os << "[";
165
    interleaveComma(map, os, [&](const auto &it) {
166
      os << this->attrStmt(it.first, it.second);
167
    });
168
    os << "]";
169
  }
170

171
  // Print an MLIR attribute to `os`. Large attributes are truncated.
172
  void emitMlirAttr(raw_ostream &os, Attribute attr) {
173
    // A value used to elide large container attribute.
174
    int64_t largeAttrLimit = getLargeAttributeSizeLimit();
175

176
    // Always emit splat attributes.
177
    if (isa<SplatElementsAttr>(attr)) {
178
      attr.print(os);
179
      return;
180
    }
181

182
    // Elide "big" elements attributes.
183
    auto elements = dyn_cast<ElementsAttr>(attr);
184
    if (elements && elements.getNumElements() > largeAttrLimit) {
185
      os << std::string(elements.getShapedType().getRank(), '[') << "..."
186
         << std::string(elements.getShapedType().getRank(), ']') << " : "
187
         << elements.getType();
188
      return;
189
    }
190

191
    auto array = dyn_cast<ArrayAttr>(attr);
192
    if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
193
      os << "[...]";
194
      return;
195
    }
196

197
    // Print all other attributes.
198
    std::string buf;
199
    llvm::raw_string_ostream ss(buf);
200
    attr.print(ss);
201
    os << truncateString(ss.str());
202
  }
203

204
  /// Append an edge to the list of edges.
205
  /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
206
  void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
207
    AttributeMap attrs;
208
    attrs["style"] = style.str();
209
    // Do not label edges that start/end at a cluster boundary. Such edges are
210
    // clipped at the boundary, but labels are not. This can lead to labels
211
    // floating around without any edge next to them.
212
    if (!n1.clusterId && !n2.clusterId)
213
      attrs["label"] = quoteString(escapeString(std::move(label)));
214
    // Use `ltail` and `lhead` to draw edges between clusters.
215
    if (n1.clusterId)
216
      attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
217
    if (n2.clusterId)
218
      attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
219

220
    edges.push_back(strFromOs([&](raw_ostream &os) {
221
      os << llvm::format("v%i -> v%i ", n1.id, n2.id);
222
      emitAttrList(os, attrs);
223
    }));
224
  }
225

226
  /// Emit a graph. The specified builder generates the body of the graph.
227
  void emitGraph(function_ref<void()> builder) {
228
    os << "digraph G {\n";
229
    os.indent();
230
    // Edges between clusters are allowed only in compound mode.
231
    os << attrStmt("compound", "true") << ";\n";
232
    builder();
233
    os.unindent();
234
    os << "}\n";
235
  }
236

237
  /// Emit a node statement.
238
  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
239
                    StringRef background = "") {
240
    int nodeId = ++counter;
241
    AttributeMap attrs;
242
    attrs["label"] = quoteString(escapeString(std::move(label)));
243
    attrs["shape"] = shape.str();
244
    if (!background.empty()) {
245
      attrs["style"] = "filled";
246
      attrs["fillcolor"] = ("\"" + background + "\"").str();
247
    }
248
    os << llvm::format("v%i ", nodeId);
249
    emitAttrList(os, attrs);
250
    os << ";\n";
251
    return Node(nodeId);
252
  }
253

254
  /// Generate a label for an operation.
255
  std::string getLabel(Operation *op) {
256
    return strFromOs([&](raw_ostream &os) {
257
      // Print operation name and type.
258
      os << op->getName();
259
      if (printResultTypes) {
260
        os << " : (";
261
        std::string buf;
262
        llvm::raw_string_ostream ss(buf);
263
        interleaveComma(op->getResultTypes(), ss);
264
        os << truncateString(ss.str()) << ")";
265
      }
266

267
      // Print attributes.
268
      if (printAttrs) {
269
        os << "\n";
270
        for (const NamedAttribute &attr : op->getAttrs()) {
271
          os << '\n' << attr.getName().getValue() << ": ";
272
          emitMlirAttr(os, attr.getValue());
273
        }
274
      }
275
    });
276
  }
277

278
  /// Generate a label for a block argument.
279
  std::string getLabel(BlockArgument arg) {
280
    return "arg" + std::to_string(arg.getArgNumber());
281
  }
282

283
  /// Process a block. Emit a cluster and one node per block argument and
284
  /// operation inside the cluster.
285
  void processBlock(Block &block) {
286
    emitClusterStmt([&]() {
287
      for (BlockArgument &blockArg : block.getArguments())
288
        valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
289

290
      // Emit a node for each operation.
291
      std::optional<Node> prevNode;
292
      for (Operation &op : block) {
293
        Node nextNode = processOperation(&op);
294
        if (printControlFlowEdges && prevNode)
295
          emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
296
                       kLineStyleControlFlow);
297
        prevNode = nextNode;
298
      }
299
    });
300
  }
301

302
  /// Process an operation. If the operation has regions, emit a cluster.
303
  /// Otherwise, emit a node.
304
  Node processOperation(Operation *op) {
305
    Node node;
306
    if (op->getNumRegions() > 0) {
307
      // Emit cluster for op with regions.
308
      node = emitClusterStmt(
309
          [&]() {
310
            for (Region &region : op->getRegions())
311
              processRegion(region);
312
          },
313
          getLabel(op));
314
    } else {
315
      node = emitNodeStmt(getLabel(op), kShapeNode,
316
                          backgroundColors[op->getName()].second);
317
    }
318

319
    // Insert data flow edges originating from each operand.
320
    if (printDataFlowEdges) {
321
      unsigned numOperands = op->getNumOperands();
322
      for (unsigned i = 0; i < numOperands; i++)
323
        dataFlowEdges.push_back({op->getOperand(i), node,
324
                                 numOperands == 1 ? "" : std::to_string(i)});
325
    }
326

327
    for (Value result : op->getResults())
328
      valueToNode[result] = node;
329

330
    return node;
331
  }
332

333
  /// Process a region.
334
  void processRegion(Region &region) {
335
    for (Block &block : region.getBlocks())
336
      processBlock(block);
337
  }
338

339
  /// Truncate long strings.
340
  std::string truncateString(std::string str) {
341
    if (str.length() <= maxLabelLen)
342
      return str;
343
    return str.substr(0, maxLabelLen) + "...";
344
  }
345

346
  /// Output stream to write DOT file to.
347
  raw_indented_ostream os;
348
  /// A list of edges. For simplicity, should be emitted after all nodes were
349
  /// emitted.
350
  std::vector<std::string> edges;
351
  /// Mapping of SSA values to Graphviz nodes/clusters.
352
  DenseMap<Value, Node> valueToNode;
353
  /// Output for data flow edges is delayed until the end to handle cycles
354
  std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
355
  /// Counter for generating unique node/subgraph identifiers.
356
  int counter = 0;
357

358
  DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
359
};
360

361
} // namespace
362

363
/// Create a pass instance that writes the Graphviz DOT graph to `os`.
std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
  return std::make_unique<PrintOpPass>(os);
}
366

367
/// Generate a CFG for a region and show it in a window.
368
static void llvmViewGraph(Region &region, const Twine &name) {
369
  int fd;
370
  std::string filename = llvm::createGraphFilename(name.str(), fd);
371
  {
372
    llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
373
    if (fd == -1) {
374
      llvm::errs() << "error opening file '" << filename << "' for writing\n";
375
      return;
376
    }
377
    PrintOpPass pass(os);
378
    pass.emitRegionCFG(region);
379
  }
380
  llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
381
}
382

383
void mlir::Region::viewGraph(const Twine &regionName) {
384
  llvmViewGraph(*this, regionName);
385
}
386

387
/// Display the CFG of this region using the default name "region".
void mlir::Region::viewGraph() { viewGraph("region"); }
388

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.