pytorch

Форк
0
/
sugared_value.cpp 
798 строк · 27.4 Кб
1
#include <torch/csrc/jit/frontend/sugared_value.h>
2

3
#include <c10/util/irange.h>
4
#include <torch/csrc/jit/frontend/schema_matching.h>
5
#include <torch/csrc/jit/frontend/tree_views.h>
6
#include <torch/csrc/jit/ir/ir.h>
7
#include <torch/csrc/jit/passes/constant_propagation.h>
8

9
namespace torch::jit {
10

11
struct NoneValue : SugaredValue {
12
  NoneValue() = default;
13
  std::string kind() const override {
14
    return "None";
15
  }
16
};
17

18
std::shared_ptr<SugaredValue> PrintValue::call(
19
    const SourceRange& loc,
20
    GraphFunction& m,
21
    at::ArrayRef<NamedValue> args,
22
    at::ArrayRef<NamedValue> kwargs,
23
    size_t n_binders) {
24
  auto& g = *m.graph();
25
  if (!kwargs.empty())
26
    throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
27

28
  std::vector<Value*> lowered_inputs = toValues(*m.graph(), args);
29
  g.insertNode(g.create(prim::Print, lowered_inputs, 0)->setSourceRange(loc));
30
  return std::make_shared<NoneValue>();
31
}
32

33
static const std::unordered_map<std::string, at::ScalarType>&
34
builtin_cast_method_to_scalar_type() {
35
  static std::unordered_map<std::string, at::ScalarType> mapping = {
36
      {"byte", at::kByte},
37
      {"char", at::kChar},
38
      {"double", at::kDouble},
39
      {"float", at::kFloat},
40
      {"cfloat", at::kComplexFloat},
41
      {"cdouble", at::kComplexDouble},
42
      {"int", at::kInt},
43
      {"long", at::kLong},
44
      {"short", at::kShort},
45
      {"half", at::kHalf}};
46
  return mapping;
47
}
48

49
std::shared_ptr<SugaredValue> BuiltinFunction::call(
50
    const SourceRange& loc,
51
    GraphFunction& m,
52
    at::ArrayRef<NamedValue> args,
53
    at::ArrayRef<NamedValue> kwargs,
54
    size_t n_binders) {
55
  return std::make_shared<SimpleValue>(
56
      emitBuiltinCall(loc, *m.graph(), symbol, args, kwargs, self));
57
}
58

59
// Hash functor for enum-class keys. Older versions of gcc/clang do not
// provide a default std::hash for enumerations, so an enum class cannot be
// used as an unordered_map key without supplying a hasher explicitly.
// https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key
struct EnumClassHash {
  // Hashes an enumerator to its underlying value, converted to std::size_t.
  template <typename Enum>
  std::size_t operator()(Enum value) const {
    return static_cast<std::size_t>(value);
  }
};
68

69
bool SimpleValue::hasAttr(
70
    const SourceRange& loc,
71
    GraphFunction& m,
72
    const std::string& field) {
73
  auto class_type = value_->type()->cast<ClassType>();
74
  if (!class_type) {
75
    throw ErrorReport(loc) << "hasattr's first argument must be an object, got "
76
                           << value_->type()->repr_str() << " instead";
77
  }
78

79
  return class_type->hasMethod(field) || class_type->hasAttribute(field) ||
80
      class_type->hasConstant(field);
81
}
82

83
// support syntax sugar for x.foo(y, z) by allowing x.foo to return a
84
// callable value that will resolve to foo(x, y, z) when called.
85
std::shared_ptr<SugaredValue> SimpleValue::attr(
86
    const SourceRange& loc,
87
    GraphFunction& m,
88
    const std::string& field) {
89
  // Allow method-style casts on Tensor types. e.g. x.int()
90
  if (value_->type()->isSubtypeOf(*TensorType::get())) {
91
    if (builtin_cast_method_to_scalar_type().count(field)) {
92
      return std::make_shared<TensorCastValue>(
93
          builtin_cast_method_to_scalar_type().at(field),
94
          NamedValue(loc, "self", value_));
95
    }
96
  }
97
  // accessing properties of Tensor and Device that are implemented as
98
  // prim:: or aten:: operators
99
  using PropertiesLookup = std::unordered_map<
100
      TypeKind,
101
      std::unordered_map<std::string, std::string>,
102
      EnumClassHash>;
103
  static const PropertiesLookup builtin_properties = {
104
      {TypeKind::OptionalType,
105
       {
106
           {"unchecked_unwrap_optional", "prim"},
107
       }},
108
      {TypeKind::TensorType,
109
       {
110
           {"dtype", "prim"},
111
           {"device", "prim"},
112
           {"grad", "prim"},
113
           {"data", "prim"},
114
           {"shape", "prim"},
115
           {"is_cuda", "prim"},
116
           {"is_cpu", "prim"},
117
           {"is_xla", "prim"},
118
           {"is_xpu", "prim"},
119
           {"is_sparse", "prim"},
120
           {"is_sparse_csr", "prim"},
121
           {"is_mkldnn", "prim"},
122
           {"is_mps", "prim"},
123
           {"is_mtia", "prim"},
124
           {"is_quantized", "prim"},
125
           {"is_vulkan", "prim"},
126
           {"is_ipu", "prim"},
127
           {"is_meta", "prim"},
128
           {"is_leaf", "aten"},
129
           {"is_nested", "prim"},
130
           {"requires_grad", "prim"},
131
           {"layout", "prim"},
132
           {"T", "prim"},
133
           {"H", "prim"},
134
           {"mT", "aten"},
135
           {"mH", "aten"},
136
           {"is_ort", "prim"},
137
           {"itemsize", "prim"},
138
           {"nbytes", "prim"},
139
           {"ndim", "prim"},
140
           {"name", "prim"},
141
           {"real", "aten"},
142
           {"imag", "aten"},
143
           {"retains_grad", "aten"},
144
       }},
145
      {TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
146
  auto kind = value_->type()->kind();
147
  auto types_for_builtin = builtin_properties.find(kind);
148
  if (types_for_builtin != builtin_properties.end()) {
149
    auto builtin_entry = types_for_builtin->second.find(field);
150
    if (builtin_entry != types_for_builtin->second.end()) {
151
      // A builtin was found, add it to the graph
152
      auto the_namespace = builtin_entry->second;
153
      auto r = m.graph()->insert(
154
          Symbol::fromQualString(the_namespace + "::" + field), {value_});
155
      return std::make_shared<SimpleValue>(r);
156
    }
157
  }
158

159
  // accessing fields of named tuples
160
  if (auto tuple_type = value_->type()->cast<TupleType>()) {
161
    if (tuple_type->schema()) {
162
      auto attrs = tuple_type->schema()->arguments();
163
      for (const auto i : c10::irange(attrs.size())) {
164
        if (attrs[i].name() == field) {
165
          auto idx = m.graph()->insertConstant(IValue(static_cast<int64_t>(i)));
166
          auto out_type = tuple_type->elements().at(i);
167
          auto r = m.graph()
168
                       ->insertNode(
169
                           m.graph()->createTupleIndex(value_, idx, out_type))
170
                       ->output();
171
          return std::make_shared<SimpleValue>(r);
172
        }
173
      }
174
    }
175
  } else if (auto awaitType = value_->type()->cast<AwaitType>()) {
176
    auto elType = awaitType->getElementType();
177
    auto& g = *m.graph();
178
    auto v = g.insert(prim::awaitable_wait, {value_}, {}, loc);
179
    auto sv = std::make_shared<SimpleValue>(v);
180
    return sv->attr(loc, m, field);
181
  } else if (auto classType = value_->type()->cast<ClassType>()) {
182
    // This is a class, emit the proper attribute lookup
183
    if (classType->findMethod(field)) {
184
      return std::make_shared<MethodValue>(getValue(), field);
185
    }
186
    if (classType->hasAttribute(field)) {
187
      auto& g = *m.graph();
188
      auto n = g.insertNode(g.createGetAttr(value_, field));
189
      return std::make_shared<SimpleValue>(n->output());
190
    }
191
    // Check and see if it's a getter attribute.
192
    auto prop = classType->getProperty(field);
193
    if (prop) {
194
      return MethodValue(value_, prop->getter->name())
195
          .call(loc, m, {}, {}, /*n_binders=*/1);
196
    }
197
  } else if (auto iface = value_->type()->cast<InterfaceType>()) {
198
    // accessing methods of interfaces
199
    if (iface->getMethod(field)) {
200
      return std::make_shared<MethodValue>(getValue(), field);
201
    }
202
  } else if (auto enum_type = value_->type()->cast<EnumType>()) {
203
    // Handle access to Enum's `name` and `value` attribute.
204
    auto& g = *m.graph();
205

206
    if (field == "name") {
207
      auto n = g.insertNode(g.createEnumName(value_));
208
      return std::make_shared<SimpleValue>(n->output());
209
    }
210

211
    if (field == "value") {
212
      auto n = g.insertNode(g.createEnumValue(value_));
213
      return std::make_shared<SimpleValue>(n->output());
214
    }
215
  }
216

217
  // none of the more-specific cases worked, so see if this is a builtin method
218
  // If field is a type, then call the aten::to op
219
  if (field == "type") {
220
    if (auto builtin = BuiltinFunction::tryCreate(
221
            Symbol::aten("to"), NamedValue(loc, "self", value_))) {
222
      return builtin;
223
    }
224
  }
225

226
  if (auto builtin = BuiltinFunction::tryCreate(
227
          Symbol::aten(field), NamedValue(loc, "self", value_))) {
228
    return builtin;
229
  }
230

231
  // Handle calling tolist() on a Tensor.
232
  if (value_->type()->isSubtypeOf(*TensorType::get()) && field == "tolist") {
233
    return SpecialFormValue::create(prim::tolist);
234
  }
235

236
  // Handle calling __getitem__() directly on a Tensor, it needs special
237
  // handling because desired method name (`__getitem__`) doesn't match `aten`
238
  // operator name of `aten::index`.
239
  if (value_->type()->isSubtypeOf(*TensorType::get()) &&
240
      field == "__getitem__") {
241
    return SpecialFormValue::create(aten::index);
242
  }
243

244
  if (auto generator_type = value_->type()->cast<GeneratorType>()) {
245
    // Handle access to Generator's `manual_seed`, `initial_seed` and `seed`
246
    // attributes.
247
    if (field == "manual_seed" || field == "initial_seed" || field == "seed") {
248
      if (auto builtin = BuiltinFunction::tryCreate(
249
              Symbol::aten(field), NamedValue(loc, "self", value_))) {
250
        return builtin;
251
      }
252
    }
253
  }
254

255
  ErrorReport report(loc);
256
  report << "'" << value_->type()->repr_str()
257
         << "' object has no attribute or method '" << field << "'.";
258
  if (auto classType = value_->type()->cast<ClassType>()) {
259
    if (classType->isUnresolvedClassAttribute(field)) {
260
      report
261
          << " '" << field
262
          << "' is defined as a class attribute which currently is not"
263
             " supported. Consider converting this to an instance attribute.";
264
    } else {
265
      report << " Did you forget to initialize an attribute in __init__()?";
266
    }
267
  }
268
  throw report;
269
}
270

271
std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
272
    const SourceRange& loc,
273
    GraphFunction& m,
274
    const c10::optional<size_t>& size_hint) {
275
  static const auto make_simple_value =
276
      [](Value* v) -> std::shared_ptr<SugaredValue> {
277
    return std::make_shared<SimpleValue>(v);
278
  };
279
  if (value_->type()->kind() == TypeKind::TupleType) {
280
    auto outputs = createTupleUnpack(value_);
281
    return fmap(outputs, make_simple_value);
282
  } else if (value_->type()->kind() == TypeKind::ListType) {
283
    if (!size_hint) {
284
      throw ErrorReport(loc)
285
          << "cannot statically infer the expected size of a "
286
          << "list in this context";
287
    }
288
    auto graph = value_->owningGraph();
289
    Node* unpack =
290
        graph->insertNode(graph->createListUnpack(value_, *size_hint));
291
    return fmap(unpack->outputs(), make_simple_value);
292
  } else if (value_->type()->kind() == TypeKind::AnyTupleType) {
293
    throw ErrorReport(loc)
294
        << "Provided tuple is not fully defined/refined including its element types, please provide a value of type like Tuple[int, int]";
295
  }
296
  throw ErrorReport(loc) << value_->type()->repr_str()
297
                         << " cannot be used as a tuple";
298
}
299

300
static bool isRecursive(const TypePtr& classType, const TypePtr& attrType) {
301
  if (attrType->isSubtypeOf(*classType)) {
302
    return true;
303
  }
304

305
  // Recursively check contained types. We need to do this because a user may do
306
  // A -> B -> A.
307
  for (const auto& type : attrType->containedTypes()) {
308
    if (isRecursive(classType, type)) {
309
      return true;
310
    }
311
  }
312
  return false;
313
}
314

315
void SimpleValue::setAttr(
316
    const SourceRange& loc,
317
    GraphFunction& m,
318
    const std::string& field,
319
    Value* newValue) {
320
  const auto classType = value_->type()->cast<ClassType>();
321
  if (!classType) {
322
    throw ErrorReport(loc) << "Tried to set an attribute: " << field
323
                           << " on a non-class: " << value_->type()->repr_str();
324
  }
325
  auto expectedType = classType->findAttribute(field);
326
  if (!expectedType) {
327
    // If we are still compiling the __init__ method for this class, then
328
    // setting an unknown attribute adds it to the class's definition.
329

330
    // We are initializing if:
331
    const auto isInitializing =
332
        // 1. The method we're currently inserting into is an init method
333
        // TODO this can be a qualified name check
334
        m.name() == "__init__" &&
335
        // 2. The `self` arg matches this value's type (i.e. we are in the init
336
        // method for this class, not some other class)
337
        !m.graph()->inputs().empty() &&
338
        m.graph()->inputs().at(0)->type() == classType;
339

340
    if (isInitializing) {
341
      if (isRecursive(classType, newValue->type())) {
342
        throw ErrorReport(loc)
343
            << "Assignment to attribute '" << field
344
            << "' cannot be of a type that contains class "
345
            << "'" << classType->repr_str() << "'.\n"
346
            << "Classes that recursively contain instances of themselves"
347
            << " are not yet supported";
348
      }
349

350
      classType->addAttribute(field, newValue->type());
351
      expectedType = newValue->type();
352

353
      const auto insertPoint = m.graph()->insertPoint();
354
      const auto topLevelBlock = m.graph()->block();
355
      if (insertPoint->owningBlock() != topLevelBlock) {
356
        throw ErrorReport(loc)
357
            << "First assignment cannot be in a control-flow block. "
358
            << "Initialize the field at the top level first";
359
      }
360
    } else {
361
      // Check and see if it's a setter attribute.
362
      auto prop = classType->getProperty(field);
363
      if (prop && prop->setter) {
364
        MethodValue(value_, prop->setter->name())
365
            .call(loc, m, {newValue}, {}, /*n_binders=*/1);
366
        return;
367
      }
368

369
      if (prop && !prop->setter) {
370
        throw ErrorReport(loc) << "Tried to set read-only attribute: " << field;
371
      }
372

373
      throw ErrorReport(loc)
374
          << "Tried to set nonexistent attribute: " << field
375
          << ". Did you forget to initialize it in __init__()?";
376
    }
377
  }
378

379
  AT_ASSERT(expectedType);
380

381
  // Check type correctness
382
  const auto newType = newValue->type();
383
  if (!newType->isSubtypeOf(*expectedType)) {
384
    throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
385
                           << expectedType->repr_str() << " but got "
386
                           << newType->repr_str();
387
  }
388

389
  auto& g = *m.graph();
390
  g.insertNode(g.createSetAttr(value_, field, newValue));
391
}
392

393
std::shared_ptr<SugaredValue> SimpleValue::call(
394
    const SourceRange& loc,
395
    GraphFunction& m,
396
    at::ArrayRef<NamedValue> args,
397
    at::ArrayRef<NamedValue> kwargs,
398
    size_t n_binders) {
399
  // allow our 'fake' closures to be called, used for fork serialization
400
  // at the moment, but can be expanded later
401
  Node* self = getValue()->node();
402
  if (self->kind() == prim::TupleConstruct && self->inputs().size() == 2 &&
403
      self->inputs().at(0)->node()->kind() == prim::Closure) {
404
    std::shared_ptr<Graph> graph =
405
        self->inputs().at(0)->node()->g(attr::Subgraph);
406
    Value* context = self->inputs().at(1);
407
    AT_ASSERT(context->node()->kind() == prim::TupleConstruct);
408

409
    // fork nodes are emitted in their own block but we do not simplify
410
    // tuple construction across blocks. To ensure we clean up the tuple
411
    // construct create another copy of the tuple construct in the fork block
412
    Value* close_context =
413
        m.graph()
414
            ->insertNode(m.graph()->createTuple(context->node()->inputs()))
415
            ->output();
416
    // TODO this needs to go in `m`s compilation unit
417
    auto cu = std::make_shared<CompilationUnit>();
418
    auto fn = cu->create_function(QualifiedName("anon"), graph);
419
    auto ret = StrongFunctionPtr(std::move(cu), fn);
420

421
    std::vector<NamedValue> ctx_inputs = {close_context};
422
    ctx_inputs.insert(ctx_inputs.end(), args.begin(), args.end());
423
    return FunctionValue(ret).call(loc, m, ctx_inputs, kwargs, n_binders);
424
  }
425

426
  if (auto class_type = getValue()->type()->cast<ClassType>()) {
427
    return attr(loc, m, "__call__")->call(loc, m, args, kwargs, n_binders);
428
  }
429

430
  return SugaredValue::call(loc, m, args, kwargs, n_binders);
431
}
432

433
// Implements `len(value)`. Emits aten::len for lists, strings and tensors.
// Any other type is reported as not iterable.
Value* SimpleValue::len(const SourceRange& loc, GraphFunction& m) {
  Value* val = getValue();
  const TypePtr val_type = val->type();
  const bool supports_len = val_type->cast<ListType>() ||
      val_type->cast<StringType>() ||
      val_type->isSubtypeOf(*TensorType::get());
  if (!supports_len) {
    throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
                           << " object is not iterable";
  }
  return m.graph()->insert(aten::len, {val}, {}, loc);
}
446

447
// Implements `value[idx]`.
//  - Lists, strings and dicts emit aten::__getitem__.
//  - Tensors emit aten::select along dimension 0.
//  - Modules with a `type_hint` emit prim::ModuleContainerIndex, typed
//    with the hint.
//  - Other class instances defer to their `__getitem__` method.
// Anything else is reported as not subscriptable.
SugaredValuePtr SimpleValue::getitem(
    const SourceRange& loc,
    GraphFunction& m,
    Value* idx,
    TypePtr type_hint) {
  Value* val = getValue();
  TypePtr val_type = val->type();
  Graph& g = *m.graph();

  const bool is_builtin_container = val_type->cast<ListType>() ||
      val_type->cast<StringType>() || val_type->cast<DictType>();
  if (is_builtin_container) {
    return std::make_shared<SimpleValue>(
        g.insert(aten::__getitem__, {val, idx}, {}, loc));
  }

  if (val_type->isSubtypeOf(*TensorType::get())) {
    return std::make_shared<SimpleValue>(
        g.insert(aten::select, {val, 0, idx}, {}, loc));
  }

  auto class_type = val_type->cast<ClassType>();
  if (!class_type) {
    throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
                           << " object is not subscriptable";
  }

  // Indexing enabled by a type hint. The ModuleDict has already been checked
  // during IR generation to make sure its contents implement the module
  // interface referred to by type_hint.
  if (class_type->is_module() && type_hint) {
    Value* indexed =
        g.insert(prim::ModuleContainerIndex, {val, idx}, {}, loc);
    indexed->setType(type_hint);
    return std::make_shared<SimpleValue>(indexed);
  }

  // Defer to the __getitem__ attr on the class.
  return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1);
}
485

486
// Returns the iterable form of this value for `for` loops.
//  - Lists, strings and tensors are iterated directly.
//  - Dicts are iterated through their keys (aten::keys).
//  - Tuples become a SugaredTupleValue holding one SimpleValue per
//    unpacked element.
// Anything else is reported as not iterable.
SugaredValuePtr SimpleValue::iter(const SourceRange& loc, GraphFunction& m) {
  Value* value = getValue();
  TypePtr type = value->type();

  const bool directly_iterable = type->cast<ListType>() ||
      type->cast<StringType>() || type->cast<TensorType>();
  if (directly_iterable) {
    return std::make_shared<SimpleValue>(value);
  }

  if (type->cast<DictType>()) {
    return std::make_shared<SimpleValue>(
        m.graph()->insert(aten::keys, {value}, {}, loc));
  }

  if (!type->cast<TupleType>()) {
    throw ErrorReport(loc) << "'" << type->repr_str() << "'"
                           << " object is not iterable";
  }

  std::vector<SugaredValuePtr> elements;
  for (Value* element : createTupleUnpack(value)) {
    elements.push_back(std::make_shared<SimpleValue>(element));
  }
  return std::make_shared<SugaredTupleValue>(elements);
}
511

512
// Builds the sugared form of `range(...)` from 1 to 3 int arguments.
//  - range(end): start and step become constants 0 and 1, and
//    `has_only_end_` is set so that len() and getitem() stay trivial.
//  - range(start, end[, step]): step defaults to a constant 1.
// `static_len` is stored as given.
RangeValue::RangeValue(
    const SourceRange& loc,
    GraphFunction& m,
    std::vector<Value*> inputs,
    c10::optional<int64_t> static_len) {
  // Every argument to range() must be an int.
  size_t arg_index = 0;
  for (Value* input : inputs) {
    const TypePtr& typ = input->type();
    if (!typ->cast<IntType>()) {
      throw ErrorReport(loc)
          << "all inputs of range must be ints, found " << typ->repr_str()
          << " in argument " << std::to_string(arg_index);
    }
    ++arg_index;
  }

  Graph& g = *m.graph();
  switch (inputs.size()) {
    case 0:
      throw ErrorReport(loc) << "range expected at least 1 arguments, got 0";
    case 1:
      end_ = inputs[0];
      start_ = g.insertConstant(0, loc);
      step_ = g.insertConstant(1, loc);
      has_only_end_ = true;
      break;
    case 2:
    case 3:
      start_ = inputs[0];
      end_ = inputs[1];
      step_ = inputs.size() == 3 ? inputs[2] : g.insertConstant(1, loc);
      has_only_end_ = false;
      break;
    default:
      throw ErrorReport(loc) << "range expected at most 3 arguments, got "
                             << inputs.size();
  }

  static_len_ = static_len;
}
551

552
// A range is its own iterable: iterating it yields this same RangeValue.
// (The stray `;` that followed this function body has been removed; an
// empty declaration at namespace scope trips -Wextra-semi.)
SugaredValuePtr RangeValue::iter(const SourceRange& loc, GraphFunction& m) {
  return shared_from_this();
}
555

556
// Length of the range:
//  - the static length as a constant, when one is known;
//  - `end` itself for the single-argument form range(end);
//  - aten::__range_length(start, end, step) otherwise.
Value* RangeValue::len(const SourceRange& loc, GraphFunction& m) {
  Graph& g = *m.graph();
  if (static_len_.has_value()) {
    return insertConstant(g, *static_len_, loc);
  }
  return has_only_end_
      ? end_
      : g.insert(aten::__range_length, {start_, end_, step_}, {}, loc);
}
567

568
// Element at position `idx` of the range. For range(end) the element equals
// the index. Otherwise it is computed by aten::__derive_index(idx, start,
// step).
SugaredValuePtr RangeValue::getitem(
    const SourceRange& loc,
    GraphFunction& m,
    Value* idx,
    TypePtr type_hint) {
  if (!has_only_end_) {
    Graph& g = *m.graph();
    Value* derived =
        g.insert(aten::__derive_index, {idx, start_, step_}, {}, loc);
    return std::make_shared<SimpleValue>(derived);
  }
  return std::make_shared<SimpleValue>(idx);
}
581

582
std::vector<SugaredValuePtr> IterableTree::get_base_iterables() {
583
  std::vector<SugaredValuePtr> base_iters{};
584

585
  for (SugaredValuePtr& sv : children_) {
586
    if (auto iv = std::dynamic_pointer_cast<IterableTree>(sv)) {
587
      std::vector<SugaredValuePtr> child_iters = iv->get_base_iterables();
588
      // merge child iters with the base_iters
589
      base_iters.insert(
590
          base_iters.end(),
591
          std::make_move_iterator(child_iters.begin()),
592
          std::make_move_iterator(child_iters.end()));
593

594
    } else {
595
      // IterableTree leaves, either SimpleValue or RangeValue
596
      base_iters.emplace_back(sv);
597
    }
598
  }
599
  return base_iters;
600
}
601

602
// Trip count for a tree without a static unroll length: the minimum of the
// lengths of all leaf iterables, computed by prim::min over an int list.
// Must only be called when `unroll_length_` is unset.
Value* IterableTree::len(const SourceRange& loc, GraphFunction& m) {
  TORCH_INTERNAL_ASSERT(!unroll_length_);
  Graph& g = *m.graph();
  const std::vector<SugaredValuePtr> leaves = get_base_iterables();

  std::vector<Value*> leaf_lengths;
  leaf_lengths.reserve(leaves.size());
  for (const auto& leaf : leaves) {
    leaf_lengths.push_back(leaf->len(loc, m));
  }

  Value* length_list =
      g.insertNode(g.createList(IntType::get(), leaf_lengths))->output();
  return g.insert(prim::min, {length_list}, {}, loc);
}
618

619
// Indexes every direct child with the same `idx` and bundles the results
// into a SugaredTupleValue, in child order. `type_hint` is not forwarded to
// the children.
SugaredValuePtr IterableTree::getitem(
    const SourceRange& loc,
    GraphFunction& m,
    Value* idx,
    TypePtr type_hint) {
  std::vector<SugaredValuePtr> items;
  items.reserve(children_.size());
  for (size_t i = 0; i < children_.size(); ++i) {
    items.push_back(children_[i]->getitem(loc, m, idx));
  }
  return std::make_shared<SugaredTupleValue>(items);
}
631

632
void IterableTree::addChild(
633
    const SourceRange& range,
634
    GraphFunction& m,
635
    const SugaredValuePtr& iter_value) {
636
  c10::optional<int64_t> child_len = iter_value->staticLen();
637
  if (children_.empty()) {
638
    unroll_length_ = child_len;
639
  } else {
640
    if ((unroll_length_ && !child_len) || (child_len && !unroll_length_)) {
641
      throw ErrorReport(range)
642
          << "Can not iterate over a module list or tuple with a value "
643
             "that does not have a statically determinable length\n";
644
    }
645
    if (unroll_length_ && child_len) {
646
      // iterables run for the minimum length of all its leaves
647
      unroll_length_ = std::min(*child_len, *unroll_length_);
648
    } else {
649
      unroll_length_ = c10::nullopt;
650
    }
651
  }
652
  children_.push_back(iter_value);
653
}
654

655
std::shared_ptr<SugaredValue> MagicMethod::call(
656
    const SourceRange& loc,
657
    GraphFunction& m,
658
    at::ArrayRef<NamedValue> args,
659
    at::ArrayRef<NamedValue> kwargs,
660
    size_t n_binders) {
661
  if (!args.empty()) {
662
    Value* self = args[0].value(*m.graph());
663
    if (auto class_ptr = self->type()->cast<ClassType>()) {
664
      return SimpleValue(self)
665
          .attr(loc, m, desugared_name_)
666
          ->call(loc, m, args.slice(1), kwargs, n_binders);
667
    }
668
  }
669
  TORCH_INTERNAL_ASSERT(base_value_);
670
  return base_value_->call(loc, m, args, kwargs, n_binders);
671
}
672

673
std::shared_ptr<SugaredValue> ClassValue::call(
674
    const SourceRange& loc,
675
    GraphFunction& m,
676
    // note: names for args will be 'argument 0', 'argument 1', etc..
677
    at::ArrayRef<NamedValue> args,
678
    at::ArrayRef<NamedValue> kwargs,
679
    size_t n_binders) {
680
  AT_ASSERT(n_binders <= 1);
681

682
  // Generate a new object of the right type, then call `__init__` on it
683
  auto& g = *m.graph();
684
  auto self = g.insertNode(g.createObject(type_))->output();
685
  self->node()->setSourceRange(loc);
686
  if (!type_->findMethod("__init__")) {
687
    throw ErrorReport(loc) << "Class " << type_->name()->name()
688
                           << " does not have an __init__ function defined";
689
  }
690

691
  // Call the init function
692
  MethodValue(self, "__init__").call(loc, m, args, kwargs, n_binders);
693

694
  return std::make_shared<SimpleValue>(self);
695
}
696

697
std::shared_ptr<SugaredValue> ClassValue::attr(
698
    const SourceRange& loc,
699
    GraphFunction& m,
700
    const std::string& field) {
701
  // Allow import_source.cpp to resolve calls to a submodule's
702
  // hooks. Edge case because normally you wouldn't allow a module to
703
  // call functions of a submodule
704
  if (Function* hook = type_->findHook(field)) {
705
    return std::make_shared<FunctionValue>(hook);
706
  }
707

708
  if (field != "__new__") {
709
    throw ErrorReport(loc) << "Tried to lookup unknown attribute on class "
710
                           << type_->annotation_str();
711
  }
712
  return SpecialFormValue::create(prim::CreateObject);
713
}
714

715
std::shared_ptr<SugaredValue> NamedTupleConstructor::call(
716
    const SourceRange& loc,
717
    GraphFunction& m,
718
    at::ArrayRef<NamedValue> args,
719
    at::ArrayRef<NamedValue> kwargs,
720
    size_t n_binders) {
721
  auto& g = *m.graph();
722

723
  auto schema = type_->schema();
724
  TORCH_INTERNAL_ASSERT(schema);
725
  auto qualname = type_->name();
726
  auto matched_schema = matchSchema(*schema, loc, g, args, kwargs);
727

728
  auto self =
729
      g.insertNode(
730
           g.createTuple(matched_schema.inputs, type_)->setSourceRange(loc))
731
          ->output();
732
  self->setType(type_);
733

734
  return std::make_shared<SimpleValue>(self);
735
}
736

737
std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
738
    Symbol symbol,
739
    c10::optional<NamedValue> self) {
740
  for (const std::shared_ptr<Operator>& op : getAllOperatorsFor(symbol)) {
741
    if (!self) {
742
      return std::make_shared<BuiltinFunction>(symbol, nullptr);
743
    }
744
    if (auto index = op->schema().argumentIndexWithName("self")) {
745
      std::unordered_map<std::string, TypePtr> type_env;
746
      TypePtr formal_type = op->schema().arguments().at(*index).type();
747
      const MatchTypeReturn matched =
748
          matchTypeVariables(formal_type, self->type(), type_env);
749
      if (!matched.success()) {
750
        continue;
751
      }
752
      const auto concrete_type = tryEvalTypeVariables(formal_type, type_env);
753
      if (!concrete_type || !self->type()->isSubtypeOf(*concrete_type)) {
754
        continue;
755
      }
756
      return std::make_shared<BuiltinFunction>(symbol, self);
757
    }
758
  }
759
  return nullptr;
760
}
761

762
std::shared_ptr<SugaredValue> SugaredEnumClass::attr(
763
    const SourceRange& loc,
764
    GraphFunction& m,
765
    const std::string& field) {
766
  const auto& names_values = enum_type_->enumNamesValues();
767
  auto it = std::find_if(
768
      names_values.begin(),
769
      names_values.end(),
770
      [&field](const at::EnumNameValue& nv) { return nv.first == field; });
771
  if (it == names_values.end()) {
772
    throw ErrorReport(loc) << enum_type_->repr_str() << "'"
773
                           << " has no attribute '" << field << "'";
774
  }
775
  auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
776
      enum_type_, it->first, it->second);
777
  return std::make_shared<SimpleValue>(
778
      m.graph()->insertConstant(IValue(enum_holder), loc));
779
}
780

781
// Iterating an Enum class yields all of its members. They are materialized
// as one constant list of Enum values, in the order returned by
// enumNamesValues().
SugaredValuePtr SugaredEnumClass::iter(
    const SourceRange& loc,
    GraphFunction& m) {
  const auto& members = enum_type_->enumNamesValues();
  auto member_list = c10::impl::GenericList(enum_type_);
  member_list.reserve(members.size());
  for (const auto& member : members) {
    member_list.emplace_back(c10::make_intrusive<at::ivalue::EnumHolder>(
        enum_type_, member.first, member.second));
  }

  Value* list_constant = m.graph()->insertConstant(member_list, loc);
  return std::make_shared<SimpleValue>(list_constant);
}
797

798
} // namespace torch::jit
799

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.