irparser.cpp
#include <torch/csrc/jit/ir/irparser.h>

#include <ATen/EmptyTensor.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
#include <torch/csrc/jit/frontend/schema_type_parser.h>
#include <torch/csrc/jit/ir/ir.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#endif

#include <string>
#include <vector>

namespace torch::jit {

struct VarWithType;
struct ParsedLiteral;

class IRParser {
  friend void parseIR(
      const std::string& str,
      torch::jit::Graph* graph,
      std::unordered_map<std::string, Value*>& vmap,
      bool parse_tensor_constants);
  IRParser(
      const std::string& str,
      torch::jit::Graph* graph,
      std::unordered_map<std::string, Value*>& vmap,
      bool parse_tensor_constants)
      : L(std::make_shared<Source>(str)),
        g(graph),
        vmap(vmap),
        type_parser(L, /*parse_complete_tensor_types*/ true),
        parse_tensor_constants_(parse_tensor_constants) {}

  std::string parseVar();
  VarWithType parseVarWithType(bool allow_optional = false);
  ParsedLiteral parseScalarLiteral(Node* n);

  void parse();
  void parseGraphInputs();
  void parseReturnOperator();

  void parseBlocks(Node* parentNode);
  void parseBlock(Node* parentNode);
  void parseBlockInputs(Block* b);
  void parseBlockOutputs(Block* b);

  void parseOperatorsList(Block* b);
  void parseOperator(Block* b);
  void parseOperatorOutputs(std::vector<VarWithType>* outs);
  std::string parseOperatorName();
  void parseOperatorInputs(Node* n);
  void parseAttrs(Node* n);
  void parseAttr(Node* n);

  void parseList(
      int begin,
      int sep,
      int end,
      const std::function<void()>& callback);

  void bypassTypeAnnotationList();

  Value* findValueInVMap(const std::string& name);

  torch::jit::Lexer L;
  torch::jit::Graph* g = nullptr;
  std::unordered_map<std::string, Value*>& vmap;
  SchemaTypeParser type_parser;
  bool parse_tensor_constants_;
  std::vector<Node*> deferred_tensor_value_initializations_;
  std::vector<Node*> deferred_empty_container_initializations_;
};

struct ParsedLiteral {
  ParsedLiteral() = default;

  AttributeKind k = AttributeKind::t;

  int64_t i = 0;
  std::string s = "";
  double f = 0.0;
  c10::complex<double> c = c10::complex<double>(0, 0);
  TypePtr ty;
  std::vector<int64_t> is;
  std::vector<std::string> ss;
  std::vector<double> fs;
  std::vector<c10::complex<double>> cs;
  std::vector<TypePtr> tys;
};

struct VarWithType {
  VarWithType() = default;
  std::string name;
  TypePtr type;
};

void parseIR(
    const std::string& str,
    torch::jit::Graph* graph,
    std::unordered_map<std::string, Value*>& vmap,
    bool parse_tensor_constants) {
  torch::jit::IRParser p(str, graph, vmap, parse_tensor_constants);
  p.parse();
}

void parseIR(
    const std::string& str,
    torch::jit::Graph* graph,
    bool parse_tensor_constants) {
  std::unordered_map<std::string, Value*> vmap;
  parseIR(str, graph, vmap, parse_tensor_constants);
}
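
// A minimal usage sketch (illustrative; it assumes the caller builds the Graph
// and that the operator used has a registered schema):
//
//   auto graph = std::make_shared<Graph>();
//   std::unordered_map<std::string, Value*> vmap;
//   parseIR(R"IR(
// graph(%a : Tensor, %b : Tensor):
//   %c : Tensor = aten::mul(%a, %b)
//   return (%c)
// )IR", graph.get(), vmap, /*parse_tensor_constants=*/false);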

VarWithType IRParser::parseVarWithType(bool allow_optional) {
  VarWithType r;
  r.name = parseVar();
  if (allow_optional) {
    r.type = nullptr;
  } else {
    r.type = TensorType::get();
  }
  if (L.nextIf(':')) {
    auto type_alias = type_parser.parseType();
    AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
    r.type = type_alias.first;
  }
  return r;
}

std::string IRParser::parseVar() {
  L.expect('%');
  std::string name;
  bool continue_parsing;
  do {
    if (L.cur().kind == TK_IDENT) {
      name += L.expect(TK_IDENT).text();
    } else {
      name += L.expect(TK_NUMBER).text();
    }
    continue_parsing = false;
    if (L.nextIf('.')) {
      continue_parsing = true;
      name += '.';
    } else if (L.cur().kind == TK_NUMBER && L.cur().text()[0] == '.') {
      continue_parsing = true;
    }
  } while (continue_parsing);
  return name;
}

void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
  if (L.cur().kind != '%') {
    return;
  }
  parseList(TK_NOTHING, ',', TK_NOTHING, [&] {
    outs->push_back(parseVarWithType(true));
  });
  L.expect('=');
}

// Parse string or numeric literal and return it along with its type.
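// Examples of literal forms handled below (illustrative): 42 and -3 (int),
// 2.5 and 1e-3 (float), 1.5j (complex), "hello" (string), a type literal such
// as int, and, when parse_tensor_constants_ is set, the deferred tensor
// constants <Tensor> and {0.5}.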
ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
  auto token = L.cur();
  std::string str;
  std::pair<TypePtr, c10::optional<c10::AliasInfo>> type_alias;
  ParsedLiteral r;
  switch (token.kind) {
    case TK_STRINGLITERAL:
      r.k = AttributeKind::s;
      r.s = parseStringLiteral(token.range, token.text());
      L.next();
      return r;
    case '-':
      str = "-";
      L.next();
      if (L.cur().kind != TK_NUMBER) {
        throw ErrorReport(token.range)
            << "Expected a number after '-' but got:" << token.text();
      }
      [[fallthrough]];
    case TK_NUMBER:
      str += L.cur().text();
      if (str.find('j') != std::string::npos) {
        r.k = AttributeKind::c;
        double imag = 0.0f;
        try {
          imag = std::stod(str.substr(0, str.size() - 1));
        } catch (const std::invalid_argument& e) {
          throw ErrorReport(token.range)
              << "Number cannot be converted to double";
        } catch (const std::out_of_range& e) {
          throw ErrorReport(token.range)
              << "Number is too long to be represented in type double";
        }
        r.c = c10::complex<double>(0, imag);
      } else if (
          str.find('.') != std::string::npos ||
          str.find('e') != std::string::npos) {
        r.k = AttributeKind::f;
        try {
          r.f = std::stod(str);
        } catch (const std::invalid_argument& e) {
          throw ErrorReport(token.range)
              << "Number cannot be converted to double";
        } catch (const std::out_of_range& e) {
          throw ErrorReport(token.range)
              << "Number is too long to be represented in type double";
        }
      } else {
        r.k = AttributeKind::i;
        try {
          r.i = std::stoll(str);
        } catch (const std::invalid_argument& e) {
          throw ErrorReport(token.range)
              << "Number cannot be converted to integer";
        } catch (const std::out_of_range& e) {
          throw ErrorReport(token.range) << "Number is too big";
        }
      }
      L.next();
      return r;
    case TK_IDENT:
      // Type literal
      r.k = AttributeKind::ty;
      type_alias = type_parser.parseType();
      AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
      r.ty = type_alias.first;
      return r;
    case '<': {
      L.next();
      auto text = L.expect(TK_IDENT);
      if (text.text() != "Tensor") {
        throw ErrorReport(token.range)
            << "Could not parse literal" << token.text();
      }
      if (!parse_tensor_constants_) {
        throw ErrorReport(token.range)
            << "Tensor constant encountered but `parse_tensor_constants` set to false"
            << token.text();
      }
      L.expect('>');
      // these values will be set with randomly initialized data in
      // a post processing pass;
      deferred_tensor_value_initializations_.push_back(n);
      r.k = AttributeKind::t;
      return r;
    }
    case '{': {
      L.next();
      if (L.cur().kind == '-') {
        L.next();
      }
      auto text = L.expect(TK_NUMBER);
      if (!parse_tensor_constants_) {
        throw ErrorReport(token.range)
            << "Single-element tensor constant encountered but "
            << "`parse_tensor_constants` is set to false " << token.text();
      }
      L.expect('}');
      deferred_tensor_value_initializations_.push_back(n);
      r.k = AttributeKind::t;
      return r;
    }
    default:
      throw ErrorReport(token.range)
          << "Could not parse literal" << token.text();
  }
}

void IRParser::bypassTypeAnnotationList() {
  int depth = 0;
  bool bypassed_list = false;
  while (depth != 0 || !bypassed_list) {
    if (L.cur().kind == '[') {
      bypassed_list = true;
      depth++;
    } else if (L.cur().kind == ']') {
      depth--;
    }
    L.next();
  }
}

/** \brief Parse attribute and add it to the node N.
 *
 * The function determines the attribute type (string, int, float, complex,
 * list of strings, list of ints, list of floats, list of complex, or a list
 * of tensors (currently only for empty lists)). An attribute looks like
 * AttrName=AttrValue, where AttrValue can be a list or a scalar literal,
 * e.g.: size = 27, name = "Bob", coefs = [1.2, 3.4, 0.6]
 */
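// In addition to lists and scalars, an empty container constant may be written
// with an annotate(...) expression; illustrative forms this branch accepts are
// value=annotate(List[int], []) and value=annotate(Dict[str, Tensor], {}), with
// the element types recovered later from the node's output type.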
void IRParser::parseAttr(Node* n) {
  std::string attrname = L.expect(TK_IDENT).text();
  L.expect('=');
  if (L.cur().kind == '[') {
    // list
    AttributeKind k = AttributeKind::ts;
    c10::List<int64_t> is;
    c10::List<std::string> ss;
    c10::List<double> fs;
    c10::List<c10::complex<double>> cs;
    std::vector<TypePtr> tys;
    int elem_num = 0;
    parseList('[', ',', ']', [&] {
      ParsedLiteral r = parseScalarLiteral(n);
      switch (r.k) {
        case AttributeKind::s:
          ss.push_back(r.s);
          AT_ASSERT(!elem_num++ || k == AttributeKind::ss);
          k = AttributeKind::ss;
          break;
        case AttributeKind::i:
          is.push_back(r.i);
          AT_ASSERT(!elem_num++ || k == AttributeKind::is);
          k = AttributeKind::is;
          break;
        case AttributeKind::f:
          fs.push_back(r.f);
          AT_ASSERT(!elem_num++ || k == AttributeKind::fs);
          k = AttributeKind::fs;
          break;
        case AttributeKind::c:
          cs.push_back(r.c);
          AT_ASSERT(!elem_num++ || k == AttributeKind::cs);
          k = AttributeKind::cs;
          break;
        case AttributeKind::ty:
          tys.push_back(r.ty);
          AT_ASSERT(!elem_num++ || k == AttributeKind::tys);
          k = AttributeKind::tys;
          break;
        default:
          throw ErrorReport(L.cur().range) << "Unexpected attr type";
      }
    });
    switch (k) {
      case AttributeKind::ts:
        n->ival_(Symbol::attr(attrname), IValue());
        break;
      case AttributeKind::ss:
        n->ival_(Symbol::attr(attrname), IValue(ss));
        break;
      case AttributeKind::fs:
        n->ival_(Symbol::attr(attrname), IValue(fs));
        break;
      case AttributeKind::cs:
        n->ival_(Symbol::attr(attrname), IValue(cs));
        break;
      case AttributeKind::is:
        n->ival_(Symbol::attr(attrname), IValue(is));
        break;
      case AttributeKind::tys:
        n->tys_(Symbol::attr(attrname), tys);
        break;
      default:
        throw ErrorReport(L.cur().range) << "Unexpected attr type";
    }
  } else if (L.cur().text() == "annotate") {
    L.next();
    L.expect('(');
    auto type = L.cur().text();
    if (type != "List" && type != "Dict") {
      throw ErrorReport(L.cur().range)
          << "Unexpected annotation (only List and Dict can be parsed)";
    }
    L.next();
    // ignore the annotations on the IValue constants, and instead recover
    // type from the Node output
    // Note: we could also use script_type_parser
    bypassTypeAnnotationList();
    L.expect(',');
    // expect an empty definition (note - this isn't always true)
    if (type == "Dict") {
      L.expect('{');
      L.expect('}');
    } else if (type == "List") {
      L.expect('[');
      L.expect(']');
    }
    L.expect(')');
    deferred_empty_container_initializations_.push_back(n);
  } else {
    // scalar
    ParsedLiteral r = parseScalarLiteral(n);
    switch (r.k) {
      case AttributeKind::s:
        n->s_(Symbol::attr(attrname), r.s);
        break;
      case AttributeKind::i:
        n->i_(Symbol::attr(attrname), r.i);
        break;
      case AttributeKind::f:
        n->f_(Symbol::attr(attrname), r.f);
        break;
      case AttributeKind::c:
        n->c_(Symbol::attr(attrname), r.c);
        break;
      case AttributeKind::ty:
        n->ty_(Symbol::attr(attrname), r.ty);
        break;
      case AttributeKind::t:
        // initialized with random data later
        break;
      default:
        throw ErrorReport(L.cur().range) << "Unexpected attr type";
    }
    return;
  }
}

void IRParser::parseAttrs(Node* n) {
  parseList('[', ',', ']', [&] { parseAttr(n); });
}

void IRParser::parseOperatorInputs(Node* n) {
  if (L.cur().kind == '[') {
    parseAttrs(n);
  }
  parseList('(', ',', ')', [&] {
    std::string var_name = parseVar();
    n->addInput(findValueInVMap(var_name));
  });
}

void IRParser::parseBlocks(Node* parentNode) {
  L.expect(TK_INDENT);
  while (L.cur().kind != TK_DEDENT) {
    parseBlock(parentNode);
  }
  L.expect(TK_DEDENT);
}

void IRParser::parseBlockInputs(Block* b) {
  parseList('(', ',', ')', [&] {
    VarWithType v = parseVarWithType();
    // If the name isn't valid, don't use it
    std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
    vmap[v.name] = b->addInput(uniq_name);
    vmap[v.name]->setType(v.type);
  });
}

void IRParser::parseBlockOutputs(Block* b) {
  L.expect(TK_ARROW);
  parseList('(', ',', ')', [&] {
    std::string var_name = parseVar();
    b->registerOutput(findValueInVMap(var_name));
  });
  L.expect(TK_NEWLINE);
  L.expect(TK_DEDENT);
}

/** \brief Parse a block.
 *
 * It should look like the following:
 * blockName(input1, input2, input3, ...):
 *   op1
 *   op2
 *   ...
 *   opN
 *   -> (output1, output2, output3, ...)
 */
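// A concrete (illustrative) block as it might appear nested under a prim::If
// or prim::Loop node:
//   block0():
//     %t : Tensor = aten::relu(%x)
//     -> (%t)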
void IRParser::parseBlock(Node* parentNode) {
  Block* b = parentNode->addBlock();
  L.expect(TK_IDENT).text(); // Block name is not used anywhere.
  parseBlockInputs(b);
  L.expect(':');
  parseOperatorsList(b);
  parseBlockOutputs(b);
}

/** \brief Parse a list of statements.
 *
 * It is expected to be delimited by TK_NEWLINE and end with TK_RETURN or
 * TK_ARROW.
 */
void IRParser::parseOperatorsList(Block* b) {
  L.expect(TK_INDENT);
  while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) {
    parseOperator(b);
  }
}

std::string IRParser::parseOperatorName() {
  std::string name = L.expect(TK_IDENT).text();
  L.expect(':');
  L.expect(':');
  name += "::" + L.expect(TK_IDENT).text();
  return name;
}

/** \brief Parse a statement.
 *
 * It should look like the following:
 *   <outputs> = NodeName[<attributes>](<inputs>)
 *     <blocks>
 * Outputs, blocks and attributes are optional.
 */
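// Concrete (illustrative) instances of the statement form above:
//   %1 : int = prim::Constant[value=3]()
//   %y : Tensor = aten::relu(%x)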
void IRParser::parseOperator(Block* b) {
  // Parse lefthand side.
  std::vector<VarWithType> outs;
  parseOperatorOutputs(&outs);

  // Parse the name and create the corresponding node in the graph.
  auto source_range = L.cur().range;
  std::string name = parseOperatorName();
  Node* n = g->create(Symbol::fromQualString(name), {}, outs.size())
                ->setSourceRange(source_range);

  // Parse attributes and inputs.
  parseOperatorInputs(n);

  const FunctionSchema* schema = n->maybeSchema();

  // Register outputs.
  unsigned idx = 0;
  for (const VarWithType& v : outs) {
    vmap[v.name] = n->outputs()[idx];
    if (schema && !schema->is_varret()) {
      TORCH_CHECK(
          schema->returns().size() > idx,
          "Operator parsing error: out of bounds access at ",
          idx,
          " to schema->returns() which size is ",
          schema->returns().size(),
          " in size");
      auto schema_return_type = schema->returns().at(idx).type();
      if (!v.type) {
        vmap[v.name]->setType(schema_return_type);
      } else {
        // Don't currently support checking against type variables
        // TODO: support?
        if (!schema_return_type->hasFreeVariables() &&
            !v.type->isSubtypeOf(*schema_return_type)) {
          throw ErrorReport(source_range)
              << "Annotated type " << v.type->repr_str()
              << " does not match schema type "
              << schema_return_type->repr_str() << " for operator " << *schema;
        }
        vmap[v.name]->setType(v.type);
      }
    } else {
      vmap[v.name]->setType(v.type ? v.type : TensorType::get());
    }
    idx++;
  }

  // Insert the new node into block B.
  b->appendNode(n);

  // If the statement has nested blocks, parse them:
  if (L.cur().kind == TK_INDENT) {
    parseBlocks(n);
  }
  L.nextIf(TK_NEWLINE);
}

void IRParser::parseGraphInputs() {
  parseList('(', ',', ')', [&] {
    VarWithType v = parseVarWithType();
    // If the name isn't valid, don't use it
    std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
    vmap[v.name] = g->addInput(uniq_name);
    vmap[v.name]->setType(v.type);
  });
}

/** \brief Parse return statement.
 *
 * It should look like the following:
 *   return (%x, %y, %z, ...)
 */
void IRParser::parseReturnOperator() {
  L.expect(TK_RETURN);

  // Parse output names and types
  parseList('(', ',', ')', [&] {
    std::string var_name = parseVar();
    g->registerOutput(findValueInVMap(var_name));
  });

  // Consume ending tokens
  if (L.cur().kind != TK_EOF) {
    L.expect(TK_NEWLINE);
    L.expect(TK_DEDENT);
  }
}

/** \brief Parse entire graph.
 *
 * It should look like the following:
 *   graphName (input1, input2, ... inputN):
 *     op1
 *     op2
 *     ...
 *     opN
 *     return (output1, output2, ... outputN)
 */
void IRParser::parse() {
  // Parse graph definition, it should look like the following:
  // graphName (input1, input2, ... inputN):
  std::string graphName = L.expect(TK_IDENT).text();
  parseGraphInputs();
  L.expect(':');

  // After the definition we should have a list of statements, parse it:
  parseOperatorsList(g->block());

  // The last statement should be return, which specifies graph outputs
  parseReturnOperator();

  for (Node* n : deferred_tensor_value_initializations_) {
    auto type = n->output()->type()->expect<TensorType>();
    auto tt = n->output()->type()->cast<TensorType>();
    TORCH_INTERNAL_ASSERT(tt, "expected tensor output ", *n);
    auto sizes = tt->sizes().concrete_sizes();
    TORCH_INTERNAL_ASSERT(sizes);
    auto strides = tt->strides().concrete_sizes();
    TORCH_INTERNAL_ASSERT(strides);
    auto device = tt->device();
    TORCH_INTERNAL_ASSERT(device);
    auto dtype = tt->scalarType();
    TORCH_INTERNAL_ASSERT(dtype);
    auto options = at::TensorOptions(*device).dtype(*dtype);
    auto t = n->t_(attr::value, at::empty_strided(*sizes, *strides, options));
    (void)t;
  }

  for (Node* n : deferred_empty_container_initializations_) {
    auto type = n->output()->type();
    IValue val;
    if (type->kind() == TypeKind::ListType) {
      val = c10::impl::GenericList(type->containedType(0));
    } else if (type->kind() == TypeKind::DictType) {
      val = c10::impl::GenericDict(
          type->containedType(0), type->containedType(1));
    }
    n->ival_(attr::value, val);
  }
}

void IRParser::parseList(
    int begin,
    int sep,
    int end,
    const std::function<void()>& callback) {
  if (begin != TK_NOTHING) {
    L.expect(begin);
  }
  if (L.cur().kind != end) {
    do {
      callback();
    } while (L.nextIf(sep));
  }
  if (end != TK_NOTHING) {
    L.expect(end);
  }
}

Value* IRParser::findValueInVMap(const std::string& name) {
  if (!vmap.count(name)) {
    throw ErrorReport(L.cur().range)
        << "Cannot find a variable with name '" << name << "'";
  }
  return vmap.at(name);
}

} // namespace torch::jit