pytorch

Форк
0
/
function_schema_parser.cpp 
388 строк · 12.2 Кб
1
#include <torch/csrc/jit/frontend/function_schema_parser.h>
2

3
#include <ATen/core/Reduction.h>
4
#include <ATen/core/type_factory.h>
5
#include <c10/util/Optional.h>
6
#include <c10/util/string_utils.h>
7
#include <torch/csrc/jit/frontend/lexer.h>
8
#include <torch/csrc/jit/frontend/parse_string_literal.h>
9
#include <torch/csrc/jit/frontend/schema_type_parser.h>
10

11
#include <functional>
12
#include <memory>
13
#include <vector>
14

15
using at::TypeKind;
16
using c10::Argument;
17
using c10::FunctionSchema;
18
using c10::IValue;
19
using c10::ListType;
20
using c10::OperatorName;
21

22
namespace torch::jit {
23

24
namespace {
25
struct SchemaParser {
26
  explicit SchemaParser(const std::string& str)
27
      : L(std::make_shared<Source>(
28
            c10::string_view(str),
29
            c10::nullopt,
30
            0,
31
            nullptr,
32
            Source::DONT_COPY)),
33
        type_parser(L, /*parse_complete_tensor_types*/ false) {}
34

35
  std::variant<OperatorName, FunctionSchema> parseDeclaration() {
36
    OperatorName name = parseName();
37

38
    // If there is no parentheses coming, then this is just the operator name
39
    // without an argument list
40
    if (L.cur().kind != '(') {
41
      return OperatorName(std::move(name));
42
    }
43

44
    std::vector<Argument> arguments;
45
    std::vector<Argument> returns;
46
    bool kwarg_only = false;
47
    bool is_vararg = false;
48
    bool is_varret = false;
49
    size_t idx = 0;
50
    parseList('(', ',', ')', [&] {
51
      if (is_vararg)
52
        throw ErrorReport(L.cur())
53
            << "... must be the last element of the argument list";
54
      if (L.nextIf('*')) {
55
        kwarg_only = true;
56
      } else if (L.nextIf(TK_DOTS)) {
57
        is_vararg = true;
58
      } else {
59
        arguments.push_back(parseArgument(
60
            idx++, /*is_return=*/false, /*kwarg_only=*/kwarg_only));
61
      }
62
    });
63

64
    // check if all arguments are not-default for vararg schemas
65
    if (is_vararg) {
66
      for (const auto& arg : arguments) {
67
        if (arg.default_value().has_value()) {
68
          throw ErrorReport(L.cur())
69
              << "schemas with vararg (...) can't have default value args";
70
        }
71
      }
72
    }
73

74
    idx = 0;
75
    L.expect(TK_ARROW);
76
    if (L.nextIf(TK_DOTS)) {
77
      is_varret = true;
78
    } else if (L.cur().kind == '(') {
79
      parseList('(', ',', ')', [&] {
80
        if (is_varret) {
81
          throw ErrorReport(L.cur())
82
              << "... must be the last element of the return list";
83
        }
84
        if (L.nextIf(TK_DOTS)) {
85
          is_varret = true;
86
        } else {
87
          returns.push_back(
88
              parseArgument(idx++, /*is_return=*/true, /*kwarg_only=*/false));
89
        }
90
      });
91
    } else {
92
      returns.push_back(
93
          parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false));
94
    }
95

96
    return FunctionSchema(
97
        std::move(name.name),
98
        std::move(name.overload_name),
99
        std::move(arguments),
100
        std::move(returns),
101
        is_vararg,
102
        is_varret);
103
  }
104

105
  c10::OperatorName parseName() {
106
    std::string name = L.expect(TK_IDENT).text();
107
    if (L.nextIf(':')) {
108
      L.expect(':');
109
      name = name + "::" + L.expect(TK_IDENT).text();
110
    }
111
    std::string overload_name = "";
112
    if (L.nextIf('.')) {
113
      overload_name = L.expect(TK_IDENT).text();
114
    }
115
    // default is used as an attribute on the `OpOverloadPacket`
116
    // (obtained using `torch.ops.aten.foo`) to get the operator
117
    // overload with overload name as an empty string
118
    // and so shouldn't be used as an overload name
119
    // also disallow dunder attribute names to be overload names
120
    bool is_a_valid_overload_name =
121
        !((overload_name == "default") || (overload_name.rfind("__", 0) == 0));
122
    TORCH_CHECK(
123
        is_a_valid_overload_name,
124
        overload_name,
125
        " is not a legal overload name for aten operators");
126
    return {name, overload_name};
127
  }
128

129
  std::vector<std::variant<OperatorName, FunctionSchema>> parseDeclarations() {
130
    std::vector<std::variant<OperatorName, FunctionSchema>> results;
131
    do {
132
      results.emplace_back(parseDeclaration());
133
    } while (L.nextIf(TK_NEWLINE));
134
    L.expect(TK_EOF);
135
    return results;
136
  }
137

138
  std::variant<OperatorName, FunctionSchema> parseExactlyOneDeclaration() {
139
    auto result = parseDeclaration();
140
    L.nextIf(TK_NEWLINE);
141
    L.expect(TK_EOF);
142
    return result;
143
  }
144

145
  Argument parseArgument(size_t /*idx*/, bool is_return, bool kwarg_only) {
146
    // fake and real type coincide except for Layout/MemoryFormat/ScalarType
147
    // the fake type for these is Int instead
148
    auto p = type_parser.parseFakeAndRealType();
149
    auto fake_type = std::move(std::get<0>(p));
150
    auto real_type = std::move(std::get<1>(p));
151
    auto alias_info = std::move(std::get<2>(p));
152
    c10::optional<int32_t> N;
153
    c10::optional<IValue> default_value;
154
    c10::optional<std::string> alias_set;
155
    std::string name;
156
    if (L.nextIf('[')) {
157
      // note: an array with a size hint can only occur at the Argument level
158
      fake_type = ListType::create(std::move(fake_type));
159
      real_type = ListType::create(std::move(real_type));
160
      N = c10::stoll(L.expect(TK_NUMBER).text());
161
      L.expect(']');
162
      auto container = type_parser.parseAliasAnnotation();
163
      if (alias_info) {
164
        if (!container) {
165
          container = c10::optional<at::AliasInfo>(at::AliasInfo());
166
          container->setIsWrite(alias_info->isWrite());
167
        }
168
        container->addContainedType(std::move(*alias_info));
169
      }
170
      alias_info = std::move(container);
171
      if (L.nextIf('?')) {
172
        fake_type =
173
            c10::TypeFactory::create<c10::OptionalType>(std::move(fake_type));
174
        real_type =
175
            c10::TypeFactory::create<c10::OptionalType>(std::move(real_type));
176
      }
177
    }
178
    if (is_return) {
179
      // optionally field names in return values
180
      if (L.cur().kind == TK_IDENT) {
181
        name = L.next().text();
182
      } else {
183
        name = "";
184
      }
185
    } else {
186
      name = L.expect(TK_IDENT).text();
187
      if (L.nextIf('=')) {
188
        // NB: this means we have to unswizzle default too
189
        default_value = parseDefaultValue(*fake_type, fake_type->kind(), N);
190
      }
191
    }
192
    return Argument(
193
        std::move(name),
194
        std::move(fake_type),
195
        std::move(real_type),
196
        N,
197
        std::move(default_value),
198
        !is_return && kwarg_only,
199
        std::move(alias_info));
200
  }
201
  IValue parseSingleConstant(const c10::Type& type, TypeKind kind) {
202
    if (kind == c10::TypeKind::DynamicType) {
203
      return parseSingleConstant(
204
          type, type.expectRef<c10::DynamicType>().dynamicKind());
205
    }
206
    switch (L.cur().kind) {
207
      case TK_TRUE:
208
        L.next();
209
        return true;
210
      case TK_FALSE:
211
        L.next();
212
        return false;
213
      case TK_NONE:
214
        L.next();
215
        return IValue();
216
      case TK_STRINGLITERAL: {
217
        auto token = L.next();
218
        return parseStringLiteral(token.range, token.text());
219
      }
220
      case TK_IDENT: {
221
        auto tok = L.next();
222
        auto text = tok.text();
223
        if ("float" == text) {
224
          return static_cast<int64_t>(at::kFloat);
225
        } else if ("complex" == text) {
226
          return static_cast<int64_t>(at::kComplexFloat);
227
        } else if ("long" == text) {
228
          return static_cast<int64_t>(at::kLong);
229
        } else if ("strided" == text) {
230
          return static_cast<int64_t>(at::kStrided);
231
        } else if ("Mean" == text) {
232
          return static_cast<int64_t>(at::Reduction::Mean);
233
        } else if ("contiguous_format" == text) {
234
          return static_cast<int64_t>(c10::MemoryFormat::Contiguous);
235
        } else {
236
          throw ErrorReport(L.cur().range) << "invalid numeric default value";
237
        }
238
      }
239
      default:
240
        std::string n;
241
        if (L.nextIf('-'))
242
          n = "-" + L.expect(TK_NUMBER).text();
243
        else
244
          n = L.expect(TK_NUMBER).text();
245

246
        if (kind == TypeKind::ComplexType || n.find('j') != std::string::npos) {
247
          auto imag = c10::stod(n.substr(0, n.size() - 1));
248
          return c10::complex<double>(0, imag);
249
        } else if (
250
            kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
251
            n.find('e') != std::string::npos) {
252
          return c10::stod(n);
253
        } else {
254
          int64_t v = c10::stoll(n);
255
          return v;
256
        }
257
    }
258
  }
259
  IValue convertToList(
260
      const c10::Type& type,
261
      TypeKind kind,
262
      const SourceRange& range,
263
      const std::vector<IValue>& vs) {
264
    switch (kind) {
265
      case TypeKind::ComplexType:
266
        return fmap(vs, [](const IValue& v) { return v.toComplexDouble(); });
267
      case TypeKind::FloatType:
268
        return fmap(vs, [](const IValue& v) { return v.toDouble(); });
269
      case TypeKind::IntType:
270
        return fmap(vs, [](const IValue& v) { return v.toInt(); });
271
      case TypeKind::BoolType:
272
        return fmap(vs, [](const IValue& v) { return v.toBool(); });
273
      case TypeKind::DynamicType:
274
        return convertToList(
275
            type, type.expectRef<c10::DynamicType>().dynamicKind(), range, vs);
276
      default:
277
        throw ErrorReport(range)
278
            << "lists are only supported for float, int and complex types";
279
    }
280
  }
281
  IValue parseConstantList(const c10::Type& type, TypeKind kind) {
282
    auto tok = L.expect('[');
283
    std::vector<IValue> vs;
284
    if (L.cur().kind != ']') {
285
      do {
286
        vs.push_back(parseSingleConstant(type, kind));
287
      } while (L.nextIf(','));
288
    }
289
    L.expect(']');
290
    return convertToList(type, kind, tok.range, vs);
291
  }
292

293
  IValue parseTensorDefault(const SourceRange& /*range*/) {
294
    L.expect(TK_NONE);
295
    return IValue();
296
  }
297
  IValue parseDefaultValue(
298
      const c10::Type& arg_type,
299
      TypeKind kind,
300
      c10::optional<int32_t> arg_N) {
301
    auto range = L.cur().range;
302
    switch (kind) {
303
      case TypeKind::TensorType:
304
      case TypeKind::GeneratorType:
305
      case TypeKind::QuantizerType: {
306
        return parseTensorDefault(range);
307
      } break;
308
      case TypeKind::StringType:
309
      case TypeKind::OptionalType:
310
      case TypeKind::NumberType:
311
      case TypeKind::IntType:
312
      case TypeKind::BoolType:
313
      case TypeKind::FloatType:
314
      case TypeKind::ComplexType:
315
        return parseSingleConstant(arg_type, kind);
316
        break;
317
      case TypeKind::DeviceObjType: {
318
        auto device_text =
319
            parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
320
        return c10::Device(device_text);
321
        break;
322
      }
323
      case TypeKind::ListType: {
324
        auto elem_type = arg_type.containedType(0);
325
        if (L.cur().kind == TK_IDENT) {
326
          return parseTensorDefault(range);
327
        } else if (arg_N && L.cur().kind != '[') {
328
          IValue v = parseSingleConstant(*elem_type, elem_type->kind());
329
          std::vector<IValue> repeated(*arg_N, v);
330
          return convertToList(*elem_type, elem_type->kind(), range, repeated);
331
        } else {
332
          return parseConstantList(*elem_type, elem_type->kind());
333
        }
334
      } break;
335
      case TypeKind::DynamicType:
336
        return parseDefaultValue(
337
            arg_type,
338
            arg_type.expectRef<c10::DynamicType>().dynamicKind(),
339
            arg_N);
340
      default:
341
        throw ErrorReport(range) << "unexpected type, file a bug report";
342
    }
343
    return IValue(); // silence warnings
344
  }
345

346
  void parseList(
347
      int begin,
348
      int sep,
349
      int end,
350
      c10::function_ref<void()> callback) {
351
    auto r = L.cur().range;
352
    if (begin != TK_NOTHING)
353
      L.expect(begin);
354
    if (L.cur().kind != end) {
355
      do {
356
        callback();
357
      } while (L.nextIf(sep));
358
    }
359
    if (end != TK_NOTHING)
360
      L.expect(end);
361
  }
362
  Lexer L;
363
  SchemaTypeParser type_parser;
364
};
365
} // namespace
366

367
std::variant<OperatorName, FunctionSchema> parseSchemaOrName(
368
    const std::string& schemaOrName) {
369
  return SchemaParser(schemaOrName).parseExactlyOneDeclaration();
370
}
371

372
// Parses `schema` as a full function schema. Raises (TORCH_CHECK) when the
// input only names an operator without an argument list.
FunctionSchema parseSchema(const std::string& schema) {
  std::variant<OperatorName, FunctionSchema> parsed =
      parseSchemaOrName(schema);
  TORCH_CHECK(
      std::holds_alternative<FunctionSchema>(parsed),
      "Tried to parse a function schema but only the operator name was given");
  return std::move(std::get<FunctionSchema>(parsed));
}
379

380
// Parses `name` as a bare operator name (`ns::op` or `ns::op.overload`).
// Raises (TORCH_CHECK) when the input is a full function schema instead.
OperatorName parseName(const std::string& name) {
  std::variant<OperatorName, FunctionSchema> parsed = parseSchemaOrName(name);
  TORCH_CHECK(
      std::holds_alternative<OperatorName>(parsed),
      "Tried to parse an operator name but function schema was given");
  return std::move(std::get<OperatorName>(parsed));
}
387

388
} // namespace torch::jit
389

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.