pytorch

Форк
0
824 строки · 25.4 Кб
1
#include <torch/csrc/jit/frontend/parser.h>
2

3
#include <c10/util/Optional.h>
4
#include <torch/csrc/jit/frontend/lexer.h>
5
#include <torch/csrc/jit/frontend/parse_string_literal.h>
6
#include <torch/csrc/jit/frontend/tree.h>
7
#include <torch/csrc/jit/frontend/tree_views.h>
8

9
namespace torch::jit {
10

11
// Merges the types from a type-comment declaration into a function or method
// declaration.
//
// `decl` supplies the parameter names and source ranges. `type_annotation_decl`
// supplies one bare type per parameter (its parameter names are ignored) and
// the return type. When `is_method` is true, the first parameter of `decl`
// (`self`) has no counterpart in the type comment and is copied unchanged.
//
// Returns a new Decl spanning `decl.range()` with the merged parameters and
// the return type taken from `type_annotation_decl`.
// Throws ErrorReport if the number of annotations does not match the number
// of annotatable parameters, or if a method has no parameters at all.
Decl mergeTypesFromTypeComment(
    const Decl& decl,
    const Decl& type_annotation_decl,
    bool is_method) {
  auto old_params = decl.params();
  auto annotated_params = type_annotation_decl.params();
  const size_t num_params = old_params.size();
  const size_t num_annotations = annotated_params.size();
  // For methods the leading `self` parameter is not annotated.
  const size_t first_annotated = is_method ? 1 : 0;

  // Checked separately: `num_params - first_annotated` would underflow size_t
  // and produce a meaningless parameter count in the error below.
  if (num_params < first_annotated) {
    throw ErrorReport(decl.range())
        << "Cannot apply a type comment to a method with no parameters "
        << "(expected `self` as the first parameter)";
  }

  const size_t expected_num_annotations = num_params - first_annotated;
  if (expected_num_annotations != num_annotations) {
    throw ErrorReport(decl.range())
        << "Number of type annotations (" << num_annotations
        << ") did not match the number of "
        << (is_method ? "method" : "function") << " parameters ("
        << expected_num_annotations << ")";
  }

  // Merge signature idents and ranges with annotation types.
  std::vector<Param> new_params;
  new_params.reserve(num_params);
  if (is_method) {
    new_params.push_back(old_params[0]);
  }
  for (size_t i = first_annotated; i < num_params; ++i) {
    new_params.emplace_back(
        old_params[i].withType(annotated_params[i - first_annotated].type()));
  }
  return Decl::create(
      decl.range(),
      List<Param>::create(decl.range(), new_params),
      type_annotation_decl.return_type());
}
46

47
struct ParserImpl {
48
  explicit ParserImpl(const std::shared_ptr<Source>& source)
49
      : L(source), shared(sharedParserData()) {}
50

51
  Ident parseIdent() {
52
    auto t = L.expect(TK_IDENT);
53
    // whenever we parse something that has a TreeView type we always
54
    // use its create method so that the accessors and the constructor
55
    // of the Compound tree are in the same place.
56
    return Ident::create(t.range, t.text());
57
  }
58
  TreeRef createApply(const Expr& expr) {
59
    TreeList attributes;
60
    auto range = L.cur().range;
61
    TreeList inputs;
62
    parseArguments(inputs, attributes);
63
    return Apply::create(
64
        range,
65
        expr,
66
        List<Expr>(makeList(range, std::move(inputs))),
67
        List<Attribute>(makeList(range, std::move(attributes))));
68
  }
69

70
  static bool followsTuple(int kind) {
71
    switch (kind) {
72
      case TK_PLUS_EQ:
73
      case TK_MINUS_EQ:
74
      case TK_TIMES_EQ:
75
      case TK_DIV_EQ:
76
      case TK_MOD_EQ:
77
      case TK_BIT_OR_EQ:
78
      case TK_BIT_AND_EQ:
79
      case TK_BIT_XOR_EQ:
80
      case TK_LSHIFT_EQ:
81
      case TK_RSHIFT_EQ:
82
      case TK_POW_EQ:
83
      case TK_NEWLINE:
84
      case '=':
85
      case ')':
86
        return true;
87
      default:
88
        return false;
89
    }
90
  }
91

92
  // exp | expr, | expr, expr, ...
93
  Expr parseExpOrExpTuple() {
94
    auto prefix = parseExp();
95
    if (L.cur().kind == ',') {
96
      std::vector<Expr> exprs = {prefix};
97
      while (L.nextIf(',')) {
98
        if (followsTuple(L.cur().kind))
99
          break;
100
        exprs.push_back(parseExp());
101
      }
102
      auto list = List<Expr>::create(prefix.range(), exprs);
103
      prefix = TupleLiteral::create(list.range(), list);
104
    }
105
    return prefix;
106
  }
107
  // things like a 1.0 or a(4) that are not unary/binary expressions
108
  // and have higher precedence than all of them
109
  TreeRef parseBaseExp() {
110
    TreeRef prefix;
111
    switch (L.cur().kind) {
112
      case TK_NUMBER: {
113
        prefix = parseConst();
114
      } break;
115
      case TK_TRUE:
116
      case TK_FALSE:
117
      case TK_NONE:
118
      case TK_NONE_TYPE: {
119
        auto k = L.cur().kind;
120
        auto r = L.cur().range;
121
        prefix = create_compound(k, r, {});
122
        L.next();
123
      } break;
124
      case '(': {
125
        L.next();
126
        if (L.nextIf(')')) {
127
          /// here we have the empty tuple case
128
          std::vector<Expr> vecExpr;
129
          List<Expr> listExpr = List<Expr>::create(L.cur().range, vecExpr);
130
          prefix = TupleLiteral::create(L.cur().range, listExpr);
131
          break;
132
        }
133
        prefix = parseExpOrExpTuple();
134
        L.expect(')');
135
      } break;
136
      case '[': {
137
        auto list = parseList('[', ',', ']', &ParserImpl::parseExp);
138

139
        if (list.size() == 1 && (*list.begin()).kind() == TK_LIST_COMP) {
140
          prefix = *list.begin();
141
        } else {
142
          for (auto se : list) {
143
            if (se.kind() == TK_LIST_COMP) {
144
              throw ErrorReport(list.range())
145
                  << " expected a single list comprehension within '[' , ']'";
146
            }
147
          }
148
          prefix = ListLiteral::create(list.range(), List<Expr>(list));
149
        }
150

151
      } break;
152
      case '{': {
153
        L.next();
154
        // If we have a dict literal, `keys` and `values` will store the keys
155
        // and values used in the object's construction. EDGE CASE: We have a
156
        // dict comprehension, so we'll get the first element of the dict
157
        // comprehension in `keys` and a list comprehension in `values`.
158
        // For example, `{i : chr(i + 65) for i in range(4)}` would give us
159
        // `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`.
160
        // The optimal way of handling this case is to simply splice the new
161
        // dict comprehension together from the existing list comprehension.
162
        // Splicing prevents breaking changes to our API and does not require
163
        // the use of global variables.
164
        std::vector<Expr> keys;
165
        std::vector<Expr> values;
166
        auto range = L.cur().range;
167
        if (L.cur().kind != '}') {
168
          do {
169
            keys.push_back(parseExp());
170
            L.expect(':');
171
            values.push_back(parseExp());
172
          } while (L.nextIf(','));
173
        }
174
        L.expect('}');
175
        if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) {
176
          ListComp lc(*values.begin());
177
          prefix = DictComp::create(
178
              range, *keys.begin(), lc.elt(), lc.target(), lc.iter());
179
        } else {
180
          prefix = DictLiteral::create(
181
              range,
182
              List<Expr>::create(range, keys),
183
              List<Expr>::create(range, values));
184
        }
185
      } break;
186
      case TK_STRINGLITERAL: {
187
        prefix = parseConcatenatedStringLiterals();
188
      } break;
189
      case TK_ELLIPSIS:
190
      case TK_DOTS: {
191
        prefix = Dots::create(L.cur().range);
192
        L.next();
193
      } break;
194
      default: {
195
        Ident name = parseIdent();
196
        prefix = Var::create(name.range(), name);
197
      } break;
198
    }
199
    while (true) {
200
      if (L.nextIf('.')) {
201
        const auto name = parseIdent();
202
        prefix = Select::create(name.range(), Expr(prefix), Ident(name));
203
      } else if (L.cur().kind == '(') {
204
        prefix = createApply(Expr(prefix));
205
      } else if (L.cur().kind == '[') {
206
        prefix = parseSubscript(prefix);
207
      } else {
208
        break;
209
      }
210
    }
211
    return prefix;
212
  }
213
  c10::optional<TreeRef> maybeParseAssignmentOp() {
214
    auto r = L.cur().range;
215
    switch (L.cur().kind) {
216
      case TK_PLUS_EQ:
217
      case TK_MINUS_EQ:
218
      case TK_TIMES_EQ:
219
      case TK_DIV_EQ:
220
      case TK_BIT_OR_EQ:
221
      case TK_BIT_AND_EQ:
222
      case TK_BIT_XOR_EQ:
223
      case TK_MOD_EQ: {
224
        int modifier = L.next().text()[0];
225
        return create_compound(modifier, r, {});
226
      } break;
227
      case TK_LSHIFT_EQ: {
228
        L.next();
229
        return create_compound(TK_LSHIFT, r, {});
230
      } break;
231
      case TK_RSHIFT_EQ: {
232
        L.next();
233
        return create_compound(TK_RSHIFT, r, {});
234
      } break;
235
      case TK_POW_EQ: {
236
        L.next();
237
        return create_compound(TK_POW, r, {});
238
      } break;
239
      case '=': {
240
        L.next();
241
        return create_compound('=', r, {}); // no reduction
242
      } break;
243
      default:
244
        return c10::nullopt;
245
    }
246
  }
247
  TreeRef parseTrinary(
248
      TreeRef true_branch,
249
      const SourceRange& range,
250
      int binary_prec) {
251
    auto cond = parseExp();
252
    L.expect(TK_ELSE);
253
    auto false_branch = parseExp(binary_prec);
254
    return create_compound(
255
        TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
256
  }
257
  // parse the longest expression whose binary operators have
258
  // precedence strictly greater than 'precedence'
259
  // precedence == 0 will parse _all_ expressions
260
  // this is the core loop of 'top-down precedence parsing'
261
  Expr parseExp() {
262
    return parseExp(0);
263
  }
264
  Expr parseExp(int precedence) {
265
    TreeRef prefix;
266
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
267
    int unary_prec;
268
    if (shared.isUnary(L.cur().kind, &unary_prec)) {
269
      auto kind = L.cur().kind;
270
      auto pos = L.cur().range;
271
      L.next();
272
      auto unary_kind = kind == '*' ? TK_STARRED
273
          : kind == '-'             ? TK_UNARY_MINUS
274
                                    : kind;
275
      auto subexp = parseExp(unary_prec);
276
      // fold '-' into constant numbers, so that attributes can accept
277
      // things like -1
278
      if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
279
        prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
280
      } else {
281
        prefix = create_compound(unary_kind, pos, {subexp});
282
      }
283
    } else {
284
      prefix = parseBaseExp();
285
    }
286
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
287
    int binary_prec;
288
    while (shared.isBinary(L.cur().kind, &binary_prec)) {
289
      if (binary_prec <= precedence) // not allowed to parse something which is
290
        // not greater than 'precedence'
291
        break;
292

293
      int kind = L.cur().kind;
294
      auto pos = L.cur().range;
295
      L.next();
296
      if (shared.isRightAssociative(kind))
297
        binary_prec--;
298

299
      if (kind == TK_NOTIN) {
300
        // NB: `not in` is just `not( in )`, so we don't introduce new tree view
301
        // but just make it a nested call in our tree view structure
302
        prefix = create_compound(TK_IN, pos, {prefix, parseExp(binary_prec)});
303
        prefix = create_compound(TK_NOT, pos, {prefix});
304
        continue;
305
      }
306

307
      // special case for trinary operator
308
      if (kind == TK_IF) {
309
        prefix = parseTrinary(prefix, pos, binary_prec);
310
        continue;
311
      }
312

313
      if (kind == TK_FOR) {
314
        // TK_FOR targets should only parse exprs prec greater than 4, which
315
        // only includes subset of Exprs that suppose to be on the LHS according
316
        // to the python grammar
317
        // https://docs.python.org/3/reference/grammar.html
318
        auto target = parseLHSExp();
319
        L.expect(TK_IN);
320
        auto iter = parseExp();
321
        prefix = ListComp::create(pos, Expr(prefix), target, iter);
322
        continue;
323
      }
324

325
      prefix = create_compound(kind, pos, {prefix, parseExp(binary_prec)});
326
    }
327
    return Expr(prefix);
328
  }
329

330
  void parseSequence(
331
      int begin,
332
      int sep,
333
      int end,
334
      const std::function<void()>& parse) {
335
    if (begin != TK_NOTHING) {
336
      L.expect(begin);
337
    }
338
    while (end != L.cur().kind) {
339
      parse();
340
      if (!L.nextIf(sep)) {
341
        if (end != TK_NOTHING) {
342
          L.expect(end);
343
        }
344
        return;
345
      }
346
    }
347
    L.expect(end);
348
  }
349
  template <typename T>
350
  List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
351
    auto r = L.cur().range;
352
    std::vector<T> elements;
353
    parseSequence(
354
        begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
355
    return List<T>::create(r, elements);
356
  }
357

358
  Const parseConst() {
359
    auto range = L.cur().range;
360
    auto t = L.expect(TK_NUMBER);
361
    return Const::create(t.range, t.text());
362
  }
363

364
  StringLiteral parseConcatenatedStringLiterals() {
365
    auto range = L.cur().range;
366
    std::string ss;
367
    while (L.cur().kind == TK_STRINGLITERAL) {
368
      auto literal_range = L.cur().range;
369
      ss.append(parseStringLiteral(literal_range, L.next().text()));
370
    }
371
    return StringLiteral::create(range, ss);
372
  }
373

374
  Expr parseAttributeValue() {
375
    return parseExp();
376
  }
377

378
  void parseArguments(TreeList& inputs, TreeList& attributes) {
379
    parseSequence('(', ',', ')', [&] {
380
      if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
381
        auto ident = parseIdent();
382
        L.expect('=');
383
        auto v = parseAttributeValue();
384
        attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
385
      } else {
386
        inputs.push_back(parseExp());
387
      }
388
    });
389
  }
390

391
  // parse LHS acceptable exprs, which only includes subset of Exprs that prec
392
  // is greater than 4 according to the python grammar
393
  Expr parseLHSExp() {
394
    return parseExp(4);
395
  }
396

397
  // Parse expr's of the form [a:], [:b], [a:b], [:] and all variations with
398
  // "::"
399
  Expr parseSubscriptExp() {
400
    TreeRef first, second, third;
401
    auto range = L.cur().range;
402
    if (L.cur().kind != ':') {
403
      first = parseExp();
404
    }
405
    if (L.nextIf(':')) {
406
      if (L.cur().kind != ',' && L.cur().kind != ']' && L.cur().kind != ':') {
407
        second = parseExp();
408
      }
409
      if (L.nextIf(':')) {
410
        if (L.cur().kind != ',' && L.cur().kind != ']') {
411
          third = parseExp();
412
        }
413
      }
414
      auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
415
                               : Maybe<Expr>::create(range);
416
      auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
417
                                 : Maybe<Expr>::create(range);
418
      auto maybe_third = third ? Maybe<Expr>::create(range, Expr(third))
419
                               : Maybe<Expr>::create(range);
420
      return SliceExpr::create(range, maybe_first, maybe_second, maybe_third);
421
    } else {
422
      return Expr(first);
423
    }
424
  }
425

426
  TreeRef parseSubscript(const TreeRef& value) {
427
    const auto range = L.cur().range;
428

429
    auto subscript_exprs =
430
        parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
431

432
    const auto whole_range =
433
        SourceRange(range.source(), range.start(), L.cur().range.start());
434
    return Subscript::create(whole_range, Expr(value), subscript_exprs);
435
  }
436

437
  Maybe<Expr> maybeParseTypeAnnotation() {
438
    if (L.nextIf(':')) {
439
      // NB: parseExp must not be called inline, since argument evaluation order
440
      // changes when L.cur().range is mutated with respect to the parseExp()
441
      // call.
442
      auto expr = parseExp();
443
      return Maybe<Expr>::create(expr.range(), expr);
444
    } else {
445
      return Maybe<Expr>::create(L.cur().range);
446
    }
447
  }
448

449
  TreeRef parseFormalParam(bool kwarg_only) {
450
    auto ident = parseIdent();
451
    TreeRef type = maybeParseTypeAnnotation();
452
    TreeRef def;
453
    if (L.nextIf('=')) {
454
      // NB: parseExp must not be called inline, since argument evaluation order
455
      // changes when L.cur().range is mutated with respect to the parseExp()
456
      // call.
457
      auto expr = parseExp();
458
      def = Maybe<Expr>::create(expr.range(), expr);
459
    } else {
460
      def = Maybe<Expr>::create(L.cur().range);
461
    }
462
    return Param::create(
463
        type->range(),
464
        Ident(ident),
465
        Maybe<Expr>(type),
466
        Maybe<Expr>(def),
467
        kwarg_only);
468
  }
469

470
  Param parseBareTypeAnnotation() {
471
    auto type = parseExp();
472
    return Param::create(
473
        type.range(),
474
        Ident::create(type.range(), ""),
475
        Maybe<Expr>::create(type.range(), type),
476
        Maybe<Expr>::create(type.range()),
477
        /*kwarg_only=*/false);
478
  }
479

480
  Decl parseTypeComment() {
481
    auto range = L.cur().range;
482
    L.expect(TK_TYPE_COMMENT);
483
    auto param_types =
484
        parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
485
    TreeRef return_type;
486
    if (L.nextIf(TK_ARROW)) {
487
      auto return_type_range = L.cur().range;
488
      return_type = Maybe<Expr>::create(return_type_range, parseExp());
489
    } else {
490
      return_type = Maybe<Expr>::create(L.cur().range);
491
    }
492
    return Decl::create(range, param_types, Maybe<Expr>(return_type));
493
  }
494

495
  // 'first' has already been parsed since expressions can exist
496
  // alone on a line:
497
  // first[,other,lhs] = rhs
498
  TreeRef parseAssign(const Expr& lhs) {
499
    auto type = maybeParseTypeAnnotation();
500
    auto maybeOp = maybeParseAssignmentOp();
501
    if (maybeOp) {
502
      // There is an assignment operator, parse the RHS and generate the
503
      // assignment.
504
      auto rhs = parseExpOrExpTuple();
505
      if (maybeOp.value()->kind() == '=') {
506
        std::vector<Expr> lhs_list = {lhs};
507
        while (L.nextIf('=')) {
508
          lhs_list.push_back(rhs);
509
          rhs = parseExpOrExpTuple();
510
        }
511
        if (type.present() && lhs_list.size() > 1) {
512
          throw ErrorReport(type.range())
513
              << "Annotated multiple assignment is not supported in python";
514
        }
515
        L.expect(TK_NEWLINE);
516
        return Assign::create(
517
            lhs.range(),
518
            List<Expr>::create(lhs_list[0].range(), lhs_list),
519
            Maybe<Expr>::create(rhs.range(), rhs),
520
            type);
521
      } else {
522
        L.expect(TK_NEWLINE);
523
        // this is an augmented assignment
524
        if (lhs.kind() == TK_TUPLE_LITERAL) {
525
          throw ErrorReport(lhs.range())
526
              << " augmented assignment can only have one LHS expression";
527
        }
528
        return AugAssign::create(
529
            lhs.range(), lhs, AugAssignKind(*maybeOp), Expr(rhs));
530
      }
531
    } else {
532
      // There is no assignment operator, so this is of the form `lhs : <type>`
533
      TORCH_INTERNAL_ASSERT(type.present());
534
      L.expect(TK_NEWLINE);
535
      return Assign::create(
536
          lhs.range(),
537
          List<Expr>::create(lhs.range(), {lhs}),
538
          Maybe<Expr>::create(lhs.range()),
539
          type);
540
    }
541
  }
542

543
  TreeRef parseStmt(bool in_class = false) {
544
    switch (L.cur().kind) {
545
      case TK_IF:
546
        return parseIf();
547
      case TK_WHILE:
548
        return parseWhile();
549
      case TK_FOR:
550
        return parseFor();
551
      case TK_GLOBAL: {
552
        auto range = L.next().range;
553
        auto idents =
554
            parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
555
        L.expect(TK_NEWLINE);
556
        return Global::create(range, idents);
557
      }
558
      case TK_RETURN: {
559
        auto range = L.next().range;
560
        Expr value = L.cur().kind != TK_NEWLINE
561
            ? parseExpOrExpTuple()
562
            : Expr(create_compound(TK_NONE, range, {}));
563
        L.expect(TK_NEWLINE);
564
        return Return::create(range, value);
565
      }
566
      case TK_RAISE: {
567
        auto range = L.next().range;
568
        auto expr = parseExp();
569
        L.expect(TK_NEWLINE);
570
        return Raise::create(range, expr);
571
      }
572
      case TK_ASSERT: {
573
        auto range = L.next().range;
574
        auto cond = parseExp();
575
        Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
576
        if (L.nextIf(',')) {
577
          auto msg = parseExp();
578
          maybe_first = Maybe<Expr>::create(range, Expr(msg));
579
        }
580
        L.expect(TK_NEWLINE);
581
        return Assert::create(range, cond, maybe_first);
582
      }
583
      case TK_BREAK: {
584
        auto range = L.next().range;
585
        L.expect(TK_NEWLINE);
586
        return Break::create(range);
587
      }
588
      case TK_CONTINUE: {
589
        auto range = L.next().range;
590
        L.expect(TK_NEWLINE);
591
        return Continue::create(range);
592
      }
593
      case TK_PASS: {
594
        auto range = L.next().range;
595
        L.expect(TK_NEWLINE);
596
        return Pass::create(range);
597
      }
598
      case TK_DEF: {
599
        return parseFunction(/*is_method=*/in_class);
600
      }
601
      case TK_DELETE: {
602
        auto range = L.next().range;
603
        auto targets =
604
            parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
605
        L.expect(TK_NEWLINE);
606
        return Delete::create(range, targets);
607
      }
608
      case TK_WITH: {
609
        return parseWith();
610
      }
611
      default: {
612
        auto lhs = parseExpOrExpTuple();
613
        if (L.cur().kind != TK_NEWLINE) {
614
          return parseAssign(lhs);
615
        } else {
616
          L.expect(TK_NEWLINE);
617
          return ExprStmt::create(lhs.range(), lhs);
618
        }
619
      }
620
    }
621
  }
622

623
  WithItem parseWithItem() {
624
    auto target = parseExp();
625

626
    if (L.cur().kind == TK_AS) {
627
      // If the current token is TK_AS, this with item is of the form
628
      // "expression as target".
629
      auto token = L.expect(TK_AS);
630
      Ident ident = parseIdent();
631
      auto var = Var::create(ident.range(), ident);
632
      return WithItem::create(
633
          token.range, target, Maybe<Var>::create(ident.range(), var));
634
    } else {
635
      // If not, this with item is of the form "expression".
636
      return WithItem::create(
637
          target.range(), target, Maybe<Var>::create(target.range()));
638
    }
639
  }
640

641
  TreeRef parseIf(bool expect_if = true) {
642
    auto r = L.cur().range;
643
    if (expect_if)
644
      L.expect(TK_IF);
645
    auto cond = parseExp();
646
    L.expect(':');
647
    auto true_branch = parseStatements(/*expect_indent=*/true);
648
    auto false_branch = makeList(L.cur().range, {});
649
    if (L.nextIf(TK_ELSE)) {
650
      L.expect(':');
651
      false_branch = parseStatements(/*expect_indent=*/true);
652
    } else if (L.nextIf(TK_ELIF)) {
653
      // NB: this needs to be a separate statement, since the call to parseIf
654
      // mutates the lexer state, and thus causes a heap-use-after-free in
655
      // compilers which evaluate argument expressions LTR
656
      auto range = L.cur().range;
657
      false_branch = makeList(range, {parseIf(false)});
658
    }
659
    return If::create(
660
        r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
661
  }
662
  TreeRef parseWhile() {
663
    auto r = L.cur().range;
664
    L.expect(TK_WHILE);
665
    auto cond = parseExp();
666
    L.expect(':');
667
    auto body = parseStatements(/*expect_indent=*/true);
668
    return While::create(r, Expr(cond), List<Stmt>(body));
669
  }
670

671
  TreeRef parseFor() {
672
    auto r = L.cur().range;
673
    L.expect(TK_FOR);
674
    auto targets = parseList(TK_NOTHING, ',', TK_IN, &ParserImpl::parseLHSExp);
675
    auto itrs = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseExp);
676
    auto body = parseStatements(/*expect_indent=*/true);
677
    return For::create(r, targets, itrs, body);
678
  }
679

680
  TreeRef parseWith() {
681
    auto r = L.cur().range;
682
    // Parse "with expression [as target][, expression [as target]]*:".
683
    L.expect(TK_WITH);
684
    auto targets = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseWithItem);
685
    // Parse the body.
686
    auto body = parseStatements(/*expect_indent=*/true);
687
    return With::create(r, targets, body);
688
  }
689

690
  TreeRef parseStatements(bool expect_indent, bool in_class = false) {
691
    auto r = L.cur().range;
692
    if (expect_indent) {
693
      L.expect(TK_INDENT);
694
    }
695
    TreeList stmts;
696
    do {
697
      stmts.push_back(parseStmt(in_class));
698
    } while (!L.nextIf(TK_DEDENT));
699
    return create_compound(TK_LIST, r, std::move(stmts));
700
  }
701

702
  Maybe<Expr> parseReturnAnnotation() {
703
    if (L.nextIf(TK_ARROW)) {
704
      // Exactly one expression for return type annotation
705
      auto return_type_range = L.cur().range;
706
      return Maybe<Expr>::create(return_type_range, parseExp());
707
    } else {
708
      return Maybe<Expr>::create(L.cur().range);
709
    }
710
  }
711

712
  List<Param> parseFormalParams() {
713
    auto r = L.cur().range;
714
    std::vector<Param> params;
715
    bool kwarg_only = false;
716
    parseSequence('(', ',', ')', [&] {
717
      if (!kwarg_only && L.nextIf('*')) {
718
        kwarg_only = true;
719
      } else {
720
        params.emplace_back(parseFormalParam(kwarg_only));
721
      }
722
    });
723
    return List<Param>::create(r, params);
724
  }
725
  Decl parseDecl() {
726
    // Parse return type annotation
727
    List<Param> paramlist = parseFormalParams();
728
    TreeRef return_type;
729
    Maybe<Expr> return_annotation = parseReturnAnnotation();
730
    L.expect(':');
731
    return Decl::create(
732
        paramlist.range(), List<Param>(paramlist), return_annotation);
733
  }
734

735
  TreeRef parseClass() {
736
    L.expect(TK_CLASS_DEF);
737
    const auto name = parseIdent();
738
    Maybe<Expr> superclass = Maybe<Expr>::create(name.range());
739
    if (L.nextIf('(')) {
740
      // Only support inheriting from NamedTuple right now.
741
      auto id = parseExp();
742
      superclass = Maybe<Expr>::create(id.range(), id);
743
      L.expect(')');
744
    }
745
    L.expect(':');
746
    const auto statements =
747
        parseStatements(/*expect_indent=*/true, /*in_class=*/true);
748
    return ClassDef::create(
749
        name.range(), name, superclass, List<Stmt>(statements));
750
  }
751

752
  TreeRef parseFunction(bool is_method) {
753
    L.expect(TK_DEF);
754
    auto name = parseIdent();
755
    auto decl = parseDecl();
756

757
    TreeRef stmts_list;
758
    if (L.nextIf(TK_INDENT)) {
759
      // Handle type annotations specified in a type comment as the first line
760
      // of the function.
761
      if (L.cur().kind == TK_TYPE_COMMENT) {
762
        auto type_annotation_decl = Decl(parseTypeComment());
763
        L.expect(TK_NEWLINE);
764
        decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
765
      }
766

767
      stmts_list = parseStatements(false);
768
    } else {
769
      // Special case: the Python grammar allows one-line functions with a
770
      // single statement.
771
      if (L.cur().kind == TK_TYPE_COMMENT) {
772
        auto type_annotation_decl = Decl(parseTypeComment());
773
        decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
774
      }
775

776
      TreeList stmts;
777
      stmts.push_back(parseStmt(is_method));
778
      stmts_list = create_compound(TK_LIST, L.cur().range, std::move(stmts));
779
    }
780

781
    return Def::create(
782
        name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
783
  }
784
  Lexer& lexer() {
785
    return L;
786
  }
787

788
 private:
789
  // short helpers to create nodes
790
  TreeRef create_compound(
791
      int kind,
792
      const SourceRange& range,
793
      TreeList&& trees) {
794
    return Compound::create(kind, range, std::move(trees));
795
  }
796
  TreeRef makeList(const SourceRange& range, TreeList&& trees) {
797
    return create_compound(TK_LIST, range, std::move(trees));
798
  }
799
  Lexer L;
800
  SharedParserData& shared;
801
};
802

803
// Creates a parser over `src`; every Parser method forwards to the owned
// ParserImpl. std::make_unique avoids a naked `new`.
Parser::Parser(const std::shared_ptr<Source>& src)
    : pImpl(std::make_unique<ParserImpl>(src)) {}

// Defined here, where ParserImpl is a complete type, so the owning pointer
// can destroy it. NOTE(review): assumes pImpl is a smart pointer (presumably
// std::unique_ptr<ParserImpl>) declared in parser.h — confirm.
Parser::~Parser() = default;
807

808
// Parses a `def` starting at the current token. `is_method` marks the first
// parameter as `self` when a type comment is merged into the signature.
TreeRef Parser::parseFunction(bool is_method) {
  ParserImpl& impl = *pImpl;
  return impl.parseFunction(is_method);
}
811
// Parses a class definition starting at the current token.
TreeRef Parser::parseClass() {
  ParserImpl& impl = *pImpl;
  return impl.parseClass();
}
814
// Exposes the lexer owned by the implementation.
Lexer& Parser::lexer() {
  ParserImpl& impl = *pImpl;
  return impl.lexer();
}
817
// Parses a type comment (parameter types plus optional return type).
Decl Parser::parseTypeComment() {
  ParserImpl& impl = *pImpl;
  return impl.parseTypeComment();
}
820
// Parses a full expression (all precedence levels).
Expr Parser::parseExp() {
  ParserImpl& impl = *pImpl;
  return impl.parseExp();
}
823

824
} // namespace torch::jit
825

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.