1
#include <c10/util/irange.h>
2
#include <torch/csrc/jit/ir/subgraph_matcher.h>
3
#include <torch/csrc/jit/jit_log.h>
14
class SubgraphMatcher {
16
explicit SubgraphMatcher(const Graph& pattern) : pattern_(pattern) {}
27
bool matchesSubgraphFromAnchorNode(Node* anchor);
30
std::unordered_map<const Node*, Node*> nodes_map() const {
35
std::unordered_map<const Value*, Value*> values_map() const {
40
bool matchValues(const Value* v1, Value* v2);
41
bool matchNodes(const Node* n1, Node* n2);
42
bool matchAttributes(const Node* n1, Node* n2);
44
static bool isInput(const Value* v);
45
static bool isOutput(const Value* v);
47
std::unordered_map<const Node*, Node*> nodes_map_;
48
std::unordered_map<const Value*, Value*> values_map_;
50
const Graph& pattern_;
51
const Node* anchor_ = nullptr;
58
bool patternGraphIsValid(const Graph& pattern) {
60
for (const Node* n : pattern.nodes()) {
61
if (!n->blocks().empty()) {
70
bool SubgraphMatcher::isInput(const Value* v) {
71
return v->node()->kind() == prim::Param;
74
bool SubgraphMatcher::isOutput(const Value* v) {
75
for (const Value* output : v->owningGraph()->outputs()) {
90
bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) {
92
if (values_map_.count(v1)) {
93
if (values_map_.at(v1) != v2) {
99
" did not match because %",
101
" has already been matched with %",
102
values_map_.at(v1)->debugName(),
112
if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) {
118
" did not match because number of their uses is different.\n");
125
"Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n");
126
values_map_[v1] = v2;
127
return matchNodes(v1->node(), v2->node());
130
bool SubgraphMatcher::matchAttributes(const Node* n1, Node* n2) {
131
if (n1->numAttributes() != n2->numAttributes()) {
132
GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2);
135
for (const Symbol& attr_name : n1->attributeNames()) {
136
if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) {
138
"Nodes did not match because type of attribute '",
139
attr_name.toQualString(),
140
"' did not match:\n",
145
switch (n1->kindOf(attr_name)) {
146
case AttributeKind::s:
147
if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) {
149
"Nodes did not match because attribute '",
150
attr_name.toQualString(),
161
case AttributeKind::c:
162
if (n1->c(attr_name) != n2->c(attr_name)) {
164
"Nodes did not match because attribute '",
165
attr_name.toQualString(),
176
case AttributeKind::f:
177
if (n1->f(attr_name) != n2->f(attr_name)) {
179
"Nodes did not match because attribute '",
180
attr_name.toQualString(),
191
case AttributeKind::i:
192
if (n1->i(attr_name) != n2->i(attr_name)) {
194
"Nodes did not match because attribute '",
195
attr_name.toQualString(),
209
"Nodes did not match because type of attribute '",
210
attr_name.toQualString(),
211
"' is not supported.\n",
221
static bool endsWith(const std::string& str, const std::string& suffix) {
222
return str.size() >= suffix.size() &&
223
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
237
bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) {
239
if (nodes_map_.count(n1)) {
240
return nodes_map_.at(n1) == n2;
244
if (n1->kind() == prim::Param) {
245
GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
251
if (n2->owningBlock() != anchor_->owningBlock()) {
253
"Nodes did not match because it is in the different block:\n",
260
if (n1->kind() == Symbol::fromQualString("match::module")) {
261
if (n2->kind() == prim::GetAttr) {
262
if (!n1->hasAttributeS("name")) {
264
"Nodes did not match because special node match::module does not have 'name' attribute:\n",
269
auto t = n2->output()->type()->expect<ClassType>();
270
auto real_typename = t->name()->qualifiedName();
271
auto pattern_typename = n1->s(attr::name);
272
if (!endsWith(real_typename, pattern_typename)) {
274
"Nodes did not match because expected module type is different:\n");
275
GRAPH_DEBUG(" actualtype: ", real_typename, "\n");
276
GRAPH_DEBUG(" expected type: ", pattern_typename, "\n");
277
GRAPH_DEBUG("Nodes:", *n1, *n2);
282
if (n1->kind() != n2->kind() ||
283
n1->outputs().size() != n2->outputs().size() ||
284
n1->inputs().size() != n2->inputs().size()) {
286
"Nodes did not match in their kind or number of inputs/outputs:\n",
291
if (!matchAttributes(n1, n2)) {
299
for (const auto i : c10::irange(n1->outputs().size())) {
300
if (!matchValues(n1->outputs()[i], n2->outputs()[i])) {
304
for (const auto i : c10::irange(n1->inputs().size())) {
305
if (!matchValues(n1->inputs()[i], n2->inputs()[i])) {
310
GRAPH_DEBUG("Nodes matched:\n", *n1, *n2);
318
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
319
GRAPH_UPDATE("Starting match from a new anchor: ", *anchor);
324
const Node* bottom_node = *(pattern_.nodes().end());
325
bottom_node = bottom_node->input(0)->node();
327
if (!matchNodes(bottom_node, anchor)) {
331
for (const Value* output : pattern_.outputs()) {
332
AT_ASSERT(values_map_.count(output));
335
GRAPH_UPDATE("Pattern matched!\n");
342
std::vector<Match> findPatternMatches(const Graph& pattern, Graph& graph) {
343
AT_ASSERT(patternGraphIsValid(pattern));
344
GRAPH_DUMP("Pattern graph: ", &pattern);
345
GRAPH_DUMP("Target graph: ", &graph);
347
SubgraphMatcher m(pattern);
348
std::vector<Match> matches;
349
std::stack<Block*> blocks_to_visit;
353
blocks_to_visit.push(graph.block());
354
while (!blocks_to_visit.empty()) {
355
Block* block = blocks_to_visit.top();
356
blocks_to_visit.pop();
357
for (Node* n : block->nodes()) {
358
if (m.matchesSubgraphFromAnchorNode(n)) {
359
matches.push_back({n, m.nodes_map(), m.values_map()});
361
for (Block* subblock : n->blocks()) {
362
blocks_to_visit.push(subblock);