pytorch

Форк
0
/
subgraph_matcher.cpp 
369 строк · 10.6 Кб
1
#include <c10/util/irange.h>
2
#include <torch/csrc/jit/ir/subgraph_matcher.h>
3
#include <torch/csrc/jit/jit_log.h>
4

5
#include <regex>
6
#include <stack>
7

8
namespace torch::jit {
9
namespace {
10

11
/**
12
 * \brief A class implementing an API for comparing subgraphs.
13
 */
14
class SubgraphMatcher {
15
 public:
16
  explicit SubgraphMatcher(const Graph& pattern) : pattern_(pattern) {}
17

18
  /**
19
   * \brief Compare matchGraph with the part of the graph denoted by a node \p
20
   * ANCHOR.
21
   *
22
   * The anchor node would be compared against the deepest node in the
23
   * match-graph. A node is considered matching if its number of inputs/outputs
24
   * is the same as in the corresponding matchGraph node, its type is the same,
25
   * and all nodes producing input-values also match.
26
   */
27
  bool matchesSubgraphFromAnchorNode(Node* anchor);
28

29
  /** \brief Return match map for nodes. */
30
  std::unordered_map<const Node*, Node*> nodes_map() const {
31
    return nodes_map_;
32
  }
33

34
  /** \brief Return match map for values. */
35
  std::unordered_map<const Value*, Value*> values_map() const {
36
    return values_map_;
37
  }
38

39
 private:
40
  bool matchValues(const Value* v1, Value* v2);
41
  bool matchNodes(const Node* n1, Node* n2);
42
  bool matchAttributes(const Node* n1, Node* n2);
43

44
  static bool isInput(const Value* v);
45
  static bool isOutput(const Value* v);
46

47
  std::unordered_map<const Node*, Node*> nodes_map_;
48
  std::unordered_map<const Value*, Value*> values_map_;
49

50
  const Graph& pattern_;
51
  const Node* anchor_ = nullptr;
52
};
53

54
/**
55
 * \brief A function to verify that \p PATTERN is valid. Concrete requirements
56
 * for validity can be found in subgraph_matcher.h.
57
 */
58
bool patternGraphIsValid(const Graph& pattern) {
59
  // Verify that pattern graph has a single block.
60
  for (const Node* n : pattern.nodes()) {
61
    if (!n->blocks().empty()) {
62
      return false;
63
    }
64
  }
65

66
  // TODO: Verify that nodes in the pattern don't alias.
67
  return true;
68
}
69

70
bool SubgraphMatcher::isInput(const Value* v) {
71
  return v->node()->kind() == prim::Param;
72
}
73

74
bool SubgraphMatcher::isOutput(const Value* v) {
75
  for (const Value* output : v->owningGraph()->outputs()) {
76
    if (v == output) {
77
      return true;
78
    }
79
  }
80
  return false;
81
}
82

83
/**
84
 * Compare two Values. V1 is from pattern, V2 is from the actual graph.
85
 *
86
 * The values are considered matching if:
87
 * 1) the nodes defining them match
88
 * 2) they have the same number of uses, except they are entry or exit nodes.
89
 */
90
bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) {
91
  // Check if we've already visited these values.
92
  if (values_map_.count(v1)) {
93
    if (values_map_.at(v1) != v2) {
94
      GRAPH_DEBUG(
95
          "Values %",
96
          v1->debugName(),
97
          " and %",
98
          v2->debugName(),
99
          " did not match because %",
100
          v1->debugName(),
101
          " has already been matched with %",
102
          values_map_.at(v1)->debugName(),
103
          ".\n");
104
      return false;
105
    }
106
    return true;
107
  }
108

109
  // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is
110
  // PARAM, we're comparing entering values - in these two cases the number of
111
  // uses don't need to be the same.
112
  if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) {
113
    GRAPH_DEBUG(
114
        "Values %",
115
        v1->debugName(),
116
        " and %",
117
        v2->debugName(),
118
        " did not match because number of their uses is different.\n");
119
    return false;
120
  }
121

122
  // Add the values to the map before calling matchNodes to avoid infinite
123
  // recursion.
124
  GRAPH_DEBUG(
125
      "Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n");
126
  values_map_[v1] = v2;
127
  return matchNodes(v1->node(), v2->node());
128
}
129

130
bool SubgraphMatcher::matchAttributes(const Node* n1, Node* n2) {
131
  if (n1->numAttributes() != n2->numAttributes()) {
132
    GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2);
133
    return false;
134
  }
135
  for (const Symbol& attr_name : n1->attributeNames()) {
136
    if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
137
      GRAPH_DEBUG(
138
          "Nodes did not match because type of attribute '",
139
          attr_name.toQualString(),
140
          "' did not match:\n",
141
          *n1,
142
          *n2);
143
      return false;
144
    }
145
    switch (n1->kindOf(attr_name)) {
146
      case AttributeKind::s:
147
        if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) {
148
          GRAPH_DEBUG(
149
              "Nodes did not match because attribute '",
150
              attr_name.toQualString(),
151
              "' did not match: ",
152
              n1->s(attr_name),
153
              " != ",
154
              n2->s(attr_name),
155
              " \n",
156
              *n1,
157
              *n2);
158
          return false;
159
        }
160
        break;
161
      case AttributeKind::c:
162
        if (n1->c(attr_name) != n2->c(attr_name)) {
163
          GRAPH_DEBUG(
164
              "Nodes did not match because attribute '",
165
              attr_name.toQualString(),
166
              "' did not match:",
167
              n1->c(attr_name),
168
              " != ",
169
              n2->c(attr_name),
170
              " \n",
171
              *n1,
172
              *n2);
173
          return false;
174
        }
175
        break;
176
      case AttributeKind::f:
177
        if (n1->f(attr_name) != n2->f(attr_name)) {
178
          GRAPH_DEBUG(
179
              "Nodes did not match because attribute '",
180
              attr_name.toQualString(),
181
              "' did not match:",
182
              n1->f(attr_name),
183
              " != ",
184
              n2->f(attr_name),
185
              " \n",
186
              *n1,
187
              *n2);
188
          return false;
189
        }
190
        break;
191
      case AttributeKind::i:
192
        if (n1->i(attr_name) != n2->i(attr_name)) {
193
          GRAPH_DEBUG(
194
              "Nodes did not match because attribute '",
195
              attr_name.toQualString(),
196
              "' did not match:",
197
              n1->i(attr_name),
198
              " != ",
199
              n2->i(attr_name),
200
              " \n",
201
              *n1,
202
              *n2);
203
          return false;
204
        }
205
        break;
206
      default: {
207
        // Other attributes types not supported yet
208
        GRAPH_DEBUG(
209
            "Nodes did not match because type of attribute '",
210
            attr_name.toQualString(),
211
            "' is not supported.\n",
212
            *n1,
213
            *n2);
214
        return false;
215
      }
216
    }
217
  }
218
  return true;
219
}
220

221
/// Return true if \p str ends with \p suffix; an empty suffix always matches.
static bool endsWith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;
  }
  const auto offset = str.size() - suffix.size();
  // Compare the tail of str, starting at offset, against suffix.
  return str.compare(offset, std::string::npos, suffix) == 0;
}
225

226
/**
227
 * Compare two Nodes. N1 is from pattern, N2 is from the actual graph.
228
 *
229
 * The nodes are considered matching if:
230
 * 1) N1 and N2 are of the same kind.
231
 * 2) Number of inputs and outputs is the same.
232
 * 3) All input and output values match.
233
 *
234
 * A special case is when N1 is PARAM - this is considered outside the pattern,
235
 * so it matches everything.
236
 */
237
bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) {
238
  // Check if we've already visited these nodes.
239
  if (nodes_map_.count(n1)) {
240
    return nodes_map_.at(n1) == n2;
241
  }
242

243
  // Param node in pattern graph matches everything.
244
  if (n1->kind() == prim::Param) {
245
    GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
246
    return true;
247
  }
248

249
  // We don't allow matches to span across blocks, so check if N2 is in the same
250
  // block as the first (anchor) node.
251
  if (n2->owningBlock() != anchor_->owningBlock()) {
252
    GRAPH_DEBUG(
253
        "Nodes did not match because it is in the different block:\n",
254
        *n1,
255
        *n2);
256
    return false;
257
  }
258

259
  // Special handling for matching modules
260
  if (n1->kind() == Symbol::fromQualString("match::module")) {
261
    if (n2->kind() == prim::GetAttr) {
262
      if (!n1->hasAttributeS("name")) {
263
        GRAPH_DEBUG(
264
            "Nodes did not match because special node match::module does not have 'name' attribute:\n",
265
            *n1,
266
            *n2);
267
        return false;
268
      }
269
      auto t = n2->output()->type()->expect<ClassType>();
270
      auto real_typename = t->name()->qualifiedName();
271
      auto pattern_typename = n1->s(attr::name);
272
      if (!endsWith(real_typename, pattern_typename)) {
273
        GRAPH_DEBUG(
274
            "Nodes did not match because expected module type is different:\n");
275
        GRAPH_DEBUG("  actualtype:    ", real_typename, "\n");
276
        GRAPH_DEBUG("  expected type: ", pattern_typename, "\n");
277
        GRAPH_DEBUG("Nodes:", *n1, *n2);
278
        return false;
279
      }
280
    }
281
  } else {
282
    if (n1->kind() != n2->kind() ||
283
        n1->outputs().size() != n2->outputs().size() ||
284
        n1->inputs().size() != n2->inputs().size()) {
285
      GRAPH_DEBUG(
286
          "Nodes did not match in their kind or number of inputs/outputs:\n",
287
          *n1,
288
          *n2);
289
      return false;
290
    }
291
    if (!matchAttributes(n1, n2)) {
292
      return false;
293
    }
294
  }
295

296
  // Add nodes to the map before calling matchValues to avoid infinite
297
  // recursion.
298
  nodes_map_[n1] = n2;
299
  for (const auto i : c10::irange(n1->outputs().size())) {
300
    if (!matchValues(n1->outputs()[i], n2->outputs()[i])) {
301
      return false;
302
    }
303
  }
304
  for (const auto i : c10::irange(n1->inputs().size())) {
305
    if (!matchValues(n1->inputs()[i], n2->inputs()[i])) {
306
      return false;
307
    }
308
  }
309

310
  GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
311
  return true;
312
}
313

314
/**
315
 * Recursively try to match pattern with the actual graph starting from the
316
 * exiting node in the pattern and anchor node in the actual graph.
317
 */
318
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
319
  GRAPH_UPDATE("Starting match from a new anchor: ", *anchor);
320
  nodes_map_.clear();
321
  values_map_.clear();
322
  anchor_ = anchor;
323

324
  const Node* bottom_node = *(pattern_.nodes().end());
325
  bottom_node = bottom_node->input(0)->node();
326

327
  if (!matchNodes(bottom_node, anchor)) {
328
    return false;
329
  }
330

331
  for (const Value* output : pattern_.outputs()) {
332
    AT_ASSERT(values_map_.count(output));
333
  }
334

335
  GRAPH_UPDATE("Pattern matched!\n");
336
  return true;
337
}
338

339
} // unnamed namespace
340

341
// Main entry point for the subgraph matching.
342
std::vector<Match> findPatternMatches(const Graph& pattern, Graph& graph) {
343
  AT_ASSERT(patternGraphIsValid(pattern));
344
  GRAPH_DUMP("Pattern graph: ", &pattern);
345
  GRAPH_DUMP("Target graph: ", &graph);
346

347
  SubgraphMatcher m(pattern);
348
  std::vector<Match> matches;
349
  std::stack<Block*> blocks_to_visit;
350

351
  // Iterate over all nodes in the graph (including nodes in subblocks) trying
352
  // to match the pattern each node.
353
  blocks_to_visit.push(graph.block());
354
  while (!blocks_to_visit.empty()) {
355
    Block* block = blocks_to_visit.top();
356
    blocks_to_visit.pop();
357
    for (Node* n : block->nodes()) {
358
      if (m.matchesSubgraphFromAnchorNode(n)) {
359
        matches.push_back({n, m.nodes_map(), m.values_map()});
360
      }
361
      for (Block* subblock : n->blocks()) {
362
        blocks_to_visit.push(subblock);
363
      }
364
    }
365
  }
366
  return matches;
367
}
368

369
} // namespace torch::jit
370

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.