1
#include <torch/csrc/jit/frontend/parser.h>
3
#include <c10/util/Optional.h>
4
#include <torch/csrc/jit/frontend/lexer.h>
5
#include <torch/csrc/jit/frontend/parse_string_literal.h>
6
#include <torch/csrc/jit/frontend/tree.h>
7
#include <torch/csrc/jit/frontend/tree_views.h>
11
Decl mergeTypesFromTypeComment(
13
const Decl& type_annotation_decl,
15
auto expected_num_annotations = decl.params().size();
18
expected_num_annotations -= 1;
20
if (expected_num_annotations != type_annotation_decl.params().size()) {
21
throw ErrorReport(decl.range())
22
<< "Number of type annotations ("
23
<< type_annotation_decl.params().size()
24
<< ") did not match the number of "
25
<< (is_method ? "method" : "function") << " parameters ("
26
<< expected_num_annotations << ")";
28
auto old = decl.params();
29
auto _new = type_annotation_decl.params();
32
std::vector<Param> new_params;
33
size_t i = is_method ? 1 : 0;
36
new_params.push_back(old[0]);
38
for (; i < decl.params().size(); ++i, ++j) {
39
new_params.emplace_back(old[i].withType(_new[j].type()));
43
List<Param>::create(decl.range(), new_params),
44
type_annotation_decl.return_type());
48
explicit ParserImpl(const std::shared_ptr<Source>& source)
49
: L(source), shared(sharedParserData()) {}
52
auto t = L.expect(TK_IDENT);
56
return Ident::create(t.range, t.text());
58
TreeRef createApply(const Expr& expr) {
60
auto range = L.cur().range;
62
parseArguments(inputs, attributes);
66
List<Expr>(makeList(range, std::move(inputs))),
67
List<Attribute>(makeList(range, std::move(attributes))));
70
static bool followsTuple(int kind) {
93
Expr parseExpOrExpTuple() {
94
auto prefix = parseExp();
95
if (L.cur().kind == ',') {
96
std::vector<Expr> exprs = {prefix};
97
while (L.nextIf(',')) {
98
if (followsTuple(L.cur().kind))
100
exprs.push_back(parseExp());
102
auto list = List<Expr>::create(prefix.range(), exprs);
103
prefix = TupleLiteral::create(list.range(), list);
109
TreeRef parseBaseExp() {
111
switch (L.cur().kind) {
113
prefix = parseConst();
119
auto k = L.cur().kind;
120
auto r = L.cur().range;
121
prefix = create_compound(k, r, {});
128
std::vector<Expr> vecExpr;
129
List<Expr> listExpr = List<Expr>::create(L.cur().range, vecExpr);
130
prefix = TupleLiteral::create(L.cur().range, listExpr);
133
prefix = parseExpOrExpTuple();
137
auto list = parseList('[', ',', ']', &ParserImpl::parseExp);
139
if (list.size() == 1 && (*list.begin()).kind() == TK_LIST_COMP) {
140
prefix = *list.begin();
142
for (auto se : list) {
143
if (se.kind() == TK_LIST_COMP) {
144
throw ErrorReport(list.range())
145
<< " expected a single list comprehension within '[' , ']'";
148
prefix = ListLiteral::create(list.range(), List<Expr>(list));
164
std::vector<Expr> keys;
165
std::vector<Expr> values;
166
auto range = L.cur().range;
167
if (L.cur().kind != '}') {
169
keys.push_back(parseExp());
171
values.push_back(parseExp());
172
} while (L.nextIf(','));
175
if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) {
176
ListComp lc(*values.begin());
177
prefix = DictComp::create(
178
range, *keys.begin(), lc.elt(), lc.target(), lc.iter());
180
prefix = DictLiteral::create(
182
List<Expr>::create(range, keys),
183
List<Expr>::create(range, values));
186
case TK_STRINGLITERAL: {
187
prefix = parseConcatenatedStringLiterals();
191
prefix = Dots::create(L.cur().range);
195
Ident name = parseIdent();
196
prefix = Var::create(name.range(), name);
201
const auto name = parseIdent();
202
prefix = Select::create(name.range(), Expr(prefix), Ident(name));
203
} else if (L.cur().kind == '(') {
204
prefix = createApply(Expr(prefix));
205
} else if (L.cur().kind == '[') {
206
prefix = parseSubscript(prefix);
213
c10::optional<TreeRef> maybeParseAssignmentOp() {
214
auto r = L.cur().range;
215
switch (L.cur().kind) {
224
int modifier = L.next().text()[0];
225
return create_compound(modifier, r, {});
229
return create_compound(TK_LSHIFT, r, {});
233
return create_compound(TK_RSHIFT, r, {});
237
return create_compound(TK_POW, r, {});
241
return create_compound('=', r, {});
247
TreeRef parseTrinary(
249
const SourceRange& range,
251
auto cond = parseExp();
253
auto false_branch = parseExp(binary_prec);
254
return create_compound(
255
TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
264
Expr parseExp(int precedence) {
268
if (shared.isUnary(L.cur().kind, &unary_prec)) {
269
auto kind = L.cur().kind;
270
auto pos = L.cur().range;
272
auto unary_kind = kind == '*' ? TK_STARRED
273
: kind == '-' ? TK_UNARY_MINUS
275
auto subexp = parseExp(unary_prec);
278
if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
279
prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
281
prefix = create_compound(unary_kind, pos, {subexp});
284
prefix = parseBaseExp();
288
while (shared.isBinary(L.cur().kind, &binary_prec)) {
289
if (binary_prec <= precedence)
293
int kind = L.cur().kind;
294
auto pos = L.cur().range;
296
if (shared.isRightAssociative(kind))
299
if (kind == TK_NOTIN) {
302
prefix = create_compound(TK_IN, pos, {prefix, parseExp(binary_prec)});
303
prefix = create_compound(TK_NOT, pos, {prefix});
309
prefix = parseTrinary(prefix, pos, binary_prec);
313
if (kind == TK_FOR) {
318
auto target = parseLHSExp();
320
auto iter = parseExp();
321
prefix = ListComp::create(pos, Expr(prefix), target, iter);
325
prefix = create_compound(kind, pos, {prefix, parseExp(binary_prec)});
334
const std::function<void()>& parse) {
335
if (begin != TK_NOTHING) {
338
while (end != L.cur().kind) {
340
if (!L.nextIf(sep)) {
341
if (end != TK_NOTHING) {
349
template <typename T>
350
List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
351
auto r = L.cur().range;
352
std::vector<T> elements;
354
begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
355
return List<T>::create(r, elements);
359
auto range = L.cur().range;
360
auto t = L.expect(TK_NUMBER);
361
return Const::create(t.range, t.text());
364
StringLiteral parseConcatenatedStringLiterals() {
365
auto range = L.cur().range;
367
while (L.cur().kind == TK_STRINGLITERAL) {
368
auto literal_range = L.cur().range;
369
ss.append(parseStringLiteral(literal_range, L.next().text()));
371
return StringLiteral::create(range, ss);
374
Expr parseAttributeValue() {
378
void parseArguments(TreeList& inputs, TreeList& attributes) {
379
parseSequence('(', ',', ')', [&] {
380
if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
381
auto ident = parseIdent();
383
auto v = parseAttributeValue();
384
attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
386
inputs.push_back(parseExp());
399
Expr parseSubscriptExp() {
400
TreeRef first, second, third;
401
auto range = L.cur().range;
402
if (L.cur().kind != ':') {
406
if (L.cur().kind != ',' && L.cur().kind != ']' && L.cur().kind != ':') {
410
if (L.cur().kind != ',' && L.cur().kind != ']') {
414
auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
415
: Maybe<Expr>::create(range);
416
auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
417
: Maybe<Expr>::create(range);
418
auto maybe_third = third ? Maybe<Expr>::create(range, Expr(third))
419
: Maybe<Expr>::create(range);
420
return SliceExpr::create(range, maybe_first, maybe_second, maybe_third);
426
TreeRef parseSubscript(const TreeRef& value) {
427
const auto range = L.cur().range;
429
auto subscript_exprs =
430
parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
432
const auto whole_range =
433
SourceRange(range.source(), range.start(), L.cur().range.start());
434
return Subscript::create(whole_range, Expr(value), subscript_exprs);
437
Maybe<Expr> maybeParseTypeAnnotation() {
442
auto expr = parseExp();
443
return Maybe<Expr>::create(expr.range(), expr);
445
return Maybe<Expr>::create(L.cur().range);
449
TreeRef parseFormalParam(bool kwarg_only) {
450
auto ident = parseIdent();
451
TreeRef type = maybeParseTypeAnnotation();
457
auto expr = parseExp();
458
def = Maybe<Expr>::create(expr.range(), expr);
460
def = Maybe<Expr>::create(L.cur().range);
462
return Param::create(
470
Param parseBareTypeAnnotation() {
471
auto type = parseExp();
472
return Param::create(
474
Ident::create(type.range(), ""),
475
Maybe<Expr>::create(type.range(), type),
476
Maybe<Expr>::create(type.range()),
480
Decl parseTypeComment() {
481
auto range = L.cur().range;
482
L.expect(TK_TYPE_COMMENT);
484
parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
486
if (L.nextIf(TK_ARROW)) {
487
auto return_type_range = L.cur().range;
488
return_type = Maybe<Expr>::create(return_type_range, parseExp());
490
return_type = Maybe<Expr>::create(L.cur().range);
492
return Decl::create(range, param_types, Maybe<Expr>(return_type));
498
TreeRef parseAssign(const Expr& lhs) {
499
auto type = maybeParseTypeAnnotation();
500
auto maybeOp = maybeParseAssignmentOp();
504
auto rhs = parseExpOrExpTuple();
505
if (maybeOp.value()->kind() == '=') {
506
std::vector<Expr> lhs_list = {lhs};
507
while (L.nextIf('=')) {
508
lhs_list.push_back(rhs);
509
rhs = parseExpOrExpTuple();
511
if (type.present() && lhs_list.size() > 1) {
512
throw ErrorReport(type.range())
513
<< "Annotated multiple assignment is not supported in python";
515
L.expect(TK_NEWLINE);
516
return Assign::create(
518
List<Expr>::create(lhs_list[0].range(), lhs_list),
519
Maybe<Expr>::create(rhs.range(), rhs),
522
L.expect(TK_NEWLINE);
524
if (lhs.kind() == TK_TUPLE_LITERAL) {
525
throw ErrorReport(lhs.range())
526
<< " augmented assignment can only have one LHS expression";
528
return AugAssign::create(
529
lhs.range(), lhs, AugAssignKind(*maybeOp), Expr(rhs));
533
TORCH_INTERNAL_ASSERT(type.present());
534
L.expect(TK_NEWLINE);
535
return Assign::create(
537
List<Expr>::create(lhs.range(), {lhs}),
538
Maybe<Expr>::create(lhs.range()),
543
TreeRef parseStmt(bool in_class = false) {
544
switch (L.cur().kind) {
552
auto range = L.next().range;
554
parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
555
L.expect(TK_NEWLINE);
556
return Global::create(range, idents);
559
auto range = L.next().range;
560
Expr value = L.cur().kind != TK_NEWLINE
561
? parseExpOrExpTuple()
562
: Expr(create_compound(TK_NONE, range, {}));
563
L.expect(TK_NEWLINE);
564
return Return::create(range, value);
567
auto range = L.next().range;
568
auto expr = parseExp();
569
L.expect(TK_NEWLINE);
570
return Raise::create(range, expr);
573
auto range = L.next().range;
574
auto cond = parseExp();
575
Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
577
auto msg = parseExp();
578
maybe_first = Maybe<Expr>::create(range, Expr(msg));
580
L.expect(TK_NEWLINE);
581
return Assert::create(range, cond, maybe_first);
584
auto range = L.next().range;
585
L.expect(TK_NEWLINE);
586
return Break::create(range);
589
auto range = L.next().range;
590
L.expect(TK_NEWLINE);
591
return Continue::create(range);
594
auto range = L.next().range;
595
L.expect(TK_NEWLINE);
596
return Pass::create(range);
599
return parseFunction(in_class);
602
auto range = L.next().range;
604
parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
605
L.expect(TK_NEWLINE);
606
return Delete::create(range, targets);
612
auto lhs = parseExpOrExpTuple();
613
if (L.cur().kind != TK_NEWLINE) {
614
return parseAssign(lhs);
616
L.expect(TK_NEWLINE);
617
return ExprStmt::create(lhs.range(), lhs);
623
WithItem parseWithItem() {
624
auto target = parseExp();
626
if (L.cur().kind == TK_AS) {
629
auto token = L.expect(TK_AS);
630
Ident ident = parseIdent();
631
auto var = Var::create(ident.range(), ident);
632
return WithItem::create(
633
token.range, target, Maybe<Var>::create(ident.range(), var));
636
return WithItem::create(
637
target.range(), target, Maybe<Var>::create(target.range()));
641
TreeRef parseIf(bool expect_if = true) {
642
auto r = L.cur().range;
645
auto cond = parseExp();
647
auto true_branch = parseStatements(true);
648
auto false_branch = makeList(L.cur().range, {});
649
if (L.nextIf(TK_ELSE)) {
651
false_branch = parseStatements(true);
652
} else if (L.nextIf(TK_ELIF)) {
656
auto range = L.cur().range;
657
false_branch = makeList(range, {parseIf(false)});
660
r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
662
TreeRef parseWhile() {
663
auto r = L.cur().range;
665
auto cond = parseExp();
667
auto body = parseStatements(true);
668
return While::create(r, Expr(cond), List<Stmt>(body));
672
auto r = L.cur().range;
674
auto targets = parseList(TK_NOTHING, ',', TK_IN, &ParserImpl::parseLHSExp);
675
auto itrs = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseExp);
676
auto body = parseStatements(true);
677
return For::create(r, targets, itrs, body);
680
TreeRef parseWith() {
681
auto r = L.cur().range;
684
auto targets = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseWithItem);
686
auto body = parseStatements(true);
687
return With::create(r, targets, body);
690
TreeRef parseStatements(bool expect_indent, bool in_class = false) {
691
auto r = L.cur().range;
697
stmts.push_back(parseStmt(in_class));
698
} while (!L.nextIf(TK_DEDENT));
699
return create_compound(TK_LIST, r, std::move(stmts));
702
Maybe<Expr> parseReturnAnnotation() {
703
if (L.nextIf(TK_ARROW)) {
705
auto return_type_range = L.cur().range;
706
return Maybe<Expr>::create(return_type_range, parseExp());
708
return Maybe<Expr>::create(L.cur().range);
712
List<Param> parseFormalParams() {
713
auto r = L.cur().range;
714
std::vector<Param> params;
715
bool kwarg_only = false;
716
parseSequence('(', ',', ')', [&] {
717
if (!kwarg_only && L.nextIf('*')) {
720
params.emplace_back(parseFormalParam(kwarg_only));
723
return List<Param>::create(r, params);
727
List<Param> paramlist = parseFormalParams();
729
Maybe<Expr> return_annotation = parseReturnAnnotation();
732
paramlist.range(), List<Param>(paramlist), return_annotation);
735
TreeRef parseClass() {
736
L.expect(TK_CLASS_DEF);
737
const auto name = parseIdent();
738
Maybe<Expr> superclass = Maybe<Expr>::create(name.range());
741
auto id = parseExp();
742
superclass = Maybe<Expr>::create(id.range(), id);
746
const auto statements =
747
parseStatements(true, true);
748
return ClassDef::create(
749
name.range(), name, superclass, List<Stmt>(statements));
752
TreeRef parseFunction(bool is_method) {
754
auto name = parseIdent();
755
auto decl = parseDecl();
758
if (L.nextIf(TK_INDENT)) {
761
if (L.cur().kind == TK_TYPE_COMMENT) {
762
auto type_annotation_decl = Decl(parseTypeComment());
763
L.expect(TK_NEWLINE);
764
decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
767
stmts_list = parseStatements(false);
771
if (L.cur().kind == TK_TYPE_COMMENT) {
772
auto type_annotation_decl = Decl(parseTypeComment());
773
decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
777
stmts.push_back(parseStmt(is_method));
778
stmts_list = create_compound(TK_LIST, L.cur().range, std::move(stmts));
782
name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
790
TreeRef create_compound(
792
const SourceRange& range,
794
return Compound::create(kind, range, std::move(trees));
796
TreeRef makeList(const SourceRange& range, TreeList&& trees) {
797
return create_compound(TK_LIST, range, std::move(trees));
800
SharedParserData& shared;
803
Parser::Parser(const std::shared_ptr<Source>& src)
804
: pImpl(new ParserImpl(src)) {}
806
Parser::~Parser() = default;
808
TreeRef Parser::parseFunction(bool is_method) {
809
return pImpl->parseFunction(is_method);
811
TreeRef Parser::parseClass() {
812
return pImpl->parseClass();
814
Lexer& Parser::lexer() {
815
return pImpl->lexer();
817
Decl Parser::parseTypeComment() {
818
return pImpl->parseTypeComment();
820
Expr Parser::parseExp() {
821
return pImpl->parseExp();