pytorch

Форк
0
2321 строка · 64.7 Кб
1
#include <torch/csrc/jit/ir/ir.h>
2

3
#include <ATen/core/builtin_function.h>
4
#include <ATen/core/function.h>
5
#include <c10/util/Exception.h>
6
#include <c10/util/StringUtil.h>
7
#include <c10/util/irange.h>
8
#include <torch/csrc/jit/api/function_impl.h>
9
#include <torch/csrc/jit/frontend/error_report.h>
10
#include <torch/csrc/jit/frontend/schema_matching.h>
11
#include <torch/csrc/jit/ir/constants.h>
12
#include <torch/csrc/jit/runtime/operator.h>
13
#include <torch/csrc/jit/serialization/python_print.h>
14

15
#include <algorithm>
16
#include <iostream>
17
#include <locale>
18
#include <memory>
19
#include <set>
20
#include <sstream>
21
#include <string>
22
#include <unordered_map>
23
#include <unordered_set>
24
#include <utility>
25

26
namespace torch::jit {
27

28
namespace utils {
29
std::string getNodesModuleHierarchy(const Node& n) {
30
  if (!n.callstack().has_value()) {
31
    return std::string();
32
  }
33
  InlinedCallStackPtr callstack_ptr = n.callstack().value();
34
  std::string module_hierarchy;
35
  for (auto& entry : callstack_ptr->vec()) {
36
    const auto& opt_module_info = std::get<kModuleInstanceInfo>(entry);
37
    if (opt_module_info.has_value()) {
38
      const auto& module_instance_info = opt_module_info.value();
39
      if (!module_hierarchy.empty()) {
40
        module_hierarchy.append(".");
41
      }
42
      module_hierarchy.append(utils::get_module_info(module_instance_info));
43
    } else {
44
      module_hierarchy += ".UNKNOWN_INSTANCE(UNKNOWN_TYPE)";
45
    }
46
  }
47
  return module_hierarchy;
48
}
49
} // namespace utils
50

51
namespace {
52

53
// Constants relating to maintaining the topological index of nodes.
54
//
55
// Lower and upper bounds of the index. Inclusive range.
56
constexpr topo_position_t kLowerBound = INT64_MIN;
57
constexpr topo_position_t kUpperBound = INT64_MAX;
58
constexpr topo_position_t kMidPoint = 0;
59

60
// How far away to space nodes that are appended to the graph.
61
// should be 2^n, where:
62
//   - n is the maximum number of repeated insertions without a re-index
63
//   - 2^(64-n) is the maximum number of appends to the end without reindex
64
constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;
65

66
void printValueRef(std::ostream& out, const Value* n) {
67
  out << "%" << n->debugName();
68
}
69

70
bool isNumber(c10::string_view str) {
71
  return str.find_first_not_of("0123456789") == std::string::npos;
72
}
73

74
std::string normalizeAttrName(c10::string_view field) {
75
  if (isNumber(field)) {
76
    return "_" + std::string{field};
77
  }
78
  return std::string{field};
79
}
80

81
void findAllNodes(
82
    Block& block,
83
    Symbol kind,
84
    bool recurse,
85
    std::vector<Node*>& ret) {
86
  for (Node* n : block.nodes()) {
87
    if (n->kind() == kind) {
88
      ret.push_back(n);
89
    }
90
    if (recurse) {
91
      for (auto b : n->blocks()) {
92
        findAllNodes(*b, kind, recurse, ret);
93
      }
94
    }
95
  }
96
}
97

98
} // namespace
99

100
// NB: This overload will become ambiguous with the one Caffe2 provides in its
101
// logging, if they ever intersect.
102
template <typename T>
103
std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) {
104
  out << at::ArrayRef<T>{nodes};
105
  return out;
106
}
107

108
template <typename T>
109
static std::ostream& printValueRefs(
110
    std::ostream& out,
111
    const at::ArrayRef<T> nodes) {
112
  size_t i = 0;
113
  for (auto n : nodes) {
114
    if (i++ > 0) {
115
      out << ", ";
116
    }
117
    printValueRef(out, n);
118
  }
119
  return out;
120
}
121

122
// Can't make these two overloads directly a template, it'll be ambiguous with
123
// the global printer for operator<<.
124

125
static std::ostream& operator<<(
126
    std::ostream& out,
127
    const at::ArrayRef<const Value*> nodes) {
128
  return printValueRefs(out, nodes);
129
}
130

131
static std::ostream& operator<<(
132
    std::ostream& out,
133
    const at::ArrayRef<Value*> nodes) {
134
  return printValueRefs(out, nodes);
135
}
136

137
struct const_value_list_with_types {
138
  const ArrayRef<const Value*> values;
139
  std::string delim;
140
  const_value_list_with_types(
141
      ArrayRef<const Value*> values,
142
      std::string delim_ = ", ")
143
      : values(values), delim(std::move(delim_)) {}
144
};
145

146
static std::ostream& operator<<(
147
    std::ostream& out,
148
    const const_value_list_with_types& l) {
149
  size_t i = 0;
150
  for (auto n : l.values) {
151
    if (i++ > 0) {
152
      out << l.delim;
153
    }
154
    printValueRef(out, n);
155
    if (c10::type_verbosity() >= c10::TypeVerbosity::Type) {
156
      out << " : ";
157
      out << *n->type();
158
    }
159
  }
160
  return out;
161
}
162

163
// Prints a tensor-valued attribute:
//   - single-element tensors as a braced scalar, e.g. "{2}" or "{0.5}";
//   - tensors with at most max_tensor_display_size elements in full, with
//     newlines replaced by spaces so the dump stays on one line;
//   - anything larger as the placeholder "<Tensor>".
static void printAttribute(std::ostream& out, const at::Tensor& tensor) {
  const auto element_count = tensor.numel();

  if (element_count == 1) {
    // 1-elem tensors are usually boxed scalars, so print them like one.
    auto scalar = tensor.view(std::vector<int64_t>{}).item();
    out << "{";
    if (scalar.isFloatingPoint()) {
      out << scalar.toDouble();
    } else if (scalar.isComplex()) {
      out << scalar.toComplexDouble();
    } else {
      out << scalar.toLong();
    }
    out << "}";
    return;
  }

  if (element_count > max_tensor_display_size) {
    out << "<Tensor>";
    return;
  }

  // TODO: This is awful code.  Also it doesn't work on Windows.
  std::ostringstream rendered;
  rendered << tensor;
  std::string text = rendered.str();
  // Collapse the multi-line tensor dump onto a single line.
  std::replace(text.begin(), text.end(), '\n', ' ');
  out << text;
}
188

189
static void printAttribute(std::ostream& out, const IValue& ival) {
190
  const auto customFormatter = [](std::ostream& ss, const IValue& input) {
191
    if (input.isTensor()) {
192
      printAttribute(ss, input.toTensor());
193
      return true;
194
    } else if (input.isTensorList()) {
195
      ss << "[<Tensors>]";
196
      return true;
197
    } else if (input.isObject() && !input.type()->is_module()) {
198
      ss << "object(" << &input.toObjectRef() << ")";
199
      return true;
200
    }
201
    return false;
202
  };
203
  ival.repr(out, customFormatter);
204
}
205

206
static void printTypeList(
207
    std::ostream& out,
208
    const std::vector<TypePtr>& items) {
209
  out << "[";
210
  int i = 0;
211
  for (auto& item : items) {
212
    if (i++ > 0)
213
      out << ", ";
214
    out << *item;
215
  }
216
  out << "]";
217
}
218

219
void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
220
  switch (kindOf(name)) {
221
    case AttributeKind::c:
222
      printAttribute(out, c(name));
223
      break;
224
    case AttributeKind::cs:
225
      // TODO(@anjali411): fix this
226
      AT_ASSERT(false);
227
      break;
228
    case AttributeKind::f:
229
      printAttribute(out, f(name));
230
      break;
231
    case AttributeKind::fs:
232
      printAttribute(out, fs(name));
233
      break;
234
    case AttributeKind::i:
235
      printAttribute(out, i(name));
236
      break;
237
    case AttributeKind::is:
238
      printAttribute(out, is(name));
239
      break;
240
    case AttributeKind::s:
241
      printAttribute(out, s(name));
242
      break;
243
    case AttributeKind::ss:
244
      printAttribute(out, ss(name));
245
      break;
246
    case AttributeKind::t:
247
      printAttribute(out, t(name));
248
      break;
249
    case AttributeKind::ts:
250
      out << "[<Tensors>]";
251
      break;
252
    case AttributeKind::ival:
253
      printAttribute(out, ival(name));
254
      break;
255
    case AttributeKind::g:
256
      out << "<Graph>";
257
      break;
258
    case AttributeKind::gs:
259
      out << "[<Graphs>]";
260
      break;
261
    case AttributeKind::ty:
262
      out << *ty(name);
263
      break;
264
    case AttributeKind::tys:
265
      printTypeList(out, tys(name));
266
      break;
267
  }
268
}
269

270
void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
271
    const {
272
  out << "[";
273
  auto names = attributeNames();
274
  int i = 0;
275
  for (auto name : names) {
276
    if (ignore_subgraph && name == attr::Subgraph) {
277
      continue;
278
    }
279
    if (i++ > 0) {
280
      out << ", ";
281
    }
282
    // TODO: debugging mode to see the qualifier.  We definitely
283
    // don't want to print the qualifier since it should always
284
    // be attribute, but you might be able to track down a weird
285
    // bug by printing it out.
286
    out << name.toUnqualString() << "=";
287

288
    printAttrValue(out, name);
289
  }
290
  out << "]";
291
}
292

293
SourceRange Node::sourceRange() const {
  // Nodes without recorded source information report a default range.
  return source_range_ ? *source_range_ : SourceRange();
}
299

300
// Writes two spaces per nesting `level` and returns `out` for chaining.
static std::ostream& indent(std::ostream& out, size_t level) {
  for (size_t remaining = level; remaining > 0; --remaining) {
    out << "  ";
  }
  return out;
}
307

308
std::ostream& Node::print(
309
    std::ostream& out,
310
    size_t level,
311
    std::vector<const Node*>* groups,
312
    bool print_source_locations,
313
    bool print_attributes,
314
    bool print_scopes,
315
    bool print_body) const {
316
  auto outs = outputs();
317
  indent(out, level) << const_value_list_with_types(outs);
318
  out << " = ";
319
  if (kind() == prim::PythonOp) {
320
    auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this);
321
    out << "^" << pyOp->name();
322
    printAttributes(out, /*ignore_subgraph=*/false);
323
    pyOp->writeScalars(out);
324
  } else if (hasAttribute(attr::Subgraph) && groups) {
325
    out << kind().toQualString() << "_" << groups->size();
326
    if (print_attributes && numAttributes() > 1 &&
327
        kind() != prim::DifferentiableGraph) {
328
      printAttributes(out, /*ignore_subgraph=*/true);
329
    }
330

331
    groups->push_back(this);
332
  } else {
333
    out << kind().toQualString();
334
    if (print_attributes && hasAttributes()) {
335
      printAttributes(out);
336
    }
337
  }
338
  out << "(" << inputs() << ")";
339

340
  if (print_scopes) {
341
    std::string scName = scopeName();
342
    if (!scName.empty()) {
343
      out << ", ";
344
      out << "scope: " << scName;
345
    }
346
  }
347

348
  // In debug print, append file:line:col as a comment after each node
349
  if (print_source_locations) {
350
    SourceRange r = sourceRange();
351
    if (sourceRange().source()) {
352
      if (auto orig = sourceRange().source()->findSourceRangeThatGenerated(r)) {
353
        r = *orig;
354
      }
355
    }
356
    if (auto file_line_col = r.file_line_col()) {
357
      auto [filename, line, col] = *file_line_col;
358
      out << " # " << filename << ":" << line << ":" << col;
359
    }
360
  }
361

362
  if (!print_body) {
363
    return out;
364
  }
365

366
  out << "\n";
367

368
  for (const auto i : c10::irange(blocks().size())) {
369
    auto b = blocks()[i];
370
    indent(out, level + 1) << "block" << i << "("
371
                           << const_value_list_with_types(b->inputs())
372
                           << "):\n";
373
    for (auto nested : b->nodes()) {
374
      nested->print(out, level + 2, groups);
375
    }
376
    indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
377
  }
378

379
  return out;
380
}
381

382
std::ostream& operator<<(std::ostream& out, const Node& n) {
383
  return n.print(out, 0, nullptr);
384
}
385

386
std::ostream& Graph::print(std::ostream& out, bool print_source_locations)
387
    const {
388
  out << "graph(" << const_value_list_with_types(inputs(), ",\n      ")
389
      << "):\n";
390
  std::vector<const Node*> groups;
391
  for (auto n : nodes()) {
392
    n->print(out, 1, &groups, print_source_locations);
393
  }
394
  out << "  return (" << outputs() << ")\n";
395
  size_t i = 0;
396
  for (auto fg : groups) {
397
    out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
398
        << *fg->g(attr::Subgraph);
399
  }
400
  out.flush();
401

402
  /*
403
  // Uncomment this to debug all_nodes issues
404
  {
405
    out << "\n";
406
    out << "all_nodes:\n";
407
    for (auto& n : all_nodes) {
408
      printNode(out, const_cast<Node*>(n), nullptr);
409
    }
410
  }
411
  */
412
  return out;
413
}
414

415
std::ostream& operator<<(std::ostream& out, const Graph& g) {
416
  return g.print(out, true);
417
}
418

419
static void checkSameDevice(const Node* node) {
420
  bool has_device = false;
421
  c10::optional<at::Device> device = c10::nullopt;
422
  auto checkValue = [&](const Value* v) {
423
    if (TensorTypePtr type = v->type()->cast<TensorType>()) {
424
      if (type->device() && !has_device) {
425
        has_device = true;
426
        device = *type->device();
427
      } else {
428
        AT_ASSERT(device == type->device());
429
      }
430
    }
431
  };
432
  for (auto input : node->inputs()) {
433
    checkValue(input);
434
  }
435
  for (auto output : node->outputs()) {
436
    checkValue(output);
437
  }
438
}
439

440
using node_set = std::set<const Node*>;
441
#define ALL_OF(container) container.begin(), container.end()
442

443
// These functions purposely operate on the internal members directly, to force
444
// you to think about how the invariants change if you change the data
445
// representation (even if the external API does not change.)
446

447
// NB: This assert is written to assume you don't have any unattached
448
// nodes.  Unattached nodes can occur while manipulations to the
449
// graph are occurring.
450
void Node::lint() const {
451
  // Node invariants
452
  // - if node should live in list, nodes_iter is consistent
453
  // - Inputs are all marked as a use by the nodes they refer to
454
  // - Owning graph is non-null and consistent
455
  // - The "Select" invariant, when the node is MultiReturn
456
  //
457
  // The handle invariant:
458
  //    If a node takes a handle as an input, it is always the
459
  //    LAST input of the node.  There is at most one handle input.
460

461
  {
462
    size_t i = 0;
463
    for (auto input : inputs_) {
464
      // WARNING: O(n^2)
465
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
466
      AT_ASSERT(
467
          std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) !=
468
          input->uses_.end());
469
      AT_ASSERT(graph_->all_nodes.count(this) == 1);
470
      i++;
471
    }
472
  }
473

474
  for (auto o : outputs()) {
475
    for (auto use : o->uses()) {
476
      // Use invariants
477
      // - Use is consistent with inputs
478
      // - Every user node is live (checked in Graph)
479
      AT_ASSERT(use.user->inputs_[use.offset] == o);
480
    }
481
  }
482

483
  // Node subclass invariants
484
  switch (kind()) {
485
    case prim::Constant:
486
      AT_ASSERT(inputs_.empty());
487
      break;
488
    case prim::Return:
489
      // Return uses is zero
490
      AT_ASSERT(outputs().empty());
491
      break;
492
    case prim::Param:
493
      // Param inputs is zero
494
      AT_ASSERT(inputs_.empty());
495
      break;
496
    case prim::PythonOp: {
497
      // Python operator cconv is correct
498
      auto* value = static_cast<const PythonOp*>(this);
499
      value->lint_python();
500
      break;
501
    }
502
    case prim::Eval:
503
      // TODO: add invariants
504
      // TODO: It's not good for these ops to be top-level, it makes cases
505
      // longer.
506
      break;
507
    case prim::FusionGroup:
508
    case prim::CudaFusionGroup:
509
    case prim::oneDNNFusionGroup:
510
      checkSameDevice(this);
511
      // TODO: Typecheck the parameters
512
      g(attr::Subgraph)->lint();
513
      break;
514
  }
515
}
516

517
// TODO: When lint fails, give better indication about which
518
// instruction triggered the failure.
519
void Graph::lint() const {
520
  // Graph invariants
521

522
  // Uncomment the following to see the graph
523
  // std::cout << *const_cast<Graph*>(this);
524

525
  // nodes
526
  // - nodes_ is a valid topological ordering for inputs
527
  // - No repeated nodes
528
  // - Params and return do NOT occur in nodes
529
  // - next_unique_ is greater than all uniques in graph
530
  // - uniques in all_nodes are unique
531
  // - every use will occur later in the toposort
532

533
  struct LintScope {
534
    LintScope() = default;
535
    LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
536
    bool contains(const Value* v) {
537
      return values.count(v) > 0 || (parent && parent->contains(v));
538
    }
539
    bool contains(const Node* n) {
540
      return nodes.count(n) > 0 || (parent && parent->contains(n));
541
    }
542
    void insert(const Value* v) {
543
      AT_ASSERT(!contains(v));
544
      values.insert(v);
545
    }
546
    void insert(const Node* n) {
547
      AT_ASSERT(!contains(n));
548
      nodes.insert(n);
549
    }
550
    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
551
    std::unique_ptr<LintScope> parent;
552

553
   private:
554
    std::unordered_set<const Value*> values;
555
    std::unordered_set<const Node*> nodes;
556
  };
557
  // Struct enables mutual recursion in linting methods.
558
  // Putting it inside Graph::lint enables access to private Graph members
559
  struct LintImpl {
560
    LintImpl(const Graph& g)
561
        : g(g),
562
          scope(new LintScope()),
563
          all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
564
    const Graph& g;
565
    std::unique_ptr<LintScope> scope;
566
    std::unordered_set<size_t> seen_uniques;
567
    std::unordered_map<const Node*, int64_t> anticipated_uses;
568
    node_set all_nodes_set;
569
    node_set sum_set;
570

571
    void check_value(const Value* v) {
572
      scope->insert(v);
573
      auto b2 = seen_uniques.insert(v->unique());
574
      AT_ASSERT(b2.second); // insertion took place
575
      AT_ASSERT(v->unique() < g.next_unique_);
576

577
      for (auto use : v->uses()) {
578
        AT_ASSERT(!scope->contains(use.user));
579
        AT_ASSERT(g.all_nodes.count(use.user) == 1);
580
        anticipated_uses[use.user]++; // int default constructs to 0
581
      }
582
    }
583
    void check_node(const Node* n) {
584
      for (auto input : n->inputs_) {
585
        if (!scope->contains(input)) {
586
          AT_ASSERTM(0, input->unique(), " not in scope");
587
        }
588
      }
589
      AT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
590
      anticipated_uses[n] = -1; // we saw the anticipated user!
591
      scope->insert(n);
592
      for (auto block : n->blocks()) {
593
        std::unique_ptr<LintScope> new_scope(new LintScope(std::move(scope)));
594
        scope = std::move(new_scope);
595
        check_block(block);
596
        scope = std::move(scope->parent);
597
      }
598
      size_t i = 0;
599
      for (auto o : n->outputs()) {
600
        AT_ASSERT(o->node() == n);
601
        AT_ASSERT(i++ == o->offset_);
602
        check_value(o);
603
      }
604
      n->lint();
605
    }
606
    void check_block(const Block* b) {
607
      // Check topological ordering
608
      AT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
609
      auto curNode = *b->nodes().begin();
610
      while (curNode != b->return_node()) {
611
        AT_ASSERT(curNode->isBefore(curNode->next()));
612
        curNode = curNode->next();
613
      }
614

615
      for (auto input : b->inputs()) {
616
        check_value(input);
617
        AT_ASSERT(input->node()->kind_ == prim::Param);
618
      }
619

620
      for (auto n : b->nodes()) {
621
        AT_ASSERT(n->kind_ != prim::Param);
622
        AT_ASSERT(n->kind_ != prim::Return);
623
        check_node(n);
624
      }
625

626
      AT_ASSERT(b->output_->kind() == prim::Return);
627
      check_node(b->output_);
628

629
      // all_nodes
630
      // - inputs_, output_ and nodes_ are all included in all_nodes
631
      // - all_nodes does not contain dead nodes??? (likely to be temporarily
632
      // suspended).  Weaker: all_nodes contains all inputs and returns
633
      // - only one return node???
634

635
      node_set nodes_set(ALL_OF(b->nodes()));
636
      node_set inputs_set{b->input_};
637
      node_set output_set{b->output_};
638
      // TODO: Make a more type safe std::includes wrapper which disallows use
639
      // on non-ordered containers
640
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
641
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
642
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));
643

644
      sum_set.insert(ALL_OF(nodes_set));
645
      sum_set.insert(ALL_OF(inputs_set));
646
      sum_set.insert(ALL_OF(output_set));
647
    }
648
    void check_graph() {
649
      node_set all_nodes_set(
650
          ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*
651

652
      check_block(g.block_);
653
      for (auto kv : anticipated_uses) {
654
        AT_ASSERT(kv.second == -1);
655
      }
656
      AT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)));
657
    }
658
  };
659
  LintImpl(*this).check_graph();
660
}
661

662
void Graph::dump() const {
663
  std::cout << *this << "\n";
664
}
665

666
void Graph::push_scope(const std::string& scope_name) {
667
  current_scope_ = current_scope_->push(Symbol::scope(scope_name));
668
  Node* block_node = insertNode(create(prim::TracedModuleForward, 0));
669
  block_node->s_(attr::scope, scope_name);
670
  Block* b = block_node->addBlock();
671
  setInsertPoint(b);
672
}
673
void Graph::pop_scope() {
674
  current_scope_ = current_scope_->parent();
675
  if (insertPoint()->owningBlock()->owningNode()->kind() ==
676
      prim::TracedModuleForward) {
677
    setInsertPoint(insertPoint()->owningBlock()->owningNode()->next());
678
  }
679
}
680

681
void LintGraph(const std::shared_ptr<Graph>& graph) {
682
  graph->lint();
683
}
684

685
Block::Block(Graph* owning_graph, Node* owning_node)
    : graph_(owning_graph),
      output_(owning_graph->create(prim::Return, 0)),
      input_(owning_graph->create(prim::Param, 0)),
      owning_node_(owning_node) {
  // The return and param sentinels start out linked to each other in both
  // directions, forming an otherwise empty node list.
  output_->prev() = input_;
  output_->next() = input_;
  input_->prev() = output_;
  input_->next() = output_;

  // Both sentinels belong to this block; they hold the extreme topological
  // positions (param at the lower bound, return at the upper bound).
  input_->owning_block_ = this;
  input_->topo_position_ = kLowerBound;
  output_->owning_block_ = this;
  output_->topo_position_ = kUpperBound;

  // Register the block with its graph.
  graph_->all_blocks.emplace(this);
}
701

702
void Block::reIndexTopology() {
703
  auto curPos = kLowerBound;
704
  for (auto node : nodes()) {
705
    AT_ASSERT(curPos <= (kUpperBound - kAppendInterval));
706
    curPos += kAppendInterval;
707
    node->topo_position_ = curPos;
708
  }
709
}
710

711
void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) {
  // Values defined inside `src` are mapped to their clones. Any other value
  // (e.g. one captured from an enclosing scope) is resolved via `value_map`.
  std::unordered_map<Value*, Value*> cloned;
  auto lookup = [&](Value* v) -> Value* {
    const auto found = cloned.find(v);
    return found == cloned.end() ? value_map(v) : found->second;
  };

  Graph* graph = owningGraph();

  // Mirror the block inputs, carrying over type and debug name.
  for (Value* input : src->inputs()) {
    cloned[input] = addInput()->copyMetadata(input);
  }

  // Clone nodes in order, recording each output's counterpart.
  for (Node* node : src->nodes()) {
    Node* clone = appendNode(graph->createClone(node, lookup));
    const auto old_outputs = node->outputs();
    const auto new_outputs = clone->outputs();
    for (const auto idx : c10::irange(old_outputs.size())) {
      cloned[old_outputs[idx]] = new_outputs[idx];
      new_outputs[idx]->copyMetadata(old_outputs[idx]);
    }
  }

  // Wire up the block outputs.
  for (Value* output : src->outputs()) {
    registerOutput(lookup(output));
  }
}
739

740
void Block::destroy() {
741
  // we cannot destroy the output because it is used as the sentinel
742
  // for the nodes() list and has to remain valid for the loop
743
  output_->removeAllInputs();
744
  for (auto it = this->nodes().reverse().begin(),
745
            end = this->nodes().reverse().end();
746
       it != end;
747
       ++it) {
748
    it.destroyCurrent();
749
  }
750
  output_->destroy();
751
  input_->destroy();
752
  graph_->freeBlock(this);
753
}
754

755
void Graph::cloneFrom(Graph& src) {
  // Every value used in `src` must be defined within `src`. Reaching this
  // fallback means the source graph refers to an out-of-scope value.
  const auto out_of_scope = [](Value* v) -> Value* {
    AT_ERROR(
        "Graph::copy() encountered a use of a value " + v->debugName() +
        " not in scope. Run lint!");
  };
  block()->cloneFrom(src.block(), out_of_scope);
}
763

764
std::shared_ptr<Graph> Graph::copy() {
765
  auto new_g = std::make_shared<Graph>();
766
  new_g->cloneFrom(*this);
767
  return new_g;
768
}
769

770
std::unique_ptr<Graph> Graph::copyUnique() {
771
  auto new_g = std::make_unique<Graph>();
772
  new_g->cloneFrom(*this);
773
  return new_g;
774
}
775

776
void Block::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
777
  for (Value* input : inputs()) {
778
    input->setType(type_map(input->type()));
779
  }
780
  for (Node* node : nodes()) {
781
    for (Value* output : node->outputs()) {
782
      output->setType(type_map(output->type()));
783
    }
784
    for (Block* sub_block : node->blocks()) {
785
      sub_block->remapTypes(type_map);
786
    }
787
    for (Symbol name : node->attributeNames()) {
788
      if (node->kindOf(name) == AttributeKind::g) {
789
        node->g(name)->remapTypes(type_map);
790
      } else if (node->kindOf(name) == AttributeKind::gs) {
791
        for (const auto& g : node->gs(name)) {
792
          g->remapTypes(type_map);
793
        }
794
      }
795
    }
796
  }
797
}
798

799
void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
800
  block()->remapTypes(type_map);
801
}
802

803
void Value::inferTypeFrom(const at::Tensor& output) {
  // Adopt the TensorType that TensorType::create derives from a concrete
  // tensor.
  auto inferred = TensorType::create(output);
  setType(std::move(inferred));
}
806

807
void Value::inferTypeFrom(
808
    const c10::intrusive_ptr<c10::ivalue::Object>& output) {
809
  setType(output->type());
810
}
811

812
bool Value::mustBeNone() const {
813
  return type()->cast<NoneType>() || node_->mustBeNone();
814
}
815
bool Value::mustNotBeNone() const {
816
  return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
817
      !type()->cast<OptionalType>() &&
818
      !(type()->cast<UnionType>() &&
819
        type()->expect<UnionType>()->canHoldType(*NoneType::get()));
820
}
821

822
std::string Value::debugNameBase() const {
823
  std::string name = debugName();
824
  std::string name_base = name;
825
  auto last_dot_pos = name.find_last_of('.');
826
  if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
827
    if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
828
        std::string::npos) {
829
      name_base = name.substr(0, last_dot_pos);
830
    }
831
  }
832
  return name_base;
833
}
834

835
bool Value::isValidName(const std::string& name) {
836
  // Empty strings are legal
837
  if (name.empty()) {
838
    return true;
839
  }
840

841
  // Numbers are not legal
842
  if (isNumber(name)) {
843
    return false;
844
  }
845

846
  return true;
847
}
848

849
// Assigns `name` as this value's unique debug name within its graph and
// returns this. Passing "" clears the name. If another value already owns
// `name`, that value is renamed to "<base>.<n>", using the first unused
// suffix at or above both the name's own numeric suffix and the highest
// suffix previously recorded for <base>.
// Throws std::runtime_error if `name` is not a valid name (see isValidName).
Value* Value::setDebugName(const std::string& name) {
  if (!isValidName(name)) {
    throw std::runtime_error("Invalid name: '" + name + "'");
  }

  auto& names = node()->owningGraph()->unique_names_;

  // clear any old name from the map
  if (hasDebugName()) {
    names.erase(unique_name_);
    unique_name_ = "";
  }

  // allow "" to clear the uniquename
  if (name.empty()) {
    return this;
  }

  // if someone else has this name, then rename the other value
  auto old_owner_of_name = names.find(name);
  if (old_owner_of_name != names.end()) {
    size_t suffix = 1;
    std::string name_base = name;
    auto last_dot_pos = name.find_last_of('.');
    if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
      // A purely numeric tail after the last '.' is treated as an existing
      // suffix ("x.3" -> base "x", suffix 3). Tails longer than 18 digits
      // stay part of the base name: they may not fit in a long long, and
      // std::stoll would throw std::out_of_range on them.
      constexpr size_t kMaxSuffixDigits = 18;
      const size_t suffix_digits = name.size() - last_dot_pos - 1;
      if (suffix_digits <= kMaxSuffixDigits &&
          name.find_first_not_of("0123456789", last_dot_pos + 1) ==
              std::string::npos) {
        suffix = std::stoll(name.substr(last_dot_pos + 1));
        name_base = name.substr(0, last_dot_pos);
      }
    }

    auto& names_suffixes = node()->owningGraph()->name_base_suffix_;
    auto it = names_suffixes.find(name_base);
    if (it != names_suffixes.end()) {
      suffix = std::max(suffix, it->second + 1);
    }

    // Verify that new name is not used and find next usable name in case
    // suffix is used.
    std::string replacement_name;
    do {
      std::stringstream ss;
#ifndef _WIN32
      // Protect 12345 integer from becoming "1,2345" if some other process sets
      // global locale For more details see
      // https://github.com/pytorch/pytorch/issues/79583#issuecomment-1161260061
      static std::locale c_locale("C");
      ss.imbue(c_locale);
#endif
      ss << name_base << "." << suffix++;
      replacement_name = ss.str();
    } while (names.count(replacement_name) > 0);

    names_suffixes[name_base] = suffix;

    old_owner_of_name->second->setDebugName(replacement_name);
  }

  names[name] = this;
  unique_name_ = name;
  return this;
}
912

913
Value* Value::copyMetadata(Value* from) {
  // Copies the type of `from`, plus its debug name when it has one. Taking
  // the name goes through setDebugName, which may rename a value that
  // currently holds that name. Returns this for chaining.
  setType(from->type());
  if (!from->hasDebugName()) {
    return this;
  }
  setDebugName(from->debugName());
  return this;
}
920

921
// Redirects the first recorded use of this value to `newValue`. The use
// moves from this value's use list to the end of newValue's. Both values must
// belong to the same graph, and this value must have at least one use.
void Value::replaceFirstUseWith(Value* newValue) {
  AT_ASSERT(owningGraph() == newValue->owningGraph());
  // Indexing an empty use list would be undefined behavior; fail loudly.
  AT_ASSERT(!uses_.empty());
  // Take a copy: when newValue == this, the push_back may reallocate uses_.
  const Use u = uses_.front();
  u.user->inputs_[u.offset] = newValue;
  newValue->uses_.push_back(u);
  uses_.erase(uses_.begin());
}
928

929
// Redirects every use of this value to `newValue`. The uses are appended to
// newValue's use list in their original order, and this value ends up with
// no uses. Replacing a value with itself is a no-op.
void Value::replaceAllUsesWith(Value* newValue) {
  // Without this guard, self-replacement looped forever: each moved use
  // landed back on this value's own use list.
  if (newValue == this || uses_.empty()) {
    return;
  }
  AT_ASSERT(owningGraph() == newValue->owningGraph());
  // Rewire all uses in one pass rather than repeatedly erasing the front of
  // uses_, which was quadratic in the number of uses.
  for (const Use& u : uses_) {
    u.user->inputs_[u.offset] = newValue;
    newValue->uses_.push_back(u);
  }
  uses_.clear();
}
934

935
void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) {
936
  std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
937
    if (u.user->isAfter(node)) {
938
      u.user->inputs_[u.offset] = newValue;
939
      newValue->uses_.push_back(u);
940
    }
941
  });
942

943
  uses_.erase(
944
      std::remove_if(
945
          uses_.begin(),
946
          uses_.end(),
947
          [&node](const Use& u) { return u.user->isAfter(node); }),
948
      uses_.end());
949
}
950

951
void Value::replaceAllUsesDominatedByNodeWith(
952
    const Node* node,
953
    Value* newValue) {
954
  std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
955
    if (u.user->isDominatedBy(node)) {
956
      u.user->inputs_[u.offset] = newValue;
957
      newValue->uses_.push_back(u);
958
    }
959
  });
960

961
  uses_.erase(
962
      std::remove_if(
963
          uses_.begin(),
964
          uses_.end(),
965
          [&node](const Use& u) { return u.user->isDominatedBy(node); }),
966
      uses_.end());
967
}
968

969
static size_t findArgument(
970
    const FunctionSchema& the_schema,
971
    const std::string& unqualName) {
972
  for (const auto i : c10::irange(the_schema.arguments().size())) {
973
    const Argument* arg = &the_schema.arguments()[i];
974
    if (arg->name() == unqualName) {
975
      return i;
976
    }
977
  }
978
  throw std::runtime_error(
979
      std::string("Couldn't find an argument called ") + unqualName);
980
}
981

982
static size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
983
  const auto unqualName = name.toUnqualString();
984
  return findArgument(the_schema, unqualName);
985
}
986

987
c10::optional<IValue> Node::get(Symbol name) const {
  // Converts the schema input called `name` with toIValue (presumably
  // nullopt when the input is not a constant — see constants.h).
  Value* input_value = namedInput(name);
  return toIValue(input_value);
}
990

991
bool Node::hasNamedInput(const std::string& name) const {
992
  for (const auto& argument : schema().arguments()) {
993
    if (argument.name() == name) {
994
      return true;
995
    }
996
  }
997
  return false;
998
}
999

1000
// Returns the input bound to the schema argument named `unqualName`; throws
// (via findArgument) when the schema has no such argument.
Value* Node::namedInput(const std::string& unqualName) const {
  const size_t position = findArgument(schema(), unqualName);
  return input(position);
}
1003
// Symbol overload of namedInput(): resolves `name` against this node's schema.
Value* Node::namedInput(Symbol name) const {
  const size_t position = findArgument(schema(), name);
  return input(position);
}
1006

1007
bool Node::matches(const FunctionSchema& schema) const {
1008
  if (isBlockListedSchema(schema)) {
1009
    return false;
1010
  }
1011
  // wrong name
1012
  if (kind().toQualString() != schema.name()) {
1013
    return false;
1014
  }
1015
  at::ArrayRef<const Value*> actuals = inputs();
1016
  const auto& formals = schema.arguments();
1017

1018
  // not enough inputs
1019
  if (actuals.size() < formals.size()) {
1020
    return false;
1021
  }
1022

1023
  TypeEnv type_env;
1024
  for (const auto i : c10::irange(formals.size())) {
1025
    auto formal = formals[i].type();
1026
    const MatchTypeReturn matched_type =
1027
        matchTypeVariables(formal, actuals[i]->type(), type_env);
1028
    if (!matched_type.success()) {
1029
      return false;
1030
    }
1031

1032
    TypePtr resolved = tryEvalTypeVariables(formal, type_env);
1033
    if (resolved) {
1034
      formal = resolved;
1035
    }
1036
    // note: it is possible at this point that type variable matching has
1037
    // not resolved all type variables, e.g. if None was matched to Optional[T]
1038
    // we will not succeed at matching T. However None <: Optional[T] so this
1039
    // check can still succeed.
1040

1041
    if (!actuals[i]->type()->isSubtypeOf(*formal)) {
1042
      return false;
1043
    }
1044
  }
1045

1046
  // too many inputs
1047
  if (!schema.is_vararg() && actuals.size() != formals.size()) {
1048
    return false;
1049
  }
1050

1051
  return true;
1052
}
1053

1054
bool Node::matches(
1055
    const char* signature_literal,
1056
    at::ArrayRef<Symbol> const_inputs) const {
1057
  if (!matches(getOperatorForLiteral(signature_literal)->schema())) {
1058
    return false;
1059
  }
1060
  for (Symbol s : const_inputs) {
1061
    if (!is_constant(s)) {
1062
      return false;
1063
    }
1064
  }
1065
  return true;
1066
}
1067

1068
bool Node::mustBeNone() const {
1069
  // We can statically deduce this Node has returning None if:
1070
  return
1071
      // It's an AutogradZero node, or ...
1072
      kind_ == prim::AutogradZero ||
1073
      // It has only one output and that output is NoneType, or ...
1074
      (outputs().size() == 1 && output()->type() == NoneType::get()) ||
1075
      // It's a constant optional with no value in the attributes.
1076
      (kind_ == prim::Constant && !this->hasAttributes() &&
1077
       output()->type()->cast<OptionalType>());
1078
}
1079

1080
void Node::dump() const {
1081
  std::cout << *this << "\n";
1082
}
1083

1084
// Returns the schema of this node's operator. Uses the cached operator when
// present; otherwise resolves it via getOperator(), which throws if no
// registered operator matches.
const FunctionSchema& Node::schema() const {
  return op_ ? op_->schema() : getOperator().schema();
}
1090

1091
// Non-throwing variant of schema(): nullptr when no operator matches.
const FunctionSchema* Node::maybeSchema() const {
  const Operator* op = maybeOperator();
  return op != nullptr ? &op->schema() : nullptr;
}
1097

1098
// Returns the first registered operator for this node's kind whose schema
// matches the node, or nullptr if none does. The result is cached in op_;
// the mutators in this file that change inputs, outputs or blocks reset op_
// to nullptr so the lookup is redone.
const Operator* Node::maybeOperator() const {
  if (op_ != nullptr) {
    return op_;
  }
  const auto& candidates = getAllOperatorsFor(kind());
  const auto hit = std::find_if(
      candidates.begin(), candidates.end(), [this](const auto& candidate) {
        return matches(candidate->schema());
      });
  if (hit != candidates.end()) {
    op_ = hit->get();
  }
  return op_;
}
1110

1111
// Returns the operator matching this node. If none matches, throws an
// ErrorReport describing the node, its input types, the candidate schemas for
// its kind, and the whole owning graph.
const Operator& Node::getOperator() const {
  if (const Operator* found = maybeOperator()) {
    return *found;
  }

  auto er = ErrorReport(sourceRange());
  er << "Schema not found for node. File a bug report.\n";
  er << "Node: " << *this << "\n";

  er << "Input types:";
  const char* separator = "";
  for (const Value* in : inputs()) {
    er << separator << *in->type();
    separator = ", ";
  }

  const auto& candidates = getAllOperatorsFor(kind());
  if (candidates.empty()) {
    er << "\nno candidates found\n";
  } else {
    er << "\ncandidates were:\n";
    for (const auto& candidate : candidates) {
      er << "  " << candidate->schema() << "\n";
    }
  }

  er << "within the graph:\n";
  er << *owningGraph() << "\n";
  throw er;
}
1138

1139
// Returns a runnable Operation for this node.
Operation Node::getOperation() const {
  // Some operators need the node itself to build their operation, hence
  // `this` is passed along; getOperator() guarantees the node matches the
  // operator's schema.
  const Operator& op = getOperator();
  return op.getOperation(this);
}
1145

1146
bool Node::isNondeterministic() const {
1147
  const auto schema = maybeSchema();
1148
  if (!kind().is_aten()) {
1149
    return false;
1150
  }
1151
  // All aten ops are expecte to have a schema. However this is left as a
1152
  // warning instead of an assert to ensure that previous use cases do not
1153
  // break.
1154
  if (!schema) {
1155
    TORCH_WARN("aten Schema not found.");
1156
    return false;
1157
  }
1158
  torch::utils::SchemaInfo schema_info(*schema);
1159
  if (hasNamedInput("train")) {
1160
    auto value = constant_as<bool>(namedInput("train"));
1161
    if (value.has_value()) {
1162
      schema_info.addArgumentValue("train", *value);
1163
    }
1164
  }
1165
  return schema_info.is_nondeterministic();
1166
}
1167

1168
bool Node::hasSideEffects() const {
1169
  switch (kind_) {
1170
    case prim::PythonOp:
1171
    case prim::IgnoredPythonOp:
1172
    case prim::Print:
1173
    case prim::RaiseException:
1174
    case aten::warn:
1175
    case aten::save:
1176
    case aten::manual_seed:
1177
    case prim::AddStatValue:
1178
    case prim::TimePoint:
1179
    case prim::CallFunction:
1180
    case prim::CallMethod:
1181
    case prim::BailoutTemplate:
1182
    case prim::BailOut:
1183
    case prim::rpc_async: // It represents RPC message sent.
1184
    case prim::rpc_sync: // It represents RPC message sent.
1185
    case prim::rpc_remote: // It represents RPC message sent.
1186
    case aten::wait: // It can represent RPC message received.
1187
#if !defined(USE_ROCM)
1188
    case cuda::set_stream:
1189
    case cuda::_set_device:
1190
    case cuda::_current_device:
1191
    case cuda::synchronize:
1192
#endif
1193
    case prim::Enter:
1194
    case prim::Exit:
1195
      return true;
1196
  }
1197

1198
  auto op = maybeOperator();
1199
  if (!op) {
1200
    TORCH_INTERNAL_ASSERT(
1201
        kind_.is_prim(),
1202
        "Only prim ops are allowed to not have a registered operator but ",
1203
        kind_.toDisplayString(),
1204
        " doesn't have one either. We don't know if this op has side effects.");
1205
    return false;
1206
  }
1207

1208
  if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) {
1209
    // TODO There is nothing in the system that relies on aten:: and prim::
1210
    // ops using AliasAnalysisKind::FROM_SCHEMA,
1211
    // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or
1212
    // AliasAnalysisKind::CONSERVATIVE but this is the intended behavior for all
1213
    // current ops and a good error check. We can consider lifting this
1214
    // constraint later if we have a use case for it.
1215
    TORCH_INTERNAL_ASSERT(
1216
        op->aliasAnalysisKind() == AliasAnalysisKind::INTERNAL_SPECIAL_CASE ||
1217
            op->aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA ||
1218
            op->aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE,
1219
        "aten:: and prim:: ops should have AliasAnalysisKind::INTERNAL_SPECIAL_CASE"
1220
        ", AliasAnalysisKind::FROM_SCHEMA or AliasAnalysisKind::CONSERVATIVE but ",
1221
        kind_.toDisplayString(),
1222
        " has ",
1223
        toString(op->aliasAnalysisKind()));
1224
  }
1225

1226
  switch (op->aliasAnalysisKind()) {
1227
    case AliasAnalysisKind::PURE_FUNCTION:
1228
    case AliasAnalysisKind::FROM_SCHEMA:
1229
    case AliasAnalysisKind::INTERNAL_SPECIAL_CASE:
1230
      return false;
1231
    case AliasAnalysisKind::CONSERVATIVE:
1232
      return true;
1233
  }
1234
  TORCH_INTERNAL_ASSERT(false, "Unhandled AliasAnalysisKind case");
1235
  return false; // silence compiler warning
1236
}
1237

1238
// Assign this node a topological position, to facilitate fast isBefore() and
1239
// isAfter() queries. Must be called right after a node is inserted into the
1240
// node list.
1241
//
1242
// The basic scheme is: assign every node a position (uint64_t).  The common
1243
// case (appending to the end of the graph) is made more efficient by advancing
1244
// a fixed interval past the previous node and placing `this` there. Otherwise,
1245
// assign `this` a position at the midpoint between its prev() and next()
1246
// nodes.
1247
//
1248
// If we ever run out of space (by, e.g. inserting too much in place), we
1249
// reindex by spreading out all the nodes again.
1250
void Node::assignTopoPosition() {
1251
  bool is_first = prev() == owningBlock()->param_node();
1252
  bool is_last = next() == owningBlock()->return_node();
1253

1254
  const auto prevPos = prev()->topo_position_;
1255
  const auto nextPos = next()->topo_position_;
1256

1257
  // Append to the end of the graph
1258
  if (is_last) {
1259
    if (is_first) {
1260
      // the node list is empty, assign the first position
1261
      topo_position_ = kMidPoint;
1262
      return;
1263
    }
1264

1265
    if (prevPos >= (kUpperBound - kAppendInterval)) {
1266
      // we're running off the edge
1267
      owningBlock()->reIndexTopology();
1268
      return;
1269
    }
1270

1271
    topo_position_ = prevPos + kAppendInterval;
1272

1273
    // Prepend to the graph
1274
  } else if (is_first) {
1275
    // next() is the first element in the block list
1276
    if (nextPos <= (kLowerBound + kAppendInterval)) {
1277
      // we're running off the edge
1278
      owningBlock()->reIndexTopology();
1279
      return;
1280
    }
1281
    topo_position_ = nextPos - kAppendInterval;
1282

1283
    // insert between two existing nodes
1284
  } else {
1285
    int64_t remaining = nextPos - prevPos;
1286
    AT_ASSERT(remaining > 0);
1287
    if (remaining == 1) {
1288
      // There was no room
1289
      owningBlock()->reIndexTopology();
1290
      return;
1291
    }
1292
    int64_t predicted_future_insertions = 0;
1293
    if (next() == graph_->insertPoint()) {
1294
      predicted_future_insertions = graph_->predicted_insert_count_++;
1295
    }
1296
    topo_position_ = prevPos +
1297
        std::max(int64_t(1), remaining / (2 + predicted_future_insertions));
1298
    AT_ASSERT(prevPos < topo_position_ && topo_position_ < nextPos);
1299
  }
1300
}
1301

1302
// Creates a detached node of kind `kind` owned by `graph`. The node starts in
// the graph's current scope, with no owning block, call stack or cached
// operator.
Node::Node(Graph* graph, NodeKind kind)
    : kind_(kind),
      graph_(graph),
      owning_block_(nullptr),
      scope_(graph->current_scope_),
      callstack_(c10::nullopt),
      op_(nullptr),
      topo_position_(0) {
  // Register with the graph; Graph's destructor deletes every registered node.
  graph->all_nodes.emplace(this);
}
1312

1313
// Removes and frees output `i`, which must have no remaining uses. Outputs
// after it shift down one slot and have their offsets updated.
void Node::eraseOutput(size_t i) {
  AT_ASSERT(i < outputs_.size());
  AT_ASSERT(outputs_[i]->uses().empty());
  op_ = nullptr;
  Value* removed = outputs_[i];
  outputs_.erase(outputs_.begin() + i);
  owningGraph()->freeValue(removed);
  for (auto it = outputs_.begin() + i; it != outputs_.end(); ++it) {
    (*it)->offset_--;
  }
}
1324

1325
// Appends a new, empty sub-block to this node and returns it.
Block* Node::addBlock() {
  op_ = nullptr;
  Block* fresh = new Block(owningGraph(), this);
  blocks_.push_back(fresh);
  return fresh;
}
1330

1331
// Detaches sub-block `i` from this node and destroys it.
void Node::eraseBlock(size_t i) {
  AT_ASSERT(i < blocks_.size());
  op_ = nullptr;
  const auto pos = blocks_.begin() + i;
  Block* doomed = *pos;
  blocks_.erase(pos);
  doomed->destroy();
}
1338

1339
void Node::destroy() {
1340
  while (!outputs().empty()) {
1341
    eraseOutput(outputs().size() - 1);
1342
  }
1343
  while (!blocks().empty()) {
1344
    eraseBlock(blocks().size() - 1);
1345
  }
1346
  removeAllInputs();
1347
  if (inBlockList()) {
1348
    removeFromList();
1349
  }
1350
  graph_->freeNode(this);
1351
}
1352

1353
// Copies source range, attributes and call stack from `s`. The scope is
// copied only when `s` has a non-blank one; otherwise this node keeps its
// own.
void Node::cloneFrom(Node* s) {
  source_range_ = s->source_range_;
  const bool has_real_scope = s->scope_ && !s->scope_->isBlank();
  if (has_real_scope) {
    scope_ = s->scope_;
  }
  copyAttributes(*s);
  callstack_ = s->callstack_;
}
1361

1362
// Redirects all uses of each output of this node to the output of `n` at the
// same position. Both nodes must have the same number of outputs.
void Node::replaceAllUsesWith(Node* n) {
  AT_ASSERT(outputs().size() == n->outputs().size());
  const auto replacements = n->outputs();
  size_t position = 0;
  for (Value* old_output : outputs()) {
    old_output->replaceAllUsesWith(replacements[position]);
    ++position;
  }
}
1369

1370
// Creates a node of kind `new_symbol` at this node's position (via a
// WithInsertPoint guard on this node). The new node gets the same inputs, one
// output per original output (metadata copied), and this node's metadata and
// attributes. All uses of this node's outputs are redirected to the new
// outputs. This node itself is not destroyed here.
//
// Asserts that the replacement resolves to a registered operator exactly when
// the original did.
Node* Node::replaceWithNewSymbol(Symbol new_symbol) {
  WithInsertPoint insert_guard{this};
  bool had_operator = maybeOperator() != nullptr;
  auto graph = owningGraph();
  auto replace_node = graph->insertNode(graph->create(new_symbol, 0));
  for (Value* v : inputs()) {
    replace_node->addInput(v);
  }
  for (Value* v : outputs()) {
    auto new_out = replace_node->addOutput()->copyMetadata(v);
    v->replaceAllUsesWith(new_out);
  }
  replace_node->copyMetadata(this);
  replace_node->copyAttributes(*this);
  // Separate the two symbols in the message; the previous format ran them
  // together (e.g. "aten::fooaten::bar").
  TORCH_INTERNAL_ASSERT(
      (replace_node->maybeOperator() != nullptr) == had_operator,
      "invalid symbol replacement: ",
      kind(),
      " -> ",
      new_symbol,
      " (exactly one of them has a registered operator matching this node)");
  return replace_node;
}
1391

1392
bool Node::isDominatedBy(const Node* dominator) const {
1393
  const Node* node = this;
1394
  while (node) {
1395
    if (node->owningBlock() == dominator->owningBlock()) {
1396
      return dominator->isBefore(node);
1397
    }
1398
    node = node->owningBlock()->owningNode();
1399
  }
1400
  return false;
1401
}
1402

1403
// Inserts `value` as input number `i` (0 <= i <= number of inputs), shifting
// later inputs right and updating their recorded use offsets. Returns
// `value`.
Value* Node::insertInput(size_t i, Value* value) {
  AT_ASSERT(graph_ == value->owningGraph());
  // `inputs_.begin() + i` is only valid up to end().
  AT_ASSERT(i <= inputs_.size());
  op_ = nullptr;
  // First we update the offsets for all existing inputs that will reside
  // after the one we're inserting. Concretely, these are the inputs at
  // indices [i, # input). Since we're inserting one input before all of
  // these inputs, increment their use offsets for this value by 1
  for (const auto use_itr : c10::irange(i, inputs_.size())) {
    // See Note [User node does not uniquely identify use]
    auto use = findUseForInput(use_itr);
    use->offset += 1;
  }
  // Insert the actual input at the specified index
  inputs_.insert(inputs_.begin() + i, value);
  // Register the new use of the value we're inserted as an input.
  value->uses_.emplace_back(this, i);
  return value;
}
1421

1422
// Appends `value` as the last input and records the corresponding use.
Value* Node::addInput(Value* value) {
  AT_ASSERT(graph_ == value->owningGraph());
  op_ = nullptr;
  // The new input lands at the current end, which is also its use offset.
  const size_t slot = inputs_.size();
  value->uses_.emplace_back(this, slot);
  inputs_.push_back(value);
  return value;
}
1429

1430
// Replaces input `i` with `newValue` and returns the previous input.
Value* Node::replaceInput(size_t i, Value* newValue) {
  AT_ASSERT(newValue->owningGraph() == graph_);
  op_ = nullptr;
  // dropInput() unregisters the old use and nulls the slot refilled below.
  Value* previous = dropInput(i);
  inputs_[i] = newValue;
  newValue->uses_.emplace_back(this, i);
  return previous;
}
1438

1439
// Replaces every occurrence of `from` among this node's inputs with `to`.
void Node::replaceInputWith(Value* from, Value* to) {
  AT_ASSERT(from->owningGraph() == graph_);
  AT_ASSERT(to->owningGraph() == graph_);
  op_ = nullptr;
  // replaceInput() rewrites a slot in place without resizing inputs_, so the
  // indices stay valid throughout the scan.
  for (size_t idx = 0; idx < inputs_.size(); ++idx) {
    if (inputs_[idx] == from) {
      replaceInput(idx, to);
    }
  }
}
1451

1452
// Appends a new output Value to this node and returns it.
Value* Node::addOutput() {
  Value* produced = new Value(this, outputs_.size());
  outputs_.push_back(produced);
  op_ = nullptr;
  return produced;
}
1457

1458
// Inserts a new output at position `i` (0 <= i <= number of outputs), shifting
// later outputs right and bumping their offsets. Returns the new output.
Value* Node::insertOutput(size_t i) {
  // Check before allocating the Value: `outputs_.begin() + i` is only valid
  // up to end().
  AT_ASSERT(i <= outputs_.size());
  op_ = nullptr;
  outputs_.insert(outputs_.begin() + i, new Value(this, i));
  for (size_t itr = i + 1; itr < outputs_.size(); ++itr) {
    outputs_[itr]->setOffset(outputs_[itr]->offset() + 1);
  }
  return outputs_.at(i);
}
1466

1467
bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
1468
  if (this->owningBlock() == n->owningBlock()) {
1469
    if (moveSide == MoveSide::BEFORE) {
1470
      return this->topo_position_ < n->topo_position_;
1471
    }
1472

1473
    if (moveSide == MoveSide::AFTER) {
1474
      return this->topo_position_ > n->topo_position_;
1475
    }
1476

1477
    AT_ASSERT(this == n);
1478
    return false;
1479
  }
1480

1481
  // These nodes don't share a common block. Traverse the blockchains upward
1482
  // until we find the first common block.
1483
  auto lhs = this;
1484
  while (lhs) {
1485
    AT_ASSERT(lhs->owningBlock());
1486

1487
    auto rhs = n;
1488
    while (rhs) {
1489
      if (!rhs->owningBlock()) {
1490
        break;
1491
      }
1492

1493
      if (lhs->owningBlock() == rhs->owningBlock()) {
1494
        return lhs->isBeforeOrAfter(rhs, moveSide);
1495
      }
1496
      rhs = rhs->owningBlock()->owningNode();
1497
    }
1498

1499
    lhs = lhs->owningBlock()->owningNode();
1500
  }
1501
  // should never reach here, since both nodes are ultimately in the same graph
1502
  AT_ASSERT(false);
1503
}
1504

1505
bool Node::isBefore(const Node* n) const {
1506
  return isBeforeOrAfter(n, MoveSide::BEFORE);
1507
}
1508

1509
bool Node::isAfter(const Node* n) const {
1510
  return isBeforeOrAfter(n, MoveSide::AFTER);
1511
}
1512

1513
// Links this (currently unlinked) node into the list immediately before `n`.
Node* Node::insertBefore(Node* n) {
  AT_ASSERT(n->inBlockList());
  // Inserting before `n` is the same as inserting after its predecessor.
  Node* predecessor = n->prev();
  insertAfter(predecessor);
  return this;
}
1518

1519
// Links this (currently unlinked) node into `n`'s block immediately after
// `n`, then assigns it a topological position. `n` must not be the block's
// Return node.
Node* Node::insertAfter(Node* n) {
  AT_ASSERT(!inBlockList() && n->inBlockList());
  AT_ASSERT(n->owningBlock());
  // The message previously lacked a space between "insert" and the printed
  // node.
  AT_ASSERTM(
      n->kind() != prim::Return,
      "Attempting to insert a Node after the Return node or before the Param node. Tried to insert ",
      *this,
      " after ",
      *n,
      ".");
  this->owning_block_ = n->owningBlock();
  Node* next = n->next();
  n->next() = this;
  this->prev() = n;
  this->next() = next;
  next->prev() = this;
  assignTopoPosition();
  return this;
}
1538

1539
// Relocates this node so it directly follows `n`: unlink, then relink.
void Node::moveAfter(Node* n) {
  this->removeFromList();
  this->insertAfter(n);
}
1543

1544
// Relocates this node so it directly precedes `n`: unlink, then relink.
void Node::moveBefore(Node* n) {
  this->removeFromList();
  this->insertBefore(n);
}
1548

1549
// Removes input `i`, unregistering its use and shifting later inputs left.
void Node::removeInput(size_t i) {
  op_ = nullptr;
  dropInput(i);
  // Every input after `i` moves one slot left, so its recorded use offset
  // must move too.
  const size_t count = inputs_.size();
  for (size_t j = i + 1; j < count; ++j) {
    findUseForInput(j)->offset--;
  }
  inputs_.erase(inputs_.begin() + i);
}
1560

1561
void Node::removeAllInputs() {
1562
  op_ = nullptr;
1563
  for (const auto i : c10::irange(inputs().size())) {
1564
    dropInput(i);
1565
  }
1566
  inputs_.clear();
1567
}
1568

1569
void Node::removeAllOutputs() {
1570
  op_ = nullptr;
1571
  size_t init_osize = outputs_.size();
1572
  for (auto i : c10::irange(init_osize)) {
1573
    eraseOutput(init_osize - i - 1);
1574
  }
1575
}
1576

1577
void Node::permuteInputs(const std::vector<size_t>& new_order) {
1578
  op_ = nullptr;
1579
  AT_ASSERT(new_order.size() == inputs_.size());
1580
  std::vector<Value*> new_inputs;
1581
  new_inputs.reserve(new_order.size());
1582
  for (const auto i : c10::irange(new_order.size())) {
1583
    AT_ASSERTM(inputs_.at(new_order[i]) != nullptr, "Repeated index");
1584
    new_inputs.push_back(inputs_.at(new_order[i]));
1585
    auto it = findUseForInput(new_order[i]);
1586
    it->offset = i;
1587
    inputs_.at(new_order[i]) = nullptr;
1588
  }
1589
  inputs_ = std::move(new_inputs);
1590
}
1591

1592
void Node::permuteOutputs(const std::vector<size_t>& new_order) {
1593
  op_ = nullptr;
1594
  AT_ASSERT(new_order.size() == outputs_.size());
1595
  std::vector<Value*> new_outputs;
1596
  new_outputs.reserve(new_order.size());
1597
  for (const auto i : c10::irange(new_order.size())) {
1598
    AT_ASSERTM(outputs_.at(new_order[i]) != nullptr, "Repeated index");
1599
    new_outputs.push_back(outputs_.at(new_order[i]));
1600
    outputs_.at(new_order[i])->setOffset(i);
1601
    outputs_.at(new_order[i]) = nullptr;
1602
  }
1603
  outputs_ = std::move(new_outputs);
1604
}
1605

1606
// Returns the iterator to the Use (this, i) in input `i`'s use list; asserts
// that it exists.
use_list::iterator Node::findUseForInput(size_t i) {
  auto& uses_of_input = inputs_[i]->uses_;
  // Linear scan: O(#uses), but for typical use-list sizes a vector scan is
  // still expected to beat a linked structure.
  const Use wanted(this, i);
  auto found = std::find(uses_of_input.begin(), uses_of_input.end(), wanted);
  AT_ASSERT(found != uses_of_input.end());
  return found;
}
1614

1615
// Unregisters this node's use of input `i` and nulls the slot without
// resizing inputs_. Returns the value that was there.
Value* Node::dropInput(size_t i) {
  AT_ASSERT(i < inputs_.size());
  Value* dropped = inputs_[i];
  dropped->uses_.erase(findUseForInput(i));
  inputs_[i] = nullptr;
  return dropped;
}
1623

1624
void Node::removeFromList() {
1625
  AT_ASSERT(inBlockList());
1626
  this->owning_block_ = nullptr;
1627
  Node* next = this->next();
1628
  Node* prev = this->prev();
1629
  prev->next() = next;
1630
  next->prev() = prev;
1631
  this->next() = nullptr;
1632
  this->prev() = nullptr;
1633
}
1634

1635
// Returns the innermost block that (transitively) contains both this node and
// `n`.
Block* Node::findCommonAncestorBlockWith(Node* n) {
  if (n->owningBlock() == owningBlock()) {
    return owningBlock();
  }

  Node* a = this;
  Node* b = n;
  size_t depth_a = a->blocksFromGraphBlock();
  size_t depth_b = b->blocksFromGraphBlock();

  // Lift whichever node is nested deeper until both are at the same depth.
  while (depth_a > depth_b) {
    a = a->owningBlock()->owningNode();
    --depth_a;
  }
  while (depth_b > depth_a) {
    b = b->owningBlock()->owningNode();
    --depth_b;
  }

  // Same depth now: step both outward in lockstep until they share a block.
  for (;;) {
    if (a->owningBlock() == b->owningBlock()) {
      return a->owningBlock();
    }
    a = a->owningBlock()->owningNode();
    b = b->owningBlock()->owningNode();
    AT_ASSERT(a != nullptr);
    AT_ASSERT(b != nullptr);
  }
}
1669

1670
// Returns how many enclosing nodes separate this node from the graph's
// top-level block (0 for a node directly in the top-level block).
size_t Node::blocksFromGraphBlock() {
  size_t depth = 0;
  for (Node* outer = owningBlock()->owningNode(); outer != nullptr;
       outer = outer->owningBlock()->owningNode()) {
    ++depth;
  }
  return depth;
}
1679

1680
// Placeholder SourceRange (empty source text, range 0..1) used by
// Graph::insert() when the caller supplies no range. Built once on first use.
inline const SourceRange& fakeRange() {
  static const SourceRange kPlaceholder(
      std::make_shared<Source>(std::string()), 0, 1);
  return kPlaceholder;
}
1684

1685
// Emits a call to builtin `opname` with the given positional and keyword
// arguments, using fakeRange() when no source range is provided.
Value* Graph::insert(
    Symbol opname,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const c10::optional<SourceRange>& range) {
  const SourceRange& loc = range ? *range : fakeRange();
  return emitBuiltinCall(loc, *this, opname, args, kwargs);
}
1693

1694
// Allocates a detached node of `kind` with `num_outputs` fresh outputs.
Node* Graph::create(NodeKind kind, size_t num_outputs) {
  // NB: Node constructor adds node to all_nodes
  Node* created = new Node(this, kind);
  for (size_t left = num_outputs; left > 0; --left) {
    created->addOutput();
  }
  return created;
}
1703

1704
// Allocates a detached node of `kind` with the given inputs (in order) and
// `num_outputs` fresh outputs.
Node* Graph::create(
    NodeKind kind,
    ArrayRef<Value*> inputs,
    size_t num_outputs) {
  Node* created = create(kind, num_outputs);
  std::for_each(inputs.begin(), inputs.end(), [created](Value* in) {
    created->addInput(in);
  });
  return created;
}
1714

1715
// Creates a detached prim::AutogradZero node (default output count).
Node* Graph::createAutogradZero() {
  Node* zero = create(prim::AutogradZero);
  return zero;
}
1718

1719
// Creates a detached prim::Constant whose output is typed NoneType.
Node* Graph::createNone() {
  Node* none_node = create(prim::Constant);
  Value* out = none_node->output();
  out->setType(NoneType::get());
  return none_node;
}
1724

1725
// Creates a detached prim::Uninitialized node whose output has type `typ`.
Node* Graph::createUninitialized(TypePtr typ) {
  Node* placeholder = create(prim::Uninitialized);
  Value* out = placeholder->output();
  out->setType(std::move(typ));
  return placeholder;
}
1730

1731
// Creates a detached node of `kind` with no outputs, carrying a new empty
// Graph (constructed from this graph's current_scope()) as its Subgraph
// attribute.
Node* Graph::createWithSubgraph(Symbol kind) {
  Node* holder = create(kind, 0);
  auto body = std::make_shared<Graph>(current_scope());
  holder->g_(attr::Subgraph, body);
  return holder;
}
1736

1737
// Creates a detached prim::TupleConstruct over `values`. `tuple_type` may be
// given only for named tuples; otherwise the type is built from the element
// value types.
Node* Graph::createTuple(at::ArrayRef<Value*> values, TupleTypePtr tuple_type) {
  TORCH_INTERNAL_ASSERT(
      !tuple_type || tuple_type->schema(),
      "only pass tuple_type when creating a named tuple");
  if (!tuple_type) {
    std::vector<TypePtr> element_types;
    element_types.reserve(values.size());
    for (Value* element : values) {
      element_types.push_back(element->type());
    }
    tuple_type = TupleType::create(std::move(element_types));
  }
  Node* tuple_node = create(prim::TupleConstruct, values);
  tuple_node->output()->setType(tuple_type);
  return tuple_node;
}
1750

1751
// Creates a detached prim::TupleUnpack of `v` (which must be tuple-typed)
// with one output per tuple element, typed to match.
Node* Graph::createTupleUnpack(Value* v) {
  TupleTypePtr tuple_type = v->type()->expect<TupleType>();
  Node* unpack = create(prim::TupleUnpack, {v}, 0);
  for (const TypePtr& element_type : tuple_type->elements()) {
    unpack->addOutput()->setType(element_type);
  }
  return unpack;
}
1759

1760
// Creates a detached prim::TupleIndex(tup, idx) whose output type is
// `output_type`.
Node* Graph::createTupleIndex(
    Value* tup,
    Value* idx,
    const TypePtr& output_type) {
  Node* index_node = create(prim::TupleIndex, {tup, idx});
  index_node->output()->setType(output_type);
  return index_node;
}
1768

1769
// Builds a prim::TupleConstruct of `num_values` elements of `tup`, taken at
// indices beg, beg + step_size, beg + 2 * step_size, ...
// Each element is extracted by inserting a constant index and a
// prim::TupleIndex node at the current insertion point; the returned
// TupleConstruct node itself is not inserted.
//
// All indices are validated before anything is inserted: an out-of-range
// index (or a negative count) raises an error instead of reading past the end
// of the tuple's element types or requesting a huge reserve().
Node* Graph::createTupleSlice(
    Value* tup,
    int64_t beg,
    int64_t step_size,
    int64_t num_values) {
  TupleTypePtr tt = tup->type()->expect<TupleType>();
  const auto& elements = tt->elements();
  const int64_t num_elements = static_cast<int64_t>(elements.size());

  TORCH_CHECK(
      num_values >= 0,
      "createTupleSlice: num_values must be non-negative, got ",
      num_values);
  for (int64_t k = 0, i = beg; k < num_values; ++k, i += step_size) {
    TORCH_CHECK(
        i >= 0 && i < num_elements,
        "createTupleSlice: index ",
        i,
        " is out of range for a tuple of size ",
        num_elements);
  }

  std::vector<Value*> new_vals;
  new_vals.reserve(static_cast<size_t>(num_values));
  int64_t i = beg;
  for (int64_t k = 0; k < num_values; ++k) {
    auto idx = insertConstant(IValue(i));
    auto tupleIndex = insertNode(createTupleIndex(tup, idx, elements[i]));
    new_vals.push_back(tupleIndex->output());
    i += step_size;
  }

  return createTuple(new_vals);
}
1791

1792
// Creates a detached prim::EnumName node for enum value `e`; the output is a
// string.
Node* Graph::createEnumName(Value* e) {
  // expect<EnumType>() already rejects non-enum operands, so the former
  // debug-only `assert(cast<EnumType>())` that followed it was redundant.
  e->type()->expect<EnumType>();
  auto n = create(prim::EnumName, {e});
  n->output()->setType(StringType::get());
  return n;
}
1799

1800
// Creates a detached prim::EnumValue node for enum value `e`; the output takes
// the enum's value type.
Node* Graph::createEnumValue(Value* e) {
  const auto enum_type = e->type()->expect<EnumType>();
  Node* value_node = create(prim::EnumValue, {e});
  value_node->output()->setType(enum_type->getValueType());
  return value_node;
}
1806

1807
// Creates a detached prim::ListConstruct of `values` typed
// List[contained_type]. Every value must be a subtype of `contained_type`.
Node* Graph::createList(
    const TypePtr& contained_type,
    at::ArrayRef<Value*> values) {
  // Validate before creating the node: Node's constructor registers it with
  // this graph, so failing afterwards would leave an orphaned node behind.
  for (const auto& v : values) {
    TORCH_CHECK(
        v->type()->isSubtypeOf(*contained_type),
        "Expected a list element that subtypes '",
        contained_type->repr_str(),
        "' but got an element of type '",
        v->type()->repr_str(),
        "'");
  }
  auto n = create(prim::ListConstruct, values);
  n->output()->setType(ListType::create(contained_type));
  return n;
}
1823

1824
// Creates a detached prim::ListUnpack of list `v` with `size` outputs, each
// typed as the list's element type.
Node* Graph::createListUnpack(Value* v, size_t size) {
  ListTypePtr list_type = v->type()->expect<ListType>();
  TypePtr elem_type = list_type->getElementType();
  Node* unpack = create(prim::ListUnpack, {v}, 0);
  for (size_t k = 0; k < size; ++k) {
    unpack->addOutput()->setType(elem_type);
  }
  return unpack;
}
1834

1835
// Creates a detached prim::DictConstruct typed Dict[key_type, value_type].
// Its inputs are the interleaved pairs key0, value0, key1, value1, ...
// `keys` and `values` must have equal length and matching subtypes.
Node* Graph::createDict(
    const TypePtr& key_type,
    const TypePtr& value_type,
    at::ArrayRef<Value*> keys,
    at::ArrayRef<Value*> values) {
  AT_ASSERT(keys.size() == values.size());
  // Validate everything before creating the node, so a failed check does not
  // leave a partially wired node registered in this graph.
  for (size_t i = 0; i < keys.size(); ++i) {
    AT_ASSERT(keys[i]->type()->isSubtypeOf(*key_type));
    AT_ASSERT(values[i]->type()->isSubtypeOf(*value_type));
  }
  auto n = create(prim::DictConstruct, 1);
  for (size_t i = 0; i < keys.size(); ++i) {
    n->addInput(keys[i]);
    n->addInput(values[i]);
  }
  n->output()->setType(DictType::create(key_type, value_type));
  return n;
}
1852

1853
// Creates a detached prim::NumToTensor of scalar `value`; the output tensor
// type is derived from the scalar's type.
Node* Graph::createNumToTensor(Value* value) {
  Node* convert = create(prim::NumToTensor, {value});
  auto tensor_type = TensorType::fromNumberType(*value->type());
  convert->output()->setType(tensor_type);
  return convert;
}
1858

1859
// Creates a detached prim::CreateObject whose output has class type `type`.
Node* Graph::createObject(const ClassTypePtr& type) {
  Node* object_node = create(prim::CreateObject);
  Value* instance = object_node->output();
  instance->setType(type);
  return object_node;
}
1864

1865
// Creates a detached, output-less prim::SetAttr(obj, newValue) node whose
// `name` attribute is `field`.
Node* Graph::createSetAttr(
    Value* obj,
    const std::string& field,
    Value* newValue) {
  Node* set_node = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0);
  return set_node->s_(attr::name, field);
}
1873

1874
// Creates a detached prim::GetAttr reading attribute `field` from object `obj`
// (which must be class-typed). The output takes the attribute's declared type
// and a debug name derived from `field`.
Node* Graph::createGetAttr(Value* obj, const std::string& field) {
  const auto classType = obj->type()->expect<ClassType>();
  // Resolve the attribute type before creating the node. Node's constructor
  // registers the node with this graph, so if the lookup throws (e.g. for an
  // unknown field), doing it first avoids leaving an orphaned GetAttr node.
  const auto outputType = classType->getAttribute(field);

  auto n = create(prim::GetAttr, {obj}, /*num_outputs=*/1);
  n->s_(attr::name, field);
  n->output()->setType(outputType);
  n->output()->setDebugName(normalizeAttrName(field));
  return n;
}
1885

1886
// Creates a detached, output-less prim::Store of `v` under variable `name`.
Node* Graph::createStore(const std::string& name, Value* v) {
  Node* store = create(prim::Store, {v}, /*num_outputs*/ 0);
  return store->s_(attr::name, name);
}
1891

1892
// Creates a detached prim::Load of variable `name`; its single output has
// type `type`.
Node* Graph::createLoad(const std::string& name, const TypePtr& type) {
  Node* load = create(prim::Load, {}, /*num_outputs*/ 1);
  load->s_(attr::name, name);
  Value* loaded = load->output();
  loaded->setType(type);
  return load;
}
1898

1899
// Creates a detached prim::isinstance check of `v` against `types`; the output
// is a bool.
Node* Graph::createIsInstance(Value* v, at::ArrayRef<TypePtr> types) {
  Node* check = create(prim::isinstance, {v}, /*num_outputs*/ 1);
  check->tys_(attr::types, types.vec());
  Value* verdict = check->output();
  verdict->setType(BoolType::get());
  return check;
}
1905
// Inserts a prim::unchecked_cast of `v` at the insertion point and returns its
// output, retyped to `type`.
Value* Graph::insertUncheckedCast(Value* v, TypePtr type) {
  Node* cast_node = create(prim::unchecked_cast, {v});
  insertNode(cast_node);
  Value* result = cast_node->output();
  result->setType(std::move(type));
  return result;
}
1910

1911
// Inserts a prim::tolist(v, dim, elem_ty) node and returns its output typed as
// `type`. `dim` is the number of nested List[...] layers in `type`. elem_ty
// encodes the innermost element type: int=0, float=1, bool=2, complex=3. Any
// other element type is rejected with an error.
Value* Graph::insertToList(Value* v, TypePtr type) {
  // Peel off nested List[...] layers; their count is the output rank.
  int dim = 0;
  TypePtr elem = type;
  for (auto layer = elem->cast<ListType>(); layer;
       layer = elem->cast<ListType>()) {
    elem = layer->getElementType();
    ++dim;
  }

  // Map the base element type to the integer code the op expects (-1 means
  // unsupported).
  const auto encode = [](const TypePtr& t) -> int {
    if (t == IntType::get()) {
      return 0;
    }
    if (t == FloatType::get()) {
      return 1;
    }
    if (t == BoolType::get()) {
      return 2;
    }
    if (t == ComplexType::get()) {
      return 3;
    }
    return -1;
  };
  const int elem_ty = encode(elem);
  TORCH_CHECK(
      elem_ty >= 0,
      elem->repr_str(),
      " is not one of the supported element types for tolist: int, float, complex, bool");

  Value* dim_val = insertConstant(IValue(dim));
  Value* elem_ty_val = insertConstant(IValue(elem_ty));
  Node* to_list = insertNode(create(prim::tolist, {v, dim_val, elem_ty_val}));
  to_list->output()->setType(std::move(type));
  return to_list->output();
}
1946

1947
// Inserts a call to `callee`. First a prim::Constant of FunctionType naming
// the callee is inserted, then a prim::CallFunction taking that constant
// followed by the matched inputs. Returns the call's output, typed with the
// first matched return type.
Value* Graph::insertFunctionCall(
    Function* callee,
    const MatchedSchema& matched) {
  const std::string func_name = callee->name();
  Node* fn_node = insertNode(create(prim::Constant));
  fn_node->s_(attr::name, func_name);
  Value* fn_constant =
      fn_node->output()->setType(FunctionType::create(callee));

  std::vector<Value*> call_inputs;
  call_inputs.reserve(matched.inputs.size() + 1);
  call_inputs.push_back(fn_constant);
  call_inputs.insert(
      call_inputs.end(), matched.inputs.begin(), matched.inputs.end());

  Node* call = insertNode(create(prim::CallFunction, call_inputs));
  return call->output()->setType(matched.return_types.at(0));
}
1962

1963
// Inserts a prim::CallMethod named `method_name` over the matched inputs and
// returns its output, typed with the first matched return type.
Value* Graph::insertMethodCall(
    std::string method_name,
    const MatchedSchema& matched) {
  Node* call = insertNode(create(prim::CallMethod, matched.inputs));
  call->s_(attr::name, std::move(method_name));
  return call->output()->setType(matched.return_types.at(0));
}
1972

1973
// Creates a detached copy of `n` (which may belong to another graph) in this
// graph. Outputs copy metadata, inputs are translated through `value_map`,
// and sub-blocks are cloned only when `copy_blocks` is set.
Node* Graph::createClone(
    Node* n,
    const std::function<Value*(Value*)>& value_map,
    bool copy_blocks) {
  Node* clone = n->allocNewInstance(this);
  for (Value* src_out : n->outputs()) {
    clone->addOutput()->copyMetadata(src_out);
  }
  clone->cloneFrom(n);
  for (Value* src_in : n->inputs()) {
    clone->addInput(value_map(src_in));
  }
  if (!copy_blocks) {
    return clone;
  }
  for (Block* src_block : n->blocks()) {
    clone->addBlock()->cloneFrom(src_block, value_map);
  }
  return clone;
}
1993

1994
// Member convenience wrapper around the free function jit::insertConstant.
Value* Graph::insertConstant(
    const IValue& val,
    c10::optional<SourceRange> loc,
    c10::optional<ScopePtr> scope) {
  Graph& self = *this;
  return jit::insertConstant(self, val, std::move(loc), std::move(scope));
}
2000

2001
std::string Graph::toString(bool print_source_locations) const {
2002
  std::ostringstream oss;
2003
  print(oss, print_source_locations);
2004
  return oss.str();
2005
}
2006

2007
// The graph owns every node, value and block registered with it; delete them
// in that order.
Graph::~Graph() {
  for (const auto* node : all_nodes) {
    delete node;
  }
  for (const auto* value : all_values) {
    delete value;
  }
  for (const auto* block : all_blocks) {
    delete block;
  }
}
2018

2019
void Graph::freeNode(Node* n) {
2020
  auto it = all_nodes.find(n);
2021
  AT_ASSERT(it != all_nodes.end());
2022
  delete *it;
2023
  all_nodes.erase(it);
2024
}
2025
void Graph::freeValue(Value* v) {
2026
  v->setDebugName("");
2027
  auto it = all_values.find(v);
2028
  AT_ASSERT(it != all_values.end());
2029
  delete *it;
2030
  all_values.erase(it);
2031
}
2032
void Graph::freeBlock(Block* b) {
2033
  auto it = all_blocks.find(b);
2034
  AT_ASSERT(it != all_blocks.end());
2035
  delete *it;
2036
  all_blocks.erase(it);
2037
}
2038

2039
at::ArrayRef<Value*> createTupleUnpack(Value* v) {
2040
  // small peephole optimization to ensure IntArrayRef attributes can still turn
2041
  // into constants e.g. in x.expand([3, 4])
2042
  if (v->node()->kind() == prim::TupleConstruct) {
2043
    return v->node()->inputs();
2044
  }
2045
  auto& g = *v->owningGraph();
2046
  return g.insertNode(g.createTupleUnpack(v))->outputs();
2047
}
2048

2049
void inlineCallStackOfNode(
2050
    Node* n,
2051
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
2052
    Function* callee,
2053
    Node* to_replace,
2054
    c10::optional<ModuleInstanceInfo> m_info);
2055

2056
static void inlineCallStackOfBlock(
2057
    Block* b,
2058
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
2059
    Function* callee,
2060
    Node* to_replace,
2061
    c10::optional<ModuleInstanceInfo> m_info) {
2062
  for (auto n : b->nodes()) {
2063
    inlineCallStackOfNode(n, new_cs_entries, callee, to_replace, m_info);
2064
  }
2065
}
2066

2067
void inlineCallStackOfNode(
2068
    Node* new_node,
2069
    std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
2070
    Function* callee,
2071
    Node* to_replace,
2072
    c10::optional<ModuleInstanceInfo> m_info) {
2073
  auto new_node_cs = new_node->callstack();
2074

2075
  InlinedCallStack* raw_callstack_ptr =
2076
      new_node_cs ? new_node_cs->get() : nullptr;
2077

2078
  if (!new_cs_entries.count(raw_callstack_ptr)) {
2079
    if (new_node_cs) {
2080
      new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
2081
          *new_node_cs, callee, to_replace->sourceRange(), m_info);
2082
    } else {
2083
      new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
2084
          callee, to_replace->sourceRange(), m_info);
2085
    }
2086
  }
2087
  new_node->setCallStack(new_cs_entries.at(raw_callstack_ptr));
2088
  // We updated the inlined callstack of new_node.
2089
  // Same must be done for the nodes of the blocks of new_node.
2090
  // For example If node's block otherwise is not annotated appropriately.
2091
  for (auto block : new_node->blocks()) {
2092
    inlineCallStackOfBlock(block, new_cs_entries, callee, to_replace, m_info);
2093
  }
2094
}
2095

2096
std::vector<Value*> inlineCallTo(
2097
    Node* to_replace,
2098
    GraphFunction* callee,
2099
    Graph* callee_graph) {
2100
  WithInsertPoint guard(to_replace);
2101
  std::unordered_map<Value*, Value*> value_map;
2102
  std::vector<torch::jit::Value*> new_outputs = insertGraph(
2103
      *to_replace->owningGraph(),
2104
      *callee_graph,
2105
      to_replace->inputs(),
2106
      value_map);
2107

2108
  std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>
2109
      new_callstack_entries;
2110

2111
  c10::optional<ModuleInstanceInfo> module_instance_info = c10::nullopt;
2112
  if (to_replace->kind() == prim::CallMethod) {
2113
    auto class_type_ptr = to_replace->input(0)->type()->cast<c10::ClassType>();
2114
    if (to_replace->input(0)->node()->kind() == prim::GetAttr) {
2115
      module_instance_info = c10::make_optional(ModuleInstanceInfo(
2116
          class_type_ptr, to_replace->input(0)->node()->s(attr::name)));
2117
    } else if (
2118
        !to_replace->owningGraph()->inputs().empty() &&
2119
        to_replace->input(0) == to_replace->owningGraph()->inputs()[0]) {
2120
      // This CallMethod must correspond to method of the same object
2121
      // to which this graph belongs.
2122
      module_instance_info =
2123
          c10::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF"));
2124
    } else {
2125
      // Not sure if it is possible to come here ever.
2126
      // TODO: Remove this else. Or add assert
2127
      module_instance_info = c10::make_optional(
2128
          ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN"));
2129
    }
2130
  }
2131

2132
  // TODO: We might need to use nodes_map instead of value_map. Otherwise, we
2133
  // are missing nodes without outputs (e.g. prim::Print).
2134
  std::unordered_set<Node*> updated_nodes;
2135
  for (const auto& kv : value_map) {
2136
    /* Skip the old value if it is the graph input.
2137
     * The reason is that, value_map contains values not all for the nodes of
2138
     * the graph but primary inputs as well, and it will create duplicates when
2139
     * the first inlined graph is input to the next one. To avoid this issue,
2140
     * skip the old value when it is one of the
2141
     * callee->optimized_graph()->inputs() or callee->graph()->inputs(), depends
2142
     * on if it is inlined_optimized_graph
2143
     */
2144
    auto is_graph_input = std::find(
2145
        callee_graph->inputs().begin(), callee_graph->inputs().end(), kv.first);
2146
    if (is_graph_input != callee_graph->inputs().end()) {
2147
      continue;
2148
    }
2149

2150
    Node* new_node = kv.second->node();
2151
    if (!updated_nodes.insert(new_node).second) {
2152
      continue;
2153
    }
2154

2155
    inlineCallStackOfNode(
2156
        new_node,
2157
        new_callstack_entries,
2158
        callee,
2159
        to_replace,
2160
        module_instance_info);
2161
  }
2162
  const auto& old_outputs = to_replace->outputs();
2163

2164
  AT_ASSERT(new_outputs.size() == old_outputs.size());
2165
  for (const auto i : c10::irange(old_outputs.size())) {
2166
    if (old_outputs[i]->hasDebugName()) {
2167
      new_outputs[i]->setDebugName(old_outputs[i]->debugName());
2168
    }
2169
    old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
2170
  }
2171
  to_replace->destroy();
2172

2173
  return new_outputs;
2174
}
2175

2176
// inline_optimized_graph argument is used in substitute function call for
2177
// ONNX conversion
2178
std::vector<Value*> inlineCallTo(
2179
    Node* to_replace,
2180
    GraphFunction* callee,
2181
    bool inline_optimized_graph /*=true*/) {
2182
  auto graph =
2183
      inline_optimized_graph ? callee->optimized_graph() : callee->graph();
2184
  return inlineCallTo(to_replace, callee, graph.get());
2185
}
2186

2187
std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs) {
2188
  std::vector<Value*> new_outputs;
2189
  if (outputs.size() != 1 || outputs.at(0)->type()->kind() != TupleType::Kind) {
2190
    return outputs;
2191
  }
2192

2193
  auto tup = outputs[0];
2194
  for (Value* v : createTupleUnpack(tup)) {
2195
    new_outputs.emplace_back(v);
2196
  }
2197
  // if this was a peephole tuple unpack we can just get rid of
2198
  // the tuple construct here and prevent needing DCE
2199
  if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) {
2200
    tup->node()->destroy();
2201
  }
2202
  return new_outputs;
2203
}
2204

2205
std::vector<Node*> findAllNodes(
2206
    at::ArrayRef<Block*> array,
2207
    Symbol kind,
2208
    bool recurse) {
2209
  std::vector<Node*> ret;
2210
  for (auto block : array) {
2211
    findAllNodes(*block, kind, recurse, ret);
2212
  }
2213
  return ret;
2214
}
2215

2216
std::vector<Node*> findAllNodes(Block& block, Symbol kind, bool recurse) {
2217
  return findAllNodes({&block}, kind, recurse);
2218
}
2219

2220
std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse) {
2221
  return findAllNodes(*g.block(), kind, recurse);
2222
}
2223

2224
std::vector<Value*> insertGraph(
2225
    Graph& g,
2226
    Graph& callee,
2227
    ArrayRef<Value*> inputs,
2228
    std::unordered_map<Value*, Value*>& value_map) {
2229
  auto value_map_func = [&](Value* v) { return value_map.at(v); };
2230
  AT_ASSERT(callee.inputs().size() == inputs.size());
2231
  for (const auto i : c10::irange(inputs.size())) {
2232
    value_map[callee.inputs()[i]] = inputs[i];
2233
  }
2234
  for (auto* node : callee.nodes()) {
2235
    auto* new_node = g.insertNode(g.createClone(node, value_map_func));
2236
    for (size_t i = 0; i < node->outputs().size(); ++i) {
2237
      value_map[node->outputs()[i]] = new_node->outputs()[i];
2238
    }
2239
  }
2240

2241
  std::vector<Value*> outputs;
2242
  for (auto* output : callee.outputs()) {
2243
    outputs.push_back(value_map_func(output));
2244
  }
2245

2246
  return outputs;
2247
}
2248

2249
std::vector<Value*> insertGraph(
2250
    Graph& g,
2251
    Graph& callee,
2252
    ArrayRef<Value*> inputs) {
2253
  std::unordered_map<Value*, Value*> value_map;
2254
  return insertGraph(g, callee, inputs, value_map);
2255
}
2256

2257
void ProfileOp::cloneFrom(Node* other_) {
2258
  Node::cloneFrom(other_);
2259
  auto other = other_->cast<ProfileOp>();
2260
  this->callback_ = other->getCallback();
2261
}
2262

2263
Node* ProfileOp::allocNewInstance(Graph* g) {
2264
  return new ProfileOp(g, {nullptr});
2265
}
2266

2267
void ProfileIValueOp::cloneFrom(Node* other_) {
2268
  Node::cloneFrom(other_);
2269
  auto other = other_->cast<ProfileIValueOp>();
2270
  this->callback_ = other->getCallback();
2271
}
2272

2273
Node* ProfileIValueOp::allocNewInstance(Graph* g) {
2274
  return new ProfileIValueOp(g, {nullptr});
2275
}
2276

2277
TypePtr NamedValue::type() const {
2278
  if (value_) {
2279
    return value_->type();
2280
  } else {
2281
    return ivalue_.type();
2282
  }
2283
}
2284

2285
const Symbol ProfileOp::Kind = ::c10::prim::profile;
2286
const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue;
2287

2288
OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
2289
  insert(sig_literals);
2290
}
2291

2292
std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const {
2293
  std::vector<std::shared_ptr<Operator>> result;
2294
  for (const auto& kv : ops) {
2295
    auto ops_for_symbol = kv.second;
2296
    result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end());
2297
  }
2298
  return result;
2299
}
2300

2301
void OperatorSet::insert(std::initializer_list<const char*> sig_literals) {
2302
  for (const char* sig : sig_literals) {
2303
    auto op = getOperatorForLiteral(sig);
2304
    ops[Symbol::fromQualString(op->schema().name())].push_back(op);
2305
  }
2306
}
2307

2308
bool Node::isMemberOf(const OperatorSet& os) const {
2309
  auto it = os.ops.find(kind());
2310
  if (it == os.ops.end()) {
2311
    return false;
2312
  }
2313
  for (auto& op : it->second) {
2314
    if (matches(op->schema())) {
2315
      return true;
2316
    }
2317
  }
2318
  return false;
2319
}
2320

2321
} // namespace torch::jit
2322

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.