// pytorch — torch/csrc/jit/frontend/tracer.cpp (JIT tracer implementation)
#include <torch/csrc/jit/frontend/tracer.h>

#include <ATen/Backtrace.h>
#include <ATen/ScalarOps.h>
#include <ATen/TracerMode.h>
#include <ATen/core/Dict.h>
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/utils/variadic.h>
#include <torch/custom_class.h>

#include <memory>
#include <sstream>
#include <string>
#include <utility>

namespace torch::jit::tracer {
30

31
////////////////////////////////////////////////////////////////////////////////
32
// Recording the traces
33
////////////////////////////////////////////////////////////////////////////////
34
namespace detail {
35

36
template <typename T>
37
void genericAddInput(Node* n, T value) {
38
  Value* v = n->owningGraph()->insertConstant(value);
39
  recordSourceLocation(v->node());
40
  n->addInput(v);
41
}
42

43
template <typename T>
44
void genericAddOptionalInput(
45
    Node* n,
46
    const char* name,
47
    const c10::optional<T>& value) {
48
  if (value) {
49
    jit::tracer::addInputs(n, name, *value);
50
  } else {
51
    Graph* g = n->owningGraph();
52
    Value* none = g->insertNode(g->createNone())->output();
53
    n->addInput(none);
54
  }
55
}
56

57
template <typename T>
58
void badArgType(const T& v) {
59
  AT_ERROR(
60
      "Found an unsupported argument type in the JIT tracer: ",
61
      c10::demangle_type<T>(),
62
      ". File a bug report.");
63
}
64

65
thread_local std::shared_ptr<TracingState> tracing_state;
66
} // namespace detail
67

// Process-wide flag controlling whether tracer warnings are emitted.
static std::atomic<bool> tracer_state_warn_mode{true};

// Returns a mutable reference so callers can toggle warning emission.
std::atomic<bool>& getTracerStateWarnMode() {
  return tracer_state_warn_mode;
}

std::function<void()> pauseTracing() {
75
  // NOLINTNEXTLINE
76
  std::shared_ptr<tracer::TracingState> state = getTracingState();
77
  tracer::setTracingState(nullptr);
78

79
  return [state]() { tracer::setTracingState(state); };
80
}
81

82
void delValueTrace(const IValue& var) {
83
  getTracingState()->delValue(var);
84
}
85
void TracingState::delValue(const IValue& var) {
86
  for (const auto i : c10::irange(env_stack.size())) {
87
    auto& value_map = env_stack.at(env_stack.size() - 1 - i);
88
    auto it = value_map.find(var);
89
    if (it == value_map.end()) {
90
      continue;
91
    }
92
    value_map.erase(it);
93
  }
94
}
95

96
// Given a IValue 'var', return the 'node' which represents the instruction
97
// which computes the value of this variable in the IR.
98
// Here, we interpret untraced variables as constants that are just embedded
99
// in the graph.  This is useful to handle code which does things like this
100
// (from torch.autograd.variable, now moved to C++):
101
//
102
//    def mm(self, matrix):
103
//      output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
104
//      return Addmm.apply(output, self, matrix, 0, 1, True)
105
//
106
// Here, mm fakes up a dummy variable with uninitialized data to do an inplace
107
// update on, but subsequently ignores it because the alpha scaling factor is
108
// zero. This is one of the cases where a Variable can be created inside of a
109
// trace, and if we treat it as a constant, everything will work out.
110
Value* getValueTrace(const IValue& var) {
111
  return getTracingState()->getValue(var);
112
}
113
static Value* getOptTensorValueTrace(const c10::optional<at::Tensor>& var) {
114
  return getValueTrace(IValue(var));
115
}
116
Value* TracingState::getValue(const IValue& var) {
117
  // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...]
118
  // arguments
119
  if (var.isTensorList()) {
120
    return graph
121
        ->insertNode(graph->createList(
122
            TensorType::get(),
123
            fmap(
124
                var.toTensorVector(),
125
                [&](const IValue& val) { return getValue(val); })))
126
        ->output();
127
  } else if (var.isTuple()) {
128
    return graph
129
        ->insertNode(graph->createTuple(fmap(
130
            var.toTupleRef().elements(),
131
            [&](const IValue& val) { return getValue(val); })))
132
        ->output();
133
  } else if (var.isGenericDict()) {
134
    auto dict = var.toGenericDict();
135
    TypePtr key_type = dict.keyType();
136
    TypePtr value_type = dict.valueType();
137
    std::vector<Value*> keys;
138
    std::vector<Value*> values;
139
    for (const auto& entry : dict) {
140
      keys.emplace_back(getValue(entry.key()));
141
      values.emplace_back(getValue(entry.value()));
142
    }
143
    auto dict_node = graph->createDict(key_type, value_type, keys, values);
144
    return graph->insertNode(dict_node)->output();
145
  }
146
  if (var.isTensor()) {
147
    auto& ten = var.toTensor();
148
    if (!ten.defined()) {
149
      Node* n = graph->createNone();
150
      return graph->insertNode(n)->output();
151
    }
152
    for (const auto i : c10::irange(env_stack.size())) {
153
      auto& value_map = env_stack.at(env_stack.size() - 1 - i);
154
      auto it = value_map.find(var);
155
      if (it == value_map.end()) {
156
        continue;
157
      }
158
      if (!it->second->hasDebugName()) {
159
        auto unique_name = getTracingState()->lookup_var_name_fn(ten);
160
        if (!unique_name.empty()) {
161
          it->second->setDebugName(unique_name);
162
        }
163
      }
164
      return it->second;
165
    }
166

167
    // Didn't find it. Bake in a constant
168
    if (ten.requires_grad()) {
169
      pauseTracing();
170
      std::ostringstream oss;
171
      oss << "Cannot insert a Tensor that requires grad as a constant. "
172
          << "Consider making it a parameter or input, or detaching the gradient\n"
173
          << "Tensor:\n"
174
          << ten;
175
      throw std::runtime_error(oss.str());
176
    }
177

178
    Value* constant = graph->insertConstant(ten);
179
    recordSourceLocation(constant->node());
180
    constant->inferTypeFrom(ten);
181
    auto it = env_stack.back().emplace(var, constant);
182
    return it.first->second;
183
  } else if (var.isFuture() || var.isObject()) {
184
    for (const auto i : c10::irange(env_stack.size())) {
185
      auto& future_map = env_stack.at(env_stack.size() - 1 - i);
186
      auto it = future_map.find(var);
187
      if (it == future_map.end()) {
188
        continue;
189
      }
190
      return it->second;
191
    }
192

193
    // Find torchbind classes
194
    if (isCustomClass(var)) {
195
      auto obj = Object(var.toObject());
196
      auto qualname = obj.type()->name();
197
      auto custom_class_type = getCustomClass(qualname->qualifiedName());
198
      if (custom_class_type) {
199
        auto capsule = var.toObject()->getAttr("capsule");
200
        for (const auto i : c10::irange(env_stack.size())) {
201
          auto& value_map = env_stack.at(env_stack.size() - 1 - i);
202
          auto it = value_map.find(capsule);
203
          if (it == value_map.end()) {
204
            continue;
205
          }
206
          return it->second;
207
        }
208
      }
209
    }
210

211
    std::ostringstream oss;
212
    if (var.isFuture()) {
213
      oss << "Tried to trace Future or Object that the tracer was not aware of.";
214
    } else {
215
      oss << "Tried to trace " << var
216
          << " but it is not part of the active trace. Modules that are called during a trace"
217
          << " must be registered as submodules of the thing being traced.";
218
    }
219
    throw std::runtime_error(oss.str());
220
  } else {
221
    // If the values are non-tensors, we try to create constants
222
    // and bake those constants into the traced graph
223
    auto constant = tryInsertConstant(*graph, var);
224
    if (constant) {
225
      recordSourceLocation(constant.value()->node());
226
      return *constant;
227
    }
228
    std::ostringstream os;
229
    os << "Tracer cannot get value trace for type " << var.tagKind() << ". "
230
       << "The below value could not be materialized as a constant:\n"
231
       << var;
232
    throw std::runtime_error(os.str());
233
  }
234
}
235
bool TracingState::hasValue(const IValue& var) const {
236
  for (const auto& frame : env_stack) {
237
    if (frame.count(var)) {
238
      return true;
239
    }
240
  }
241
  return false;
242
}
243

244
Value* TracingState::getOutput(const IValue& iv, size_t i) {
245
  bool tracing_mode_strict = getTracingState()->strict;
246
  if (iv.isTensor()) {
247
    const at::Tensor& var = iv.toTensor();
248
    if (!var.defined()) {
249
      Node* n = graph->createNone();
250
      return graph->insertNode(n)->output();
251
    }
252

253
    auto& value_map = getTracingState()->env_stack.back();
254
    auto it = value_map.find(iv);
255
    if (it == value_map.end()) {
256
      std::ostringstream os;
257
      os << "output " << i << " (" << var
258
         << ") of traced region did not have observable "
259
         << "data dependence with trace inputs; this probably indicates your "
260
            "program "
261
         << "cannot be understood by the tracer.";
262
      throw std::runtime_error(os.str());
263
    }
264
    return it->second;
265
  } else if (iv.isTensorList()) {
266
    if (tracing_mode_strict) {
267
      tracer::warn(
268
          "Encountering a list at the output of the tracer", STRICT_TRACER_MSG);
269
    }
270
    return graph
271
        ->insertNode(graph->createList(
272
            TensorType::get(),
273
            fmap(
274
                iv.toTensorVector(),
275
                [&](const IValue& ival) { return getOutput(ival, i); })))
276
        ->output();
277
  } else if (iv.isTuple()) {
278
    const auto& tuple = iv.toTupleRef().elements();
279
    auto tuple_node = graph->createTuple(
280
        fmap(tuple, [&](const IValue& ival) { return getOutput(ival, i); }));
281
    graph->insertNode(tuple_node);
282
    return tuple_node->output();
283
  } else if (iv.isGenericDict()) {
284
    if (tracing_mode_strict) {
285
      throw std::runtime_error(
286
          "Encountering a dict at the output of the tracer" +
287
          std::string(STRICT_TRACER_MSG));
288
    }
289
    auto dict = iv.toGenericDict();
290
    TypePtr key_type = dict.keyType();
291
    TypePtr value_type = dict.valueType();
292

293
    bool key_type_valid = key_type->isSubtypeOf(*StringType::get()) ||
294
        key_type->isSubtypeOf(*TensorType::get());
295
    bool value_type_valid = value_type->isSubtypeOf(*TensorType::get());
296

297
    // Support tuple values that contain only tensors
298
    if (value_type->isSubtypeOf(*AnyTupleType::get())) {
299
      value_type_valid = true;
300
      for (const auto& type : value_type->containedTypes()) {
301
        if (!type->isSubtypeOf(*TensorType::get())) {
302
          value_type_valid = false;
303
          break;
304
        }
305
      }
306
    }
307

308
    if (!key_type_valid || !value_type_valid) {
309
      std::ostringstream os;
310
      os << "output " << i << " (" << dict << ") of traced region "
311
         << "cannot be understood by the tracer, only outputs matching"
312
         << "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
313
         << "can be a dictionary output of a traced function";
314
      throw std::runtime_error(os.str());
315
    }
316
    std::vector<Value*> keys;
317
    std::vector<Value*> values;
318
    for (const auto& entry : dict) {
319
      keys.emplace_back(getValue(entry.key()));
320
      values.emplace_back(getOutput(entry.value(), i));
321
    }
322
    auto dict_node = graph->createDict(key_type, value_type, keys, values);
323
    graph->insertNode(dict_node);
324
    return dict_node->output();
325
  } else {
326
    AT_ERROR(
327
        "Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions");
328
  }
329
}
330

331
Node* TracingState::createNode(c10::Symbol op_name, size_t num_outputs) {
332
  return graph->create(op_name, num_outputs);
333
}
334

335
void TracingState::insertNode(Node* node) {
336
  graph->insertNode(node);
337
}
338

339
// XXX: this function mutates input
340
static IValue addInput(
341
    const std::shared_ptr<TracingState>& state,
342
    const IValue& input,
343
    const TypePtr& type,
344
    Value* value) {
345
  value->setType(type);
346
  if (type->isSubtypeOf(*TensorType::get())) {
347
    auto input_tensor = input.toTensor();
348
    auto name = Variable(input_tensor).name();
349
    if (state->hasValue(input)) {
350
      input_tensor = input_tensor.view(input_tensor.sizes());
351
    }
352
    if (!value->hasDebugName()) {
353
      value->setDebugName(name);
354
    }
355
    state->setValue(input_tensor, value);
356
    return input_tensor;
357
  } else if (auto tuple_type = type->cast<TupleType>()) {
358
    auto unpack_node =
359
        state->graph->insertNode(state->graph->createTupleUnpack(value));
360
    auto elem_values = unpack_node->outputs();
361
    auto elem_types = tuple_type->elements();
362
    auto tuple = input.toTuple();
363
    const auto& elems = tuple->elements();
364
    size_t num_elems = elems.size();
365
    AT_ASSERT(
366
        elem_values.size() == num_elems && elem_types.size() == num_elems);
367
    for (const auto i : c10::irange(num_elems)) {
368
      tuple->unsafeSetElement(
369
          i, addInput(state, elems.at(i), elem_types[i], elem_values[i]));
370
    }
371
    return tuple;
372
  } else if (auto dict_type = type->cast<DictType>()) {
373
    auto dict = input.toGenericDict();
374

375
    // Unpack the list values statically
376
    for (const auto& entry : dict) {
377
      const IValue& key = entry.key();
378
      auto static_key = state->graph->insertConstant(key);
379
      auto static_value =
380
          state->graph->insert(aten::__getitem__, {value, static_key});
381
      recordSourceLocation(static_value->node());
382
      dict.insert_or_assign(
383
          entry.key(),
384
          addInput(
385
              state, entry.value(), dict_type->getValueType(), static_value));
386
    }
387

388
    return dict;
389
  } else if (auto list_type = type->cast<ListType>()) {
390
    size_t num_elems = input.isList() ? input.toListRef().size()
391
                                      : input.toTensorVector().size();
392
    auto list_unpack = state->graph->insertNode(
393
        state->graph->createListUnpack(value, num_elems));
394
    auto unpack_outputs = list_unpack->outputs();
395

396
    if (input.isTensorList()) {
397
      auto elems = input.toTensorList();
398
      for (const auto i : c10::irange(num_elems)) {
399
        elems[i] = addInput(
400
                       state,
401
                       elems.get(i),
402
                       list_type->getElementType(),
403
                       unpack_outputs[i])
404
                       .toTensor();
405
      }
406
      return elems;
407
    } else {
408
      auto elems = input.toList();
409
      for (const auto i : c10::irange(num_elems)) {
410
        elems[i] = addInput(
411
            state,
412
            elems.get(i),
413
            list_type->getElementType(),
414
            unpack_outputs[i]);
415
      }
416
      return elems;
417
    }
418
  } else {
419
    AT_ERROR(
420
        "Only tensors or (possibly nested) dict or tuples of tensors can be "
421
        "inputs to traced functions. Got ",
422
        type->repr_str());
423
  }
424
}
425

426
static void gatherParametersAndBuffers(
427
    const std::shared_ptr<TracingState>& state,
428
    Value* self_value,
429
    const Module& self,
430
    const std::string& prefix) {
431
  Graph& g = *self_value->owningGraph();
432

433
  state->setValue(self._ivalue(), self_value);
434

435
  auto self_ty = self.type();
436
  for (const NameValue& s : self.named_attributes(/*recurse=*/false)) {
437
    auto qualname = prefix + "." + s.name;
438
    Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr))
439
                                ->s_(attr::scope, qualname)
440
                                ->output()
441
                                ->setType(s.value.type());
442
    if (s.value.type()->isSubtypeOf(*TensorType::get())) {
443
      addInput(state, s.value, s.value.type(), trace_get_attr);
444
    }
445
    if (isCustomClass(s.value)) {
446
      tracer::setValueTrace(s.value, trace_get_attr);
447
    }
448

449
    auto attr_type = self_ty->getAttribute(s.name);
450
    // Skipping Parameters and Buffers that are behind an `InterfaceType`
451
    // because it is illegal for InterfaceType to expose any attribute.
452
    // And these attributes should never be used/exposed outside of
453
    // InterfaceType'd module anyway.
454
    if (attr_type->is_module() &&
455
        attr_type->kind() != TypeKind::InterfaceType) {
456
      gatherParametersAndBuffers(
457
          state, trace_get_attr, Module(s.value.toObject()), qualname);
458
    }
459
  }
460
}
461

462
std::pair<std::shared_ptr<TracingState>, Stack> trace(
463
    Stack inputs,
464
    const std::function<Stack(Stack)>& traced_fn,
465
    std::function<std::string(const Variable&)> var_name_lookup_fn,
466
    bool strict,
467
    bool force_outplace,
468
    Module* self,
469
    const std::vector<std::string>& argument_names) {
470
  try {
471
    // Start tracing, treating 'inputs' as inputs to the trace, which can be
472
    // varied on subsequent invocations of the trace.  Any other variables
473
    // will be treated as constants.
474
    if (isTracing()) {
475
      AT_ERROR("Tracing can't be nested");
476
    }
477
    auto state = std::make_shared<TracingState>();
478
    setTracingState(state);
479

480
    // if we are a module, then make sure the modules parameters are in the map
481
    // and mapped to accesses to the self object
482
    if (self) {
483
      Value* self_value = state->graph->insertInput(0, "self")->setType(
484
          self->_ivalue()->type());
485
      gatherParametersAndBuffers(state, self_value, *self, {"__module"});
486
    }
487

488
    // When enough argument name hints are provided, use them as debug names
489
    // for traced function/modules.
490
    // Here argument_names is allowed to have more names than needed because
491
    // some arguments may have valid default values, therefore they don't need
492
    // example inputs.
493
    if (argument_names.size() >= inputs.size()) {
494
      for (size_t i = 0, e = inputs.size(); i < e; ++i) {
495
        IValue& input = inputs[i];
496
        input = addInput(
497
            state,
498
            input,
499
            input.type(),
500
            state->graph->addInput(argument_names[i]));
501
      }
502
    } else {
503
      for (IValue& input : inputs) {
504
        input = addInput(state, input, input.type(), state->graph->addInput());
505
      }
506
    }
507

508
    auto graph = state->graph;
509

510
    getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
511
    getTracingState()->strict = strict;
512
    getTracingState()->force_outplace = force_outplace;
513

514
    // Invoke the traced function
515
    auto out_stack = traced_fn(inputs);
516

517
    // Exit a trace, treating 'out_stack' as the outputs of the trace.  These
518
    // are the variables whose values will be computed upon subsequent
519
    // invocations of the trace.
520
    size_t i = 0;
521
    for (auto& output : out_stack) {
522
      // NB: The stack is in "reverse" order, so when we pass the diagnostic
523
      // number we need to flip it based on size.
524
      state->graph->registerOutput(
525
          state->getOutput(output, out_stack.size() - i));
526
      i++;
527
    }
528
    setTracingState(nullptr);
529

530
    if (getInlineEverythingMode()) {
531
      Inline(*graph);
532
    }
533
    FixupTraceScopeBlocks(graph, self);
534
    NormalizeOps(graph);
535
    return {state, out_stack};
536
  } catch (...) {
537
    tracer::abandon();
538
    throw;
539
  }
540
}
541

542
// Abort tracing. Used to reset the state in case of errors.
543
void abandon() {
544
  setTracingState(nullptr);
545
}
546

547
void setValueTrace(const IValue& v, Value* value) {
548
  return getTracingState()->setValue(v, value);
549
}
550
void TracingState::setValue(const IValue& v, Value* value) {
551
  if (v.isTensor()) {
552
    auto& var = v.toTensor();
553
    AT_ASSERT(var.defined());
554
    env_stack.back()[v] = value;
555

556
    // If the value comes from a CallFunction or CallMethod, it may not have
557
    // shape information attached. For debuggability, we enhance the type
558
    // information by assigning the concrete value's tupe to the jit::Value.
559
    if (auto tensor_type = value->type()->cast<TensorType>()) {
560
      if (!tensor_type->isComplete()) {
561
        value->inferTypeFrom(var);
562
      }
563
    }
564
  } else if (v.isTensorList()) {
565
    auto outputs = v.toTensorList();
566
    Node* unpack_node =
567
        graph->insertNode(graph->createListUnpack(value, outputs.size()));
568
    for (const auto i : c10::irange(outputs.size())) {
569
      setValue(outputs.get(i), unpack_node->outputs()[i]);
570
    }
571
  } else if (v.isTuple()) {
572
    const auto& outputs = v.toTupleRef().elements();
573
    Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
574
    for (const auto i : c10::irange(outputs.size())) {
575
      setValue(outputs[i], unpack_node->outputs()[i]);
576
    }
577
  } else if (v.isList()) {
578
    auto elements = v.toListRef();
579
    Node* unpack_node =
580
        graph->insertNode(graph->createListUnpack(value, elements.size()));
581
    for (const auto i : c10::irange(elements.size())) {
582
      setValue(elements[i], unpack_node->outputs()[i]);
583
    }
584
  } else if (isCustomClass(v)) {
585
    auto capsule = v.toObject()->getAttr("capsule");
586
    env_stack.back()[capsule] = value;
587
  } else if (v.isFuture() || v.isObject()) {
588
    env_stack.back()[v] = value;
589
  } else if (v.isGenericDict()) {
590
    auto dict = v.toGenericDict();
591
    TypePtr key_type = dict.keyType();
592
    TypePtr value_type = dict.valueType();
593
    for (const auto& entry : dict) {
594
      auto static_key = graph->insertConstant(entry.key());
595
      auto static_value = graph->insert(aten::__getitem__, {value, static_key});
596
      setValue(entry.value(), static_value);
597
    }
598
  } else {
599
    std::ostringstream os;
600
    os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
601
       << "Supported types are tensor, tensor list, and tuple of tensors.";
602
    throw std::runtime_error(os.str());
603
  }
604
}
605

606
void addInputs(Node* n, const char* name, int64_t value) {
607
  using ArgumentStash = jit::tracer::ArgumentStash;
608
  if (ArgumentStash::hasValue(name)) {
609
    Value* v = ArgumentStash::popValue(name);
610
    n->addInput(v);
611
  } else {
612
    detail::genericAddInput(n, value);
613
  }
614
}
615

616
void addInputs(Node* n, const char* name, c10::SymInt value) {
617
  addInputs(n, name, value.guard_int(__FILE__, __LINE__));
618
}
619

620
void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
621
  using ArgumentStash = jit::tracer::ArgumentStash;
622
  if (ArgumentStash::hasValue(name)) {
623
    Value* v = ArgumentStash::popValue(name);
624
    n->addInput(v);
625
  } else if (value) {
626
    detail::genericAddInput(n, *value);
627
  } else {
628
    Graph* g = n->owningGraph();
629
    Value* none = g->insertNode(g->createNone())->output();
630
    n->addInput(none);
631
  }
632
}
633
void addInputs(Node* n, const char* name, bool value) {
634
  detail::genericAddInput(n, value);
635
}
636
void addInputs(Node* n, const char* name, const c10::optional<bool>& value) {
637
  detail::genericAddOptionalInput(n, name, value);
638
}
639
void addInputs(Node* n, const char* name, double value) {
640
  detail::genericAddInput(n, value);
641
}
642
void addInputs(Node* n, const char* name, const c10::optional<double>& value) {
643
  detail::genericAddOptionalInput(n, name, value);
644
}
645
void addInputs(Node* n, const char* name, const at::Scalar& value) {
646
  using ArgumentStash = jit::tracer::ArgumentStash;
647
  if (ArgumentStash::hasValue(name)) {
648
    Value* v = ArgumentStash::popValue(name);
649
    n->addInput(v);
650
  } else {
651
    detail::genericAddInput(n, value);
652
  }
653
}
654
void addInputs(
655
    Node* n,
656
    const char* name,
657
    const c10::optional<at::Scalar>& value) {
658
  detail::genericAddOptionalInput(n, name, value);
659
}
660
void addInputs(Node* n, const char* name, const c10::string_view value) {
661
  detail::genericAddInput(n, std::string(value));
662
}
663
void addInputs(
664
    Node* n,
665
    const char* name,
666
    const c10::optional<c10::string_view>& value) {
667
  detail::genericAddOptionalInput(n, name, value);
668
}
669
void addInputs(Node* n, const char* name, const at::Tensor& value) {
670
  n->addInput(getValueTrace(value));
671
}
672
void addInputs(
673
    Node* n,
674
    const char* name,
675
    const c10::optional<at::Tensor>& value) {
676
  detail::genericAddOptionalInput(n, name, value);
677
}
678
void addInputs(
679
    Node* n,
680
    const char* name,
681
    const c10::optional<at::Generator>& value) {
682
  Graph* g = n->owningGraph();
683

684
  if (value.has_value() && value->defined()) {
685
    detail::genericAddInput(n, *value);
686
  } else {
687
    Value* undef_gen = g->insertNode(g->createNone())->output();
688
    n->addInput(undef_gen);
689
  }
690
}
691
void addInputs(Node* n, const char* name, at::Device value) {
692
  detail::genericAddInput(n, value);
693
}
694
void addInputs(Node* n, const char* name, c10::Stream stream) {
695
  detail::genericAddInput(n, c10::IValue(stream));
696
}
697
void addInputs(Node* n, const char* name, at::Layout value) {
698
  detail::genericAddInput(n, static_cast<int64_t>(value));
699
}
700
void addInputs(Node* n, const char* name, at::ScalarType value) {
701
  detail::genericAddInput(n, static_cast<int64_t>(value));
702
}
703
void addInputs(Node* n, const char* name, at::MemoryFormat value) {
704
  detail::genericAddInput(n, static_cast<int64_t>(value));
705
}
706
void addInputs(
707
    Node* n,
708
    const char* name,
709
    const c10::optional<at::MemoryFormat>& value) {
710
  detail::genericAddOptionalInput(n, name, value);
711
}
712
void addInputs(
713
    Node* n,
714
    const char* name,
715
    const c10::optional<at::Layout>& value) {
716
  detail::genericAddOptionalInput(n, name, value);
717
}
718
void addInputs(
719
    Node* n,
720
    const char* name,
721
    const c10::optional<at::Device>& value) {
722
  detail::genericAddOptionalInput(n, name, value);
723
}
724
void addInputs(
725
    Node* n,
726
    const char* name,
727
    c10::optional<at::DimnameList> value) {
728
  TORCH_CHECK(false, "NYI: Named tensors are not supported with the tracer");
729
}
730
void addInputs(
731
    Node* n,
732
    const char* name,
733
    const c10::optional<at::ScalarType>& value) {
734
  detail::genericAddOptionalInput(n, name, value);
735
}
736
void addInputs(
737
    Node* n,
738
    const char* name,
739
    at::ArrayRef<at::Tensor> value,
740
    bool allow_undefined) {
741
  addInputs(n, name, at::ITensorListRef(value), allow_undefined);
742
}
743
void addInputs(
744
    Node* n,
745
    const char* name,
746
    std::vector<at::Tensor> value,
747
    bool allow_undefined) {
748
  addInputs(n, name, at::ITensorListRef(value), allow_undefined);
749
}
750
void addInputs(
751
    Node* n,
752
    const char* name,
753
    at::ITensorListRef value,
754
    bool allow_undefined) {
755
  Graph* g = n->owningGraph();
756
  Node* list_node = nullptr;
757
  if (allow_undefined) {
758
    // if allow undefined, we create a list of optional tensors
759
    list_node = g->insertNode(
760
        g->createList(OptionalType::ofTensor(), fmap(value, getValueTrace)));
761
  } else {
762
    list_node = g->insertNode(
763
        g->createList(TensorType::get(), fmap(value, getValueTrace)));
764
  }
765
  n->addInput(list_node->output());
766
}
767
TORCH_API void addInputs(
768
    Node* n,
769
    const char* name,
770
    const List<c10::optional<at::Tensor>>& value) {
771
  Graph* g = n->owningGraph();
772
  Node* list_node = nullptr;
773
  list_node = g->insertNode(g->createList(
774
      OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace)));
775
  n->addInput(list_node->output());
776
}
777
void addInputs(
778
    Node* n,
779
    const char* name,
780
    ArrayRef<c10::intrusive_ptr<c10::ivalue::Object>> value,
781
    const ClassTypePtr& class_type) {
782
  Graph* g = n->owningGraph();
783
  Node* list_node =
784
      g->insertNode(g->createList(class_type, fmap(value, getValueTrace)));
785
  n->addInput(list_node->output());
786
}
787

788
void addInputs(Node* n, const char* name, at::IntArrayRef value) {
789
  using ArgumentStash = jit::tracer::ArgumentStash;
790
  std::vector<Value*> info = ArgumentStash::hasIntArrayRef(name)
791
      ? ArgumentStash::popIntArrayRef(name)
792
      : ArgumentStash::IntArrayRefTrace(value.size());
793

794
  auto& g = getTracingState()->graph;
795
  for (const auto i : c10::irange(info.size())) {
796
    if (info[i] != nullptr)
797
      continue;
798
    info[i] = g->insertConstant(value[i]);
799
    recordSourceLocation(info[i]->node());
800
  }
801
  for (jit::Value* v : info) {
802
    if (*v->type() != *jit::IntType::get()) {
803
      throw std::runtime_error(
804
          "Type mismatch in setposattr for IntArrayRef. Check that your program "
805
          "is valid without tracing, and please file a bug report if it is.");
806
    }
807
  }
808
  n->addInput(
809
      g->insertNode(g->createList(jit::IntType::get(), info))->output());
810
}
811

812
void addInputs(Node* n, const char* name, c10::SymIntArrayRef value) {
813
  addInputs(n, name, C10_AS_INTARRAYREF_SLOW(value));
814
}
815

816
void addInputs(Node* n, const char* name, c10::optional<c10::SymInt> value) {
817
  addInputs(
818
      n,
819
      name,
820
      value.has_value()
821
          ? c10::make_optional(value->guard_int(__FILE__, __LINE__))
822
          : c10::nullopt);
823
}
824

825
void addInputs(
826
    Node* n,
827
    const char* name,
828
    const c10::optional<at::IntArrayRef>& opt_value) {
829
  detail::genericAddOptionalInput(n, name, opt_value);
830
}
831

832
void addInputs(
833
    Node* n,
834
    const char* name,
835
    const at::OptionalIntArrayRef& opt_value) {
836
  if (opt_value.has_value()) {
837
    jit::tracer::addInputs(n, name, *opt_value);
838
  } else {
839
    Graph* g = n->owningGraph();
840
    Value* none = g->insertNode(g->createNone())->output();
841
    n->addInput(none);
842
  }
843
}
844

845
void addInputs(
846
    Node* n,
847
    const char* name,
848
    const at::OptionalSymIntArrayRef& opt_value) {
849
  if (opt_value.has_value()) {
850
    jit::tracer::addInputs(n, name, *opt_value);
851
  } else {
852
    Graph* g = n->owningGraph();
853
    Value* none = g->insertNode(g->createNone())->output();
854
    n->addInput(none);
855
  }
856
}
857

858
// Trace a list of doubles as a float-typed list construct whose elements are
// individual constants, each with a recorded source location.
void addInputs(Node* n, const char* name, ArrayRef<double> value) {
  auto& graph = getTracingState()->graph;
  std::vector<Value*> elements;
  elements.reserve(value.size());
  for (double elem : value) {
    Value* constant = graph->insertConstant(elem);
    recordSourceLocation(constant->node());
    elements.push_back(constant);
  }
  n->addInput(
      graph->insertNode(graph->createList(jit::FloatType::get(), elements))
          ->output());
}

void addInputs(
870
    Node* n,
871
    const char* name,
872
    const c10::optional<c10::ArrayRef<double>>& opt_value) {
873
  detail::genericAddOptionalInput(n, name, opt_value);
874
}
875

876
void addInputs(
877
    Node* n,
878
    const char* name,
879
    const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
880
  Value* v = getValueTrace(obj);
881
  n->addInput(v);
882
}
883

884
// Append a fresh output Value to `node` and bind `output` to it.
void addOutput(Node* node, const at::Tensor& output) {
  Value* out_val = node->addOutput();
  setOutput(out_val, output);
}

// Bind a produced tensor to a graph Value: infer the Value's type from the
// tensor and record the tensor -> Value mapping for later lookups.
// Undefined tensors are skipped entirely — they get no type and no trace
// entry.
void setOutput(Value* value, const at::Tensor& output) {
  if (!output.defined()) {
    return;
  }
  value->inferTypeFrom(output);
  setValueTrace(output, value);
}

// A tensor-list output is modeled as a single Tensor[] value immediately
// followed by a prim::ListUnpack, so that each element tensor can be traced
// (typed and mapped) individually.
void addOutput(Node* node, const std::vector<at::Tensor>& outputs) {
  Value* list_val = node->addOutput()->setType(ListType::ofTensors());
  Graph* graph = node->owningGraph();
  Node* unpack = graph->insertNode(
      graph->create(prim::ListUnpack, {list_val}, outputs.size()));
  size_t idx = 0;
  for (const at::Tensor& tensor : outputs) {
    Value* elem_val = unpack->outputs()[idx++];
    elem_val->inferTypeFrom(tensor);
    setValueTrace(tensor, elem_val);
  }
}

// c10::List outputs are converted to std::vector and handled by the vector
// overload above.
void addOutput(Node* node, const c10::List<at::Tensor>& outputs) {
  addOutput(node, outputs.vec());
}

void addOutput(
912
    Node* node,
913
    const c10::intrusive_ptr<c10::ivalue::Object>& output) {
914
  Value* output_val = node->addOutput();
915
  output_val->inferTypeFrom(output);
916
  setValueTrace(output, output_val);
917
}
918

919
const std::shared_ptr<TracingState>& getTracingState() {
920
  return detail::tracing_state;
921
}
922

923
void setTracingState(std::shared_ptr<TracingState> state) {
924
  at::tracer::impl::set_dispatch_enabled(state != nullptr);
925
  detail::tracing_state = std::move(state);
926
}
927

928
// Start tracing with an empty graph and a single root frame on the
// environment stack.
TracingState::TracingState() : graph(new Graph()), env_stack{Frame()} {}

// Out-of-line defaulted destructor (members are cleaned up by their own
// destructors).
TracingState::~TracingState() = default;

// Return var.size(dim) as a traced 0-dim tensor: the concrete size value is
// materialized outside the trace, and an aten::size + NumToTensor pair is
// inserted into the graph so the trace stays symbolic in the input sizes.
autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
  auto& state = getTracingState();
  auto& graph = state->graph;

  Variable size_var;
  {
    // Make sure this scalar to tensor isn't traced!
    at::AutoDispatchBelowADInplaceOrView guard;
    size_var = scalar_to_tensor(at::Scalar(var.size(dim)));
  }
  auto* self_val = getValueTrace(var);
  auto* dim_const = graph->insertConstant(dim);
  recordSourceLocation(dim_const->node());
  auto* size_node =
      graph->insertNode(graph->create(aten::size, {self_val, dim_const}));
  recordSourceLocation(size_node);
  size_node->output()->setType(jit::IntType::get());

  // Wrap the int back into a tensor so it can stand in for `size_var`.
  auto* as_tensor =
      graph->insertNode(graph->createNumToTensor(size_node->output()))
          ->output();
  setValueTrace(size_var, as_tensor);
  return size_var;
}

// Return var.numel() as a traced 0-dim tensor: the concrete element count is
// materialized outside the trace, and an aten::numel + NumToTensor pair is
// inserted into the graph so the trace stays symbolic in the input sizes.
autograd::Variable getNumelOf(const autograd::Variable& var) {
  auto& state = getTracingState();
  auto& graph = state->graph;

  Variable numel_var;
  {
    // Make sure this scalar to tensor isn't traced!
    at::AutoDispatchBelowADInplaceOrView guard;
    numel_var = scalar_to_tensor(at::Scalar(var.numel()));
  }
  auto* self_val = getValueTrace(var);
  auto* numel_node =
      graph->insertNode(graph->create(Symbol::aten("numel"), {self_val}));
  recordSourceLocation(numel_node);
  numel_node->output()->setType(jit::IntType::get());

  // Wrap the int back into a tensor so it can stand in for `numel_var`.
  auto* as_tensor =
      graph->insertNode(graph->createNumToTensor(numel_node->output()))
          ->output();
  setValueTrace(numel_var, as_tensor);
  return numel_var;
}

// When tracing with force_outplace (in-place ops rewritten to out-of-place),
// warn if the tensor's storage has other live references: the rewrite would
// silently drop the in-place update from all other views.
void ensureUniqueIfOutOfPlaced(const char* name, const at::Tensor& tensor) {
  auto& state = getTracingState();
  if (state && !state->force_outplace) {
    // If we're not converting in-place ops to out-of-place, this check is
    // unnecessary
    return;
  }
  const auto aliases = tensor.storage().use_count();
  if (isTracing() && aliases > 1) {
    std::stringstream ss;
    ss << "There are " << aliases
       << " live references to the data region being modified when tracing in-place operator "
       << name
       << ". This might cause the trace to be incorrect, because all other views "
       << "that also reference this data will not reflect this change in the trace! "
       << "On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. "
       << "are outputs of torch.split), this might still be safe.";
    warn(ss.str().c_str());
  }
}

void ensureUniqueIfOutOfPlaced(
997
    const char* name,
998
    const c10::optional<at::Tensor>& tensor) {
999
  ensureUniqueIfOutOfPlaced(name, tensor.has_value() ? *tensor : at::Tensor());
1000
}
1001

1002
////////////////////////////////////////////////////////////////////////////////
1003
// Argument stash
1004
////////////////////////////////////////////////////////////////////////////////
1005
// One stash per thread: tracing (and thus argument stashing) is
// thread-local state.
thread_local ArgumentStash ArgumentStash::stash;

void ArgumentStash::stashIntArrayRefElem(
1008
    const std::string& arg_name,
1009
    size_t size,
1010
    size_t idx,
1011
    const Variable& var) {
1012
  // TODO: check type?
1013
  if (!isTracing())
1014
    return;
1015
  IntArrayRefTrace& list_trace =
1016
      stash.intlists.emplace(arg_name, size).first->second;
1017
  AT_ASSERT(size == list_trace.size());
1018
  AT_ASSERT(idx < list_trace.size());
1019
  AT_ASSERT(list_trace[idx] == nullptr);
1020

1021
  Value* ten = getValueTrace(var);
1022
  auto& g = *ten->owningGraph();
1023
  WithInsertPoint guard(ten->node()->next());
1024
  auto prim = g.insert(aten::Int, {ten});
1025
  list_trace[idx] = prim;
1026
}
1027

1028
void ArgumentStash::stashValue(
1029
    const std::string& arg_name,
1030
    size_t idx,
1031
    const Variable& var,
1032
    const TypePtr& type) {
1033
  if (!isTracing())
1034
    return;
1035

1036
  Value* ten = getValueTrace(var);
1037
  WithInsertPoint guard(ten->node()->next());
1038
  auto& g = *ten->owningGraph();
1039

1040
  if (type == IntType::get()) {
1041
    ten = g.insert(aten::Int, {ten});
1042
  } else if (type == FloatType::get()) {
1043
    ten = g.insert(aten::Float, {ten});
1044
  } else if (type == NumberType::get()) {
1045
    ten = g.insert(aten::ScalarImplicit, {ten});
1046
  }
1047

1048
  stash.values.emplace(arg_name, ten);
1049
}
1050

1051
////////////////////////////////////////////////////////////////////////////////
1052
// Stack trace recording
1053
////////////////////////////////////////////////////////////////////////////////
1054
// no python present so we just do not record source information
1055
static void defaultRecordSourceLocation(Node* n) {}
1056
std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(
1057
    defaultRecordSourceLocation);
1058
void recordSourceLocation(Node* n) {
1059
  return record_source_location.load()(n);
1060
}
1061
void setRecordSourceLocation(void (*v)(Node*)) {
1062
  record_source_location.store(v);
1063
}
1064

1065
// Python-callstack capture is likewise pluggable: the default returns an
// empty stack; the python frontend installs a real implementation via
// setPythonCallstack. Stored atomically for cross-thread safety.
static std::vector<StackEntry> defaultPythonCallstack() {
  return {};
}
std::atomic<decltype(&defaultPythonCallstack)> python_callstack_fn(
    defaultPythonCallstack);
std::vector<StackEntry> pythonCallstack() {
  auto hook = python_callstack_fn.load();
  return hook();
}
void setPythonCallstack(std::vector<StackEntry> (*v)()) {
  python_callstack_fn.store(v);
}

// Default tracer-warning handler: route through TORCH_WARN. The python
// frontend replaces this (via setWarn) with a handler that emits python
// warnings.
static void defaultWarn(const std::string& str) {
  TORCH_WARN(str);
}
// Atomic so the handler can be swapped while other threads are warning.
std::atomic<warn_fn_type> warn_callback{defaultWarn};

// Canned suffixes appended to tracer warnings by _do_warn (the `_kind`
// argument); each describes one class of trace-unsafety.

// A python value was used in a way the tracer cannot follow.
const char* WARN_PYTHON_DATAFLOW =
    " might cause the trace to be incorrect. We can't record the data flow of "
    "Python values, so this value will be treated as a constant in the future. "
    "This means that the trace might not generalize to other inputs!";
// A tensor constructor's result was baked into the trace as a constant.
const char* WARN_CONSTRUCTOR =
    " results are registered as constants in the trace. You can safely ignore this "
    "warning if you use this function to create tensors out of constant variables "
    "that would be the same every time you call this function. In any other case, "
    "this might cause the trace to be incorrect.";
// An untraceable resize broke the connection to the value's trace.
const char* WARN_RESIZE =
    " can't be represented in the JIT at the moment, so we won't connect any uses of "
    "this value with its current trace. If you happen to use it again, it will show "
    "up as a constant in the graph. Consider using `view` or `reshape` to make "
    "it traceable.";
// A mutable container was traced under strict mode.
const char* STRICT_TRACER_MSG =
    " might cause the trace to be incorrect, this is only valid if the container "
    "structure does not change based on the module's inputs. Consider using a constant "
    "container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a "
    "`NamedTuple` instead). If you absolutely need this and know the side effects, pass "
    "strict=False to trace() to allow this behavior.";

// XXX: _kind can be a nullptr
1103
void _do_warn(const char* _reason, const char* _kind) {
1104
  std::string reason{_reason};
1105
  std::string kind{_kind ? _kind : ""};
1106
  std::ostringstream s;
1107
  s << reason << kind;
1108
  warn_callback.load()(s.str());
1109
}
1110

1111
// Install a custom tracer-warning handler (used by the python frontend to
// turn tracer warnings into python warnings).
void setWarn(warn_fn_type fn) {
  warn_callback.store(fn);
}
} // namespace torch::jit::tracer
1115

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.