1
#include <torch/csrc/utils/invalid_arguments.h>
3
#include <torch/csrc/utils/python_strings.h>
5
#include <c10/util/irange.h>
9
#include <unordered_map>
15
// Returns the name of `object`'s Python type, read from the `tp_name` field
// of its type object (Py_TYPE) and copied into a std::string.
std::string py_typename(PyObject* object) {
16
return Py_TYPE(object)->tp_name;
21
// Copy and move operations are explicitly defaulted: declaring the virtual
// destructor below would otherwise suppress the implicit move operations.
Type(const Type&) = default;
22
Type& operator=(const Type&) = default;
23
Type(Type&&) noexcept = default;
24
Type& operator=(Type&&) noexcept = default;
25
// Returns true when `object` satisfies this type description; each subclass
// (SimpleType, MultiType, NullableType, TupleType, SequenceType) supplies
// its own rule.
virtual bool is_matching(PyObject* object) = 0;
26
// Virtual so that subclasses held through std::unique_ptr<Type> are
// destroyed correctly.
virtual ~Type() = default;
29
// Matches an object whose Python type name (see py_typename) equals `name`
// exactly.
struct SimpleType : public Type {
30
// NOTE(review): `name` is taken by non-const lvalue reference, so this
// cannot be constructed from a temporary string; taking it by value and
// moving it into the member (as NullableType and TupleType do with their
// constructor arguments) would be more flexible.
SimpleType(std::string& name) : name(name){};
32
bool is_matching(PyObject* object) override {
33
return py_typename(object) == name;
39
// Matches an object whose Python type name equals any entry of `types`.
struct MultiType : public Type {
40
MultiType(std::initializer_list<std::string> accepted_types)
41
: types(accepted_types){};
43
bool is_matching(PyObject* object) override {
44
// Linear search over `types`; the MultiTypes built in _buildType hold only
// two or three names.
// NOTE(review): std::find is declared in <algorithm>, which is not among
// the #include lines visible in this file — confirm it is included.
auto it = std::find(types.begin(), types.end(), py_typename(object));
45
return it != types.end();
48
// Accepted Python type names.
std::vector<std::string> types;
51
// Wraps another Type so that Python None is accepted in addition to
// whatever the wrapped `type` accepts.
struct NullableType : public Type {
52
NullableType(std::unique_ptr<Type> type) : type(std::move(type)){};
54
bool is_matching(PyObject* object) override {
55
// Py_None is tested first, so the wrapped matcher is never asked about None.
return object == Py_None || type->is_matching(object);
58
// The wrapped (non-null) matcher.
std::unique_ptr<Type> type;
61
// Matches a Python tuple position by position: its length must equal
// types.size() and element i must match types[i].
struct TupleType : public Type {
62
TupleType(std::vector<std::unique_ptr<Type>> types)
63
: types(std::move(types)){};
65
// Checks, in order: the object is a tuple, its length equals types.size(),
// and every element matches its matcher (presumably returning false at the
// first failed check — confirm).
bool is_matching(PyObject* object) override {
66
if (!PyTuple_Check(object))
68
auto num_elements = PyTuple_GET_SIZE(object);
69
// NOTE(review): PyTuple_GET_SIZE yields Py_ssize_t, but types.size() is
// C-style cast to `long`, which is 32 bits on LLP64 targets (64-bit
// Windows); static_cast<Py_ssize_t>(types.size()) would compare like types.
if (num_elements != (long)types.size())
71
for (const auto i : c10::irange(num_elements)) {
72
// PyTuple_GET_ITEM returns a borrowed reference, so nothing is released.
if (!types[i]->is_matching(PyTuple_GET_ITEM(object, i)))
78
// One matcher per tuple position.
std::vector<std::unique_ptr<Type>> types;
81
// Matches any object for which PySequence_Check succeeds and whose items all
// match the single element matcher `type` (presumably rejecting at the first
// mismatch — confirm).
struct SequenceType : public Type {
82
SequenceType(std::unique_ptr<Type> type) : type(std::move(type)){};
84
bool is_matching(PyObject* object) override {
85
if (!PySequence_Check(object))
87
// NOTE(review): on failure PySequence_Length returns -1 and
// PySequence_GetItem returns NULL, each with a Python exception set;
// neither result is checked here. A NULL item handed to an element matcher
// such as SimpleType would presumably be dereferenced by py_typename
// (via Py_TYPE).
auto num_elements = PySequence_Length(object);
88
for (const auto i : c10::irange(num_elements)) {
89
// PySequence_GetItem returns a new reference; reinterpret_steal hands it to
// a temporary py::object, which releases it at the end of the expression.
if (!type->is_matching(
90
py::reinterpret_steal<py::object>(PySequence_GetItem(object, i))
97
// Matcher applied to every element.
std::unique_ptr<Type> type;
101
// Builds one parsed parameter from its name and the Type matcher used to
// check a value supplied for it; _parseOption emplaces these.
Argument(std::string name, std::unique_ptr<Type> type)
102
: name(std::move(name)), type(std::move(type)){};
105
// Owning pointer to this parameter's matcher.
std::unique_ptr<Type> type;
109
// A parsed signature: its parameters plus the is_variadic / has_out flags.
Option(std::vector<Argument> arguments, bool is_variadic, bool has_out)
110
: arguments(std::move(arguments)),
111
is_variadic(is_variadic),
113
// Signature with no parameters; _parseOption uses Option(false, false) for
// the literal "no arguments" option.
Option(bool is_variadic, bool has_out)
114
: arguments(), is_variadic(is_variadic), has_out(has_out){};
115
// Copying is deleted (each Argument owns a std::unique_ptr<Type>, so a copy
// could not be made anyway); moving lets _parseOption return an Option
// inside a std::pair.
Option(const Option&) = delete;
116
Option(Option&& other) noexcept
117
: arguments(std::move(other.arguments)),
118
is_variadic(other.is_variadic),
119
has_out(other.has_out){};
121
// Parsed parameters, in the order they appear in the signature string.
std::vector<Argument> arguments;
126
// Splits `s` at every occurrence of `delim`, collecting the pieces in order
// into `tokens`. Empty pieces (e.g. between adjacent delimiters) are kept,
// and the text after the last delimiter is always appended as the final
// piece.
// NOTE(review): an empty `delim` makes s.find return `start` on every
// iteration without advancing, so the loop would never end; the visible
// callers pass "," and ", ".
std::vector<std::string> _splitString(
127
const std::string& s,
128
const std::string& delim) {
129
std::vector<std::string> tokens;
132
while ((end = s.find(delim, start)) != std::string::npos) {
133
tokens.push_back(s.substr(start, end - start));
134
// Resume the search just past the delimiter that was found.
start = end + delim.length();
136
tokens.push_back(s.substr(start));
140
// Builds the Type matcher for one type name taken from an option string:
//   "float"       -> MultiType accepting Python type names float, int, long
//   "int"         -> MultiType accepting int, long
//   "tuple[A,B]"  -> TupleType with one matcher per comma-separated name,
//                    each built recursively as non-nullable
//   "sequence[T]" -> SequenceType over a matcher built (non-nullable) from
//                    the text after "sequence["
//   anything else -> SimpleType matching the exact Python type name
// The chosen matcher is then wrapped in NullableType so None is also
// accepted (presumably only when `is_nullable` is true — confirm the
// guarding condition).
std::unique_ptr<Type> _buildType(std::string type_name, bool is_nullable) {
141
std::unique_ptr<Type> result;
142
if (type_name == "float") {
143
result = std::make_unique<MultiType>(MultiType{"float", "int", "long"});
144
} else if (type_name == "int") {
145
result = std::make_unique<MultiType>(MultiType{"int", "long"});
146
} else if (type_name.find("tuple[") == 0) {
// NOTE(review): find(prefix) == 0 scans the whole string when the prefix is
// absent; type_name.compare(0, 6, "tuple[") == 0 would test only the start.
147
// Strip the "tuple[" prefix and the final character (expected to be ']').
auto type_list = type_name.substr(6);
148
type_list.pop_back();
149
std::vector<std::unique_ptr<Type>> types;
150
// NOTE(review): splitting on every ',' means a nested element type that
// itself contains a comma (e.g. "tuple[int,int]") would be cut apart.
for (auto& type : _splitString(type_list, ","))
151
types.emplace_back(_buildType(type, false));
152
result = std::make_unique<TupleType>(std::move(types));
153
} else if (type_name.find("sequence[") == 0) {
154
// Text after the 9-character "sequence[" prefix.
auto subtype = type_name.substr(9);
156
result = std::make_unique<SequenceType>(_buildType(subtype, false));
158
result = std::make_unique<SimpleType>(type_name);
161
result = std::make_unique<NullableType>(std::move(result));
165
// Parses one entry of `options` into an Option plus a printable form of the
// signature for the error message. The literal "no arguments" yields an
// argument-less, non-variadic Option without an `out` parameter.
//
// Otherwise the first and last characters of the string (presumably the
// enclosing parentheses) are stripped and the rest is split on ", ". Each
// entry is "<type> <name>": the type is the text before the last space and
// the name the text after it. Markers recognised here:
//   - '#' starts the keyword-only part of the signature;
//   - "[<type> <name> or None]" is a nullable parameter;
//   - "..." anywhere in the list makes the Option variadic.
//
// `kwargs` (the caller's keyword arguments) only affects the printable form.
// If the caller passed `out`, the '#' is rendered as "*, ". Otherwise the
// text from '#' onward (plus the two characters before it, when present) is
// cut off and ")" is appended.
std::pair<Option, std::string> _parseOption(
166
const std::string& _option_str,
167
const std::unordered_map<std::string, PyObject*>& kwargs) {
168
if (_option_str == "no arguments")
169
return std::pair<Option, std::string>(Option(false, false), _option_str);
170
bool has_out = false;
171
std::vector<Argument> arguments;
172
std::string printable_option = _option_str;
173
std::string option_str = _option_str.substr(1, _option_str.length() - 2);
176
// Rewrite the keyword-only marker '#' in the printable signature.
auto out_pos = printable_option.find('#');
177
if (out_pos != std::string::npos) {
178
if (kwargs.count("out") > 0) {
179
std::string kwonly_part = printable_option.substr(out_pos + 1);
180
printable_option.erase(out_pos);
181
printable_option += "*, ";
182
printable_option += kwonly_part;
183
} else if (out_pos >= 2) {
184
printable_option.erase(out_pos - 2);
185
printable_option += ")";
187
printable_option.erase(out_pos);
188
printable_option += ")";
193
// Parse each "<type> <name>" entry into an Argument.
for (auto& arg : _splitString(option_str, ", ")) {
194
bool is_nullable = false;
195
auto type_start_idx = 0;
196
// NOTE(review): presumably steps over a leading keyword-only marker '#'
// (and records has_out) — confirm.
if (arg[type_start_idx] == '#') {
199
// "[<type> <name> or None]": skip the '[' and drop the " or None]" suffix
// (presumably also setting is_nullable). NOTE(review): if such an entry
// were shorter than " or None]", the subtraction below would wrap and
// std::string::erase would throw std::out_of_range.
if (arg[type_start_idx] == '[') {
202
arg.erase(arg.length() - std::string(" or None]").length());
205
// The type runs up to the last space and the name starts just after it.
// With no space, find_last_of returns npos, name_start_idx wraps to 0 and
// the name becomes the whole entry.
auto type_end_idx = arg.find_last_of(' ');
206
auto name_start_idx = type_end_idx + 1;
210
// NOTE(review): presumably strips the "..." marker from the entry — confirm.
auto dots_idx = arg.find("...");
211
if (dots_idx != std::string::npos)
214
std::string type_name =
215
arg.substr(type_start_idx, type_end_idx - type_start_idx);
216
std::string name = arg.substr(name_start_idx);
218
arguments.emplace_back(name, _buildType(type_name, is_nullable));
221
// Variadic if "..." appears anywhere in the argument list.
bool is_variadic = option_str.find("...") != std::string::npos;
222
return std::pair<Option, std::string>(
223
Option(std::move(arguments), is_variadic, has_out),
224
std::move(printable_option));
228
const Option& option,
229
const std::vector<PyObject*>& arguments,
230
const std::unordered_map<std::string, PyObject*>& kwargs) {
// Decides whether the number of supplied arguments (positional plus
// keyword) fits `option`: either exactly the expected count, or more than
// expected when the option is variadic.
231
auto num_expected = option.arguments.size();
232
auto num_got = arguments.size() + kwargs.size();
234
// NOTE(review): presumably discounts the `out` parameter from num_expected
// when the caller did not pass `out` — confirm.
if (option.has_out && kwargs.count("out") == 0)
236
return num_got == num_expected ||
237
(option.is_variadic && num_got > num_expected);
240
// Renders the supplied arguments for the error message as "(...)", one
// entry per argument. Each entry is presumably coloured (red/green) by
// whether it matches the corresponding parameter of `option`.
// Keyword arguments are written as "name=<type>". Tuple and list arguments
// are described as "<tuple|list> of " followed by their element type names.
// The colour strings are ANSI escape sequences, set only when both stdout
// (fd 1) and stderr (fd 2) are terminals; otherwise they stay empty.
std::string _formattedArgDesc(
241
const Option& option,
242
const std::vector<PyObject*>& arguments,
243
const std::unordered_map<std::string, PyObject*>& kwargs) {
245
std::string reset_red;
247
std::string reset_green;
248
if (isatty(1) && isatty(2)) {
250
reset_red = "\33[0m";
252
reset_green = "\33[0m";
260
auto num_args = arguments.size() + kwargs.size();
261
std::string result = "(";
262
for (const auto i : c10::irange(num_args)) {
263
// Indices past the positional arguments are keyword arguments, looked up
// by the name of the i-th parameter of `option`.
// NOTE(review): kwargs.at throws std::out_of_range if that name was not
// passed, and option.arguments[i] is out of bounds if i can reach
// option.arguments.size() (e.g. a variadic option called with surplus
// arguments plus keywords) — confirm callers rule both out.
bool is_kwarg = i >= arguments.size();
265
is_kwarg ? kwargs.at(option.arguments[i].name) : arguments[i];
267
// Check against the i-th parameter, or against the last parameter for
// surplus arguments of a variadic option.
bool is_matching = false;
268
if (i < option.arguments.size()) {
269
is_matching = option.arguments[i].type->is_matching(arg);
270
} else if (option.is_variadic) {
271
is_matching = option.arguments.back().type->is_matching(arg);
279
result += option.arguments[i].name + "=";
280
bool is_tuple = PyTuple_Check(arg);
281
if (is_tuple || PyList_Check(arg)) {
282
result += py_typename(arg) + " of ";
283
auto num_elements = PySequence_Length(arg);
289
// NOTE(review): this inner `i` shadows the outer argument index, and the
// results of PySequence_Length / PySequence_GetItem are not checked for
// -1 / NULL.
for (const auto i : c10::irange(num_elements)) {
293
result += py_typename(
294
py::reinterpret_steal<py::object>(PySequence_GetItem(arg, i))
298
if (num_elements == 1) {
306
result += py_typename(arg);
309
result += reset_green;
314
// NOTE(review): trimming is conditioned on positional arguments only. When
// only keyword arguments were supplied, the separator appended after the
// last entry (presumably ", ", as in _argDesc) stays in the output.
// `num_args > 0` looks like the intended condition.
if (!arguments.empty())
315
result.erase(result.length() - 2);
321
const std::vector<PyObject*>& arguments,
322
const std::unordered_map<std::string, PyObject*>& kwargs) {
// Describes what the caller passed, independent of any option: "(" then the
// positional argument type names, then "name=<type>" for each keyword
// argument (in unordered_map iteration order), each followed by ", ".
323
std::string result = "(";
324
for (auto& arg : arguments)
325
result += std::string(py_typename(arg)) + ", ";
326
for (auto& kwarg : kwargs)
327
result += kwarg.first + "=" + py_typename(kwarg.second) + ", ";
328
// NOTE(review): both loops above append ", ", but the trailing separator is
// removed only when there are positional arguments, so a keyword-only call
// keeps a trailing ", ". `!arguments.empty() || !kwargs.empty()` (or
// `result.size() > 1`) looks like the intended condition.
if (!arguments.empty())
329
result.erase(result.length() - 2);
334
// Collects into `unmatched` the keys of `kwargs` that are not found among
// the names of `option`'s parameters from index start_idx onward. Those are
// the trailing parameters that keyword arguments are expected to bind to.
// Presumably a key is pushed only when no name matched — confirm.
std::vector<std::string> _tryMatchKwargs(
335
const Option& option,
336
const std::unordered_map<std::string, PyObject*>& kwargs) {
337
std::vector<std::string> unmatched;
339
// NOTE(review): this subtraction is evaluated in size_t and wraps when there
// are more keyword arguments than parameters; converted to int64_t it is
// negative on common platforms. Confirm start_idx is clamped to >= 0 before
// the `unsigned int i = start_idx` loop below. Otherwise a negative value
// becomes a huge index and every comparison is skipped.
int64_t start_idx = option.arguments.size() - kwargs.size();
340
// NOTE(review): presumably shifts start_idx when the option has an `out`
// parameter the caller did not pass — confirm.
if (option.has_out && kwargs.count("out") == 0)
344
for (auto& entry : kwargs) {
346
for (unsigned int i = start_idx; i < option.arguments.size(); i++) {
347
if (option.arguments[i].name == entry.first) {
353
unmatched.push_back(entry.first);
360
std::string format_invalid_args(
361
PyObject* given_args,
362
PyObject* given_kwargs,
363
const std::string& function_name,
364
const std::vector<std::string>& options) {
365
std::vector<PyObject*> args;
366
std::unordered_map<std::string, PyObject*> kwargs;
367
std::string error_msg;
368
error_msg.reserve(2000);
369
error_msg += function_name;
370
error_msg += " received an invalid combination of arguments - ";
372
Py_ssize_t num_args = PyTuple_Size(given_args);
373
for (const auto i : c10::irange(num_args)) {
374
PyObject* arg = PyTuple_GET_ITEM(given_args, i);
378
bool has_kwargs = given_kwargs && PyDict_Size(given_kwargs) > 0;
380
PyObject *key = nullptr, *value = nullptr;
383
while (PyDict_Next(given_kwargs, &pos, &key, &value)) {
384
kwargs.emplace(THPUtils_unpackString(key), value);
388
if (options.size() == 1) {
389
auto pair = _parseOption(options[0], kwargs);
390
auto& option = pair.first;
391
auto& option_str = pair.second;
392
std::vector<std::string> unmatched_kwargs;
394
unmatched_kwargs = _tryMatchKwargs(option, kwargs);
395
if (!unmatched_kwargs.empty()) {
396
error_msg += "got unrecognized keyword arguments: ";
397
for (auto& kwarg : unmatched_kwargs)
398
error_msg += kwarg + ", ";
399
error_msg.erase(error_msg.length() - 2);
402
if (_argcountMatch(option, args, kwargs)) {
403
error_msg += _formattedArgDesc(option, args, kwargs);
405
error_msg += _argDesc(args, kwargs);
407
error_msg += ", but expected ";
408
error_msg += option_str;
412
error_msg += _argDesc(args, kwargs);
413
error_msg += ", but expected one of:\n";
414
for (auto& option_str : options) {
415
auto pair = _parseOption(option_str, kwargs);
416
auto& option = pair.first;
417
auto& printable_option_str = pair.second;
419
error_msg += printable_option_str;
421
if (_argcountMatch(option, args, kwargs)) {
422
std::vector<std::string> unmatched_kwargs;
424
unmatched_kwargs = _tryMatchKwargs(option, kwargs);
425
if (!unmatched_kwargs.empty()) {
427
" didn't match because some of the keywords were incorrect: ";
428
for (auto& kwarg : unmatched_kwargs)
429
error_msg += kwarg + ", ";
430
error_msg.erase(error_msg.length() - 2);
434
" didn't match because some of the arguments have invalid types: ";
435
error_msg += _formattedArgDesc(option, args, kwargs);