// Source: llvm-project (code-viewer listing metadata: 387 lines, 12.4 KB)
1//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Transforms/ViewOpGraph.h"
10
11#include "mlir/Analysis/TopologicalSortUtils.h"
12#include "mlir/IR/Block.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/Operation.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Support/IndentedOstream.h"
17#include "llvm/Support/Format.h"
18#include "llvm/Support/GraphWriter.h"
19#include <map>
20#include <optional>
21#include <utility>
22
23namespace mlir {
24#define GEN_PASS_DEF_VIEWOPGRAPH
25#include "mlir/Transforms/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30static const StringRef kLineStyleControlFlow = "dashed";
31static const StringRef kLineStyleDataFlow = "solid";
32static const StringRef kShapeNode = "ellipse";
33static const StringRef kShapeNone = "plain";
34
35/// Return the size limits for eliding large attributes.
36static int64_t getLargeAttributeSizeLimit() {
37// Use the default from the printer flags if possible.
38if (std::optional<int64_t> limit =
39OpPrintingFlags().getLargeElementsAttrLimit())
40return *limit;
41return 16;
42}
43
44/// Return all values printed onto a stream as a string.
45static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46std::string buf;
47llvm::raw_string_ostream os(buf);
48func(os);
49return os.str();
50}
51
52/// Escape special characters such as '\n' and quotation marks.
53static std::string escapeString(std::string str) {
54return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
55}
56
57/// Put quotation marks around a given string.
58static std::string quoteString(const std::string &str) {
59return "\"" + str + "\"";
60}
61
62using AttributeMap = std::map<std::string, std::string>;
63
64namespace {
65
66/// This struct represents a node in the DOT language. Each node has an
67/// identifier and an optional identifier for the cluster (subgraph) that
68/// contains the node.
69/// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
70/// not between clusters. However, edges can be clipped to the boundary of a
71/// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
72/// cluster, an invisible "anchor" node is created.
73struct Node {
74public:
75Node(int id = 0, std::optional<int> clusterId = std::nullopt)
76: id(id), clusterId(clusterId) {}
77
78int id;
79std::optional<int> clusterId;
80};
81
82/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
83/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
84/// about the Graphviz DOT language.
85class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
86public:
87PrintOpPass(raw_ostream &os) : os(os) {}
88PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
89
90void runOnOperation() override {
91initColorMapping(*getOperation());
92emitGraph([&]() {
93processOperation(getOperation());
94emitAllEdgeStmts();
95});
96}
97
98/// Create a CFG graph for a region. Used in `Region::viewGraph`.
99void emitRegionCFG(Region ®ion) {
100printControlFlowEdges = true;
101printDataFlowEdges = false;
102initColorMapping(region);
103emitGraph([&]() { processRegion(region); });
104}
105
106private:
107/// Generate a color mapping that will color every operation with the same
108/// name the same way. It'll interpolate the hue in the HSV color-space,
109/// attempting to keep the contrast suitable for black text.
110template <typename T>
111void initColorMapping(T &irEntity) {
112backgroundColors.clear();
113SmallVector<Operation *> ops;
114irEntity.walk([&](Operation *op) {
115auto &entry = backgroundColors[op->getName()];
116if (entry.first == 0)
117ops.push_back(op);
118++entry.first;
119});
120for (auto indexedOps : llvm::enumerate(ops)) {
121double hue = ((double)indexedOps.index()) / ops.size();
122backgroundColors[indexedOps.value()->getName()].second =
123std::to_string(hue) + " 1.0 1.0";
124}
125}
126
127/// Emit all edges. This function should be called after all nodes have been
128/// emitted.
129void emitAllEdgeStmts() {
130if (printDataFlowEdges) {
131for (const auto &[value, node, label] : dataFlowEdges) {
132emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow);
133}
134}
135
136for (const std::string &edge : edges)
137os << edge << ";\n";
138edges.clear();
139}
140
141/// Emit a cluster (subgraph). The specified builder generates the body of the
142/// cluster. Return the anchor node of the cluster.
143Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
144int clusterId = ++counter;
145os << "subgraph cluster_" << clusterId << " {\n";
146os.indent();
147// Emit invisible anchor node from/to which arrows can be drawn.
148Node anchorNode = emitNodeStmt(" ", kShapeNone);
149os << attrStmt("label", quoteString(escapeString(std::move(label))))
150<< ";\n";
151builder();
152os.unindent();
153os << "}\n";
154return Node(anchorNode.id, clusterId);
155}
156
157/// Generate an attribute statement.
158std::string attrStmt(const Twine &key, const Twine &value) {
159return (key + " = " + value).str();
160}
161
162/// Emit an attribute list.
163void emitAttrList(raw_ostream &os, const AttributeMap &map) {
164os << "[";
165interleaveComma(map, os, [&](const auto &it) {
166os << this->attrStmt(it.first, it.second);
167});
168os << "]";
169}
170
171// Print an MLIR attribute to `os`. Large attributes are truncated.
172void emitMlirAttr(raw_ostream &os, Attribute attr) {
173// A value used to elide large container attribute.
174int64_t largeAttrLimit = getLargeAttributeSizeLimit();
175
176// Always emit splat attributes.
177if (isa<SplatElementsAttr>(attr)) {
178attr.print(os);
179return;
180}
181
182// Elide "big" elements attributes.
183auto elements = dyn_cast<ElementsAttr>(attr);
184if (elements && elements.getNumElements() > largeAttrLimit) {
185os << std::string(elements.getShapedType().getRank(), '[') << "..."
186<< std::string(elements.getShapedType().getRank(), ']') << " : "
187<< elements.getType();
188return;
189}
190
191auto array = dyn_cast<ArrayAttr>(attr);
192if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
193os << "[...]";
194return;
195}
196
197// Print all other attributes.
198std::string buf;
199llvm::raw_string_ostream ss(buf);
200attr.print(ss);
201os << truncateString(ss.str());
202}
203
204/// Append an edge to the list of edges.
205/// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
206void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
207AttributeMap attrs;
208attrs["style"] = style.str();
209// Do not label edges that start/end at a cluster boundary. Such edges are
210// clipped at the boundary, but labels are not. This can lead to labels
211// floating around without any edge next to them.
212if (!n1.clusterId && !n2.clusterId)
213attrs["label"] = quoteString(escapeString(std::move(label)));
214// Use `ltail` and `lhead` to draw edges between clusters.
215if (n1.clusterId)
216attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
217if (n2.clusterId)
218attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
219
220edges.push_back(strFromOs([&](raw_ostream &os) {
221os << llvm::format("v%i -> v%i ", n1.id, n2.id);
222emitAttrList(os, attrs);
223}));
224}
225
226/// Emit a graph. The specified builder generates the body of the graph.
227void emitGraph(function_ref<void()> builder) {
228os << "digraph G {\n";
229os.indent();
230// Edges between clusters are allowed only in compound mode.
231os << attrStmt("compound", "true") << ";\n";
232builder();
233os.unindent();
234os << "}\n";
235}
236
237/// Emit a node statement.
238Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
239StringRef background = "") {
240int nodeId = ++counter;
241AttributeMap attrs;
242attrs["label"] = quoteString(escapeString(std::move(label)));
243attrs["shape"] = shape.str();
244if (!background.empty()) {
245attrs["style"] = "filled";
246attrs["fillcolor"] = ("\"" + background + "\"").str();
247}
248os << llvm::format("v%i ", nodeId);
249emitAttrList(os, attrs);
250os << ";\n";
251return Node(nodeId);
252}
253
254/// Generate a label for an operation.
255std::string getLabel(Operation *op) {
256return strFromOs([&](raw_ostream &os) {
257// Print operation name and type.
258os << op->getName();
259if (printResultTypes) {
260os << " : (";
261std::string buf;
262llvm::raw_string_ostream ss(buf);
263interleaveComma(op->getResultTypes(), ss);
264os << truncateString(ss.str()) << ")";
265}
266
267// Print attributes.
268if (printAttrs) {
269os << "\n";
270for (const NamedAttribute &attr : op->getAttrs()) {
271os << '\n' << attr.getName().getValue() << ": ";
272emitMlirAttr(os, attr.getValue());
273}
274}
275});
276}
277
278/// Generate a label for a block argument.
279std::string getLabel(BlockArgument arg) {
280return "arg" + std::to_string(arg.getArgNumber());
281}
282
283/// Process a block. Emit a cluster and one node per block argument and
284/// operation inside the cluster.
285void processBlock(Block &block) {
286emitClusterStmt([&]() {
287for (BlockArgument &blockArg : block.getArguments())
288valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
289
290// Emit a node for each operation.
291std::optional<Node> prevNode;
292for (Operation &op : block) {
293Node nextNode = processOperation(&op);
294if (printControlFlowEdges && prevNode)
295emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
296kLineStyleControlFlow);
297prevNode = nextNode;
298}
299});
300}
301
302/// Process an operation. If the operation has regions, emit a cluster.
303/// Otherwise, emit a node.
304Node processOperation(Operation *op) {
305Node node;
306if (op->getNumRegions() > 0) {
307// Emit cluster for op with regions.
308node = emitClusterStmt(
309[&]() {
310for (Region ®ion : op->getRegions())
311processRegion(region);
312},
313getLabel(op));
314} else {
315node = emitNodeStmt(getLabel(op), kShapeNode,
316backgroundColors[op->getName()].second);
317}
318
319// Insert data flow edges originating from each operand.
320if (printDataFlowEdges) {
321unsigned numOperands = op->getNumOperands();
322for (unsigned i = 0; i < numOperands; i++)
323dataFlowEdges.push_back({op->getOperand(i), node,
324numOperands == 1 ? "" : std::to_string(i)});
325}
326
327for (Value result : op->getResults())
328valueToNode[result] = node;
329
330return node;
331}
332
333/// Process a region.
334void processRegion(Region ®ion) {
335for (Block &block : region.getBlocks())
336processBlock(block);
337}
338
339/// Truncate long strings.
340std::string truncateString(std::string str) {
341if (str.length() <= maxLabelLen)
342return str;
343return str.substr(0, maxLabelLen) + "...";
344}
345
346/// Output stream to write DOT file to.
347raw_indented_ostream os;
348/// A list of edges. For simplicity, should be emitted after all nodes were
349/// emitted.
350std::vector<std::string> edges;
351/// Mapping of SSA values to Graphviz nodes/clusters.
352DenseMap<Value, Node> valueToNode;
353/// Output for data flow edges is delayed until the end to handle cycles
354std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
355/// Counter for generating unique node/subgraph identifiers.
356int counter = 0;
357
358DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
359};
360
361} // namespace
362
363std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
364return std::make_unique<PrintOpPass>(os);
365}
366
367/// Generate a CFG for a region and show it in a window.
368static void llvmViewGraph(Region ®ion, const Twine &name) {
369int fd;
370std::string filename = llvm::createGraphFilename(name.str(), fd);
371{
372llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
373if (fd == -1) {
374llvm::errs() << "error opening file '" << filename << "' for writing\n";
375return;
376}
377PrintOpPass pass(os);
378pass.emitRegionCFG(region);
379}
380llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
381}
382
383void mlir::Region::viewGraph(const Twine ®ionName) {
384llvmViewGraph(*this, regionName);
385}
386
387void mlir::Region::viewGraph() { viewGraph("region"); }
388