1
#include <torch/csrc/jit/ir/irparser.h>
3
#include <ATen/EmptyTensor.h>
4
#include <torch/csrc/jit/frontend/lexer.h>
5
#include <torch/csrc/jit/frontend/parse_string_literal.h>
6
#include <torch/csrc/jit/frontend/schema_type_parser.h>
7
#include <torch/csrc/jit/ir/ir.h>
9
#ifndef AT_PER_OPERATOR_HEADERS
10
#include <ATen/Functions.h>
12
#include <ATen/ops/empty.h>
13
#include <ATen/ops/empty_strided.h>
26
const std::string& str,
27
torch::jit::Graph* graph,
28
std::unordered_map<std::string, Value*>& vmap,
29
bool parse_tensor_constants);
31
const std::string& str,
32
torch::jit::Graph* graph,
33
std::unordered_map<std::string, Value*>& vmap,
34
bool parse_tensor_constants)
35
: L(std::make_shared<Source>(str)),
39
parse_tensor_constants_(parse_tensor_constants) {}
41
std::string parseVar();
42
VarWithType parseVarWithType(bool allow_optional = false);
43
ParsedLiteral parseScalarLiteral(Node* n);
46
void parseGraphInputs();
47
void parseReturnOperator();
49
void parseBlocks(Node* parentNode);
50
void parseBlock(Node* parentNode);
51
void parseBlockInputs(Block* b);
52
void parseBlockOutputs(Block* b);
54
void parseOperatorsList(Block* b);
55
void parseOperator(Block* b);
56
void parseOperatorOutputs(std::vector<VarWithType>* outs);
57
std::string parseOperatorName();
58
void parseOperatorInputs(Node* n);
59
void parseAttrs(Node* n);
60
void parseAttr(Node* n);
66
const std::function<void()>& callback);
68
void bypassTypeAnnotationList();
70
Value* findValueInVMap(const std::string& name);
73
torch::jit::Graph* g = nullptr;
74
std::unordered_map<std::string, Value*>& vmap;
75
SchemaTypeParser type_parser;
76
bool parse_tensor_constants_;
77
std::vector<Node*> deferred_tensor_value_initializations_;
78
std::vector<Node*> deferred_empty_container_initializations_;
82
ParsedLiteral() = default;
84
AttributeKind k = AttributeKind::t;
89
c10::complex<double> c = c10::complex<double>(0, 0);
91
std::vector<int64_t> is;
92
std::vector<std::string> ss;
93
std::vector<double> fs;
94
std::vector<c10::complex<double>> cs;
95
std::vector<TypePtr> tys;
99
VarWithType() = default;
105
const std::string& str,
106
torch::jit::Graph* graph,
107
std::unordered_map<std::string, Value*>& vmap,
108
bool parse_tensor_constants) {
109
torch::jit::IRParser p(str, graph, vmap, parse_tensor_constants);
114
const std::string& str,
115
torch::jit::Graph* graph,
116
bool parse_tensor_constants) {
117
std::unordered_map<std::string, Value*> vmap;
118
parseIR(str, graph, vmap, parse_tensor_constants);
121
VarWithType IRParser::parseVarWithType(bool allow_optional) {
124
if (allow_optional) {
127
r.type = TensorType::get();
130
auto type_alias = type_parser.parseType();
131
AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
132
r.type = type_alias.first;
137
std::string IRParser::parseVar() {
140
bool continue_parsing;
142
if (L.cur().kind == TK_IDENT) {
143
name += L.expect(TK_IDENT).text();
145
name += L.expect(TK_NUMBER).text();
147
continue_parsing = false;
149
continue_parsing = true;
151
} else if (L.cur().kind == TK_NUMBER && L.cur().text()[0] == '.') {
152
continue_parsing = true;
154
} while (continue_parsing);
158
void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
159
if (L.cur().kind != '%') {
162
parseList(TK_NOTHING, ',', TK_NOTHING, [&] {
163
outs->push_back(parseVarWithType(true));
169
ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
170
auto token = L.cur();
172
std::pair<TypePtr, c10::optional<c10::AliasInfo>> type_alias;
174
switch (token.kind) {
175
case TK_STRINGLITERAL:
176
r.k = AttributeKind::s;
177
r.s = parseStringLiteral(token.range, token.text());
183
if (L.cur().kind != TK_NUMBER) {
184
throw ErrorReport(token.range)
185
<< "Expected a number after '-' but got:" << token.text();
189
str += L.cur().text();
190
if (str.find('j') != std::string::npos) {
191
r.k = AttributeKind::c;
194
imag = std::stod(str.substr(0, str.size() - 1));
195
} catch (const std::invalid_argument& e) {
196
throw ErrorReport(token.range)
197
<< "Number cannot be converted to double";
198
} catch (const std::out_of_range& e) {
199
throw ErrorReport(token.range)
200
<< "Number is too long to be represented in type double";
202
r.c = c10::complex<double>(0, imag);
204
str.find('.') != std::string::npos ||
205
str.find('e') != std::string::npos) {
206
r.k = AttributeKind::f;
208
r.f = std::stod(str);
209
} catch (const std::invalid_argument& e) {
210
throw ErrorReport(token.range)
211
<< "Number cannot be converted to double";
212
} catch (const std::out_of_range& e) {
213
throw ErrorReport(token.range)
214
<< "Number is too long to be represented in type double";
217
r.k = AttributeKind::i;
219
r.i = std::stoll(str);
220
} catch (const std::invalid_argument& e) {
221
throw ErrorReport(token.range)
222
<< "Number cannot be converted to integer";
223
} catch (const std::out_of_range& e) {
224
throw ErrorReport(token.range) << "Number is too big";
231
r.k = AttributeKind::ty;
232
type_alias = type_parser.parseType();
233
AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
234
r.ty = type_alias.first;
238
auto text = L.expect(TK_IDENT);
239
if (text.text() != "Tensor") {
240
throw ErrorReport(token.range)
241
<< "Could not parse literal" << token.text();
243
if (!parse_tensor_constants_) {
244
throw ErrorReport(token.range)
245
<< "Tensor constant encountered but `parse_tensor_constants` set to false"
251
deferred_tensor_value_initializations_.push_back(n);
252
r.k = AttributeKind::t;
257
if (L.cur().kind == '-') {
260
auto text = L.expect(TK_NUMBER);
261
if (!parse_tensor_constants_) {
262
throw ErrorReport(token.range)
263
<< "Single-element tensor constant encountered but "
264
<< "`parse_tensor_constants` is set to false " << token.text();
267
deferred_tensor_value_initializations_.push_back(n);
268
r.k = AttributeKind::t;
272
throw ErrorReport(token.range)
273
<< "Could not parse literal" << token.text();
277
void IRParser::bypassTypeAnnotationList() {
279
bool bypassed_list = false;
280
while (depth != 0 || !bypassed_list) {
281
if (L.cur().kind == '[') {
282
bypassed_list = true;
284
} else if (L.cur().kind == ']') {
299
void IRParser::parseAttr(Node* n) {
300
std::string attrname = L.expect(TK_IDENT).text();
302
if (L.cur().kind == '[') {
304
AttributeKind k = AttributeKind::ts;
305
c10::List<int64_t> is;
306
c10::List<std::string> ss;
307
c10::List<double> fs;
308
c10::List<c10::complex<double>> cs;
309
std::vector<TypePtr> tys;
311
parseList('[', ',', ']', [&] {
312
ParsedLiteral r = parseScalarLiteral(n);
314
case AttributeKind::s:
316
AT_ASSERT(!elem_num++ || k == AttributeKind::ss);
317
k = AttributeKind::ss;
319
case AttributeKind::i:
321
AT_ASSERT(!elem_num++ || k == AttributeKind::is);
322
k = AttributeKind::is;
324
case AttributeKind::f:
326
AT_ASSERT(!elem_num++ || k == AttributeKind::fs);
327
k = AttributeKind::fs;
329
case AttributeKind::c:
331
AT_ASSERT(!elem_num++ || k == AttributeKind::cs);
332
k = AttributeKind::cs;
334
case AttributeKind::ty:
336
AT_ASSERT(!elem_num++ || k == AttributeKind::tys);
337
k = AttributeKind::tys;
340
throw ErrorReport(L.cur().range) << "Unexpected attr type";
344
case AttributeKind::ts:
345
n->ival_(Symbol::attr(attrname), IValue());
347
case AttributeKind::ss:
348
n->ival_(Symbol::attr(attrname), IValue(ss));
350
case AttributeKind::fs:
351
n->ival_(Symbol::attr(attrname), IValue(fs));
353
case AttributeKind::cs:
354
n->ival_(Symbol::attr(attrname), IValue(cs));
356
case AttributeKind::is:
357
n->ival_(Symbol::attr(attrname), IValue(is));
359
case AttributeKind::tys:
360
n->tys_(Symbol::attr(attrname), tys);
363
throw ErrorReport(L.cur().range) << "Unexpected attr type";
365
} else if (L.cur().text() == "annotate") {
368
auto type = L.cur().text();
369
if (type != "List" && type != "Dict") {
370
throw ErrorReport(L.cur().range)
371
<< "Unexpected annotation (only List and Dict can be parsed)";
377
bypassTypeAnnotationList();
380
if (type == "Dict") {
383
} else if (type == "List") {
388
deferred_empty_container_initializations_.push_back(n);
391
ParsedLiteral r = parseScalarLiteral(n);
393
case AttributeKind::s:
394
n->s_(Symbol::attr(attrname), r.s);
396
case AttributeKind::i:
397
n->i_(Symbol::attr(attrname), r.i);
399
case AttributeKind::f:
400
n->f_(Symbol::attr(attrname), r.f);
402
case AttributeKind::c:
403
n->c_(Symbol::attr(attrname), r.c);
405
case AttributeKind::ty:
406
n->ty_(Symbol::attr(attrname), r.ty);
408
case AttributeKind::t:
412
throw ErrorReport(L.cur().range) << "Unexpected attr type";
418
void IRParser::parseAttrs(Node* n) {
419
parseList('[', ',', ']', [&] { parseAttr(n); });
422
void IRParser::parseOperatorInputs(Node* n) {
423
if (L.cur().kind == '[') {
426
parseList('(', ',', ')', [&] {
427
std::string var_name = parseVar();
428
n->addInput(findValueInVMap(var_name));
432
void IRParser::parseBlocks(Node* parentNode) {
434
while (L.cur().kind != TK_DEDENT) {
435
parseBlock(parentNode);
440
void IRParser::parseBlockInputs(Block* b) {
441
parseList('(', ',', ')', [&] {
442
VarWithType v = parseVarWithType();
444
std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
445
vmap[v.name] = b->addInput(uniq_name);
446
vmap[v.name]->setType(v.type);
450
void IRParser::parseBlockOutputs(Block* b) {
452
parseList('(', ',', ')', [&] {
453
std::string var_name = parseVar();
454
b->registerOutput(findValueInVMap(var_name));
456
L.expect(TK_NEWLINE);
470
void IRParser::parseBlock(Node* parentNode) {
471
Block* b = parentNode->addBlock();
472
L.expect(TK_IDENT).text();
475
parseOperatorsList(b);
476
parseBlockOutputs(b);
484
void IRParser::parseOperatorsList(Block* b) {
486
while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) {
491
std::string IRParser::parseOperatorName() {
492
std::string name = L.expect(TK_IDENT).text();
495
name += "::" + L.expect(TK_IDENT).text();
506
void IRParser::parseOperator(Block* b) {
508
std::vector<VarWithType> outs;
509
parseOperatorOutputs(&outs);
512
auto source_range = L.cur().range;
513
std::string name = parseOperatorName();
514
Node* n = g->create(Symbol::fromQualString(name), {}, outs.size())
515
->setSourceRange(source_range);
518
parseOperatorInputs(n);
520
const FunctionSchema* schema = n->maybeSchema();
524
for (const VarWithType& v : outs) {
525
vmap[v.name] = n->outputs()[idx];
526
if (schema && !schema->is_varret()) {
528
schema->returns().size() > idx,
529
"Operator parsing error: out of bounds access at ",
531
" to schema->returns() which size is ",
532
schema->returns().size(),
534
auto schema_return_type = schema->returns().at(idx).type();
536
vmap[v.name]->setType(schema_return_type);
540
if (!schema_return_type->hasFreeVariables() &&
541
!v.type->isSubtypeOf(*schema_return_type)) {
542
throw ErrorReport(source_range)
543
<< "Annotated type " << v.type->repr_str()
544
<< " does not match schema type "
545
<< schema_return_type->repr_str() << " for operator " << *schema;
547
vmap[v.name]->setType(v.type);
550
vmap[v.name]->setType(v.type ? v.type : TensorType::get());
559
if (L.cur().kind == TK_INDENT) {
562
L.nextIf(TK_NEWLINE);
565
void IRParser::parseGraphInputs() {
566
parseList('(', ',', ')', [&] {
567
VarWithType v = parseVarWithType();
569
std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
570
vmap[v.name] = g->addInput(uniq_name);
571
vmap[v.name]->setType(v.type);
580
void IRParser::parseReturnOperator() {
584
parseList('(', ',', ')', [&] {
585
std::string var_name = parseVar();
586
g->registerOutput(findValueInVMap(var_name));
590
if (L.cur().kind != TK_EOF) {
591
L.expect(TK_NEWLINE);
606
void IRParser::parse() {
609
std::string graphName = L.expect(TK_IDENT).text();
614
parseOperatorsList(g->block());
617
parseReturnOperator();
619
for (Node* n : deferred_tensor_value_initializations_) {
620
auto type = n->output()->type()->expect<TensorType>();
621
auto tt = n->output()->type()->cast<TensorType>();
622
TORCH_INTERNAL_ASSERT(tt, "expected tensor output ", *n);
623
auto sizes = tt->sizes().concrete_sizes();
624
TORCH_INTERNAL_ASSERT(sizes);
625
auto strides = tt->strides().concrete_sizes();
626
TORCH_INTERNAL_ASSERT(strides);
627
auto device = tt->device();
628
TORCH_INTERNAL_ASSERT(device);
629
auto dtype = tt->scalarType();
630
TORCH_INTERNAL_ASSERT(dtype);
631
auto options = at::TensorOptions(*device).dtype(*dtype);
632
auto t = n->t_(attr::value, at::empty_strided(*sizes, *strides, options));
636
for (Node* n : deferred_empty_container_initializations_) {
637
auto type = n->output()->type();
639
if (type->kind() == TypeKind::ListType) {
640
val = c10::impl::GenericList(type->containedType(0));
641
} else if (type->kind() == TypeKind::DictType) {
642
val = c10::impl::GenericDict(
643
type->containedType(0), type->containedType(1));
645
n->ival_(attr::value, val);
649
void IRParser::parseList(
653
const std::function<void()>& callback) {
654
if (begin != TK_NOTHING) {
657
if (L.cur().kind != end) {
660
} while (L.nextIf(sep));
662
if (end != TK_NOTHING) {
667
Value* IRParser::findValueInVMap(const std::string& name) {
668
if (!vmap.count(name)) {
669
throw ErrorReport(L.cur().range)
670
<< "Cannot find a variable with name '" << name << "'";
672
return vmap.at(name);