1
#include <torch/csrc/jit/ir/ir.h>
3
#include <ATen/core/builtin_function.h>
4
#include <ATen/core/function.h>
5
#include <c10/util/Exception.h>
6
#include <c10/util/StringUtil.h>
7
#include <c10/util/irange.h>
8
#include <torch/csrc/jit/api/function_impl.h>
9
#include <torch/csrc/jit/frontend/error_report.h>
10
#include <torch/csrc/jit/frontend/schema_matching.h>
11
#include <torch/csrc/jit/ir/constants.h>
12
#include <torch/csrc/jit/runtime/operator.h>
13
#include <torch/csrc/jit/serialization/python_print.h>
22
#include <unordered_map>
23
#include <unordered_set>
29
std::string getNodesModuleHierarchy(const Node& n) {
30
if (!n.callstack().has_value()) {
33
InlinedCallStackPtr callstack_ptr = n.callstack().value();
34
std::string module_hierarchy;
35
for (auto& entry : callstack_ptr->vec()) {
36
const auto& opt_module_info = std::get<kModuleInstanceInfo>(entry);
37
if (opt_module_info.has_value()) {
38
const auto& module_instance_info = opt_module_info.value();
39
if (!module_hierarchy.empty()) {
40
module_hierarchy.append(".");
42
module_hierarchy.append(utils::get_module_info(module_instance_info));
44
module_hierarchy += ".UNKNOWN_INSTANCE(UNKNOWN_TYPE)";
47
return module_hierarchy;
56
// Topological position bounds for nodes inside a Block's doubly-linked
// node list: the Param node sits at kLowerBound, the Return node at
// kUpperBound, and ordinary nodes are spaced kAppendInterval apart so
// that repeated insertions between neighbors rarely force a full
// reindex (see Block::reIndexTopology and Node::assignTopoPosition).
constexpr topo_position_t kLowerBound = INT64_MIN;
constexpr topo_position_t kUpperBound = INT64_MAX;
constexpr topo_position_t kMidPoint = 0;
// 2^40: the gap left between consecutively appended nodes.
constexpr topo_position_t kAppendInterval = 1099511627776ULL;
66
void printValueRef(std::ostream& out, const Value* n) {
67
out << "%" << n->debugName();
70
// True when `str` consists only of decimal digits. Note this is also
// true for the empty string, since find_first_not_of returns npos then.
bool isNumber(c10::string_view str) {
  return str.find_first_not_of("0123456789") == std::string::npos;
}
74
// Purely numeric attribute names are not valid identifiers when the IR
// is printed back as source, so prefix them with '_'; other names pass
// through unchanged.
std::string normalizeAttrName(c10::string_view field) {
  if (isNumber(field)) {
    return "_" + std::string{field};
  }
  return std::string{field};
}
85
std::vector<Node*>& ret) {
86
for (Node* n : block.nodes()) {
87
if (n->kind() == kind) {
91
for (auto b : n->blocks()) {
92
findAllNodes(*b, kind, recurse, ret);
103
std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) {
104
out << at::ArrayRef<T>{nodes};
109
static std::ostream& printValueRefs(
111
const at::ArrayRef<T> nodes) {
113
for (auto n : nodes) {
117
printValueRef(out, n);
125
static std::ostream& operator<<(
127
const at::ArrayRef<const Value*> nodes) {
128
return printValueRefs(out, nodes);
131
static std::ostream& operator<<(
133
const at::ArrayRef<Value*> nodes) {
134
return printValueRefs(out, nodes);
137
struct const_value_list_with_types {
138
const ArrayRef<const Value*> values;
140
const_value_list_with_types(
141
ArrayRef<const Value*> values,
142
std::string delim_ = ", ")
143
: values(values), delim(std::move(delim_)) {}
146
static std::ostream& operator<<(
148
const const_value_list_with_types& l) {
150
for (auto n : l.values) {
154
printValueRef(out, n);
155
if (c10::type_verbosity() >= c10::TypeVerbosity::Type) {
163
static void printAttribute(std::ostream& out, const at::Tensor& tensor) {
165
if (tensor.numel() == 1) {
166
auto scalar_tensor = tensor.view(std::vector<int64_t>{}).item();
168
if (scalar_tensor.isFloatingPoint()) {
169
out << scalar_tensor.toDouble();
170
} else if (scalar_tensor.isComplex()) {
171
out << scalar_tensor.toComplexDouble();
173
out << scalar_tensor.toLong();
176
} else if (tensor.numel() <= max_tensor_display_size) {
178
std::ostringstream tensor_ss;
180
std::string tensor_s{tensor_ss.str()};
182
std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
189
static void printAttribute(std::ostream& out, const IValue& ival) {
190
const auto customFormatter = [](std::ostream& ss, const IValue& input) {
191
if (input.isTensor()) {
192
printAttribute(ss, input.toTensor());
194
} else if (input.isTensorList()) {
197
} else if (input.isObject() && !input.type()->is_module()) {
198
ss << "object(" << &input.toObjectRef() << ")";
203
ival.repr(out, customFormatter);
206
static void printTypeList(
208
const std::vector<TypePtr>& items) {
211
for (auto& item : items) {
219
void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
220
switch (kindOf(name)) {
221
case AttributeKind::c:
222
printAttribute(out, c(name));
224
case AttributeKind::cs:
228
case AttributeKind::f:
229
printAttribute(out, f(name));
231
case AttributeKind::fs:
232
printAttribute(out, fs(name));
234
case AttributeKind::i:
235
printAttribute(out, i(name));
237
case AttributeKind::is:
238
printAttribute(out, is(name));
240
case AttributeKind::s:
241
printAttribute(out, s(name));
243
case AttributeKind::ss:
244
printAttribute(out, ss(name));
246
case AttributeKind::t:
247
printAttribute(out, t(name));
249
case AttributeKind::ts:
250
out << "[<Tensors>]";
252
case AttributeKind::ival:
253
printAttribute(out, ival(name));
255
case AttributeKind::g:
258
case AttributeKind::gs:
261
case AttributeKind::ty:
264
case AttributeKind::tys:
265
printTypeList(out, tys(name));
270
void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
273
auto names = attributeNames();
275
for (auto name : names) {
276
if (ignore_subgraph && name == attr::Subgraph) {
286
out << name.toUnqualString() << "=";
288
printAttrValue(out, name);
293
SourceRange Node::sourceRange() const {
295
return *source_range_;
297
return SourceRange();
300
static std::ostream& indent(std::ostream& out, size_t level) {
301
for (const auto i : c10::irange(level)) {
308
std::ostream& Node::print(
311
std::vector<const Node*>* groups,
312
bool print_source_locations,
313
bool print_attributes,
315
bool print_body) const {
316
auto outs = outputs();
317
indent(out, level) << const_value_list_with_types(outs);
319
if (kind() == prim::PythonOp) {
320
auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this);
321
out << "^" << pyOp->name();
322
printAttributes(out, false);
323
pyOp->writeScalars(out);
324
} else if (hasAttribute(attr::Subgraph) && groups) {
325
out << kind().toQualString() << "_" << groups->size();
326
if (print_attributes && numAttributes() > 1 &&
327
kind() != prim::DifferentiableGraph) {
328
printAttributes(out, true);
331
groups->push_back(this);
333
out << kind().toQualString();
334
if (print_attributes && hasAttributes()) {
335
printAttributes(out);
338
out << "(" << inputs() << ")";
341
std::string scName = scopeName();
342
if (!scName.empty()) {
344
out << "scope: " << scName;
349
if (print_source_locations) {
350
SourceRange r = sourceRange();
351
if (sourceRange().source()) {
352
if (auto orig = sourceRange().source()->findSourceRangeThatGenerated(r)) {
356
if (auto file_line_col = r.file_line_col()) {
357
auto [filename, line, col] = *file_line_col;
358
out << " # " << filename << ":" << line << ":" << col;
368
for (const auto i : c10::irange(blocks().size())) {
369
auto b = blocks()[i];
370
indent(out, level + 1) << "block" << i << "("
371
<< const_value_list_with_types(b->inputs())
373
for (auto nested : b->nodes()) {
374
nested->print(out, level + 2, groups);
376
indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
382
std::ostream& operator<<(std::ostream& out, const Node& n) {
383
return n.print(out, 0, nullptr);
386
std::ostream& Graph::print(std::ostream& out, bool print_source_locations)
388
out << "graph(" << const_value_list_with_types(inputs(), ",\n ")
390
std::vector<const Node*> groups;
391
for (auto n : nodes()) {
392
n->print(out, 1, &groups, print_source_locations);
394
out << " return (" << outputs() << ")\n";
396
for (auto fg : groups) {
397
out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
398
<< *fg->g(attr::Subgraph);
415
std::ostream& operator<<(std::ostream& out, const Graph& g) {
416
return g.print(out, true);
419
static void checkSameDevice(const Node* node) {
420
bool has_device = false;
421
c10::optional<at::Device> device = c10::nullopt;
422
auto checkValue = [&](const Value* v) {
423
if (TensorTypePtr type = v->type()->cast<TensorType>()) {
424
if (type->device() && !has_device) {
426
device = *type->device();
428
AT_ASSERT(device == type->device());
432
for (auto input : node->inputs()) {
435
for (auto output : node->outputs()) {
440
using node_set = std::set<const Node*>;
441
#define ALL_OF(container) container.begin(), container.end()
450
void Node::lint() const {
463
for (auto input : inputs_) {
467
std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) !=
469
AT_ASSERT(graph_->all_nodes.count(this) == 1);
474
for (auto o : outputs()) {
475
for (auto use : o->uses()) {
479
AT_ASSERT(use.user->inputs_[use.offset] == o);
486
AT_ASSERT(inputs_.empty());
490
AT_ASSERT(outputs().empty());
494
AT_ASSERT(inputs_.empty());
496
case prim::PythonOp: {
498
auto* value = static_cast<const PythonOp*>(this);
499
value->lint_python();
507
case prim::FusionGroup:
508
case prim::CudaFusionGroup:
509
case prim::oneDNNFusionGroup:
510
checkSameDevice(this);
512
g(attr::Subgraph)->lint();
519
void Graph::lint() const {
534
LintScope() = default;
535
LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
536
bool contains(const Value* v) {
537
return values.count(v) > 0 || (parent && parent->contains(v));
539
bool contains(const Node* n) {
540
return nodes.count(n) > 0 || (parent && parent->contains(n));
542
void insert(const Value* v) {
543
AT_ASSERT(!contains(v));
546
void insert(const Node* n) {
547
AT_ASSERT(!contains(n));
551
std::unique_ptr<LintScope> parent;
554
std::unordered_set<const Value*> values;
555
std::unordered_set<const Node*> nodes;
560
LintImpl(const Graph& g)
562
scope(new LintScope()),
563
all_nodes_set(ALL_OF(g.all_nodes)) {}
565
std::unique_ptr<LintScope> scope;
566
std::unordered_set<size_t> seen_uniques;
567
std::unordered_map<const Node*, int64_t> anticipated_uses;
568
node_set all_nodes_set;
571
void check_value(const Value* v) {
573
auto b2 = seen_uniques.insert(v->unique());
574
AT_ASSERT(b2.second);
575
AT_ASSERT(v->unique() < g.next_unique_);
577
for (auto use : v->uses()) {
578
AT_ASSERT(!scope->contains(use.user));
579
AT_ASSERT(g.all_nodes.count(use.user) == 1);
580
anticipated_uses[use.user]++;
583
void check_node(const Node* n) {
584
for (auto input : n->inputs_) {
585
if (!scope->contains(input)) {
586
AT_ASSERTM(0, input->unique(), " not in scope");
589
AT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
590
anticipated_uses[n] = -1;
592
for (auto block : n->blocks()) {
593
std::unique_ptr<LintScope> new_scope(new LintScope(std::move(scope)));
594
scope = std::move(new_scope);
596
scope = std::move(scope->parent);
599
for (auto o : n->outputs()) {
600
AT_ASSERT(o->node() == n);
601
AT_ASSERT(i++ == o->offset_);
606
void check_block(const Block* b) {
608
AT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
609
auto curNode = *b->nodes().begin();
610
while (curNode != b->return_node()) {
611
AT_ASSERT(curNode->isBefore(curNode->next()));
612
curNode = curNode->next();
615
for (auto input : b->inputs()) {
617
AT_ASSERT(input->node()->kind_ == prim::Param);
620
for (auto n : b->nodes()) {
621
AT_ASSERT(n->kind_ != prim::Param);
622
AT_ASSERT(n->kind_ != prim::Return);
626
AT_ASSERT(b->output_->kind() == prim::Return);
627
check_node(b->output_);
635
node_set nodes_set(ALL_OF(b->nodes()));
636
node_set inputs_set{b->input_};
637
node_set output_set{b->output_};
640
AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
641
AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
642
AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));
644
sum_set.insert(ALL_OF(nodes_set));
645
sum_set.insert(ALL_OF(inputs_set));
646
sum_set.insert(ALL_OF(output_set));
649
node_set all_nodes_set(
650
ALL_OF(g.all_nodes));
652
check_block(g.block_);
653
for (auto kv : anticipated_uses) {
654
AT_ASSERT(kv.second == -1);
656
AT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)));
659
LintImpl(*this).check_graph();
662
void Graph::dump() const {
663
std::cout << *this << "\n";
666
void Graph::push_scope(const std::string& scope_name) {
667
current_scope_ = current_scope_->push(Symbol::scope(scope_name));
668
Node* block_node = insertNode(create(prim::TracedModuleForward, 0));
669
block_node->s_(attr::scope, scope_name);
670
Block* b = block_node->addBlock();
673
void Graph::pop_scope() {
674
current_scope_ = current_scope_->parent();
675
if (insertPoint()->owningBlock()->owningNode()->kind() ==
676
prim::TracedModuleForward) {
677
setInsertPoint(insertPoint()->owningBlock()->owningNode()->next());
681
void LintGraph(const std::shared_ptr<Graph>& graph) {
685
Block::Block(Graph* graph_, Node* node_)
687
output_(graph_->create(prim::Return, 0)),
688
input_(graph_->create(prim::Param, 0)),
689
owning_node_(node_) {
690
input_->next() = output_;
691
input_->prev() = output_;
692
output_->next() = input_;
693
output_->prev() = input_;
695
graph_->all_blocks.emplace(this);
696
output_->owning_block_ = this;
697
output_->topo_position_ = kUpperBound;
698
input_->owning_block_ = this;
699
input_->topo_position_ = kLowerBound;
702
void Block::reIndexTopology() {
703
auto curPos = kLowerBound;
704
for (auto node : nodes()) {
705
AT_ASSERT(curPos <= (kUpperBound - kAppendInterval));
706
curPos += kAppendInterval;
707
node->topo_position_ = curPos;
711
void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) {
712
std::unordered_map<Value*, Value*> local_map;
713
auto env = [&](Value* v) {
714
auto it = local_map.find(v);
715
if (it != local_map.end()) {
721
auto graph = owningGraph();
722
for (auto input : src->inputs()) {
723
local_map[input] = this->addInput()->copyMetadata(input);
726
for (auto node : src->nodes()) {
727
auto new_node = this->appendNode(graph->createClone(node, env));
728
for (size_t i = 0; i < node->outputs().size(); ++i) {
729
auto oo = node->outputs()[i];
730
auto no = new_node->outputs()[i];
732
no->copyMetadata(oo);
735
for (auto output : src->outputs()) {
736
this->registerOutput(env(output));
740
void Block::destroy() {
743
output_->removeAllInputs();
744
for (auto it = this->nodes().reverse().begin(),
745
end = this->nodes().reverse().end();
752
graph_->freeBlock(this);
755
void Graph::cloneFrom(Graph& src) {
756
auto env = [](Value* v) -> Value* {
758
"Graph::copy() encountered a use of a value " + v->debugName() +
759
" not in scope. Run lint!");
761
block()->cloneFrom(src.block(), env);
764
std::shared_ptr<Graph> Graph::copy() {
765
auto new_g = std::make_shared<Graph>();
766
new_g->cloneFrom(*this);
770
std::unique_ptr<Graph> Graph::copyUnique() {
771
auto new_g = std::make_unique<Graph>();
772
new_g->cloneFrom(*this);
776
void Block::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
777
for (Value* input : inputs()) {
778
input->setType(type_map(input->type()));
780
for (Node* node : nodes()) {
781
for (Value* output : node->outputs()) {
782
output->setType(type_map(output->type()));
784
for (Block* sub_block : node->blocks()) {
785
sub_block->remapTypes(type_map);
787
for (Symbol name : node->attributeNames()) {
788
if (node->kindOf(name) == AttributeKind::g) {
789
node->g(name)->remapTypes(type_map);
790
} else if (node->kindOf(name) == AttributeKind::gs) {
791
for (const auto& g : node->gs(name)) {
792
g->remapTypes(type_map);
799
void Graph::remapTypes(const std::function<TypePtr(TypePtr)>& type_map) {
800
block()->remapTypes(type_map);
803
// Sets this value's static type to a TensorType describing `output`.
void Value::inferTypeFrom(const at::Tensor& output) {
  setType(TensorType::create(output));
}
807
void Value::inferTypeFrom(
808
const c10::intrusive_ptr<c10::ivalue::Object>& output) {
809
setType(output->type());
812
bool Value::mustBeNone() const {
813
return type()->cast<NoneType>() || node_->mustBeNone();
815
bool Value::mustNotBeNone() const {
816
return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
817
!type()->cast<OptionalType>() &&
818
!(type()->cast<UnionType>() &&
819
type()->expect<UnionType>()->canHoldType(*NoneType::get()));
822
std::string Value::debugNameBase() const {
823
std::string name = debugName();
824
std::string name_base = name;
825
auto last_dot_pos = name.find_last_of('.');
826
if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
827
if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
829
name_base = name.substr(0, last_dot_pos);
835
bool Value::isValidName(const std::string& name) {
842
if (isNumber(name)) {
849
Value* Value::setDebugName(const std::string& name) {
850
if (!isValidName(name)) {
851
throw std::runtime_error("Invalid name: '" + name + "'");
854
auto& names = node()->owningGraph()->unique_names_;
857
if (hasDebugName()) {
858
names.erase(unique_name_);
868
auto old_owner_of_name = names.find(name);
869
if (old_owner_of_name != names.end()) {
871
std::string name_base = name;
872
auto last_dot_pos = name.find_last_of('.');
873
if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
874
if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
876
suffix = std::stoll(name.substr(last_dot_pos + 1));
877
name_base = name.substr(0, last_dot_pos);
881
auto& names_suffixes = node()->owningGraph()->name_base_suffix_;
882
auto it = names_suffixes.find(name_base);
883
if (it != names_suffixes.end()) {
884
suffix = std::max(suffix, it->second + 1);
889
std::string replacement_name;
891
std::stringstream ss;
896
static std::locale c_locale("C");
899
ss << name_base << "." << suffix++;
900
replacement_name = ss.str();
901
} while (names.count(replacement_name) > 0);
903
names_suffixes[name_base] = suffix;
905
old_owner_of_name->second->setDebugName(replacement_name);
913
Value* Value::copyMetadata(Value* from) {
914
setType(from->type());
915
if (from->hasDebugName()) {
916
setDebugName(from->debugName());
921
void Value::replaceFirstUseWith(Value* newValue) {
922
AT_ASSERT(owningGraph() == newValue->owningGraph());
924
u.user->inputs_[u.offset] = newValue;
925
newValue->uses_.push_back(u);
926
uses_.erase(uses_.begin());
929
// Rewrites every use of this value to refer to `newValue` instead.
// Repeatedly replaces the first use until none remain.
void Value::replaceAllUsesWith(Value* newValue) {
  while (!uses().empty()) {
    replaceFirstUseWith(newValue);
  }
}
935
void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) {
936
std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
937
if (u.user->isAfter(node)) {
938
u.user->inputs_[u.offset] = newValue;
939
newValue->uses_.push_back(u);
947
[&node](const Use& u) { return u.user->isAfter(node); }),
951
void Value::replaceAllUsesDominatedByNodeWith(
954
std::for_each(uses_.begin(), uses_.end(), [&node, newValue](Use& u) {
955
if (u.user->isDominatedBy(node)) {
956
u.user->inputs_[u.offset] = newValue;
957
newValue->uses_.push_back(u);
965
[&node](const Use& u) { return u.user->isDominatedBy(node); }),
969
static size_t findArgument(
970
const FunctionSchema& the_schema,
971
const std::string& unqualName) {
972
for (const auto i : c10::irange(the_schema.arguments().size())) {
973
const Argument* arg = &the_schema.arguments()[i];
974
if (arg->name() == unqualName) {
978
throw std::runtime_error(
979
std::string("Couldn't find an argument called ") + unqualName);
982
static size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
983
const auto unqualName = name.toUnqualString();
984
return findArgument(the_schema, unqualName);
987
// Returns the constant IValue bound to the schema input named `name`,
// or nullopt if that input is not a compile-time constant.
c10::optional<IValue> Node::get(Symbol name) const {
  return toIValue(namedInput(name));
}
991
bool Node::hasNamedInput(const std::string& name) const {
992
for (const auto& argument : schema().arguments()) {
993
if (argument.name() == name) {
1000
// Returns the input value bound to the schema argument `unqualName`.
// Throws (via findArgument) if the schema has no such argument.
Value* Node::namedInput(const std::string& unqualName) const {
  return input(findArgument(schema(), unqualName));
}
1003
// Symbol overload of namedInput; resolves the position via this node's
// operator schema.
Value* Node::namedInput(Symbol name) const {
  return input(findArgument(schema(), name));
}
1007
bool Node::matches(const FunctionSchema& schema) const {
1008
if (isBlockListedSchema(schema)) {
1012
if (kind().toQualString() != schema.name()) {
1015
at::ArrayRef<const Value*> actuals = inputs();
1016
const auto& formals = schema.arguments();
1019
if (actuals.size() < formals.size()) {
1024
for (const auto i : c10::irange(formals.size())) {
1025
auto formal = formals[i].type();
1026
const MatchTypeReturn matched_type =
1027
matchTypeVariables(formal, actuals[i]->type(), type_env);
1028
if (!matched_type.success()) {
1032
TypePtr resolved = tryEvalTypeVariables(formal, type_env);
1041
if (!actuals[i]->type()->isSubtypeOf(*formal)) {
1047
if (!schema.is_vararg() && actuals.size() != formals.size()) {
1055
const char* signature_literal,
1056
at::ArrayRef<Symbol> const_inputs) const {
1057
if (!matches(getOperatorForLiteral(signature_literal)->schema())) {
1060
for (Symbol s : const_inputs) {
1061
if (!is_constant(s)) {
1068
bool Node::mustBeNone() const {
1072
kind_ == prim::AutogradZero ||
1074
(outputs().size() == 1 && output()->type() == NoneType::get()) ||
1076
(kind_ == prim::Constant && !this->hasAttributes() &&
1077
output()->type()->cast<OptionalType>());
1080
void Node::dump() const {
1081
std::cout << *this << "\n";
1084
const FunctionSchema& Node::schema() const {
1086
return op_->schema();
1088
return getOperator().schema();
1091
const FunctionSchema* Node::maybeSchema() const {
1092
if (auto op = maybeOperator()) {
1093
return &op->schema();
1098
const Operator* Node::maybeOperator() const {
1100
const auto& candidates = getAllOperatorsFor(kind());
1101
for (const auto& candidate : candidates) {
1102
if (matches(candidate->schema())) {
1103
op_ = candidate.get();
1111
const Operator& Node::getOperator() const {
1112
const Operator* maybe = maybeOperator();
1116
auto er = ErrorReport(sourceRange());
1117
er << "Schema not found for node. File a bug report.\n";
1118
er << "Node: " << *this << "\n";
1119
er << "Input types:";
1120
for (const auto i : c10::irange(inputs().size())) {
1123
er << *inputs()[i]->type();
1125
const auto& candidates = getAllOperatorsFor(kind());
1126
if (!candidates.empty()) {
1127
er << "\ncandidates were:\n";
1128
for (auto& candidate : candidates) {
1129
er << " " << candidate->schema() << "\n";
1132
er << "\nno candidates found\n";
1134
er << "within the graph:\n";
1135
er << *owningGraph() << "\n";
1139
Operation Node::getOperation() const {
1143
return getOperator().getOperation(this);
1146
bool Node::isNondeterministic() const {
1147
const auto schema = maybeSchema();
1148
if (!kind().is_aten()) {
1155
TORCH_WARN("aten Schema not found.");
1158
torch::utils::SchemaInfo schema_info(*schema);
1159
if (hasNamedInput("train")) {
1160
auto value = constant_as<bool>(namedInput("train"));
1161
if (value.has_value()) {
1162
schema_info.addArgumentValue("train", *value);
1165
return schema_info.is_nondeterministic();
1168
bool Node::hasSideEffects() const {
1170
case prim::PythonOp:
1171
case prim::IgnoredPythonOp:
1173
case prim::RaiseException:
1176
case aten::manual_seed:
1177
case prim::AddStatValue:
1178
case prim::TimePoint:
1179
case prim::CallFunction:
1180
case prim::CallMethod:
1181
case prim::BailoutTemplate:
1183
case prim::rpc_async:
1184
case prim::rpc_sync:
1185
case prim::rpc_remote:
1187
#if !defined(USE_ROCM)
1188
case cuda::set_stream:
1189
case cuda::_set_device:
1190
case cuda::_current_device:
1191
case cuda::synchronize:
1198
auto op = maybeOperator();
1200
TORCH_INTERNAL_ASSERT(
1202
"Only prim ops are allowed to not have a registered operator but ",
1203
kind_.toDisplayString(),
1204
" doesn't have one either. We don't know if this op has side effects.");
1208
if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) {
1215
TORCH_INTERNAL_ASSERT(
1216
op->aliasAnalysisKind() == AliasAnalysisKind::INTERNAL_SPECIAL_CASE ||
1217
op->aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA ||
1218
op->aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE,
1219
"aten:: and prim:: ops should have AliasAnalysisKind::INTERNAL_SPECIAL_CASE"
1220
", AliasAnalysisKind::FROM_SCHEMA or AliasAnalysisKind::CONSERVATIVE but ",
1221
kind_.toDisplayString(),
1223
toString(op->aliasAnalysisKind()));
1226
switch (op->aliasAnalysisKind()) {
1227
case AliasAnalysisKind::PURE_FUNCTION:
1228
case AliasAnalysisKind::FROM_SCHEMA:
1229
case AliasAnalysisKind::INTERNAL_SPECIAL_CASE:
1231
case AliasAnalysisKind::CONSERVATIVE:
1234
TORCH_INTERNAL_ASSERT(false, "Unhandled AliasAnalysisKind case");
1250
void Node::assignTopoPosition() {
1251
bool is_first = prev() == owningBlock()->param_node();
1252
bool is_last = next() == owningBlock()->return_node();
1254
const auto prevPos = prev()->topo_position_;
1255
const auto nextPos = next()->topo_position_;
1261
topo_position_ = kMidPoint;
1265
if (prevPos >= (kUpperBound - kAppendInterval)) {
1267
owningBlock()->reIndexTopology();
1271
topo_position_ = prevPos + kAppendInterval;
1274
} else if (is_first) {
1276
if (nextPos <= (kLowerBound + kAppendInterval)) {
1278
owningBlock()->reIndexTopology();
1281
topo_position_ = nextPos - kAppendInterval;
1285
int64_t remaining = nextPos - prevPos;
1286
AT_ASSERT(remaining > 0);
1287
if (remaining == 1) {
1289
owningBlock()->reIndexTopology();
1292
int64_t predicted_future_insertions = 0;
1293
if (next() == graph_->insertPoint()) {
1294
predicted_future_insertions = graph_->predicted_insert_count_++;
1296
topo_position_ = prevPos +
1297
std::max(int64_t(1), remaining / (2 + predicted_future_insertions));
1298
AT_ASSERT(prevPos < topo_position_ && topo_position_ < nextPos);
1302
Node::Node(Graph* graph_, NodeKind kind_)
1305
owning_block_(nullptr),
1306
scope_(graph_->current_scope_),
1307
callstack_(c10::nullopt),
1310
graph_->all_nodes.emplace(this);
1313
void Node::eraseOutput(size_t i) {
1314
AT_ASSERT(i < outputs_.size());
1315
AT_ASSERT(outputs_[i]->uses().empty());
1317
Value* n = outputs_[i];
1318
outputs_.erase(outputs_.begin() + i);
1319
owningGraph()->freeValue(n);
1320
for (const auto j : c10::irange(i, outputs_.size())) {
1321
outputs_[j]->offset_--;
1325
Block* Node::addBlock() {
1327
blocks_.push_back(new Block(owningGraph(), this));
1328
return blocks_.back();
1331
void Node::eraseBlock(size_t i) {
1332
AT_ASSERT(i < blocks_.size());
1334
Block* n = blocks_[i];
1335
blocks_.erase(blocks_.begin() + i);
1339
void Node::destroy() {
1340
while (!outputs().empty()) {
1341
eraseOutput(outputs().size() - 1);
1343
while (!blocks().empty()) {
1344
eraseBlock(blocks().size() - 1);
1347
if (inBlockList()) {
1350
graph_->freeNode(this);
1353
void Node::cloneFrom(Node* s) {
1354
source_range_ = s->source_range_;
1355
if (s->scope_ && !s->scope_->isBlank()) {
1359
callstack_ = s->callstack_;
1362
// Replaces every use of each of this node's outputs with the
// corresponding output of `n`. Both nodes must have the same number of
// outputs.
void Node::replaceAllUsesWith(Node* n) {
  AT_ASSERT(outputs().size() == n->outputs().size());
  size_t nOutputs = outputs().size();
  for (const auto i : c10::irange(nOutputs)) {
    outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
  }
}
1370
Node* Node::replaceWithNewSymbol(Symbol new_symbol) {
1371
WithInsertPoint insert_guard{this};
1372
bool had_operator = maybeOperator() != nullptr;
1373
auto graph = owningGraph();
1374
auto replace_node = graph->insertNode(graph->create(new_symbol, 0));
1375
for (Value* v : inputs()) {
1376
replace_node->addInput(v);
1378
for (Value* v : outputs()) {
1379
auto new_out = replace_node->addOutput()->copyMetadata(v);
1380
v->replaceAllUsesWith(new_out);
1382
replace_node->copyMetadata(this);
1383
replace_node->copyAttributes(*this);
1384
TORCH_INTERNAL_ASSERT(
1385
(replace_node->maybeOperator() != nullptr) == had_operator,
1386
"invalid symbol replacement:",
1389
return replace_node;
1392
bool Node::isDominatedBy(const Node* dominator) const {
1393
const Node* node = this;
1395
if (node->owningBlock() == dominator->owningBlock()) {
1396
return dominator->isBefore(node);
1398
node = node->owningBlock()->owningNode();
1403
Value* Node::insertInput(size_t i, Value* value) {
1404
AT_ASSERT(graph_ == value->owningGraph());
1410
for (const auto use_itr : c10::irange(i, inputs_.size())) {
1412
auto use = findUseForInput(use_itr);
1416
inputs_.insert(inputs_.begin() + i, value);
1418
value->uses_.emplace_back(this, i);
1422
Value* Node::addInput(Value* value) {
1423
AT_ASSERT(graph_ == value->owningGraph());
1425
value->uses_.emplace_back(this, inputs_.size());
1426
inputs_.push_back(value);
1430
Value* Node::replaceInput(size_t i, Value* newValue) {
1431
AT_ASSERT(newValue->owningGraph() == graph_);
1433
Value* old = dropInput(i);
1434
inputs_[i] = newValue;
1435
newValue->uses_.emplace_back(this, i);
1439
void Node::replaceInputWith(Value* from, Value* to) {
1440
AT_ASSERT(from->owningGraph() == graph_);
1441
AT_ASSERT(to->owningGraph() == graph_);
1444
for (auto input : inputs()) {
1445
if (input == from) {
1446
replaceInput(i, to);
1452
Value* Node::addOutput() {
1453
outputs_.push_back(new Value(this, outputs_.size()));
1455
return outputs_.back();
1458
Value* Node::insertOutput(size_t i) {
1460
outputs_.insert(outputs_.begin() + i, new Value(this, i));
1461
for (size_t itr = i + 1; itr < outputs_.size(); ++itr) {
1462
outputs_[itr]->setOffset(outputs_[itr]->offset() + 1);
1464
return outputs_.at(i);
1467
bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
1468
if (this->owningBlock() == n->owningBlock()) {
1469
if (moveSide == MoveSide::BEFORE) {
1470
return this->topo_position_ < n->topo_position_;
1473
if (moveSide == MoveSide::AFTER) {
1474
return this->topo_position_ > n->topo_position_;
1477
AT_ASSERT(this == n);
1485
AT_ASSERT(lhs->owningBlock());
1489
if (!rhs->owningBlock()) {
1493
if (lhs->owningBlock() == rhs->owningBlock()) {
1494
return lhs->isBeforeOrAfter(rhs, moveSide);
1496
rhs = rhs->owningBlock()->owningNode();
1499
lhs = lhs->owningBlock()->owningNode();
1505
bool Node::isBefore(const Node* n) const {
1506
return isBeforeOrAfter(n, MoveSide::BEFORE);
1509
bool Node::isAfter(const Node* n) const {
1510
return isBeforeOrAfter(n, MoveSide::AFTER);
1513
Node* Node::insertBefore(Node* n) {
1514
AT_ASSERT(n->inBlockList());
1515
insertAfter(n->prev());
1519
Node* Node::insertAfter(Node* n) {
1520
AT_ASSERT(!inBlockList() && n->inBlockList());
1521
AT_ASSERT(n->owningBlock());
1523
n->kind() != prim::Return,
1524
"Attempting to insert a Node after the Return node or before the Param node. Tried to insert",
1529
this->owning_block_ = n->owningBlock();
1530
Node* next = n->next();
1533
this->next() = next;
1534
next->prev() = this;
1535
assignTopoPosition();
1539
void Node::moveAfter(Node* n) {
1544
void Node::moveBefore(Node* n) {
1549
void Node::removeInput(size_t i) {
1554
for (size_t j = i + 1; j < inputs_.size(); j++) {
1555
auto it = findUseForInput(j);
1558
inputs_.erase(inputs_.begin() + i);
1561
void Node::removeAllInputs() {
1563
for (const auto i : c10::irange(inputs().size())) {
1569
void Node::removeAllOutputs() {
1571
size_t init_osize = outputs_.size();
1572
for (auto i : c10::irange(init_osize)) {
1573
eraseOutput(init_osize - i - 1);
1577
void Node::permuteInputs(const std::vector<size_t>& new_order) {
1579
AT_ASSERT(new_order.size() == inputs_.size());
1580
std::vector<Value*> new_inputs;
1581
new_inputs.reserve(new_order.size());
1582
for (const auto i : c10::irange(new_order.size())) {
1583
AT_ASSERTM(inputs_.at(new_order[i]) != nullptr, "Repeated index");
1584
new_inputs.push_back(inputs_.at(new_order[i]));
1585
auto it = findUseForInput(new_order[i]);
1587
inputs_.at(new_order[i]) = nullptr;
1589
inputs_ = std::move(new_inputs);
1592
void Node::permuteOutputs(const std::vector<size_t>& new_order) {
1594
AT_ASSERT(new_order.size() == outputs_.size());
1595
std::vector<Value*> new_outputs;
1596
new_outputs.reserve(new_order.size());
1597
for (const auto i : c10::irange(new_order.size())) {
1598
AT_ASSERTM(outputs_.at(new_order[i]) != nullptr, "Repeated index");
1599
new_outputs.push_back(outputs_.at(new_order[i]));
1600
outputs_.at(new_order[i])->setOffset(i);
1601
outputs_.at(new_order[i]) = nullptr;
1603
outputs_ = std::move(new_outputs);
1606
use_list::iterator Node::findUseForInput(size_t i) {
1607
auto& input_uses = inputs_[i]->uses_;
1610
auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
1611
AT_ASSERT(use_it != input_uses.end());
1615
// Detach input i: remove this node's use from the producing value and null
// the slot. Returns the value that was the input (so callers can reuse it).
Value* Node::dropInput(size_t i) {
  AT_ASSERT(i < inputs_.size());
  auto input_node = inputs_[i];
  auto use_it = findUseForInput(i);
  input_node->uses_.erase(use_it);
  inputs_[i] = nullptr;
  return input_node;
}
void Node::removeFromList() {
1625
AT_ASSERT(inBlockList());
1626
this->owning_block_ = nullptr;
1627
Node* next = this->next();
1628
Node* prev = this->prev();
1629
prev->next() = next;
1630
next->prev() = prev;
1631
this->next() = nullptr;
1632
this->prev() = nullptr;
1635
// Find the innermost block that (transitively) contains both this node and n.
Block* Node::findCommonAncestorBlockWith(Node* n) {
  if (n->owningBlock() == owningBlock()) {
    return owningBlock();
  }

  Node* n1 = this;
  Node* n2 = n;

  size_t d_1 = n1->blocksFromGraphBlock();
  size_t d_2 = n2->blocksFromGraphBlock();

  // Walk the deeper node up until both are at the same nesting depth.
  for (; d_1 > d_2; --d_1) {
    n1 = n1->owningBlock()->owningNode();
  }

  for (; d_2 > d_1; --d_2) {
    n2 = n2->owningBlock()->owningNode();
  }

  // Same depth: ascend in lockstep until they share a block.
  while (true) {
    if (n1->owningBlock() == n2->owningBlock()) {
      return n1->owningBlock();
    }

    n1 = n1->owningBlock()->owningNode();
    n2 = n2->owningBlock()->owningNode();

    // Both nodes live in the same graph, so a common ancestor must exist
    // before either walk reaches the graph's root block.
    AT_ASSERT(n1 != nullptr);
    AT_ASSERT(n2 != nullptr);
  }
}
// Count how many blocks deep this node is nested below the graph's top block.
size_t Node::blocksFromGraphBlock() {
  Node* n = this;
  size_t dist = 0;
  while (n->owningBlock()->owningNode()) {
    n = n->owningBlock()->owningNode();
    ++dist;
  }
  return dist;
}
// Placeholder source range used when a caller provides no location.
inline const SourceRange& fakeRange() {
  static SourceRange range(std::make_shared<Source>(std::string("")), 0, 1);
  return range;
}
// Insert a call to builtin `opname` at the current insertion point,
// resolving it against args/kwargs via schema matching.
Value* Graph::insert(
    Symbol opname,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const c10::optional<SourceRange>& range) {
  return emitBuiltinCall(
      range.value_or(fakeRange()), *this, opname, args, kwargs);
}
// Allocate a new node of `kind` with `num_outputs` fresh outputs.
Node* Graph::create(NodeKind kind, size_t num_outputs) {
  // NB: the Node constructor registers the node in all_nodes.
  auto n = new Node(this, kind);
  for (const auto i : c10::irange(num_outputs)) {
    (void)i;
    n->addOutput();
  }
  return n;
}
ArrayRef<Value*> inputs,
1707
size_t num_outputs) {
1708
auto n = create(kind, num_outputs);
1709
for (auto i : inputs) {
1715
// Create a prim::AutogradZero node.
Node* Graph::createAutogradZero() {
  return create(prim::AutogradZero);
}
// Create a constant node producing None.
Node* Graph::createNone() {
  Node* n = create(prim::Constant);
  n->output()->setType(NoneType::get());
  return n;
}
// Create a prim::Uninitialized node whose output has the given type.
Node* Graph::createUninitialized(TypePtr typ) {
  Node* n = create(prim::Uninitialized);
  n->output()->setType(std::move(typ));
  return n;
}
// Create a node of `kind` that owns a fresh subgraph attribute.
Node* Graph::createWithSubgraph(Symbol kind) {
  auto n = create(kind, 0);
  n->g_(attr::Subgraph, std::make_shared<Graph>(current_scope()));
  return n;
}
// Create a prim::TupleConstruct from `values`. If `tuple_type` is given it
// must describe a named tuple; otherwise an unnamed TupleType is derived
// from the element types.
Node* Graph::createTuple(at::ArrayRef<Value*> values, TupleTypePtr tuple_type) {
  TORCH_INTERNAL_ASSERT(
      !tuple_type || tuple_type->schema(),
      "only pass tuple_type when creating a named tuple");
  if (!tuple_type) {
    auto types = fmap(values, [](Value* v) { return v->type(); });
    tuple_type = TupleType::create(std::move(types));
  }
  auto n = create(prim::TupleConstruct, values);
  n->output()->setType(tuple_type);
  return n;
}
// Create a prim::TupleUnpack of `v`, with one typed output per tuple element.
Node* Graph::createTupleUnpack(Value* v) {
  TupleTypePtr tt = v->type()->expect<TupleType>();
  auto n = create(prim::TupleUnpack, {v}, 0);
  for (auto& element : tt->elements()) {
    n->addOutput()->setType(element);
  }
  return n;
}
// Create a prim::TupleIndex extracting element `idx` of tuple `tup`;
// `output_type` is the static type of the selected element.
Node* Graph::createTupleIndex(
    Value* tup,
    Value* idx,
    const TypePtr& output_type) {
  auto n = create(prim::TupleIndex, {tup, idx});
  n->output()->setType(output_type);
  return n;
}
// Build a new tuple from a strided slice of `tup`: elements
// beg, beg+step_size, ... (num_values of them), re-packed via createTuple.
Node* Graph::createTupleSlice(
    Value* tup,
    int64_t beg,
    int64_t step_size,
    int64_t num_values) {
  std::vector<Value*> new_vals;
  TupleTypePtr tt = tup->type()->expect<TupleType>();
  new_vals.reserve(num_values);

  int64_t i = beg;
  for (const auto j : c10::irange(num_values)) {
    (void)j; // index not used; `i` tracks the strided position
    auto idx = insertConstant(IValue(static_cast<int64_t>(i)));
    auto tupleIndex = insertNode(createTupleIndex(tup, idx, tt->elements()[i]));

    new_vals.push_back(tupleIndex->output());
    i += step_size;
  }

  auto n = createTuple(new_vals);
  return n;
}
// Create a prim::EnumName node returning the string name of enum value `e`.
Node* Graph::createEnumName(Value* e) {
  e->type()->expect<EnumType>(); // hard error if e is not an enum
  assert(e->type()->cast<EnumType>());
  auto n = create(prim::EnumName, {e});
  n->output()->setType(StringType::get());
  return n;
}
// Create a prim::EnumValue node returning the underlying value of enum `e`.
Node* Graph::createEnumValue(Value* e) {
  auto enum_type = e->type()->expect<EnumType>();
  auto n = create(prim::EnumValue, {e});
  n->output()->setType(enum_type->getValueType());
  return n;
}
// Create a prim::ListConstruct of `values`, checking that every element
// subtypes `contained_type`.
Node* Graph::createList(
    const TypePtr& contained_type,
    at::ArrayRef<Value*> values) {
  auto n = create(prim::ListConstruct, values);
  for (const auto& v : values) {
    TORCH_CHECK(
        v->type()->isSubtypeOf(*contained_type),
        "Expected a list element that subtypes '",
        contained_type->repr_str(),
        "' but got an element of type '",
        v->type()->repr_str(),
        "'");
  }
  n->output()->setType(ListType::create(contained_type));
  return n;
}
// Create a prim::ListUnpack of list `v` into `size` outputs, each typed
// with the list's element type.
Node* Graph::createListUnpack(Value* v, size_t size) {
  ListTypePtr list_type = v->type()->expect<ListType>();
  TypePtr elem_type = list_type->getElementType();
  auto n = create(prim::ListUnpack, {v}, 0);
  for (const auto i : c10::irange(size)) {
    (void)i;
    n->addOutput()->setType(elem_type);
  }
  return n;
}
// Create a prim::DictConstruct from parallel key/value arrays. Inputs are
// interleaved (key0, value0, key1, value1, ...); each key/value must
// subtype key_type/value_type.
Node* Graph::createDict(
    const TypePtr& key_type,
    const TypePtr& value_type,
    at::ArrayRef<Value*> keys,
    at::ArrayRef<Value*> values) {
  AT_ASSERT(keys.size() == values.size());
  auto n = create(prim::DictConstruct, 1);
  for (const auto i : c10::irange(keys.size())) {
    AT_ASSERT(keys[i]->type()->isSubtypeOf(*key_type));
    AT_ASSERT(values[i]->type()->isSubtypeOf(*value_type));

    n->addInput(keys[i]);
    n->addInput(values[i]);
  }
  n->output()->setType(DictType::create(key_type, value_type));
  return n;
}
// Create a prim::NumToTensor node wrapping a scalar `value` as a tensor.
Node* Graph::createNumToTensor(Value* value) {
  Node* result = create(prim::NumToTensor, {value});
  result->output()->setType(TensorType::fromNumberType(*value->type()));
  return result;
}
// Create a prim::CreateObject node instantiating class `type`.
Node* Graph::createObject(const ClassTypePtr& type) {
  auto result = create(prim::CreateObject);
  result->output()->setType(type);
  return result;
}
// Create a prim::SetAttr node assigning `newValue` to `obj.field`.
Node* Graph::createSetAttr(
    Value* obj,
    const std::string& field,
    Value* newValue) {
  auto n = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0);
  n->s_(attr::name, field);
  return n;
}
// Create a prim::GetAttr node reading `obj.field`; the output takes the
// attribute's declared type and a debug name derived from the field.
Node* Graph::createGetAttr(Value* obj, const std::string& field) {
  const auto classType = obj->type()->expect<ClassType>();

  auto n = create(prim::GetAttr, {obj}, 1);
  n->s_(attr::name, field);

  const auto outputType = classType->getAttribute(field);
  n->output()->setType(outputType);
  n->output()->setDebugName(normalizeAttrName(field));
  return n;
}
// Create a prim::Store node binding value `v` to environment name `name`.
Node* Graph::createStore(const std::string& name, Value* v) {
  auto n = create(prim::Store, {v}, /*num_outputs=*/0);
  n->s_(attr::name, name);
  return n;
}
// Create a prim::Load node reading environment name `name` with the given type.
Node* Graph::createLoad(const std::string& name, const TypePtr& type) {
  auto n = create(prim::Load, {}, /*num_outputs=*/1);
  n->s_(attr::name, name);
  n->output()->setType(type);
  return n;
}
// Create a prim::isinstance node testing whether `v` matches any of `types`.
Node* Graph::createIsInstance(Value* v, at::ArrayRef<TypePtr> types) {
  auto n = create(prim::isinstance, {v}, /*num_outputs=*/1);
  n->tys_(attr::types, types.vec());
  n->output()->setType(BoolType::get());
  return n;
}
// Insert a prim::unchecked_cast of `v` to `type` at the insertion point.
Value* Graph::insertUncheckedCast(Value* v, TypePtr type) {
  Node* n = insertNode(create(prim::unchecked_cast, {v}));
  n->output()->setType(std::move(type));
  return n->output();
}
// Insert a prim::tolist converting tensor `v` to a (possibly nested) list of
// `type`. The dimension count and element-type code are passed as constants.
Value* Graph::insertToList(Value* v, TypePtr type) {
  int dim = 0;
  TypePtr ptr = type;

  // Unwrap nested ListTypes to find the depth and the base element type.
  while (auto list_type = ptr->cast<ListType>()) {
    ptr = list_type->getElementType();
    ++dim;
  }

  // Encode the base element type as an integer understood by prim::tolist.
  int elem_ty = 0;
  if (ptr == IntType::get()) {
    elem_ty = 0;
  } else if (ptr == FloatType::get()) {
    elem_ty = 1;
  } else if (ptr == BoolType::get()) {
    elem_ty = 2;
  } else if (ptr == ComplexType::get()) {
    elem_ty = 3;
  } else {
    TORCH_CHECK(
        false,
        ptr->repr_str(),
        " is not one of the supported element types for tolist: int, float, complex, bool");
  }

  // Pass the dimension count and element type code as constant inputs.
  Value* dim_val = insertConstant(IValue(dim));
  Value* elem_ty_val = insertConstant(IValue(elem_ty));
  Node* n = insertNode(create(prim::tolist, {v, dim_val, elem_ty_val}));
  n->output()->setType(std::move(type));
  return n->output();
}
// Insert a prim::CallFunction to `callee` with already-matched inputs.
// A constant holding the function is created and prepended to the inputs.
Value* Graph::insertFunctionCall(
    Function* callee,
    const MatchedSchema& matched) {
  std::string func_name = callee->name();
  Value* fn_constant = insertNode(create(prim::Constant))
                           ->s_(attr::name, func_name)
                           ->output()
                           ->setType(FunctionType::create(callee));
  std::vector<Value*> inputs = {fn_constant};
  inputs.insert(inputs.end(), matched.inputs.begin(), matched.inputs.end());
  Value* result = insertNode(create(prim::CallFunction, inputs))
                      ->output()
                      ->setType(matched.return_types.at(0));
  return result;
}
// Insert a prim::CallMethod named `method_name` with already-matched inputs.
Value* Graph::insertMethodCall(
    std::string method_name,
    const MatchedSchema& matched) {
  Value* result = insertNode(create(prim::CallMethod, matched.inputs))
                      ->s_(attr::name, std::move(method_name))
                      ->output()
                      ->setType(matched.return_types.at(0));
  return result;
}
// Clone node `n` (possibly from another graph) into this graph. Inputs are
// remapped through `value_map`; nested blocks are cloned when copy_blocks.
Node* Graph::createClone(
    Node* n,
    const std::function<Value*(Value*)>& value_map,
    bool copy_blocks) {
  // n can belong to a different graph.
  Node* r = n->allocNewInstance(this);
  for (auto o : n->outputs()) {
    r->addOutput()->copyMetadata(o);
  }
  r->cloneFrom(n);
  for (auto i : n->inputs()) {
    r->addInput(value_map(i));
  }
  if (copy_blocks) {
    for (auto b : n->blocks()) {
      r->addBlock()->cloneFrom(b, value_map);
    }
  }
  return r;
}
// Insert a constant holding `val` at the current insertion point; delegates
// to the free function in constants.h.
Value* Graph::insertConstant(
    const IValue& val,
    c10::optional<SourceRange> loc,
    c10::optional<ScopePtr> scope) {
  return jit::insertConstant(*this, val, std::move(loc), std::move(scope));
}
std::string Graph::toString(bool print_source_locations) const {
2002
std::ostringstream oss;
2003
print(oss, print_source_locations);
2008
for (const Node* n : all_nodes) {
2011
for (const Value* v : all_values) {
2014
for (const Block* b : all_blocks) {
2019
// Delete node `n` and remove it from the graph's ownership set.
void Graph::freeNode(Node* n) {
  auto it = all_nodes.find(n);
  AT_ASSERT(it != all_nodes.end());
  delete *it;
  all_nodes.erase(it);
}
// Delete value `v` and remove it from the graph's ownership set.
void Graph::freeValue(Value* v) {
  // Release the value's debug name so it can be reused.
  v->setDebugName("");
  auto it = all_values.find(v);
  AT_ASSERT(it != all_values.end());
  delete *it;
  all_values.erase(it);
}
// Delete block `b` and remove it from the graph's ownership set.
void Graph::freeBlock(Block* b) {
  auto it = all_blocks.find(b);
  AT_ASSERT(it != all_blocks.end());
  delete *it;
  all_blocks.erase(it);
}
// Unpack tuple value `v` into its elements. Peephole: if v comes directly
// from a TupleConstruct, reuse its inputs instead of emitting a TupleUnpack.
at::ArrayRef<Value*> createTupleUnpack(Value* v) {
  if (v->node()->kind() == prim::TupleConstruct) {
    return v->node()->inputs();
  }
  auto& g = *v->owningGraph();
  return g.insertNode(g.createTupleUnpack(v))->outputs();
}
void inlineCallStackOfNode(
2051
std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
2054
c10::optional<ModuleInstanceInfo> m_info);
2056
static void inlineCallStackOfBlock(
2058
std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
2061
c10::optional<ModuleInstanceInfo> m_info) {
2062
for (auto n : b->nodes()) {
2063
inlineCallStackOfNode(n, new_cs_entries, callee, to_replace, m_info);
2067
void inlineCallStackOfNode(
2069
std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>& new_cs_entries,
2072
c10::optional<ModuleInstanceInfo> m_info) {
2073
auto new_node_cs = new_node->callstack();
2075
InlinedCallStack* raw_callstack_ptr =
2076
new_node_cs ? new_node_cs->get() : nullptr;
2078
if (!new_cs_entries.count(raw_callstack_ptr)) {
2080
new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
2081
*new_node_cs, callee, to_replace->sourceRange(), m_info);
2083
new_cs_entries[raw_callstack_ptr] = c10::make_intrusive<InlinedCallStack>(
2084
callee, to_replace->sourceRange(), m_info);
2087
new_node->setCallStack(new_cs_entries.at(raw_callstack_ptr));
2091
for (auto block : new_node->blocks()) {
2092
inlineCallStackOfBlock(block, new_cs_entries, callee, to_replace, m_info);
2096
std::vector<Value*> inlineCallTo(
2098
GraphFunction* callee,
2099
Graph* callee_graph) {
2100
WithInsertPoint guard(to_replace);
2101
std::unordered_map<Value*, Value*> value_map;
2102
std::vector<torch::jit::Value*> new_outputs = insertGraph(
2103
*to_replace->owningGraph(),
2105
to_replace->inputs(),
2108
std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>
2109
new_callstack_entries;
2111
c10::optional<ModuleInstanceInfo> module_instance_info = c10::nullopt;
2112
if (to_replace->kind() == prim::CallMethod) {
2113
auto class_type_ptr = to_replace->input(0)->type()->cast<c10::ClassType>();
2114
if (to_replace->input(0)->node()->kind() == prim::GetAttr) {
2115
module_instance_info = c10::make_optional(ModuleInstanceInfo(
2116
class_type_ptr, to_replace->input(0)->node()->s(attr::name)));
2118
!to_replace->owningGraph()->inputs().empty() &&
2119
to_replace->input(0) == to_replace->owningGraph()->inputs()[0]) {
2122
module_instance_info =
2123
c10::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF"));
2127
module_instance_info = c10::make_optional(
2128
ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN"));
2134
std::unordered_set<Node*> updated_nodes;
2135
for (const auto& kv : value_map) {
2144
auto is_graph_input = std::find(
2145
callee_graph->inputs().begin(), callee_graph->inputs().end(), kv.first);
2146
if (is_graph_input != callee_graph->inputs().end()) {
2150
Node* new_node = kv.second->node();
2151
if (!updated_nodes.insert(new_node).second) {
2155
inlineCallStackOfNode(
2157
new_callstack_entries,
2160
module_instance_info);
2162
const auto& old_outputs = to_replace->outputs();
2164
AT_ASSERT(new_outputs.size() == old_outputs.size());
2165
for (const auto i : c10::irange(old_outputs.size())) {
2166
if (old_outputs[i]->hasDebugName()) {
2167
new_outputs[i]->setDebugName(old_outputs[i]->debugName());
2169
old_outputs[i]->replaceAllUsesWith(new_outputs[i]);
2171
to_replace->destroy();
2178
std::vector<Value*> inlineCallTo(
2180
GraphFunction* callee,
2181
bool inline_optimized_graph ) {
2183
inline_optimized_graph ? callee->optimized_graph() : callee->graph();
2184
return inlineCallTo(to_replace, callee, graph.get());
2187
// If `outputs` is a single tuple value, unpack it into its elements;
// otherwise return the outputs unchanged.
std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs) {
  std::vector<Value*> new_outputs;
  if (outputs.size() != 1 || outputs.at(0)->type()->kind() != TupleType::Kind) {
    return outputs;
  }

  auto tup = outputs[0];
  for (Value* v : createTupleUnpack(tup)) {
    new_outputs.emplace_back(v);
  }
  // If the unpack was the peephole case (direct TupleConstruct), the
  // construct may now be dead; destroy it here so DCE is not needed.
  if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) {
    tup->node()->destroy();
  }
  return new_outputs;
}
std::vector<Node*> findAllNodes(
2206
at::ArrayRef<Block*> array,
2209
std::vector<Node*> ret;
2210
for (auto block : array) {
2211
findAllNodes(*block, kind, recurse, ret);
2216
// Single-block convenience overload.
std::vector<Node*> findAllNodes(Block& block, Symbol kind, bool recurse) {
  return findAllNodes({&block}, kind, recurse);
}
// Whole-graph convenience overload (searches from the graph's top block).
std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse) {
  return findAllNodes(*g.block(), kind, recurse);
}
std::vector<Value*> insertGraph(
2227
ArrayRef<Value*> inputs,
2228
std::unordered_map<Value*, Value*>& value_map) {
2229
auto value_map_func = [&](Value* v) { return value_map.at(v); };
2230
AT_ASSERT(callee.inputs().size() == inputs.size());
2231
for (const auto i : c10::irange(inputs.size())) {
2232
value_map[callee.inputs()[i]] = inputs[i];
2234
for (auto* node : callee.nodes()) {
2235
auto* new_node = g.insertNode(g.createClone(node, value_map_func));
2236
for (size_t i = 0; i < node->outputs().size(); ++i) {
2237
value_map[node->outputs()[i]] = new_node->outputs()[i];
2241
std::vector<Value*> outputs;
2242
for (auto* output : callee.outputs()) {
2243
outputs.push_back(value_map_func(output));
2249
std::vector<Value*> insertGraph(
2252
ArrayRef<Value*> inputs) {
2253
std::unordered_map<Value*, Value*> value_map;
2254
return insertGraph(g, callee, inputs, value_map);
2257
// Copy base-node state plus the profiling callback from `other_`.
void ProfileOp::cloneFrom(Node* other_) {
  Node::cloneFrom(other_);
  auto other = other_->cast<ProfileOp>();
  this->callback_ = other->getCallback();
}
// Allocate a blank ProfileOp in `g`; the callback is filled by cloneFrom.
Node* ProfileOp::allocNewInstance(Graph* g) {
  return new ProfileOp(g, {nullptr});
}
// Copy base-node state plus the profiling callback from `other_`.
void ProfileIValueOp::cloneFrom(Node* other_) {
  Node::cloneFrom(other_);
  auto other = other_->cast<ProfileIValueOp>();
  this->callback_ = other->getCallback();
}
// Allocate a blank ProfileIValueOp in `g`; the callback is filled by cloneFrom.
Node* ProfileIValueOp::allocNewInstance(Graph* g) {
  return new ProfileIValueOp(g, {nullptr});
}
// Return the type of this named value: from the IR value if present,
// otherwise from the stored IValue.
TypePtr NamedValue::type() const {
  if (value_) {
    return value_->type();
  } else {
    return ivalue_.type();
  }
}
// Node-kind symbols for the profiling ops.
const Symbol ProfileOp::Kind = ::c10::prim::profile;
const Symbol ProfileIValueOp::Kind = ::c10::prim::profile_ivalue;
// Build an operator set from schema-literal strings.
OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
  insert(sig_literals);
}
std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const {
2293
std::vector<std::shared_ptr<Operator>> result;
2294
for (const auto& kv : ops) {
2295
auto ops_for_symbol = kv.second;
2296
result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end());
2301
void OperatorSet::insert(std::initializer_list<const char*> sig_literals) {
2302
for (const char* sig : sig_literals) {
2303
auto op = getOperatorForLiteral(sig);
2304
ops[Symbol::fromQualString(op->schema().name())].push_back(op);
2308
bool Node::isMemberOf(const OperatorSet& os) const {
2309
auto it = os.ops.find(kind());
2310
if (it == os.ops.end()) {
2313
for (auto& op : it->second) {
2314
if (matches(op->schema())) {