pytorch

Форк
0
/
invalid_arguments.cpp 
444 строки · 12.9 Кб
1
#include <torch/csrc/utils/invalid_arguments.h>
2

3
#include <torch/csrc/utils/python_strings.h>
4

5
#include <c10/util/irange.h>
6

7
#include <algorithm>
8
#include <memory>
9
#include <unordered_map>
10

11
namespace torch {
12

13
namespace {
14

15
// Returns the name stored in the `tp_name` slot of `object`'s Python type,
// copied into a std::string.
std::string py_typename(PyObject* object) {
  const char* const type_name = Py_TYPE(object)->tp_name;
  return std::string(type_name);
}
18

19
struct Type {
20
  Type() = default;
21
  Type(const Type&) = default;
22
  Type& operator=(const Type&) = default;
23
  Type(Type&&) noexcept = default;
24
  Type& operator=(Type&&) noexcept = default;
25
  virtual bool is_matching(PyObject* object) = 0;
26
  virtual ~Type() = default;
27
};
28

29
struct SimpleType : public Type {
30
  SimpleType(std::string& name) : name(name){};
31

32
  bool is_matching(PyObject* object) override {
33
    return py_typename(object) == name;
34
  }
35

36
  std::string name;
37
};
38

39
struct MultiType : public Type {
40
  MultiType(std::initializer_list<std::string> accepted_types)
41
      : types(accepted_types){};
42

43
  bool is_matching(PyObject* object) override {
44
    auto it = std::find(types.begin(), types.end(), py_typename(object));
45
    return it != types.end();
46
  }
47

48
  std::vector<std::string> types;
49
};
50

51
struct NullableType : public Type {
52
  NullableType(std::unique_ptr<Type> type) : type(std::move(type)){};
53

54
  bool is_matching(PyObject* object) override {
55
    return object == Py_None || type->is_matching(object);
56
  }
57

58
  std::unique_ptr<Type> type;
59
};
60

61
struct TupleType : public Type {
62
  TupleType(std::vector<std::unique_ptr<Type>> types)
63
      : types(std::move(types)){};
64

65
  bool is_matching(PyObject* object) override {
66
    if (!PyTuple_Check(object))
67
      return false;
68
    auto num_elements = PyTuple_GET_SIZE(object);
69
    if (num_elements != (long)types.size())
70
      return false;
71
    for (const auto i : c10::irange(num_elements)) {
72
      if (!types[i]->is_matching(PyTuple_GET_ITEM(object, i)))
73
        return false;
74
    }
75
    return true;
76
  }
77

78
  std::vector<std::unique_ptr<Type>> types;
79
};
80

81
struct SequenceType : public Type {
82
  SequenceType(std::unique_ptr<Type> type) : type(std::move(type)){};
83

84
  bool is_matching(PyObject* object) override {
85
    if (!PySequence_Check(object))
86
      return false;
87
    auto num_elements = PySequence_Length(object);
88
    for (const auto i : c10::irange(num_elements)) {
89
      if (!type->is_matching(
90
              py::reinterpret_steal<py::object>(PySequence_GetItem(object, i))
91
                  .ptr()))
92
        return false;
93
    }
94
    return true;
95
  }
96

97
  std::unique_ptr<Type> type;
98
};
99

100
struct Argument {
101
  Argument(std::string name, std::unique_ptr<Type> type)
102
      : name(std::move(name)), type(std::move(type)){};
103

104
  std::string name;
105
  std::unique_ptr<Type> type;
106
};
107

108
struct Option {
109
  Option(std::vector<Argument> arguments, bool is_variadic, bool has_out)
110
      : arguments(std::move(arguments)),
111
        is_variadic(is_variadic),
112
        has_out(has_out){};
113
  Option(bool is_variadic, bool has_out)
114
      : arguments(), is_variadic(is_variadic), has_out(has_out){};
115
  Option(const Option&) = delete;
116
  Option(Option&& other) noexcept
117
      : arguments(std::move(other.arguments)),
118
        is_variadic(other.is_variadic),
119
        has_out(other.has_out){};
120

121
  std::vector<Argument> arguments;
122
  bool is_variadic;
123
  bool has_out;
124
};
125

126
// Splits `s` at every occurrence of `delim`, keeping empty fields:
// "a,,b" -> {"a", "", "b"}. The result always has at least one element.
// An empty `delim` returns {s}; the previous loop never advanced past an
// empty match and spun forever.
std::vector<std::string> _splitString(
    const std::string& s,
    const std::string& delim) {
  if (delim.empty())
    return {s};
  std::vector<std::string> tokens;
  size_t start = 0;
  for (size_t end = s.find(delim); end != std::string::npos;
       end = s.find(delim, start)) {
    tokens.push_back(s.substr(start, end - start));
    start = end + delim.size();
  }
  tokens.push_back(s.substr(start));
  return tokens;
}
139

140
std::unique_ptr<Type> _buildType(std::string type_name, bool is_nullable) {
141
  std::unique_ptr<Type> result;
142
  if (type_name == "float") {
143
    result = std::make_unique<MultiType>(MultiType{"float", "int", "long"});
144
  } else if (type_name == "int") {
145
    result = std::make_unique<MultiType>(MultiType{"int", "long"});
146
  } else if (type_name.find("tuple[") == 0) {
147
    auto type_list = type_name.substr(6);
148
    type_list.pop_back();
149
    std::vector<std::unique_ptr<Type>> types;
150
    for (auto& type : _splitString(type_list, ","))
151
      types.emplace_back(_buildType(type, false));
152
    result = std::make_unique<TupleType>(std::move(types));
153
  } else if (type_name.find("sequence[") == 0) {
154
    auto subtype = type_name.substr(9);
155
    subtype.pop_back();
156
    result = std::make_unique<SequenceType>(_buildType(subtype, false));
157
  } else {
158
    result = std::make_unique<SimpleType>(type_name);
159
  }
160
  if (is_nullable)
161
    result = std::make_unique<NullableType>(std::move(result));
162
  return result;
163
}
164

165
std::pair<Option, std::string> _parseOption(
166
    const std::string& _option_str,
167
    const std::unordered_map<std::string, PyObject*>& kwargs) {
168
  if (_option_str == "no arguments")
169
    return std::pair<Option, std::string>(Option(false, false), _option_str);
170
  bool has_out = false;
171
  std::vector<Argument> arguments;
172
  std::string printable_option = _option_str;
173
  std::string option_str = _option_str.substr(1, _option_str.length() - 2);
174

175
  /// XXX: this is a hack only for the out arg in TensorMethods
176
  auto out_pos = printable_option.find('#');
177
  if (out_pos != std::string::npos) {
178
    if (kwargs.count("out") > 0) {
179
      std::string kwonly_part = printable_option.substr(out_pos + 1);
180
      printable_option.erase(out_pos);
181
      printable_option += "*, ";
182
      printable_option += kwonly_part;
183
    } else if (out_pos >= 2) {
184
      printable_option.erase(out_pos - 2);
185
      printable_option += ")";
186
    } else {
187
      printable_option.erase(out_pos);
188
      printable_option += ")";
189
    }
190
    has_out = true;
191
  }
192

193
  for (auto& arg : _splitString(option_str, ", ")) {
194
    bool is_nullable = false;
195
    auto type_start_idx = 0;
196
    if (arg[type_start_idx] == '#') {
197
      type_start_idx++;
198
    }
199
    if (arg[type_start_idx] == '[') {
200
      is_nullable = true;
201
      type_start_idx++;
202
      arg.erase(arg.length() - std::string(" or None]").length());
203
    }
204

205
    auto type_end_idx = arg.find_last_of(' ');
206
    auto name_start_idx = type_end_idx + 1;
207

208
    // "type ... name" => "type ... name"
209
    //          ^              ^
210
    auto dots_idx = arg.find("...");
211
    if (dots_idx != std::string::npos)
212
      type_end_idx -= 4;
213

214
    std::string type_name =
215
        arg.substr(type_start_idx, type_end_idx - type_start_idx);
216
    std::string name = arg.substr(name_start_idx);
217

218
    arguments.emplace_back(name, _buildType(type_name, is_nullable));
219
  }
220

221
  bool is_variadic = option_str.find("...") != std::string::npos;
222
  return std::pair<Option, std::string>(
223
      Option(std::move(arguments), is_variadic, has_out),
224
      std::move(printable_option));
225
}
226

227
bool _argcountMatch(
228
    const Option& option,
229
    const std::vector<PyObject*>& arguments,
230
    const std::unordered_map<std::string, PyObject*>& kwargs) {
231
  auto num_expected = option.arguments.size();
232
  auto num_got = arguments.size() + kwargs.size();
233
  // Note: variadic functions don't accept kwargs, so it's ok
234
  if (option.has_out && kwargs.count("out") == 0)
235
    num_expected--;
236
  return num_got == num_expected ||
237
      (option.is_variadic && num_got > num_expected);
238
}
239

240
std::string _formattedArgDesc(
241
    const Option& option,
242
    const std::vector<PyObject*>& arguments,
243
    const std::unordered_map<std::string, PyObject*>& kwargs) {
244
  std::string red;
245
  std::string reset_red;
246
  std::string green;
247
  std::string reset_green;
248
  if (isatty(1) && isatty(2)) {
249
    red = "\33[31;1m";
250
    reset_red = "\33[0m";
251
    green = "\33[32;1m";
252
    reset_green = "\33[0m";
253
  } else {
254
    red = "!";
255
    reset_red = "!";
256
    green = "";
257
    reset_green = "";
258
  }
259

260
  auto num_args = arguments.size() + kwargs.size();
261
  std::string result = "(";
262
  for (const auto i : c10::irange(num_args)) {
263
    bool is_kwarg = i >= arguments.size();
264
    PyObject* arg =
265
        is_kwarg ? kwargs.at(option.arguments[i].name) : arguments[i];
266

267
    bool is_matching = false;
268
    if (i < option.arguments.size()) {
269
      is_matching = option.arguments[i].type->is_matching(arg);
270
    } else if (option.is_variadic) {
271
      is_matching = option.arguments.back().type->is_matching(arg);
272
    }
273

274
    if (is_matching)
275
      result += green;
276
    else
277
      result += red;
278
    if (is_kwarg)
279
      result += option.arguments[i].name + "=";
280
    bool is_tuple = PyTuple_Check(arg);
281
    if (is_tuple || PyList_Check(arg)) {
282
      result += py_typename(arg) + " of ";
283
      auto num_elements = PySequence_Length(arg);
284
      if (is_tuple) {
285
        result += "(";
286
      } else {
287
        result += "[";
288
      }
289
      for (const auto i : c10::irange(num_elements)) {
290
        if (i != 0) {
291
          result += ", ";
292
        }
293
        result += py_typename(
294
            py::reinterpret_steal<py::object>(PySequence_GetItem(arg, i))
295
                .ptr());
296
      }
297
      if (is_tuple) {
298
        if (num_elements == 1) {
299
          result += ",";
300
        }
301
        result += ")";
302
      } else {
303
        result += "]";
304
      }
305
    } else {
306
      result += py_typename(arg);
307
    }
308
    if (is_matching)
309
      result += reset_green;
310
    else
311
      result += reset_red;
312
    result += ", ";
313
  }
314
  if (!arguments.empty())
315
    result.erase(result.length() - 2);
316
  result += ")";
317
  return result;
318
}
319

320
std::string _argDesc(
321
    const std::vector<PyObject*>& arguments,
322
    const std::unordered_map<std::string, PyObject*>& kwargs) {
323
  std::string result = "(";
324
  for (auto& arg : arguments)
325
    result += std::string(py_typename(arg)) + ", ";
326
  for (auto& kwarg : kwargs)
327
    result += kwarg.first + "=" + py_typename(kwarg.second) + ", ";
328
  if (!arguments.empty())
329
    result.erase(result.length() - 2);
330
  result += ")";
331
  return result;
332
}
333

334
std::vector<std::string> _tryMatchKwargs(
335
    const Option& option,
336
    const std::unordered_map<std::string, PyObject*>& kwargs) {
337
  std::vector<std::string> unmatched;
338
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
339
  int64_t start_idx = option.arguments.size() - kwargs.size();
340
  if (option.has_out && kwargs.count("out") == 0)
341
    start_idx--;
342
  if (start_idx < 0)
343
    start_idx = 0;
344
  for (auto& entry : kwargs) {
345
    bool found = false;
346
    for (unsigned int i = start_idx; i < option.arguments.size(); i++) {
347
      if (option.arguments[i].name == entry.first) {
348
        found = true;
349
        break;
350
      }
351
    }
352
    if (!found)
353
      unmatched.push_back(entry.first);
354
  }
355
  return unmatched;
356
}
357

358
} // anonymous namespace
359

360
std::string format_invalid_args(
361
    PyObject* given_args,
362
    PyObject* given_kwargs,
363
    const std::string& function_name,
364
    const std::vector<std::string>& options) {
365
  std::vector<PyObject*> args;
366
  std::unordered_map<std::string, PyObject*> kwargs;
367
  std::string error_msg;
368
  error_msg.reserve(2000);
369
  error_msg += function_name;
370
  error_msg += " received an invalid combination of arguments - ";
371

372
  Py_ssize_t num_args = PyTuple_Size(given_args);
373
  for (const auto i : c10::irange(num_args)) {
374
    PyObject* arg = PyTuple_GET_ITEM(given_args, i);
375
    args.push_back(arg);
376
  }
377

378
  bool has_kwargs = given_kwargs && PyDict_Size(given_kwargs) > 0;
379
  if (has_kwargs) {
380
    PyObject *key = nullptr, *value = nullptr;
381
    Py_ssize_t pos = 0;
382

383
    while (PyDict_Next(given_kwargs, &pos, &key, &value)) {
384
      kwargs.emplace(THPUtils_unpackString(key), value);
385
    }
386
  }
387

388
  if (options.size() == 1) {
389
    auto pair = _parseOption(options[0], kwargs);
390
    auto& option = pair.first;
391
    auto& option_str = pair.second;
392
    std::vector<std::string> unmatched_kwargs;
393
    if (has_kwargs)
394
      unmatched_kwargs = _tryMatchKwargs(option, kwargs);
395
    if (!unmatched_kwargs.empty()) {
396
      error_msg += "got unrecognized keyword arguments: ";
397
      for (auto& kwarg : unmatched_kwargs)
398
        error_msg += kwarg + ", ";
399
      error_msg.erase(error_msg.length() - 2);
400
    } else {
401
      error_msg += "got ";
402
      if (_argcountMatch(option, args, kwargs)) {
403
        error_msg += _formattedArgDesc(option, args, kwargs);
404
      } else {
405
        error_msg += _argDesc(args, kwargs);
406
      }
407
      error_msg += ", but expected ";
408
      error_msg += option_str;
409
    }
410
  } else {
411
    error_msg += "got ";
412
    error_msg += _argDesc(args, kwargs);
413
    error_msg += ", but expected one of:\n";
414
    for (auto& option_str : options) {
415
      auto pair = _parseOption(option_str, kwargs);
416
      auto& option = pair.first;
417
      auto& printable_option_str = pair.second;
418
      error_msg += " * ";
419
      error_msg += printable_option_str;
420
      error_msg += "\n";
421
      if (_argcountMatch(option, args, kwargs)) {
422
        std::vector<std::string> unmatched_kwargs;
423
        if (has_kwargs)
424
          unmatched_kwargs = _tryMatchKwargs(option, kwargs);
425
        if (!unmatched_kwargs.empty()) {
426
          error_msg +=
427
              "      didn't match because some of the keywords were incorrect: ";
428
          for (auto& kwarg : unmatched_kwargs)
429
            error_msg += kwarg + ", ";
430
          error_msg.erase(error_msg.length() - 2);
431
          error_msg += "\n";
432
        } else {
433
          error_msg +=
434
              "      didn't match because some of the arguments have invalid types: ";
435
          error_msg += _formattedArgDesc(option, args, kwargs);
436
          error_msg += "\n";
437
        }
438
      }
439
    }
440
  }
441
  return error_msg;
442
}
443

444
} // namespace torch
445

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.