#include <torch/csrc/utils/python_arg_parser.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_torch_function_mode.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>

#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/TracerMode.h>
#include <c10/util/irange.h>

#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
25
static std::unordered_map<std::string, ParameterType> type_map = {
26
{"Tensor", ParameterType::TENSOR},
27
{"Scalar", ParameterType::SCALAR},
28
{"int64_t", ParameterType::INT64},
29
{"SymInt", ParameterType::SYM_INT},
30
{"double", ParameterType::DOUBLE},
31
{"complex", ParameterType::COMPLEX},
32
{"TensorList", ParameterType::TENSOR_LIST},
33
{"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
34
{"IntArrayRef", ParameterType::INT_LIST},
35
{"SymIntArrayRef", ParameterType::SYM_INT_LIST},
36
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
37
{"Generator", ParameterType::GENERATOR},
38
{"bool", ParameterType::BOOL},
39
{"Storage", ParameterType::STORAGE},
40
{"PyObject*", ParameterType::PYOBJECT},
41
{"ScalarType", ParameterType::SCALARTYPE},
42
{"Layout", ParameterType::LAYOUT},
43
{"MemoryFormat", ParameterType::MEMORY_FORMAT},
44
{"QScheme", ParameterType::QSCHEME},
45
{"Device", ParameterType::DEVICE},
46
{"DeviceIndex", ParameterType::INT64},
47
{"Stream", ParameterType::STREAM},
48
{"std::string", ParameterType::STRING},
49
{"c10::string_view", ParameterType::STRING},
50
{"Dimname", ParameterType::DIMNAME},
51
{"DimnameList", ParameterType::DIMNAME_LIST},
52
{"ScalarList", ParameterType::SCALAR_LIST},
53
{"DispatchKeySet", ParameterType::DISPATCH_KEY_SET},
71
// For a given torch argument name, the list of numpy-style keyword aliases
// that are also accepted (e.g. `keepdims` is accepted for `keepdim`).
// NOTE(review): this region of the file was truncated in extraction; the
// number gap suggests at least one entry (upstream maps "dim" -> {"axis"}
// and "other" -> {"x2"}) is missing here — confirm against the full file.
static const std::unordered_map<std::string, std::vector<std::string>>
    numpy_compatibility_arg_names = {
        {"keepdim", {"keepdims"}},
        {"input", {"x", "a", "x1"}},
};
87
// Returns true if the given operator name may implicitly wrap plain Python
// numbers as tensors. The allow-list covers binary arithmetic ops (and their
// in-place/out variants) plus copy/conversion ops, matching the visible set
// in the source.
bool should_allow_numbers_as_tensors(const std::string& name) {
  // Function-local static: built once on first call, then reused.
  static const std::unordered_set<std::string> allowed = {
      "add",          "add_",          "add_out",
      "div",          "div_",          "div_out",
      "divide",       "divide_",       "divide_out",
      "mul",          "mul_",          "mul_out",
      "multiply",     "multiply_",     "multiply_out",
      "sub",          "sub_",          "sub_out",
      "subtract",     "subtract_",     "subtract_out",
      "true_divide",  "true_divide_",  "true_divide_out",
      "to",           "_to_copy",      "copy_",
      "floor_divide", "floor_divide_", "floor_divide_out"};
  return allowed.find(name) != allowed.end();
}
103
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
106
keyword_only(keyword_only),
109
auto space = fmt.find(' ');
110
if (space == std::string::npos) {
111
throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
114
auto type_str = fmt.substr(0, space);
116
auto question = type_str.find('?');
117
if (question != std::string::npos) {
119
type_str = type_str.substr(0, question);
123
auto bracket = type_str.find('[');
124
if (bracket != std::string::npos) {
126
type_str.substr(bracket + 1, type_str.length() - bracket - 2);
127
size = atoi(size_str.c_str());
128
type_str = type_str.substr(0, bracket);
131
auto name_str = fmt.substr(space + 1);
132
auto it = type_map.find(type_str);
133
if (it == type_map.end()) {
134
throw std::runtime_error(
135
"FunctionParameter(): invalid type string: " + type_str);
139
auto eq = name_str.find('=');
140
if (eq != std::string::npos) {
141
name = name_str.substr(0, eq);
143
set_default_str(name_str.substr(eq + 1));
147
python_name = THPUtils_internString(name);
148
auto np_compat_it = numpy_compatibility_arg_names.find(name);
149
if (np_compat_it != numpy_compatibility_arg_names.end()) {
150
for (const auto& str : np_compat_it->second) {
151
numpy_python_names.push_back(THPUtils_internString(str));
156
auto handle_torch_function_getter(
158
const std::string& property_name) -> PyObject* {
159
py::object torch_api = PyObject_FastGetAttrString(
160
THPVariableClass, (char*)property_name.c_str());
161
std::string module_name = "torch.Tensor." + property_name;
162
return handle_torch_function(
171
auto handle_torch_function_setter(
173
const std::string& property_name,
174
PyObject* value) -> int {
175
py::object torch_api = PyObject_FastGetAttrString(
176
THPVariableClass, (char*)property_name.c_str());
177
std::string module_name = "torch.Tensor." + property_name;
178
if (value != nullptr) {
179
py::tuple args_ = py::make_tuple(py::handle(value));
180
handle_torch_function(
188
handle_torch_function(
200
// Builds the positional-args tuple passed to a __torch_function__ handler by
// prepending `self` (when present) to `args`. Either argument may be null:
// a null `args` yields a 1-tuple of `self`, a null `self` borrows `args`
// unchanged.
static auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple {
  if (args == nullptr) {
    return py::make_tuple(py::handle(self));
  } else if (self == nullptr) {
    return py::reinterpret_borrow<py::tuple>(args);
  }

  auto py_args = py::reinterpret_borrow<py::tuple>(args);
  size_t n = py_args.size();
  // New tuple of size n+1 with self in slot 0, then the original args.
  auto args_ = py::tuple(n + 1);
  args_[0] = py::handle(self);
  for (const auto i : c10::irange(n)) {
    args_[i + 1] = py_args[i];
  }
  return args_;
}
223
const char* torch_function_mode_name = "__torch_function__";
225
auto handle_torch_function(
227
const std::string& func_name,
231
const std::string& module_name) -> PyObject* {
232
py::object torch_api_function =
233
PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str());
234
TORCH_INTERNAL_ASSERT(
235
torch_api_function.ptr() != nullptr, "torch API function must exist");
236
py::tuple args_ = combine_self_args(self, args);
237
return handle_torch_function_no_python_arg_parser(
242
torch_api_function.ptr(),
244
TorchFunctionName::TorchFunction);
255
// Entries in overloaded_args may be either instances (torch_function) or
// already types (torch_dispatch modes register types directly); normalize
// both to the type object. Returns a borrowed reference.
static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
  if (PyType_Check(obj_or_type)) {
    return obj_or_type;
  }
  return (PyObject*)Py_TYPE(obj_or_type);
}
262
static py::object dispatch_on_subclass(
265
at::ArrayRef<PyObject*> overloaded_args,
267
PyObject* torch_api_function,
268
bool is_torch_function,
269
const char* torch_function_name_str,
270
c10::optional<c10::impl::TorchDispatchModeKey> maybe_mode_key =
273
for (auto& arg : overloaded_args) {
274
py::object torch_function =
275
PyObject_FastGetAttrString(arg, torch_function_name_str);
276
if (!torch_function) {
277
TORCH_INTERNAL_ASSERT(0);
279
if (torch_function.ptr() == torch::disabled_torch_dispatch_impl()) {
290
if (is_torch_function &&
291
PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
292
.is(py::handle(arg)) &&
293
torch_function.ptr() != torch::disabled_torch_function_impl()) {
296
torch_function_name_str,
297
"` as a plain method is deprecated ",
298
"and will be an error in future, please define it as a classmethod.");
301
ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
302
torch_function.ptr(),
308
if (ret.ptr() == nullptr) {
309
throw python_error();
311
if (ret.ptr() != Py_NotImplemented) {
321
static std::tuple<py::object, py::object> dispatch_on_mode(
325
PyObject* torch_api_function,
326
bool is_torch_function,
327
const char* torch_function_name_str) {
330
at::optional<torch::overrides::StashTorchFunctionModeGuard> tf_g;
331
at::optional<torch_dispatch_mode::StashTorchDispatchModeGuard> td_g;
335
if (is_torch_function) {
337
mode_obj = py::reinterpret_borrow<py::object>(
338
tf_g->get_cur_mode()->ptr(getPyInterpreter()));
341
mode_obj = py::reinterpret_borrow<py::object>(
342
td_g->get_cur_mode()->ptr(getPyInterpreter()));
344
py::object torch_function =
345
PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str);
346
if (!torch_function) {
347
TORCH_INTERNAL_ASSERT(0);
349
TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr);
350
TORCH_INTERNAL_ASSERT(args != nullptr);
353
PyObject_FastGetAttrString(torch_function.ptr(), "__self__").is(mode_obj),
354
"Defining your mode's `",
355
torch_function_name_str,
356
"` as a classmethod is not supported, please make it a plain method");
361
if (kwargs == nullptr) {
362
ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
364
torch_function_name_str,
370
ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
372
torch_function_name_str,
379
if (ret.ptr() == nullptr) {
380
throw python_error();
382
return std::make_tuple(ret, mode_obj);
386
auto handle_torch_function_no_python_arg_parser(
387
at::ArrayRef<PyObject*> overloaded_args,
390
const char* func_name,
391
PyObject* torch_api_function,
392
const char* module_name,
393
TorchFunctionName torch_function_name) -> PyObject* {
394
const char* torch_function_name_str = nullptr;
395
switch (torch_function_name) {
396
case TorchFunctionName::TorchFunction:
397
torch_function_name_str = "__torch_function__";
399
case TorchFunctionName::TorchDispatch:
400
torch_function_name_str = "__torch_dispatch__";
403
TORCH_INTERNAL_ASSERT(0, static_cast<int>(torch_function_name));
408
std::vector<py::object> overloaded_types;
409
overloaded_types.reserve(overloaded_args.size());
410
for (auto& arg : overloaded_args) {
411
overloaded_types.push_back(
412
py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg)));
414
py::tuple py_types = py::cast(overloaded_types);
420
const bool is_torch_function =
421
torch_function_name == TorchFunctionName::TorchFunction;
422
const auto is_mode_active = [&]() {
423
return is_torch_function
424
? at::impl::torch_function_mode_enabled()
427
: c10::impl::dispatch_mode_enabled();
467
if (is_mode_active()) {
470
auto ret_ = dispatch_on_mode(
476
torch_function_name_str);
477
ret = std::get<0>(ret_);
478
mode_obj = std::get<1>(ret_);
487
if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
488
auto curr_ret = dispatch_on_subclass(
495
torch_function_name_str);
496
if (curr_ret.ptr() != nullptr) {
501
if (ret.ptr() == nullptr) {
504
throw python_error();
505
} else if (ret.ptr() == Py_NotImplemented) {
508
std::stringstream ss;
509
ss << "Multiple dispatch failed for '";
510
if (module_name && func_name) {
511
ss << module_name << "." << func_name;
513
py::handle fn = torch_api_function;
514
ss << py::str(fn.attr("__module__")) << "."
515
<< py::str(fn.attr("__name__"));
517
ss << "'; all " << torch_function_name_str
518
<< " handlers returned NotImplemented:\n\n";
520
ss << " - mode object " << py::repr(mode_obj) << "\n";
522
for (auto& arg : overloaded_args) {
523
ss << " - tensor subclass " << py::repr(get_type_of_overloaded_arg(arg))
526
ss << "\nFor more information, try re-running with TORCH_LOGS=not_implemented";
527
const std::string& tmp = ss.str();
528
PyErr_SetString(PyExc_TypeError, tmp.c_str());
529
throw python_error();
531
return ret.release().ptr();
534
auto handle_torch_function(
540
const char* module_name,
541
const char* func_name_override) -> PyObject* {
542
py::object torch_api_function = PyObject_FastGetAttrString(
544
(char*)(func_name_override ? func_name_override
545
: r.get_func_name().c_str()));
546
TORCH_INTERNAL_ASSERT(
547
torch_api_function.ptr() != nullptr, "torch API function must exist");
548
py::tuple args_ = combine_self_args(self, args);
549
return handle_torch_function_no_python_arg_parser(
553
r.get_func_name().c_str(),
554
torch_api_function.ptr(),
558
auto handle_torch_function(
563
const char* module_name,
564
const char* func_name_override) -> PyObject* {
565
return handle_torch_function(
566
r, nullptr, args, kwargs, torch_api, module_name, func_name_override);
569
auto handle_torch_function_indexing(
572
PyObject* val) -> PyObject* {
573
const char* func_name = (val == nullptr) ? "__getitem__" : "__setitem__";
574
py::object index_tup;
575
if (PyTuple_Check(index)) {
576
index_tup = py::reinterpret_borrow<py::object>(index);
578
index_tup = py::make_tuple(py::handle(index));
580
std::vector<PyObject*> overridable_args;
581
is_tensor_and_append_overloaded(self, &overridable_args);
582
auto size = PyTuple_GET_SIZE(index_tup.ptr());
583
for (auto i : c10::irange(size)) {
584
auto* obj = PyTuple_GetItem(index_tup.ptr(), i);
585
is_tensor_and_append_overloaded(obj, &overridable_args);
587
if (val != nullptr) {
588
is_tensor_and_append_overloaded(val, &overridable_args);
591
PyObject_FastGetAttrString(THPVariableClass, (char*)func_name);
592
py::object args = (val == nullptr)
593
? py::make_tuple(py::handle(self), py::handle(index))
594
: py::make_tuple(py::handle(self), py::handle(index), py::handle(val));
595
return handle_torch_function_no_python_arg_parser(
644
static void append_overloaded_arg(
645
std::vector<PyObject*>* overloaded_args,
648
bool class_not_seen_yet = true;
649
PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
650
for (auto& arg : *overloaded_args) {
651
if (obj_type == get_type_of_overloaded_arg(arg)) {
656
class_not_seen_yet = false;
660
if (class_not_seen_yet) {
661
auto arg_index = overloaded_args->size();
662
for (const auto j : c10::irange(arg_index)) {
663
if (PyObject_IsSubclass(
664
obj_type, get_type_of_overloaded_arg((*overloaded_args)[j]))) {
675
overloaded_args->insert(
676
overloaded_args->begin() + static_cast<long>(arg_index), obj);
680
void append_overloaded_tensor(
681
std::vector<PyObject*>* overloaded_args,
683
append_overloaded_arg(overloaded_args, obj, false);
686
void append_overloaded_type(
687
std::vector<PyObject*>* overloaded_args,
689
append_overloaded_arg(overloaded_args, obj, true);
692
bool is_tensor_and_append_overloaded(
694
std::vector<PyObject*>* overloaded_args) {
695
if (THPVariable_CheckExact(obj)) {
700
if (check_has_torch_function(obj, true)) {
702
append_overloaded_tensor(overloaded_args, obj);
704
} else if (THPVariable_Check(obj)) {
712
// Returns true if obj is a tuple or list whose every element passes
// THPUtils_checkScalar. An empty sequence vacuously returns true.
static bool is_scalar_list(PyObject* obj) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }
  // GET_SIZE/GET_ITEM are the unchecked fast-path accessors; safe because
  // the tuple/list check above established the concrete type.
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (const auto idx : c10::irange(size)) {
    PyObject* iobj =
        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
    if (!THPUtils_checkScalar(iobj)) {
      return false;
    }
  }
  return true;
}
729
bool is_tensor_list_and_append_overloaded(
731
std::vector<PyObject*>* overloaded_args,
734
auto tuple = six::isTuple(obj);
735
if (!(tuple || PyList_Check(obj))) {
739
const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
740
for (long idx = 0; idx < size; idx++) {
742
tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
743
if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) {
747
"expected Tensor as element ",
752
Py_TYPE(iobj)->tp_name);
760
// Returns true if obj is a tuple or list whose *first* element is a double
// or complex (only element 0 is inspected in the visible source; an empty
// sequence is accepted).
static bool is_float_or_complex_list(PyObject* obj) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }

  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  if (size > 0) {
    PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
    if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
      return false;
    }
  }

  return true;
}
778
static bool is_int_or_symint(PyObject* obj) {
783
if (torch::is_symint(py::handle(obj))) {
794
if (THPVariable_Check(obj)) {
795
auto& var = THPVariable_Unpack(obj);
796
if (var.numel() == 1 &&
797
at::isIntegralType(var.dtype().toScalarType(), true)) {
802
if (THPUtils_checkIndex(obj)) {
809
static bool is_int_or_symint_list(
812
int64_t* failed_idx = nullptr) {
813
if (PyTuple_Check(obj) || PyList_Check(obj)) {
814
if (PySequence_Size(obj) == 0) {
817
auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
819
if (is_int_or_symint(item.ptr())) {
826
(jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
827
THPVariable_Unpack(item.ptr()).sizes().empty());
828
if (!r && failed_idx != nullptr) {
836
return broadcast_size > 0 && is_int_or_symint(obj);
840
auto FunctionParameter::check(
842
std::vector<PyObject*>& overloaded_args,
844
int64_t* failed_idx) -> bool {
846
case ParameterType::TENSOR: {
847
if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
850
if (allow_numbers_as_tensors) {
851
return THPUtils_checkScalar(obj);
855
case ParameterType::SCALAR:
856
if (THPUtils_checkScalar(obj)) {
860
case ParameterType::COMPLEX:
861
if (PyComplex_Check(obj)) {
865
case ParameterType::DOUBLE: {
866
if (THPUtils_checkDouble(obj)) {
869
if (THPVariable_Check(obj)) {
870
const auto& var = THPVariable_Unpack(obj);
871
return !var.requires_grad() && var.dim() == 0;
873
if (torch::is_symfloat(py::handle(obj)) ||
874
torch::is_symint(py::handle(obj))) {
880
case ParameterType::INT64: {
881
if (THPUtils_checkLong(obj)) {
884
if (THPVariable_Check(obj)) {
885
const auto& var = THPVariable_Unpack(obj);
886
return at::isIntegralType(var.scalar_type(), false) &&
887
!var.requires_grad() && var.dim() == 0;
889
if (torch::is_symint(py::handle(obj))) {
895
case ParameterType::DIMNAME:
896
return THPUtils_checkDimname(obj);
897
case ParameterType::DIMNAME_LIST: {
898
if (THPUtils_checkDimnameList(obj)) {
903
return size == 1 && THPUtils_checkDimname(obj);
905
case ParameterType::TENSOR_LIST: {
906
return is_tensor_list_and_append_overloaded(
907
obj, &overloaded_args, argnum, true );
909
case ParameterType::FLOAT_LIST:
910
return is_float_or_complex_list(obj);
911
case ParameterType::GENERATOR:
912
return THPGenerator_Check(obj);
913
case ParameterType::BOOL:
914
return PyBool_Check(obj);
915
case ParameterType::STORAGE:
916
return isStorage(obj);
917
case ParameterType::PYOBJECT:
919
case ParameterType::SCALARTYPE:
920
return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
921
case ParameterType::LAYOUT:
922
return THPLayout_Check(obj);
923
case ParameterType::MEMORY_FORMAT:
924
return THPMemoryFormat_Check(obj);
925
case ParameterType::QSCHEME:
926
return THPQScheme_Check(obj);
927
case ParameterType::DEVICE:
928
return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
929
THPDevice_Check(obj);
930
case ParameterType::STREAM:
931
return THPStream_Check(obj);
932
case ParameterType::STRING:
933
return THPUtils_checkString(obj);
934
case ParameterType::SCALAR_LIST:
935
return is_scalar_list(obj);
936
case ParameterType::SYM_INT:
937
return is_int_or_symint(obj);
939
case ParameterType::INT_LIST:
940
case ParameterType::SYM_INT_LIST:
941
return is_int_or_symint_list(obj, size, failed_idx);
942
case ParameterType::DISPATCH_KEY_SET:
943
return py::isinstance<c10::DispatchKeySet>(py::handle(obj));
945
throw std::runtime_error("unknown parameter type");
950
std::string FunctionParameter::type_name() const {
952
case ParameterType::TENSOR:
954
case ParameterType::SCALAR:
956
case ParameterType::INT64:
959
case ParameterType::SYM_INT:
961
case ParameterType::DOUBLE:
963
case ParameterType::COMPLEX:
965
case ParameterType::TENSOR_LIST:
966
return "tuple of Tensors";
967
case ParameterType::INT_LIST:
968
return "tuple of ints";
969
case ParameterType::FLOAT_LIST:
970
return "tuple of floats";
971
case ParameterType::GENERATOR:
972
return "torch.Generator";
973
case ParameterType::BOOL:
975
case ParameterType::STORAGE:
976
return "torch.Storage";
977
case ParameterType::PYOBJECT:
979
case ParameterType::SCALARTYPE:
980
return "torch.dtype";
981
case ParameterType::LAYOUT:
982
return "torch.layout";
983
case ParameterType::MEMORY_FORMAT:
984
return "torch.memory_format";
985
case ParameterType::QSCHEME:
986
return "torch.qscheme";
987
case ParameterType::DEVICE:
988
return "torch.device";
989
case ParameterType::STRING:
991
case ParameterType::DIMNAME:
993
case ParameterType::DIMNAME_LIST:
994
return "tuple of names";
995
case ParameterType::SCALAR_LIST:
996
return "tuple of Scalars";
997
case ParameterType::SYM_INT_LIST:
998
return "tuple of ints";
999
case ParameterType::DISPATCH_KEY_SET:
1000
return "DispatchKeySet";
1002
throw std::runtime_error("unknown parameter type");
1006
// Parses s as an integer (base auto-detected by strtol: 0x hex, 0 octal,
// else decimal). Returns nullopt for an empty string or if any trailing
// characters remain unparsed.
static inline c10::optional<int64_t> parse_as_integer(const std::string& s) {
  if (s.empty())
    return c10::nullopt;
  char* str_end = nullptr;
  long ans = strtol(s.c_str(), &str_end, 0);
  // *str_end == 0 iff the entire string was consumed as an integer.
  return (*str_end == 0) ? c10::optional<int64_t>(ans) : c10::nullopt;
}
1023
static inline std::vector<int64_t> parse_intlist_args(
1024
const std::string& s,
1026
size_t n = s.size();
1029
return std::vector<int64_t>();
1033
TORCH_CHECK(size > 0, "Incorrect size of IntArrayRef: ", size);
1034
return std::vector<int64_t>(size, std::stol(s));
1043
"Default value of IntArrayRef is missing right brace '}', found ",
1046
auto args = std::vector<int64_t>();
1047
std::istringstream ss(s.substr(1, s.length() - 2));
1050
while (std::getline(ss, tok, ',')) {
1051
args.emplace_back(std::stol(tok));
1057
static std::string parse_string_literal(c10::string_view str) {
1058
TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
1060
if (str.front() == '"') {
1062
str.back() == '"', "Mismatched quotes in string default: ", str);
1065
str.front() == '\'' && str.back() == '\'',
1066
"Invalid quotes in string default: ",
1071
parsed.reserve(str.size());
1072
for (size_t i = 1; i < str.size() - 1;) {
1073
if (str[i] != '\\') {
1074
parsed.push_back(str[i]);
1081
i < str.size() - 2, "String ends with escaped final quote: ", str)
1082
char c = str[i + 1];
1109
"Unsupported escape sequence in string default: \\",
1112
parsed.push_back(c);
1118
void FunctionParameter::set_default_str(const std::string& str) {
1119
if (str == "None") {
1122
if (type_ == ParameterType::TENSOR ||
1123
type_ == ParameterType::DISPATCH_KEY_SET) {
1124
if (str != "None") {
1125
throw std::runtime_error(
1126
"default value for Tensor must be none, got: " + str);
1128
} else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
1129
default_int = atol(str.c_str());
1130
} else if (type_ == ParameterType::BOOL) {
1131
default_bool = (str == "True" || str == "true");
1132
} else if (type_ == ParameterType::DOUBLE) {
1133
default_double = atof(str.c_str());
1134
} else if (type_ == ParameterType::COMPLEX) {
1135
default_complex[0] = atof(str.c_str());
1136
default_complex[1] = 0;
1137
} else if (type_ == ParameterType::SCALAR) {
1138
if (str != "None") {
1140
const auto as_integer = parse_as_integer(str);
1141
default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value())
1142
: at::Scalar(atof(str.c_str()));
1145
type_ == ParameterType::INT_LIST ||
1146
type_ == ParameterType::SYM_INT_LIST) {
1147
if (str != "None") {
1148
default_intlist = parse_intlist_args(str, size);
1150
} else if (type_ == ParameterType::FLOAT_LIST) {
1151
if (str != "None") {
1152
throw std::runtime_error("Defaults not supported for float[]");
1154
} else if (type_ == ParameterType::SCALARTYPE) {
1155
if (str == "None") {
1156
default_scalartype = at::ScalarType::Undefined;
1157
} else if (str == "torch.int64") {
1158
default_scalartype = at::ScalarType::Long;
1160
throw std::runtime_error("invalid default value for ScalarType: " + str);
1162
} else if (type_ == ParameterType::LAYOUT) {
1163
if (str == "None") {
1164
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
1165
} else if (str == "torch.strided") {
1166
default_layout = at::Layout::Strided;
1167
} else if (str == "torch.sparse_coo") {
1168
default_layout = at::Layout::Sparse;
1170
throw std::runtime_error("invalid default value for layout: " + str);
1172
} else if (type_ == ParameterType::DEVICE) {
1173
if (str != "None") {
1174
throw std::runtime_error("invalid device: " + str);
1176
} else if (type_ == ParameterType::STREAM) {
1177
if (str != "None") {
1178
throw std::runtime_error("invalid stream: " + str);
1180
} else if (type_ == ParameterType::STRING) {
1181
if (str != "None") {
1182
default_string = parse_string_literal(str);
1189
else if (type_ == ParameterType::TENSOR_LIST) {
1191
} else if (type_ == ParameterType::GENERATOR) {
1193
} else if (type_ == ParameterType::PYOBJECT) {
1195
} else if (type_ == ParameterType::MEMORY_FORMAT) {
1197
} else if (type_ == ParameterType::DIMNAME) {
1199
} else if (type_ == ParameterType::DIMNAME_LIST) {
1201
} else if (type_ == ParameterType::SCALAR_LIST) {
1203
} else if (type_ == ParameterType::STORAGE) {
1205
} else if (type_ == ParameterType::QSCHEME) {
1208
throw std::runtime_error("unknown parameter type");
1213
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
1220
auto open_paren = fmt.find('(');
1221
if (open_paren == std::string::npos) {
1222
throw std::runtime_error("missing opening parenthesis: " + fmt);
1224
name = fmt.substr(0, open_paren);
1226
bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);
1228
auto last_offset = open_paren + 1;
1229
bool keyword_only = false;
1232
auto offset = fmt.find(", ", last_offset);
1233
auto next_offset = offset + 2;
1234
if (offset == std::string::npos) {
1235
offset = fmt.find(')', last_offset);
1237
next_offset = offset + 1;
1239
if (offset == last_offset) {
1240
last_offset = next_offset;
1244
if (offset == std::string::npos) {
1245
throw std::runtime_error("missing closing parenthesis: " + fmt);
1247
if (offset == last_offset) {
1248
throw std::runtime_error("malformed signature: " + fmt);
1251
auto param_str = fmt.substr(last_offset, offset - last_offset);
1252
last_offset = next_offset;
1253
if (param_str == "*") {
1254
keyword_only = true;
1256
params.emplace_back(param_str, keyword_only);
1257
params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
1261
if (fmt.substr(last_offset) == "|deprecated") {
1265
} else if (fmt.substr(last_offset) == "|hidden") {
1269
max_args = params.size();
1272
for (auto& param : params) {
1273
if (!param.optional) {
1276
if (!param.keyword_only) {
1282
std::string FunctionSignature::toString() const {
1285
std::ostringstream ss;
1286
bool keyword_already = false;
1289
for (auto& param : params) {
1293
if (param.keyword_only && !keyword_already) {
1295
keyword_already = true;
1297
ss << param.type_name() << " " << param.name;
1304
[[noreturn]] static void extra_args(
1305
const FunctionSignature& signature,
1307
const auto max_pos_args = signature.max_pos_args;
1308
const auto min_args = signature.min_args;
1309
const long nargs_ = nargs;
1310
if (min_args != max_pos_args) {
1312
"%s() takes from %zu to %zu positional arguments but %ld were given",
1313
signature.name.c_str(),
1319
"%s() takes %zu positional argument%s but %ld %s given",
1320
signature.name.c_str(),
1322
max_pos_args == 1 ? "" : "s",
1324
nargs == 1 ? "was" : "were");
1327
[[noreturn]] static void missing_args(
1328
const FunctionSignature& signature,
1330
int num_missing = 0;
1331
std::stringstream ss;
1333
auto& params = signature.params;
1334
for (auto it = params.begin() + idx; it != params.end(); ++it) {
1335
if (!it->optional) {
1336
if (num_missing > 0) {
1339
ss << '"' << it->name << '"';
1345
"%s() missing %d required positional argument%s: %s",
1346
signature.name.c_str(),
1348
num_missing == 1 ? "s" : "",
1352
static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) {
1354
for (auto& param : signature.params) {
1355
int cmp = PyObject_RichCompareBool(name, param.python_name, Py_EQ);
1357
throw python_error();
1366
[[noreturn]] static void extra_kwargs(
1367
FunctionSignature& signature,
1369
Py_ssize_t num_pos_args) {
1370
PyObject* key = nullptr;
1371
PyObject* value = nullptr;
1374
while (PyDict_Next(kwargs, &pos, &key, &value)) {
1375
if (!THPUtils_checkString(key)) {
1376
throw TypeError("keywords must be strings");
1379
auto param_idx = find_param(signature, key);
1380
if (param_idx < 0) {
1382
"%s() got an unexpected keyword argument '%s'",
1383
signature.name.c_str(),
1384
THPUtils_unpackString(key).c_str());
1387
if (param_idx < num_pos_args) {
1389
"%s() got multiple values for argument '%s'",
1390
signature.name.c_str(),
1391
THPUtils_unpackString(key).c_str());
1396
throw TypeError("invalid keyword arguments");
1399
bool FunctionSignature::parse(
1404
std::vector<PyObject*>& overloaded_args,
1405
bool raise_exception) {
1406
Py_ssize_t nargs = args ? PyTuple_GET_SIZE(args) : 0;
1407
auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
1409
bool allow_varargs_intlist = false;
1414
if (max_pos_args == 1 &&
1415
(params[0].type_ == ParameterType::INT_LIST ||
1416
params[0].type_ == ParameterType::SYM_INT_LIST)) {
1417
allow_varargs_intlist = true;
1420
if (static_cast<size_t>(nargs) > max_pos_args && !allow_varargs_intlist) {
1421
if (raise_exception) {
1423
extra_args(*this, nargs);
1429
if (self != nullptr && check_has_torch_function(self, true)) {
1430
append_overloaded_tensor(&overloaded_args, self);
1432
for (auto& param : params) {
1433
PyObject* obj = nullptr;
1434
bool is_kwd = false;
1435
if (arg_pos < static_cast<size_t>(nargs)) {
1437
if (param.keyword_only) {
1438
if (raise_exception) {
1439
extra_args(*this, nargs);
1443
obj = PyTuple_GET_ITEM(args, arg_pos);
1444
} else if (kwargs) {
1445
obj = PyDict_GetItem(kwargs, param.python_name);
1446
for (PyObject* numpy_name : param.numpy_python_names) {
1450
obj = PyDict_GetItem(kwargs, numpy_name);
1455
int64_t failed_idx = -1;
1456
bool varargs_eligible = allow_varargs_intlist && arg_pos == 0 && !is_kwd;
1457
if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
1460
if (raise_exception) {
1462
missing_args(*this, i);
1465
} else if (param.check(obj, overloaded_args, i, &failed_idx)) {
1472
(is_int_or_symint_list(args, param.size, &failed_idx))) {
1478
} else if (raise_exception) {
1482
"%s(): argument '%s' must be %s, not %s",
1485
param.type_name().c_str(),
1486
Py_TYPE(obj)->tp_name);
1489
if (failed_idx != -1) {
1490
if (!(PyTuple_Check(obj) || PyList_Check(obj))) {
1491
TORCH_INTERNAL_ASSERT(varargs_eligible);
1494
TORCH_INTERNAL_ASSERT(failed_idx < PySequence_Size(obj));
1496
"%s(): argument '%s' (position %ld) must be %s, but found element of type %s at pos %ld",
1499
static_cast<long>(arg_pos + 1),
1500
param.type_name().c_str(),
1501
Py_TYPE(py::reinterpret_steal<py::object>(
1502
PySequence_GetItem(obj, failed_idx))
1505
static_cast<long>(failed_idx));
1508
"%s(): argument '%s' (position %ld) must be %s, not %s",
1511
static_cast<long>(arg_pos + 1),
1512
param.type_name().c_str(),
1513
Py_TYPE(obj)->tp_name);
1526
if (remaining_kwargs > 0) {
1527
if (raise_exception) {
1529
extra_kwargs(*this, kwargs, nargs);
1536
PythonArgParser::PythonArgParser(
1537
const std::vector<std::string>& fmts,
1539
: max_args(0), traceable(traceable) {
1541
for (auto& fmt : fmts) {
1542
signatures_.emplace_back(fmt, index);
1545
for (auto& signature : signatures_) {
1546
if (signature.max_args > max_args) {
1547
max_args = signature.max_args;
1550
if (!signatures_.empty()) {
1551
function_name = signatures_[0].name;
1555
std::stable_partition(
1556
signatures_.begin(), signatures_.end(), [](const FunctionSignature& sig) {
1557
return !sig.deprecated;
1561
void PythonArgParser::check_deprecated(const FunctionSignature& signature) {
1562
if (signature.deprecated) {
1563
auto msg = c10::str(
1564
"This overload of ",
1566
" is deprecated:\n\t",
1568
signature.toString());
1569
auto signatures = get_signatures();
1570
if (!signatures.empty()) {
1571
msg += "\nConsider using one of the following signatures instead:";
1572
for (const auto& sig : signatures) {
1574
msg += signature.name;
1578
TORCH_WARN_ONCE(msg);
1582
PythonArgs PythonArgParser::raw_parse(
1586
PyObject* parsed_args[]) {
1587
if (signatures_.size() == 1) {
1588
auto& signature = signatures_[0];
1589
std::vector<PyObject*> overloaded_args;
1590
signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
1591
check_deprecated(signature);
1593
traceable, signature, parsed_args, std::move(overloaded_args));
1596
for (auto& signature : signatures_) {
1597
std::vector<PyObject*> overloaded_args;
1598
if (signature.parse(
1599
self, args, kwargs, parsed_args, overloaded_args, false)) {
1600
check_deprecated(signature);
1602
traceable, signature, parsed_args, std::move(overloaded_args));
1606
print_error(self, args, kwargs, parsed_args);
1609
void PythonArgParser::print_error(
1613
PyObject* parsed_args[]) {
1615
(args ? PyTuple_GET_SIZE(args) : 0) + (kwargs ? PyDict_Size(kwargs) : 0);
1616
std::vector<unsigned> plausible_idxs;
1618
for (auto& signature : signatures_) {
1619
if (num_args >= signature.min_args && num_args <= signature.max_args &&
1620
!signature.hidden) {
1621
plausible_idxs.push_back(i);
1626
if (plausible_idxs.size() == 1) {
1627
auto& signature = signatures_[plausible_idxs[0]];
1628
std::vector<PyObject*> overloaded_args;
1629
signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
1632
auto options = get_signatures();
1634
torch::format_invalid_args(args, kwargs, function_name + "()", options);
1635
throw TypeError("%s", msg.c_str());
1638
std::vector<std::string> PythonArgParser::get_signatures() const {
1639
std::vector<std::string> options;
1640
for (auto& signature : signatures_) {
1641
if (!signature.hidden) {
1642
options.push_back(signature.toString());
1648
at::Tensor PythonArgs::tensor_slow(int i) {
1649
PyObject* obj = args[i];
1651
return at::Tensor();
1653
if (THPVariable_Check(obj)) {
1654
return THPVariable_Unpack(obj);
1657
bool save_symint = false;
1659
if (PyBool_Check(obj)) {
1660
scalar = at::Scalar(THPUtils_unpackBool(obj));
1661
} else if (THPUtils_checkLong(obj)) {
1663
long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
1664
if (value == -1 && PyErr_Occurred()) {
1665
throw python_error();
1667
if (overflow != 0) {
1669
unsigned long long value = PyLong_AsUnsignedLongLong(obj);
1670
if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
1671
throw python_error();
1673
scalar = at::Scalar(static_cast<uint64_t>(value));
1675
scalar = at::Scalar(static_cast<int64_t>(value));
1677
} else if (PyComplex_Check(obj)) {
1678
scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
1679
} else if (THPUtils_checkDouble(obj)) {
1680
scalar = at::Scalar(THPUtils_unpackDouble(obj));
1684
} else if (torch::is_symint(py::handle(obj))) {
1689
scalar = at::Scalar(7777777);
1690
} else if (torch::is_symfloat(py::handle(obj))) {
1692
scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
1693
} else if (torch::is_symbool(py::handle(obj))) {
1695
scalar = at::Scalar(true);
1703
"expected Tensor as argument %d, but got %s", i, Py_TYPE(obj)->tp_name);
1705
at::AutoDispatchBelowADInplaceOrView guard;
1706
at::tracer::impl::NoTracerDispatchMode tracer_guard;
1708
at::Tensor tensor = scalar_to_tensor(scalar);
1709
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
1712
auto py_tensor = py::cast(tensor);
1713
if (PyObject_SetAttrString(py_tensor.ptr(), "_wrapped_number", obj) < 0) {
1714
throw python_error();
1721
at::Scalar PythonArgs::scalar_slow(int i) {
1722
if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
1723
auto& var = THPVariable_Unpack(args[i]);
1724
jit::tracer::ArgumentStash::stashValue(
1725
signature.params[i].name, idx, var, c10::NumberType::get());
1728
return scalar_slow(args[i]);
1731
at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
1734
if (THPVariable_Check(arg)) {
1735
return THPVariable_Unpack(arg).item();
1738
if (THPUtils_checkLong(arg)) {
1740
long long value = PyLong_AsLongLongAndOverflow(arg, &overflow);
1741
if (value == -1 && PyErr_Occurred()) {
1742
throw python_error();
1744
if (overflow != 0) {
1746
unsigned long long value = PyLong_AsUnsignedLongLong(arg);
1747
if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
1748
throw python_error();
1750
return at::Scalar(static_cast<uint64_t>(value));
1752
return at::Scalar(static_cast<int64_t>(value));
1756
if (PyBool_Check(arg)) {
1757
return at::Scalar(THPUtils_unpackBool(arg));
1760
if (PyComplex_Check(arg)) {
1761
return at::Scalar(THPUtils_unpackComplexDouble(arg));
1764
if (torch::is_symint(arg)) {
1765
return at::Scalar(py::cast<c10::SymInt>(arg));
1768
if (torch::is_symfloat(arg)) {
1769
return at::Scalar(py::cast<c10::SymFloat>(arg));
1772
if (torch::is_symbool(arg)) {
1775
auto sym_bool = py::handle(arg).cast<c10::SymBool>();
1776
return at::Scalar(sym_bool);
1779
return at::Scalar(THPUtils_unpackDouble(arg));