1
#include <torch/csrc/jit/frontend/function_schema_parser.h>
3
#include <ATen/core/Reduction.h>
4
#include <ATen/core/type_factory.h>
5
#include <c10/util/Optional.h>
6
#include <c10/util/string_utils.h>
7
#include <torch/csrc/jit/frontend/lexer.h>
8
#include <torch/csrc/jit/frontend/parse_string_literal.h>
9
#include <torch/csrc/jit/frontend/schema_type_parser.h>
17
using c10::FunctionSchema;
20
using c10::OperatorName;
26
explicit SchemaParser(const std::string& str)
27
: L(std::make_shared<Source>(
28
c10::string_view(str),
33
type_parser(L, /*parse_complete_tensor_types*/ false) {}
35
std::variant<OperatorName, FunctionSchema> parseDeclaration() {
36
OperatorName name = parseName();
38
// If there is no parentheses coming, then this is just the operator name
39
// without an argument list
40
if (L.cur().kind != '(') {
41
return OperatorName(std::move(name));
44
std::vector<Argument> arguments;
45
std::vector<Argument> returns;
46
bool kwarg_only = false;
47
bool is_vararg = false;
48
bool is_varret = false;
50
parseList('(', ',', ')', [&] {
52
throw ErrorReport(L.cur())
53
<< "... must be the last element of the argument list";
56
} else if (L.nextIf(TK_DOTS)) {
59
arguments.push_back(parseArgument(
60
idx++, /*is_return=*/false, /*kwarg_only=*/kwarg_only));
64
// check if all arguments are not-default for vararg schemas
66
for (const auto& arg : arguments) {
67
if (arg.default_value().has_value()) {
68
throw ErrorReport(L.cur())
69
<< "schemas with vararg (...) can't have default value args";
76
if (L.nextIf(TK_DOTS)) {
78
} else if (L.cur().kind == '(') {
79
parseList('(', ',', ')', [&] {
81
throw ErrorReport(L.cur())
82
<< "... must be the last element of the return list";
84
if (L.nextIf(TK_DOTS)) {
88
parseArgument(idx++, /*is_return=*/true, /*kwarg_only=*/false));
93
parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false));
96
return FunctionSchema(
98
std::move(name.overload_name),
105
c10::OperatorName parseName() {
106
std::string name = L.expect(TK_IDENT).text();
109
name = name + "::" + L.expect(TK_IDENT).text();
111
std::string overload_name = "";
113
overload_name = L.expect(TK_IDENT).text();
115
// default is used as an attribute on the `OpOverloadPacket`
116
// (obtained using `torch.ops.aten.foo`) to get the operator
117
// overload with overload name as an empty string
118
// and so shouldn't be used as an overload name
119
// also disallow dunder attribute names to be overload names
120
bool is_a_valid_overload_name =
121
!((overload_name == "default") || (overload_name.rfind("__", 0) == 0));
123
is_a_valid_overload_name,
125
" is not a legal overload name for aten operators");
126
return {name, overload_name};
129
std::vector<std::variant<OperatorName, FunctionSchema>> parseDeclarations() {
130
std::vector<std::variant<OperatorName, FunctionSchema>> results;
132
results.emplace_back(parseDeclaration());
133
} while (L.nextIf(TK_NEWLINE));
138
std::variant<OperatorName, FunctionSchema> parseExactlyOneDeclaration() {
139
auto result = parseDeclaration();
140
L.nextIf(TK_NEWLINE);
145
Argument parseArgument(size_t /*idx*/, bool is_return, bool kwarg_only) {
146
// fake and real type coincide except for Layout/MemoryFormat/ScalarType
147
// the fake type for these is Int instead
148
auto p = type_parser.parseFakeAndRealType();
149
auto fake_type = std::move(std::get<0>(p));
150
auto real_type = std::move(std::get<1>(p));
151
auto alias_info = std::move(std::get<2>(p));
152
c10::optional<int32_t> N;
153
c10::optional<IValue> default_value;
154
c10::optional<std::string> alias_set;
157
// note: an array with a size hint can only occur at the Argument level
158
fake_type = ListType::create(std::move(fake_type));
159
real_type = ListType::create(std::move(real_type));
160
N = c10::stoll(L.expect(TK_NUMBER).text());
162
auto container = type_parser.parseAliasAnnotation();
165
container = c10::optional<at::AliasInfo>(at::AliasInfo());
166
container->setIsWrite(alias_info->isWrite());
168
container->addContainedType(std::move(*alias_info));
170
alias_info = std::move(container);
173
c10::TypeFactory::create<c10::OptionalType>(std::move(fake_type));
175
c10::TypeFactory::create<c10::OptionalType>(std::move(real_type));
179
// optionally field names in return values
180
if (L.cur().kind == TK_IDENT) {
181
name = L.next().text();
186
name = L.expect(TK_IDENT).text();
188
// NB: this means we have to unswizzle default too
189
default_value = parseDefaultValue(*fake_type, fake_type->kind(), N);
194
std::move(fake_type),
195
std::move(real_type),
197
std::move(default_value),
198
!is_return && kwarg_only,
199
std::move(alias_info));
201
IValue parseSingleConstant(const c10::Type& type, TypeKind kind) {
202
if (kind == c10::TypeKind::DynamicType) {
203
return parseSingleConstant(
204
type, type.expectRef<c10::DynamicType>().dynamicKind());
206
switch (L.cur().kind) {
216
case TK_STRINGLITERAL: {
217
auto token = L.next();
218
return parseStringLiteral(token.range, token.text());
222
auto text = tok.text();
223
if ("float" == text) {
224
return static_cast<int64_t>(at::kFloat);
225
} else if ("complex" == text) {
226
return static_cast<int64_t>(at::kComplexFloat);
227
} else if ("long" == text) {
228
return static_cast<int64_t>(at::kLong);
229
} else if ("strided" == text) {
230
return static_cast<int64_t>(at::kStrided);
231
} else if ("Mean" == text) {
232
return static_cast<int64_t>(at::Reduction::Mean);
233
} else if ("contiguous_format" == text) {
234
return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
236
throw ErrorReport(L.cur().range) << "invalid numeric default value";
242
n = "-" + L.expect(TK_NUMBER).text();
244
n = L.expect(TK_NUMBER).text();
246
if (kind == TypeKind::ComplexType || n.find('j') != std::string::npos) {
247
auto imag = c10::stod(n.substr(0, n.size() - 1));
248
return c10::complex<double>(0, imag);
250
kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
251
n.find('e') != std::string::npos) {
254
int64_t v = c10::stoll(n);
259
IValue convertToList(
260
const c10::Type& type,
262
const SourceRange& range,
263
const std::vector<IValue>& vs) {
265
case TypeKind::ComplexType:
266
return fmap(vs, [](const IValue& v) { return v.toComplexDouble(); });
267
case TypeKind::FloatType:
268
return fmap(vs, [](const IValue& v) { return v.toDouble(); });
269
case TypeKind::IntType:
270
return fmap(vs, [](const IValue& v) { return v.toInt(); });
271
case TypeKind::BoolType:
272
return fmap(vs, [](const IValue& v) { return v.toBool(); });
273
case TypeKind::DynamicType:
274
return convertToList(
275
type, type.expectRef<c10::DynamicType>().dynamicKind(), range, vs);
277
throw ErrorReport(range)
278
<< "lists are only supported for float, int and complex types";
281
IValue parseConstantList(const c10::Type& type, TypeKind kind) {
282
auto tok = L.expect('[');
283
std::vector<IValue> vs;
284
if (L.cur().kind != ']') {
286
vs.push_back(parseSingleConstant(type, kind));
287
} while (L.nextIf(','));
290
return convertToList(type, kind, tok.range, vs);
293
IValue parseTensorDefault(const SourceRange& /*range*/) {
297
IValue parseDefaultValue(
298
const c10::Type& arg_type,
300
c10::optional<int32_t> arg_N) {
301
auto range = L.cur().range;
303
case TypeKind::TensorType:
304
case TypeKind::GeneratorType:
305
case TypeKind::QuantizerType: {
306
return parseTensorDefault(range);
308
case TypeKind::StringType:
309
case TypeKind::OptionalType:
310
case TypeKind::NumberType:
311
case TypeKind::IntType:
312
case TypeKind::BoolType:
313
case TypeKind::FloatType:
314
case TypeKind::ComplexType:
315
return parseSingleConstant(arg_type, kind);
317
case TypeKind::DeviceObjType: {
319
parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
320
return c10::Device(device_text);
323
case TypeKind::ListType: {
324
auto elem_type = arg_type.containedType(0);
325
if (L.cur().kind == TK_IDENT) {
326
return parseTensorDefault(range);
327
} else if (arg_N && L.cur().kind != '[') {
328
IValue v = parseSingleConstant(*elem_type, elem_type->kind());
329
std::vector<IValue> repeated(*arg_N, v);
330
return convertToList(*elem_type, elem_type->kind(), range, repeated);
332
return parseConstantList(*elem_type, elem_type->kind());
335
case TypeKind::DynamicType:
336
return parseDefaultValue(
338
arg_type.expectRef<c10::DynamicType>().dynamicKind(),
341
throw ErrorReport(range) << "unexpected type, file a bug report";
343
return IValue(); // silence warnings
350
c10::function_ref<void()> callback) {
351
auto r = L.cur().range;
352
if (begin != TK_NOTHING)
354
if (L.cur().kind != end) {
357
} while (L.nextIf(sep));
359
if (end != TK_NOTHING)
363
SchemaTypeParser type_parser;
367
std::variant<OperatorName, FunctionSchema> parseSchemaOrName(
368
const std::string& schemaOrName) {
369
return SchemaParser(schemaOrName).parseExactlyOneDeclaration();
372
FunctionSchema parseSchema(const std::string& schema) {
373
auto parsed = parseSchemaOrName(schema);
375
std::holds_alternative<FunctionSchema>(parsed),
376
"Tried to parse a function schema but only the operator name was given");
377
return std::get<FunctionSchema>(std::move(parsed));
380
OperatorName parseName(const std::string& name) {
381
auto parsed = parseSchemaOrName(name);
383
std::holds_alternative<OperatorName>(parsed),
384
"Tried to parse an operator name but function schema was given");
385
return std::get<OperatorName>(std::move(parsed));
388
} // namespace torch::jit