1
#include <torch/csrc/jit/frontend/tracer.h>
3
#include <ATen/Backtrace.h>
4
#include <ATen/ScalarOps.h>
5
#include <ATen/TracerMode.h>
6
#include <ATen/core/Dict.h>
7
#include <ATen/core/functional.h>
8
#include <c10/util/Exception.h>
9
#include <c10/util/irange.h>
10
#include <torch/csrc/autograd/engine.h>
11
#include <torch/csrc/autograd/function.h>
12
#include <torch/csrc/autograd/variable.h>
13
#include <torch/csrc/jit/api/module.h>
14
#include <torch/csrc/jit/ir/constants.h>
15
#include <torch/csrc/jit/ir/ir.h>
16
#include <torch/csrc/jit/passes/dead_code_elimination.h>
17
#include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>
18
#include <torch/csrc/jit/passes/inliner.h>
19
#include <torch/csrc/jit/passes/lower_tuples.h>
20
#include <torch/csrc/jit/passes/normalize_ops.h>
21
#include <torch/csrc/jit/passes/remove_expands.h>
22
#include <torch/csrc/utils/variadic.h>
23
#include <torch/custom_class.h>
29
namespace torch::jit::tracer {
31
////////////////////////////////////////////////////////////////////////////////
32
// Recording the traces
33
////////////////////////////////////////////////////////////////////////////////
37
void genericAddInput(Node* n, T value) {
38
Value* v = n->owningGraph()->insertConstant(value);
39
recordSourceLocation(v->node());
44
void genericAddOptionalInput(
47
const c10::optional<T>& value) {
49
jit::tracer::addInputs(n, name, *value);
51
Graph* g = n->owningGraph();
52
Value* none = g->insertNode(g->createNone())->output();
58
void badArgType(const T& v) {
60
"Found an unsupported argument type in the JIT tracer: ",
61
c10::demangle_type<T>(),
62
". File a bug report.");
65
// Per-thread shared pointer to a TracingState (thread_local storage, so each
// thread has its own copy).
thread_local std::shared_ptr<TracingState> tracing_state;
68
// Atomic boolean flag, initialized to true; handed out by reference through
// getTracerStateWarnMode().
static std::atomic<bool> tracer_state_warn_mode{true};
70
// Returns a mutable reference to the tracer_state_warn_mode atomic flag, so
// callers can both read and update it.
std::atomic<bool>& getTracerStateWarnMode() {
71
return tracer_state_warn_mode;
74
// Suspends tracing: saves the current tracing state, clears it by setting the
// state to nullptr, and returns a callable that reinstalls the saved state.
std::function<void()> pauseTracing() {
76
std::shared_ptr<tracer::TracingState> state = getTracingState();
77
tracer::setTracingState(nullptr);
79
// The lambda captures `state` by value, so the saved shared_ptr stays alive
// for as long as the returned callable does.
return [state]() { tracer::setTracingState(state); };
82
// Free-function wrapper: forwards to delValue() on the current tracing state.
// NOTE(review): getTracingState() is dereferenced without a null check here.
void delValueTrace(const IValue& var) {
83
getTracingState()->delValue(var);
85
void TracingState::delValue(const IValue& var) {
86
for (const auto i : c10::irange(env_stack.size())) {
87
auto& value_map = env_stack.at(env_stack.size() - 1 - i);
88
auto it = value_map.find(var);
89
if (it == value_map.end()) {
96
// Given an IValue 'var', return the 'node' which represents the instruction
97
// which computes the value of this variable in the IR.
98
// Here, we interpret untraced variables as constants that are just embedded
99
// in the graph. This is useful to handle code which does things like this
100
// (from torch.autograd.variable, now moved to C++):
102
// def mm(self, matrix):
103
// output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
104
// return Addmm.apply(output, self, matrix, 0, 1, True)
106
// Here, mm fakes up a dummy variable with uninitialized data to do an inplace
107
// update on, but subsequently ignores it because the alpha scaling factor is
108
// zero. This is one of the cases where a Variable can be created inside of a
109
// trace, and if we treat it as a constant, everything will work out.
110
// Thin wrapper: delegates the lookup to getValue() on the current tracing
// state and returns its result.
Value* getValueTrace(const IValue& var) {
111
return getTracingState()->getValue(var);
113
// Wraps the optional tensor in an IValue and returns the Value that
// getValueTrace() yields for it.
static Value* getOptTensorValueTrace(const c10::optional<at::Tensor>& var) {
114
return getValueTrace(IValue(var));
116
Value* TracingState::getValue(const IValue& var) {
117
// allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...]
119
if (var.isTensorList()) {
121
->insertNode(graph->createList(
124
var.toTensorVector(),
125
[&](const IValue& val) { return getValue(val); })))
127
} else if (var.isTuple()) {
129
->insertNode(graph->createTuple(fmap(
130
var.toTupleRef().elements(),
131
[&](const IValue& val) { return getValue(val); })))
133
} else if (var.isGenericDict()) {
134
auto dict = var.toGenericDict();
135
TypePtr key_type = dict.keyType();
136
TypePtr value_type = dict.valueType();
137
std::vector<Value*> keys;
138
std::vector<Value*> values;
139
for (const auto& entry : dict) {
140
keys.emplace_back(getValue(entry.key()));
141
values.emplace_back(getValue(entry.value()));
143
auto dict_node = graph->createDict(key_type, value_type, keys, values);
144
return graph->insertNode(dict_node)->output();
146
if (var.isTensor()) {
147
auto& ten = var.toTensor();
148
if (!ten.defined()) {
149
Node* n = graph->createNone();
150
return graph->insertNode(n)->output();
152
for (const auto i : c10::irange(env_stack.size())) {
153
auto& value_map = env_stack.at(env_stack.size() - 1 - i);
154
auto it = value_map.find(var);
155
if (it == value_map.end()) {
158
if (!it->second->hasDebugName()) {
159
auto unique_name = getTracingState()->lookup_var_name_fn(ten);
160
if (!unique_name.empty()) {
161
it->second->setDebugName(unique_name);
167
// Didn't find it. Bake in a constant
168
if (ten.requires_grad()) {
170
std::ostringstream oss;
171
oss << "Cannot insert a Tensor that requires grad as a constant. "
172
<< "Consider making it a parameter or input, or detaching the gradient\n"
175
throw std::runtime_error(oss.str());
178
Value* constant = graph->insertConstant(ten);
179
recordSourceLocation(constant->node());
180
constant->inferTypeFrom(ten);
181
auto it = env_stack.back().emplace(var, constant);
182
return it.first->second;
183
} else if (var.isFuture() || var.isObject()) {
184
for (const auto i : c10::irange(env_stack.size())) {
185
auto& future_map = env_stack.at(env_stack.size() - 1 - i);
186
auto it = future_map.find(var);
187
if (it == future_map.end()) {
193
// Find torchbind classes
194
if (isCustomClass(var)) {
195
auto obj = Object(var.toObject());
196
auto qualname = obj.type()->name();
197
auto custom_class_type = getCustomClass(qualname->qualifiedName());
198
if (custom_class_type) {
199
auto capsule = var.toObject()->getAttr("capsule");
200
for (const auto i : c10::irange(env_stack.size())) {
201
auto& value_map = env_stack.at(env_stack.size() - 1 - i);
202
auto it = value_map.find(capsule);
203
if (it == value_map.end()) {
211
std::ostringstream oss;
212
if (var.isFuture()) {
213
oss << "Tried to trace Future or Object that the tracer was not aware of.";
215
oss << "Tried to trace " << var
216
<< " but it is not part of the active trace. Modules that are called during a trace"
217
<< " must be registered as submodules of the thing being traced.";
219
throw std::runtime_error(oss.str());
221
// If the values are non-tensors, we try to create constants
222
// and bake those constants into the traced graph
223
auto constant = tryInsertConstant(*graph, var);
225
recordSourceLocation(constant.value()->node());
228
std::ostringstream os;
229
os << "Tracer cannot get value trace for type " << var.tagKind() << ". "
230
<< "The below value could not be materialized as a constant:\n"
232
throw std::runtime_error(os.str());
235
bool TracingState::hasValue(const IValue& var) const {
236
for (const auto& frame : env_stack) {
237
if (frame.count(var)) {
244
Value* TracingState::getOutput(const IValue& iv, size_t i) {
245
bool tracing_mode_strict = getTracingState()->strict;
247
const at::Tensor& var = iv.toTensor();
248
if (!var.defined()) {
249
Node* n = graph->createNone();
250
return graph->insertNode(n)->output();
253
auto& value_map = getTracingState()->env_stack.back();
254
auto it = value_map.find(iv);
255
if (it == value_map.end()) {
256
std::ostringstream os;
257
os << "output " << i << " (" << var
258
<< ") of traced region did not have observable "
259
<< "data dependence with trace inputs; this probably indicates your "
261
<< "cannot be understood by the tracer.";
262
throw std::runtime_error(os.str());
265
} else if (iv.isTensorList()) {
266
if (tracing_mode_strict) {
268
"Encountering a list at the output of the tracer", STRICT_TRACER_MSG);
271
->insertNode(graph->createList(
275
[&](const IValue& ival) { return getOutput(ival, i); })))
277
} else if (iv.isTuple()) {
278
const auto& tuple = iv.toTupleRef().elements();
279
auto tuple_node = graph->createTuple(
280
fmap(tuple, [&](const IValue& ival) { return getOutput(ival, i); }));
281
graph->insertNode(tuple_node);
282
return tuple_node->output();
283
} else if (iv.isGenericDict()) {
284
if (tracing_mode_strict) {
285
throw std::runtime_error(
286
"Encountering a dict at the output of the tracer" +
287
std::string(STRICT_TRACER_MSG));
289
auto dict = iv.toGenericDict();
290
TypePtr key_type = dict.keyType();
291
TypePtr value_type = dict.valueType();
293
bool key_type_valid = key_type->isSubtypeOf(*StringType::get()) ||
294
key_type->isSubtypeOf(*TensorType::get());
295
bool value_type_valid = value_type->isSubtypeOf(*TensorType::get());
297
// Support tuple values that contain only tensors
298
if (value_type->isSubtypeOf(*AnyTupleType::get())) {
299
value_type_valid = true;
300
for (const auto& type : value_type->containedTypes()) {
301
if (!type->isSubtypeOf(*TensorType::get())) {
302
value_type_valid = false;
308
if (!key_type_valid || !value_type_valid) {
309
std::ostringstream os;
310
os << "output " << i << " (" << dict << ") of traced region "
311
<< "cannot be understood by the tracer, only outputs matching"
312
<< "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
313
<< "can be a dictionary output of a traced function";
314
throw std::runtime_error(os.str());
316
std::vector<Value*> keys;
317
std::vector<Value*> values;
318
for (const auto& entry : dict) {
319
keys.emplace_back(getValue(entry.key()));
320
values.emplace_back(getOutput(entry.value(), i));
322
auto dict_node = graph->createDict(key_type, value_type, keys, values);
323
graph->insertNode(dict_node);
324
return dict_node->output();
327
"Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions");
331
// Asks the graph to create a node of kind `op_name` with `num_outputs`
// outputs and returns the node produced by graph->create().
Node* TracingState::createNode(c10::Symbol op_name, size_t num_outputs) {
332
return graph->create(op_name, num_outputs);
335
// Hands `node` to graph->insertNode(); the return value of that call is
// discarded.
void TracingState::insertNode(Node* node) {
336
graph->insertNode(node);
339
// XXX: this function mutates input
340
static IValue addInput(
341
const std::shared_ptr<TracingState>& state,
345
value->setType(type);
346
if (type->isSubtypeOf(*TensorType::get())) {
347
auto input_tensor = input.toTensor();
348
auto name = Variable(input_tensor).name();
349
if (state->hasValue(input)) {
350
input_tensor = input_tensor.view(input_tensor.sizes());
352
if (!value->hasDebugName()) {
353
value->setDebugName(name);
355
state->setValue(input_tensor, value);
357
} else if (auto tuple_type = type->cast<TupleType>()) {
359
state->graph->insertNode(state->graph->createTupleUnpack(value));
360
auto elem_values = unpack_node->outputs();
361
auto elem_types = tuple_type->elements();
362
auto tuple = input.toTuple();
363
const auto& elems = tuple->elements();
364
size_t num_elems = elems.size();
366
elem_values.size() == num_elems && elem_types.size() == num_elems);
367
for (const auto i : c10::irange(num_elems)) {
368
tuple->unsafeSetElement(
369
i, addInput(state, elems.at(i), elem_types[i], elem_values[i]));
372
} else if (auto dict_type = type->cast<DictType>()) {
373
auto dict = input.toGenericDict();
375
// Unpack the list values statically
376
for (const auto& entry : dict) {
377
const IValue& key = entry.key();
378
auto static_key = state->graph->insertConstant(key);
380
state->graph->insert(aten::__getitem__, {value, static_key});
381
recordSourceLocation(static_value->node());
382
dict.insert_or_assign(
385
state, entry.value(), dict_type->getValueType(), static_value));
389
} else if (auto list_type = type->cast<ListType>()) {
390
size_t num_elems = input.isList() ? input.toListRef().size()
391
: input.toTensorVector().size();
392
auto list_unpack = state->graph->insertNode(
393
state->graph->createListUnpack(value, num_elems));
394
auto unpack_outputs = list_unpack->outputs();
396
if (input.isTensorList()) {
397
auto elems = input.toTensorList();
398
for (const auto i : c10::irange(num_elems)) {
402
list_type->getElementType(),
408
auto elems = input.toList();
409
for (const auto i : c10::irange(num_elems)) {
413
list_type->getElementType(),
420
"Only tensors or (possibly nested) dict or tuples of tensors can be "
421
"inputs to traced functions. Got ",
426
static void gatherParametersAndBuffers(
427
const std::shared_ptr<TracingState>& state,
430
const std::string& prefix) {
431
Graph& g = *self_value->owningGraph();
433
state->setValue(self._ivalue(), self_value);
435
auto self_ty = self.type();
436
for (const NameValue& s : self.named_attributes(/*recurse=*/false)) {
437
auto qualname = prefix + "." + s.name;
438
Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr))
439
->s_(attr::scope, qualname)
441
->setType(s.value.type());
442
if (s.value.type()->isSubtypeOf(*TensorType::get())) {
443
addInput(state, s.value, s.value.type(), trace_get_attr);
445
if (isCustomClass(s.value)) {
446
tracer::setValueTrace(s.value, trace_get_attr);
449
auto attr_type = self_ty->getAttribute(s.name);
450
// Skipping Parameters and Buffers that are behind an `InterfaceType`
451
// because it is illegal for InterfaceType to expose any attribute.
452
// And these attributes should never be used/exposed outside of
453
// InterfaceType'd module anyway.
454
if (attr_type->is_module() &&
455
attr_type->kind() != TypeKind::InterfaceType) {
456
gatherParametersAndBuffers(
457
state, trace_get_attr, Module(s.value.toObject()), qualname);
462
std::pair<std::shared_ptr<TracingState>, Stack> trace(
464
const std::function<Stack(Stack)>& traced_fn,
465
std::function<std::string(const Variable&)> var_name_lookup_fn,
469
const std::vector<std::string>& argument_names) {
471
// Start tracing, treating 'inputs' as inputs to the trace, which can be
472
// varied on subsequent invocations of the trace. Any other variables
473
// will be treated as constants.
475
AT_ERROR("Tracing can't be nested");
477
auto state = std::make_shared<TracingState>();
478
setTracingState(state);
480
// if we are a module, then make sure the modules parameters are in the map
481
// and mapped to accesses to the self object
483
Value* self_value = state->graph->insertInput(0, "self")->setType(
484
self->_ivalue()->type());
485
gatherParametersAndBuffers(state, self_value, *self, {"__module"});
488
// When enough argument name hints are provided, use them as debug names
489
// for traced function/modules.
490
// Here argument_names is allowed to have more names than needed because
491
// some arguments may have valid default values, therefore they don't need
493
if (argument_names.size() >= inputs.size()) {
494
for (size_t i = 0, e = inputs.size(); i < e; ++i) {
495
IValue& input = inputs[i];
500
state->graph->addInput(argument_names[i]));
503
for (IValue& input : inputs) {
504
input = addInput(state, input, input.type(), state->graph->addInput());
508
auto graph = state->graph;
510
getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
511
getTracingState()->strict = strict;
512
getTracingState()->force_outplace = force_outplace;
514
// Invoke the traced function
515
auto out_stack = traced_fn(inputs);
517
// Exit a trace, treating 'out_stack' as the outputs of the trace. These
518
// are the variables whose values will be computed upon subsequent
519
// invocations of the trace.
521
for (auto& output : out_stack) {
522
// NB: The stack is in "reverse" order, so when we pass the diagnostic
523
// number we need to flip it based on size.
524
state->graph->registerOutput(
525
state->getOutput(output, out_stack.size() - i));
528
setTracingState(nullptr);
530
if (getInlineEverythingMode()) {
533
FixupTraceScopeBlocks(graph, self);
535
return {state, out_stack};
542
// Abort tracing. Used to reset the state in case of errors.
544
setTracingState(nullptr);
547
// Free-function wrapper: forwards to setValue() on the current tracing state.
void setValueTrace(const IValue& v, Value* value) {
548
return getTracingState()->setValue(v, value);
550
// Records `value` as the graph Value associated with IValue `v` in the
// innermost environment frame (env_stack.back()). Tensor lists, tuples,
// generic lists and dicts are taken apart and their elements are recorded by
// recursive calls to setValue(). Builds an error message and throws
// std::runtime_error on the final path, which reports an unsupported type.
void TracingState::setValue(const IValue& v, Value* value) {
552
auto& var = v.toTensor();
553
AT_ASSERT(var.defined());
554
env_stack.back()[v] = value;
556
// If the value comes from a CallFunction or CallMethod, it may not have
557
// shape information attached. For debuggability, we enhance the type
558
// information by assigning the concrete value's type to the jit::Value.
559
if (auto tensor_type = value->type()->cast<TensorType>()) {
560
if (!tensor_type->isComplete()) {
561
value->inferTypeFrom(var);
564
} else if (v.isTensorList()) {
565
auto outputs = v.toTensorList();
567
graph->insertNode(graph->createListUnpack(value, outputs.size()));
568
for (const auto i : c10::irange(outputs.size())) {
569
setValue(outputs.get(i), unpack_node->outputs()[i]);
571
} else if (v.isTuple()) {
572
const auto& outputs = v.toTupleRef().elements();
573
Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
574
for (const auto i : c10::irange(outputs.size())) {
575
setValue(outputs[i], unpack_node->outputs()[i]);
577
} else if (v.isList()) {
578
auto elements = v.toListRef();
580
graph->insertNode(graph->createListUnpack(value, elements.size()));
581
for (const auto i : c10::irange(elements.size())) {
582
setValue(elements[i], unpack_node->outputs()[i]);
584
} else if (isCustomClass(v)) {
585
// Custom-class objects are keyed by their "capsule" attribute rather than
// by the object IValue itself.
auto capsule = v.toObject()->getAttr("capsule");
586
env_stack.back()[capsule] = value;
587
} else if (v.isFuture() || v.isObject()) {
588
env_stack.back()[v] = value;
589
} else if (v.isGenericDict()) {
590
auto dict = v.toGenericDict();
591
TypePtr key_type = dict.keyType();
592
TypePtr value_type = dict.valueType();
593
// Each entry is reached through a constant key and an aten::__getitem__
// call on `value`; the resulting Value is recorded for the entry's value.
for (const auto& entry : dict) {
594
auto static_key = graph->insertConstant(entry.key());
595
auto static_value = graph->insert(aten::__getitem__, {value, static_key});
596
setValue(entry.value(), static_value);
599
std::ostringstream os;
600
os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
601
<< "Supported types are tensor, tensor list, and tuple of tensors.";
602
throw std::runtime_error(os.str());
606
void addInputs(Node* n, const char* name, int64_t value) {
607
using ArgumentStash = jit::tracer::ArgumentStash;
608
if (ArgumentStash::hasValue(name)) {
609
Value* v = ArgumentStash::popValue(name);
612
detail::genericAddInput(n, value);
616
// SymInt overload: calls guard_int() on the symbolic integer (passing this
// file and line) and forwards the result to another addInputs overload.
void addInputs(Node* n, const char* name, c10::SymInt value) {
617
addInputs(n, name, value.guard_int(__FILE__, __LINE__));
620
void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
621
using ArgumentStash = jit::tracer::ArgumentStash;
622
if (ArgumentStash::hasValue(name)) {
623
Value* v = ArgumentStash::popValue(name);
626
detail::genericAddInput(n, *value);
628
Graph* g = n->owningGraph();
629
Value* none = g->insertNode(g->createNone())->output();
633
// bool overload: passes `value` to detail::genericAddInput for node `n`.
// `name` is not used by this overload.
void addInputs(Node* n, const char* name, bool value) {
634
detail::genericAddInput(n, value);
636
// optional<bool> overload: forwards all arguments to
// detail::genericAddOptionalInput.
void addInputs(Node* n, const char* name, const c10::optional<bool>& value) {
637
detail::genericAddOptionalInput(n, name, value);
639
// double overload: passes `value` to detail::genericAddInput for node `n`.
// `name` is not used by this overload.
void addInputs(Node* n, const char* name, double value) {
640
detail::genericAddInput(n, value);
642
// optional<double> overload: forwards all arguments to
// detail::genericAddOptionalInput.
void addInputs(Node* n, const char* name, const c10::optional<double>& value) {
643
detail::genericAddOptionalInput(n, name, value);
645
void addInputs(Node* n, const char* name, const at::Scalar& value) {
646
using ArgumentStash = jit::tracer::ArgumentStash;
647
if (ArgumentStash::hasValue(name)) {
648
Value* v = ArgumentStash::popValue(name);
651
detail::genericAddInput(n, value);
657
const c10::optional<at::Scalar>& value) {
658
detail::genericAddOptionalInput(n, name, value);
660
// string_view overload: copies the view into a std::string and passes that to
// detail::genericAddInput. `name` is not used by this overload.
void addInputs(Node* n, const char* name, const c10::string_view value) {
661
detail::genericAddInput(n, std::string(value));
666
const c10::optional<c10::string_view>& value) {
667
detail::genericAddOptionalInput(n, name, value);
669
// Tensor overload: obtains the Value for `value` from getValueTrace() and
// appends it as an input of `n`. `name` is not used by this overload.
void addInputs(Node* n, const char* name, const at::Tensor& value) {
670
n->addInput(getValueTrace(value));
675
const c10::optional<at::Tensor>& value) {
676
detail::genericAddOptionalInput(n, name, value);
681
const c10::optional<at::Generator>& value) {
682
Graph* g = n->owningGraph();
684
if (value.has_value() && value->defined()) {
685
detail::genericAddInput(n, *value);
687
Value* undef_gen = g->insertNode(g->createNone())->output();
688
n->addInput(undef_gen);
691
// Device overload: passes `value` to detail::genericAddInput for node `n`.
// `name` is not used by this overload.
void addInputs(Node* n, const char* name, at::Device value) {
692
detail::genericAddInput(n, value);
694
// Stream overload: wraps `stream` in a c10::IValue before passing it to
// detail::genericAddInput. `name` is not used by this overload.
void addInputs(Node* n, const char* name, c10::Stream stream) {
695
detail::genericAddInput(n, c10::IValue(stream));
697
// Layout overload: the enum value is cast to int64_t and passed to
// detail::genericAddInput. `name` is not used by this overload.
void addInputs(Node* n, const char* name, at::Layout value) {
698
detail::genericAddInput(n, static_cast<int64_t>(value));
700
// ScalarType overload: the enum value is cast to int64_t and passed to
// detail::genericAddInput. `name` is not used by this overload.
void addInputs(Node* n, const char* name, at::ScalarType value) {
701
detail::genericAddInput(n, static_cast<int64_t>(value));
703
// MemoryFormat overload: the enum value is cast to int64_t and passed to
// detail::genericAddInput. `name` is not used by this overload.
void addInputs(Node* n, const char* name, at::MemoryFormat value) {
704
detail::genericAddInput(n, static_cast<int64_t>(value));
709
const c10::optional<at::MemoryFormat>& value) {
710
detail::genericAddOptionalInput(n, name, value);
715
const c10::optional<at::Layout>& value) {
716
detail::genericAddOptionalInput(n, name, value);
721
const c10::optional<at::Device>& value) {
722
detail::genericAddOptionalInput(n, name, value);
727
c10::optional<at::DimnameList> value) {
728
TORCH_CHECK(false, "NYI: Named tensors are not supported with the tracer");
733
const c10::optional<at::ScalarType>& value) {
734
detail::genericAddOptionalInput(n, name, value);
739
at::ArrayRef<at::Tensor> value,
740
bool allow_undefined) {
741
addInputs(n, name, at::ITensorListRef(value), allow_undefined);
746
std::vector<at::Tensor> value,
747
bool allow_undefined) {
748
addInputs(n, name, at::ITensorListRef(value), allow_undefined);
753
at::ITensorListRef value,
754
bool allow_undefined) {
755
Graph* g = n->owningGraph();
756
Node* list_node = nullptr;
757
if (allow_undefined) {
758
// if allow undefined, we create a list of optional tensors
759
list_node = g->insertNode(
760
g->createList(OptionalType::ofTensor(), fmap(value, getValueTrace)));
762
list_node = g->insertNode(
763
g->createList(TensorType::get(), fmap(value, getValueTrace)));
765
n->addInput(list_node->output());
767
TORCH_API void addInputs(
770
const List<c10::optional<at::Tensor>>& value) {
771
Graph* g = n->owningGraph();
772
Node* list_node = nullptr;
773
list_node = g->insertNode(g->createList(
774
OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace)));
775
n->addInput(list_node->output());
780
ArrayRef<c10::intrusive_ptr<c10::ivalue::Object>> value,
781
const ClassTypePtr& class_type) {
782
Graph* g = n->owningGraph();
784
g->insertNode(g->createList(class_type, fmap(value, getValueTrace)));
785
n->addInput(list_node->output());
788
void addInputs(Node* n, const char* name, at::IntArrayRef value) {
789
using ArgumentStash = jit::tracer::ArgumentStash;
790
std::vector<Value*> info = ArgumentStash::hasIntArrayRef(name)
791
? ArgumentStash::popIntArrayRef(name)
792
: ArgumentStash::IntArrayRefTrace(value.size());
794
auto& g = getTracingState()->graph;
795
for (const auto i : c10::irange(info.size())) {
796
if (info[i] != nullptr)
798
info[i] = g->insertConstant(value[i]);
799
recordSourceLocation(info[i]->node());
801
for (jit::Value* v : info) {
802
if (*v->type() != *jit::IntType::get()) {
803
throw std::runtime_error(
804
"Type mismatch in setposattr for IntArrayRef. Check that your program "
805
"is valid without tracing, and please file a bug report if it is.");
809
g->insertNode(g->createList(jit::IntType::get(), info))->output());
812
// SymIntArrayRef overload: converts `value` with C10_AS_INTARRAYREF_SLOW and
// forwards the result to another addInputs overload.
void addInputs(Node* n, const char* name, c10::SymIntArrayRef value) {
813
addInputs(n, name, C10_AS_INTARRAYREF_SLOW(value));
816
void addInputs(Node* n, const char* name, c10::optional<c10::SymInt> value) {
821
? c10::make_optional(value->guard_int(__FILE__, __LINE__))
828
const c10::optional<at::IntArrayRef>& opt_value) {
829
detail::genericAddOptionalInput(n, name, opt_value);
835
const at::OptionalIntArrayRef& opt_value) {
836
if (opt_value.has_value()) {
837
jit::tracer::addInputs(n, name, *opt_value);
839
Graph* g = n->owningGraph();
840
Value* none = g->insertNode(g->createNone())->output();
848
const at::OptionalSymIntArrayRef& opt_value) {
849
if (opt_value.has_value()) {
850
jit::tracer::addInputs(n, name, *opt_value);
852
Graph* g = n->owningGraph();
853
Value* none = g->insertNode(g->createNone())->output();
858
void addInputs(Node* n, const char* name, ArrayRef<double> value) {
859
std::vector<Value*> info;
860
auto& g = getTracingState()->graph;
861
for (double elt : value) {
862
info.push_back(g->insertConstant(elt));
863
recordSourceLocation(info.back()->node());
866
g->insertNode(g->createList(jit::FloatType::get(), info))->output());
872
const c10::optional<c10::ArrayRef<double>>& opt_value) {
873
detail::genericAddOptionalInput(n, name, opt_value);
879
const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
880
Value* v = getValueTrace(obj);
884
// Appends a new output to `node` and hands it, together with the tensor, to
// setOutput().
void addOutput(Node* node, const at::Tensor& output) {
885
setOutput(node->addOutput(), output);
888
// When `output` is a defined tensor, infers the type of `value` from it and
// registers the tensor -> value association through setValueTrace().
void setOutput(Value* value, const at::Tensor& output) {
889
if (output.defined()) {
890
value->inferTypeFrom(output);
891
setValueTrace(output, value);
895
// Tensor-vector overload: adds a single output of type ListType::ofTensors()
// to `node`, inserts a prim::ListUnpack node with outputs.size() outputs that
// consumes it, and associates each unpacked Value with the matching tensor.
void addOutput(Node* node, const std::vector<at::Tensor>& outputs) {
896
Value* value = node->addOutput()->setType(ListType::ofTensors());
897
Graph* graph = node->owningGraph();
898
Node* unpack_node = graph->insertNode(
899
graph->create(prim::ListUnpack, {value}, outputs.size()));
900
for (const auto i : c10::irange(outputs.size())) {
901
Value* output_val = unpack_node->outputs()[i];
902
// Infer the element's type from the concrete tensor before registering it.
output_val->inferTypeFrom(outputs[i]);
903
setValueTrace(outputs[i], output_val);
907
// c10::List overload: converts the list with vec() and forwards to another
// addOutput overload.
void addOutput(Node* node, const c10::List<at::Tensor>& outputs) {
908
return addOutput(node, outputs.vec());
913
const c10::intrusive_ptr<c10::ivalue::Object>& output) {
914
Value* output_val = node->addOutput();
915
output_val->inferTypeFrom(output);
916
setValueTrace(output, output_val);
919
// Returns a const reference to detail::tracing_state.
const std::shared_ptr<TracingState>& getTracingState() {
920
return detail::tracing_state;
923
// Installs `state` as detail::tracing_state and switches tracer dispatch on
// when `state` is non-null, off when it is null.
void setTracingState(std::shared_ptr<TracingState> state) {
924
// The null test runs before the move below, while `state` still holds its
// original value.
at::tracer::impl::set_dispatch_enabled(state != nullptr);
925
detail::tracing_state = std::move(state);
928
// Initializes `graph` from a newly allocated Graph and seeds `env_stack` with
// a default-constructed Frame.
TracingState::TracingState() : graph(new Graph()), env_stack{Frame()} {}
930
// Compiler-generated destructor, defined out of line here.
TracingState::~TracingState() = default;
932
autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
933
auto& tracing_state = getTracingState();
934
auto& graph = tracing_state->graph;
938
// Make sure this scalar to tensor isn't traced!
939
at::AutoDispatchBelowADInplaceOrView guard;
940
size_var = scalar_to_tensor(at::Scalar(var.size(dim)));
942
auto* value = getValueTrace(var);
943
auto dim_val = graph->insertConstant(dim);
944
recordSourceLocation(dim_val->node());
945
auto* node = graph->insertNode(graph->create(aten::size, {value, dim_val}));
946
recordSourceLocation(node);
947
node->output()->setType(jit::IntType::get());
950
graph->insertNode(graph->createNumToTensor(node->output()))->output();
951
setValueTrace(size_var, ten);
955
autograd::Variable getNumelOf(const autograd::Variable& var) {
956
auto& tracing_state = getTracingState();
957
auto& graph = tracing_state->graph;
961
// Make sure this scalar to tensor isn't traced!
962
at::AutoDispatchBelowADInplaceOrView guard;
963
numel_var = scalar_to_tensor(at::Scalar(var.numel()));
965
auto* value = getValueTrace(var);
966
auto* node = graph->insertNode(graph->create(Symbol::aten("numel"), {value}));
967
recordSourceLocation(node);
968
node->output()->setType(jit::IntType::get());
971
graph->insertNode(graph->createNumToTensor(node->output()))->output();
972
setValueTrace(numel_var, ten);
976
void ensureUniqueIfOutOfPlaced(const char* name, const at::Tensor& tensor) {
977
auto& state = getTracingState();
978
if (state && state->force_outplace == false) {
979
// If we're not converting in-place ops to out-of-place, this check is
983
auto aliases = tensor.storage().use_count();
984
if (isTracing() && aliases > 1) {
985
std::stringstream ss;
986
ss << "There are " << aliases
987
<< " live references to the data region being modified when tracing in-place operator "
989
<< ". This might cause the trace to be incorrect, because all other views "
990
<< "that also reference this data will not reflect this change in the trace! "
991
<< "On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. "
992
<< "are outputs of torch.split), this might still be safe.";
993
warn(ss.str().c_str());
996
void ensureUniqueIfOutOfPlaced(
998
const c10::optional<at::Tensor>& tensor) {
999
ensureUniqueIfOutOfPlaced(name, tensor.has_value() ? *tensor : at::Tensor());
1002
////////////////////////////////////////////////////////////////////////////////
1004
////////////////////////////////////////////////////////////////////////////////
1005
thread_local ArgumentStash ArgumentStash::stash;
1007
void ArgumentStash::stashIntArrayRefElem(
1008
const std::string& arg_name,
1011
const Variable& var) {
1012
// TODO: check type?
1015
IntArrayRefTrace& list_trace =
1016
stash.intlists.emplace(arg_name, size).first->second;
1017
AT_ASSERT(size == list_trace.size());
1018
AT_ASSERT(idx < list_trace.size());
1019
AT_ASSERT(list_trace[idx] == nullptr);
1021
Value* ten = getValueTrace(var);
1022
auto& g = *ten->owningGraph();
1023
WithInsertPoint guard(ten->node()->next());
1024
auto prim = g.insert(aten::Int, {ten});
1025
list_trace[idx] = prim;
1028
void ArgumentStash::stashValue(
1029
const std::string& arg_name,
1031
const Variable& var,
1032
const TypePtr& type) {
1036
Value* ten = getValueTrace(var);
1037
WithInsertPoint guard(ten->node()->next());
1038
auto& g = *ten->owningGraph();
1040
if (type == IntType::get()) {
1041
ten = g.insert(aten::Int, {ten});
1042
} else if (type == FloatType::get()) {
1043
ten = g.insert(aten::Float, {ten});
1044
} else if (type == NumberType::get()) {
1045
ten = g.insert(aten::ScalarImplicit, {ten});
1048
stash.values.emplace(arg_name, ten);
1051
////////////////////////////////////////////////////////////////////////////////
1052
// Stack trace recording
1053
////////////////////////////////////////////////////////////////////////////////
1054
// no python present so we just do not record source information
1055
// No-op recorder: the body is empty and `n` is ignored. Used as the initial
// value of record_source_location.
static void defaultRecordSourceLocation(Node* n) {}
1056
// Atomic function pointer to the active source-location recorder; starts out
// pointing at defaultRecordSourceLocation.
std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(
1057
defaultRecordSourceLocation);
1058
// Loads the currently installed recorder from record_source_location and
// invokes it on `n`.
void recordSourceLocation(Node* n) {
1059
return record_source_location.load()(n);
1061
// Replaces the installed source-location recorder with `v` (atomic store).
// NOTE(review): `v` is stored unchecked; a null pointer would be called by
// recordSourceLocation().
void setRecordSourceLocation(void (*v)(Node*)) {
1062
record_source_location.store(v);
1065
// Default callstack provider: always returns an empty vector of StackEntry.
static std::vector<StackEntry> defaultPythonCallstack() {
1066
return std::vector<StackEntry>();
1068
// Atomic function pointer to the active callstack provider; starts out
// pointing at defaultPythonCallstack.
std::atomic<decltype(&defaultPythonCallstack)> python_callstack_fn(
1069
defaultPythonCallstack);
1070
// Loads the currently installed provider from python_callstack_fn, calls it,
// and returns the resulting vector of StackEntry.
std::vector<StackEntry> pythonCallstack() {
1071
return python_callstack_fn.load()();
1073
// Replaces the installed callstack provider with `v` (atomic store).
void setPythonCallstack(std::vector<StackEntry> (*v)()) {
1074
python_callstack_fn.store(v);
1077
static void defaultWarn(const std::string& str) {
1080
// Atomic holder for the active warning callback; initialized to defaultWarn.
std::atomic<warn_fn_type> warn_callback{defaultWarn};
1082
const char* WARN_PYTHON_DATAFLOW =
1083
" might cause the trace to be incorrect. We can't record the data flow of "
1084
"Python values, so this value will be treated as a constant in the future. "
1085
"This means that the trace might not generalize to other inputs!";
1086
const char* WARN_CONSTRUCTOR =
1087
" results are registered as constants in the trace. You can safely ignore this "
1088
"warning if you use this function to create tensors out of constant variables "
1089
"that would be the same every time you call this function. In any other case, "
1090
"this might cause the trace to be incorrect.";
1091
const char* WARN_RESIZE =
1092
" can't be represented in the JIT at the moment, so we won't connect any uses of "
1093
"this value with its current trace. If you happen to use it again, it will show "
1094
"up as a constant in the graph. Consider using `view` or `reshape` to make "
1096
const char* STRICT_TRACER_MSG =
1097
" might cause the trace to be incorrect, this is only valid if the container "
1098
"structure does not change based on the module's inputs. Consider using a constant "
1099
"container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a "
1100
"`NamedTuple` instead). If you absolutely need this and know the side effects, pass "
1101
"strict=False to trace() to allow this behavior.";
1102
// Builds a warning message by concatenating `_reason` and `_kind`, then
// passes it to the callback currently stored in warn_callback.
// NOTE(review): `_reason` is used to construct a std::string directly, so it
// must not be null; only `_kind` is guarded.
// XXX: _kind can be a nullptr
1103
void _do_warn(const char* _reason, const char* _kind) {
1104
std::string reason{_reason};
1105
// A null `_kind` is replaced by the empty string.
std::string kind{_kind ? _kind : ""};
1106
std::ostringstream s;
1107
s << reason << kind;
1108
warn_callback.load()(s.str());
1111
// Replaces the active warning callback with `fn` (atomic store into
// warn_callback).
void setWarn(warn_fn_type fn) {
1112
warn_callback.store(fn);
1114
} // namespace torch::jit::tracer