pytorch

Форк
0
/
python_arg_parser.cpp 
1782 строки · 57.3 Кб
1
#include <torch/csrc/utils/python_arg_parser.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_torch_function_mode.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>

#include <ATen/ATen.h>
#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/TracerMode.h>
#include <c10/util/irange.h>

#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
22

23
namespace torch {
24

25
static std::unordered_map<std::string, ParameterType> type_map = {
26
    {"Tensor", ParameterType::TENSOR},
27
    {"Scalar", ParameterType::SCALAR},
28
    {"int64_t", ParameterType::INT64},
29
    {"SymInt", ParameterType::SYM_INT},
30
    {"double", ParameterType::DOUBLE},
31
    {"complex", ParameterType::COMPLEX},
32
    {"TensorList", ParameterType::TENSOR_LIST},
33
    {"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
34
    {"IntArrayRef", ParameterType::INT_LIST},
35
    {"SymIntArrayRef", ParameterType::SYM_INT_LIST},
36
    {"ArrayRef<double>", ParameterType::FLOAT_LIST},
37
    {"Generator", ParameterType::GENERATOR},
38
    {"bool", ParameterType::BOOL},
39
    {"Storage", ParameterType::STORAGE},
40
    {"PyObject*", ParameterType::PYOBJECT},
41
    {"ScalarType", ParameterType::SCALARTYPE},
42
    {"Layout", ParameterType::LAYOUT},
43
    {"MemoryFormat", ParameterType::MEMORY_FORMAT},
44
    {"QScheme", ParameterType::QSCHEME},
45
    {"Device", ParameterType::DEVICE},
46
    {"DeviceIndex", ParameterType::INT64},
47
    {"Stream", ParameterType::STREAM},
48
    {"std::string", ParameterType::STRING},
49
    {"c10::string_view", ParameterType::STRING},
50
    {"Dimname", ParameterType::DIMNAME},
51
    {"DimnameList", ParameterType::DIMNAME_LIST},
52
    {"ScalarList", ParameterType::SCALAR_LIST},
53
    {"DispatchKeySet", ParameterType::DISPATCH_KEY_SET},
54
};
55

56
// Default keyword-argument aliases accepted for compatibility with NumPy.
//
// Example:
// ```python
// t = torch.randn(10,10)
// torch.sum(a=t, axis=0, keepdim=True)
// ```
//
// Each PyTorch name maps to a *vector* of aliases because more than one
// spelling may have to be tried: NumPy sometimes calls the main input tensor
// "x" and sometimes "a". Instead of annotating every function with the
// spelling it should accept, all of them are tried.
//
// TODO: Allow individual functions to specify non-default translations:
// For example, `torch.pow` should translate "exponent" to "x2".
static const std::unordered_map<std::string, std::vector<std::string>>
    numpy_compatibility_arg_names{
        {"dim", {"axis"}},
        {"keepdim", {"keepdims"}},
        {"input", {"x", "a", "x1"}},
        {"other", {"x2"}},
    };
78

79
// TODO: remove this. This is a temporary list of functions that allow Python
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
// overloads and binding to the Tensor overload with a number of a different
// type will trigger a type error.
//
// If you modify this, you will need to adjust the blocklist in
// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
// functions.)
//
// Returns true iff `name` is exactly one of the listed function names
// (the comparison is case-sensitive and there is no prefix matching).
bool should_allow_numbers_as_tensors(const std::string& name) {
  // The table is never modified after construction, so keep it const: this
  // documents the intent and rules out accidental mutation of shared state.
  static const std::unordered_set<std::string> allowed = {
      "add",          "add_",          "add_out",
      "div",          "div_",          "div_out",
      "divide",       "divide_",       "divide_out", // alias of div
      "mul",          "mul_",          "mul_out",
      "multiply",     "multiply_",     "multiply_out", // alias of mul
      "sub",          "sub_",          "sub_out",
      "subtract",     "subtract_",     "subtract_out", // alias of sub
      "true_divide",  "true_divide_",  "true_divide_out",
      "to",           "_to_copy",      "copy_",
      "floor_divide", "floor_divide_", "floor_divide_out"};
  return allowed.count(name) != 0;
}
101

102
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
103
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
104
    : optional(false),
105
      allow_none(false),
106
      keyword_only(keyword_only),
107
      size(0),
108
      default_scalar(0) {
109
  auto space = fmt.find(' ');
110
  if (space == std::string::npos) {
111
    throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
112
  }
113

114
  auto type_str = fmt.substr(0, space);
115

116
  auto question = type_str.find('?');
117
  if (question != std::string::npos) {
118
    allow_none = true;
119
    type_str = type_str.substr(0, question);
120
  }
121

122
  // Parse and remove brackets from type_str
123
  auto bracket = type_str.find('[');
124
  if (bracket != std::string::npos) {
125
    auto size_str =
126
        type_str.substr(bracket + 1, type_str.length() - bracket - 2);
127
    size = atoi(size_str.c_str());
128
    type_str = type_str.substr(0, bracket);
129
  }
130

131
  auto name_str = fmt.substr(space + 1);
132
  auto it = type_map.find(type_str);
133
  if (it == type_map.end()) {
134
    throw std::runtime_error(
135
        "FunctionParameter(): invalid type string: " + type_str);
136
  }
137
  type_ = it->second;
138

139
  auto eq = name_str.find('=');
140
  if (eq != std::string::npos) {
141
    name = name_str.substr(0, eq);
142
    optional = true;
143
    set_default_str(name_str.substr(eq + 1));
144
  } else {
145
    name = name_str;
146
  }
147
  python_name = THPUtils_internString(name);
148
  auto np_compat_it = numpy_compatibility_arg_names.find(name);
149
  if (np_compat_it != numpy_compatibility_arg_names.end()) {
150
    for (const auto& str : np_compat_it->second) {
151
      numpy_python_names.push_back(THPUtils_internString(str));
152
    }
153
  }
154
}
155

156
auto handle_torch_function_getter(
157
    THPVariable* self,
158
    const std::string& property_name) -> PyObject* {
159
  py::object torch_api = PyObject_FastGetAttrString(
160
      THPVariableClass, (char*)property_name.c_str());
161
  std::string module_name = "torch.Tensor." + property_name;
162
  return handle_torch_function(
163
      (PyObject*)self,
164
      "__get__",
165
      nullptr,
166
      nullptr,
167
      torch_api.ptr(),
168
      module_name);
169
}
170

171
auto handle_torch_function_setter(
172
    THPVariable* self,
173
    const std::string& property_name,
174
    PyObject* value) -> int {
175
  py::object torch_api = PyObject_FastGetAttrString(
176
      THPVariableClass, (char*)property_name.c_str());
177
  std::string module_name = "torch.Tensor." + property_name;
178
  if (value != nullptr) {
179
    py::tuple args_ = py::make_tuple(py::handle(value));
180
    handle_torch_function(
181
        (PyObject*)self,
182
        "__set__",
183
        args_.ptr(),
184
        nullptr,
185
        torch_api.ptr(),
186
        module_name);
187
  } else {
188
    handle_torch_function(
189
        (PyObject*)self,
190
        "__delete__",
191
        nullptr,
192
        nullptr,
193
        torch_api.ptr(),
194
        module_name);
195
  }
196
  return 0;
197
}
198

199
// Combines self and args into one tuple.
200
static auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple {
201
  if (args == nullptr) {
202
    return py::make_tuple(py::handle(self));
203
  } else if (self == nullptr) {
204
    return py::reinterpret_borrow<py::tuple>(args);
205
  }
206

207
  auto py_args = py::reinterpret_borrow<py::tuple>(args);
208
  size_t n = py_args.size();
209
  auto args_ = py::tuple(n + 1);
210
  args_[0] = py::handle(self);
211
  for (const auto i : c10::irange(n)) {
212
    args_[i + 1] = py_args[i];
213
  }
214
  return args_;
215
}
216

217
// TODO: I'm not sure if I should call this __torch_function__ or
218
// torch_function.  The former makes it easier to take an existing
219
// Tensor-like __torch_function__ object and turn it into a mode;
220
// but in general modes don't have to be Tensor-like (and we will
221
// improperly accept mode objects as arguments when they shouldn't
222
// be passed around in this way).
223
// Name of the handler attribute for torch-function modes (see the naming TODO
// directly above).
// NOTE(review): not referenced in the visible part of this file; presumably
// declared in the header for use elsewhere -- confirm before changing.
const char* torch_function_mode_name = "__torch_function__";
224

225
auto handle_torch_function(
226
    PyObject* self,
227
    const std::string& func_name,
228
    PyObject* args,
229
    PyObject* kwargs,
230
    PyObject* torch_api,
231
    const std::string& module_name) -> PyObject* {
232
  py::object torch_api_function =
233
      PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str());
234
  TORCH_INTERNAL_ASSERT(
235
      torch_api_function.ptr() != nullptr, "torch API function must exist");
236
  py::tuple args_ = combine_self_args(self, args);
237
  return handle_torch_function_no_python_arg_parser(
238
      {self},
239
      args_.ptr(),
240
      kwargs,
241
      func_name.c_str(),
242
      torch_api_function.ptr(),
243
      module_name.c_str(),
244
      TorchFunctionName::TorchFunction);
245
}
246

247
// Note: [Overloaded args]
248
// An overloaded arg may be one of the following:
249
// - an instance of an object that has a __torch_function__ method
250
// - an instance of an object that has a __torch_dispatch__ classmethod
251
// - a class type that has a __torch_dispatch__ classmethod
252
//
253
// This function returns the type of the arg (if the arg is an instance),
254
// otherwise, it returns the arg.
255
static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
256
  if (PyType_Check(obj_or_type)) {
257
    return obj_or_type;
258
  }
259
  return (PyObject*)Py_TYPE(obj_or_type);
260
}
261

262
// Tries the `torch_function_name_str` handler of every overloaded argument,
// in order, until one returns something other than NotImplemented.
//
// Returns:
//  - the first non-NotImplemented result, or
//  - NotImplemented if every handler that was called declined, or
//  - a null py::object if no handler was called at all (no overloaded args,
//    or all of them use the disabled __torch_dispatch__ implementation).
// Throws python_error as soon as a handler call fails.
//
// Note: PyObject_CallFunctionObjArgs takes a NULL-terminated argument list,
// so when `kwargs` is null the list simply ends one argument early and the
// handler receives (torch_api_function, py_types, args).
static py::object dispatch_on_subclass(
    PyObject* args,
    PyObject* kwargs,
    at::ArrayRef<PyObject*> overloaded_args,
    py::tuple py_types,
    PyObject* torch_api_function,
    bool is_torch_function,
    const char* torch_function_name_str,
    c10::optional<c10::impl::TorchDispatchModeKey> maybe_mode_key =
        c10::nullopt) {
  py::object result;
  for (PyObject* const candidate : overloaded_args) {
    py::object handler =
        PyObject_FastGetAttrString(candidate, torch_function_name_str);
    if (!handler) {
      TORCH_INTERNAL_ASSERT(0);
    }

    // During __torch_dispatch__, don't dispatch on args with a disabled
    // torch_dispatch. This code runs before infra modes, so we need to make
    // sure that infra modes can run first. (In theory, maybe we can rearrange
    // things so that infra modes are *always* attempted first, and just
    // return NotImplemented when there are any user subclasses. Maybe that
    // would fix this problem?)
    if (handler.ptr() == torch::disabled_torch_dispatch_impl()) {
      continue;
    }

    // A handler bound to the instance itself means __torch_function__ was
    // written as a plain method rather than a classmethod.
    // See https://github.com/pytorch/pytorch/issues/63767
    if (is_torch_function &&
        PyObject_FastGetAttrString(handler.ptr(), "__self__")
            .is(py::handle(candidate)) &&
        handler.ptr() != torch::disabled_torch_function_impl()) {
      TORCH_WARN(
          "Defining your `",
          torch_function_name_str,
          "` as a plain method is deprecated ",
          "and will be an error in future, please define it as a classmethod.");
    }

    result = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
        handler.ptr(),
        torch_api_function,
        py_types.ptr(),
        args,
        kwargs,
        nullptr));
    if (result.ptr() == nullptr) {
      // The handler raised; propagate the Python exception.
      throw python_error();
    }
    if (result.ptr() != Py_NotImplemented) {
      // This handler accepted the call; later candidates are not consulted.
      return result;
    }
  }
  return result;
}
320

321
static std::tuple<py::object, py::object> dispatch_on_mode(
322
    PyObject* args,
323
    PyObject* kwargs,
324
    py::tuple py_types,
325
    PyObject* torch_api_function,
326
    bool is_torch_function,
327
    const char* torch_function_name_str) {
328
  // Disable mode on the inside; this makes for a more user-friendly
329
  // experience if you try to, e.g., print your tensors.
330
  at::optional<torch::overrides::StashTorchFunctionModeGuard> tf_g;
331
  at::optional<torch_dispatch_mode::StashTorchDispatchModeGuard> td_g;
332
  py::object mode_obj;
333
  // NB: We only really need keep the mode_obj live if the function call
334
  // fails for error reporting, but whatever, Python refcounts are cheap
335
  if (is_torch_function) {
336
    tf_g.emplace();
337
    mode_obj = py::reinterpret_borrow<py::object>(
338
        tf_g->get_cur_mode()->ptr(getPyInterpreter()));
339
  } else {
340
    td_g.emplace();
341
    mode_obj = py::reinterpret_borrow<py::object>(
342
        td_g->get_cur_mode()->ptr(getPyInterpreter()));
343
  }
344
  py::object torch_function =
345
      PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str);
346
  if (!torch_function) {
347
    TORCH_INTERNAL_ASSERT(0);
348
  }
349
  TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr);
350
  TORCH_INTERNAL_ASSERT(args != nullptr);
351

352
  TORCH_CHECK(
353
      PyObject_FastGetAttrString(torch_function.ptr(), "__self__").is(mode_obj),
354
      "Defining your mode's `",
355
      torch_function_name_str,
356
      "` as a classmethod is not supported, please make it a plain method");
357

358
  // Blegh.  This accidentally works in PyObject_CallFunctionObjArgs below
359
  // because the nullptr terminates the argument list ick ick ick.
360
  py::object ret;
361
  if (kwargs == nullptr) {
362
    ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
363
        mode_obj.ptr(),
364
        torch_function_name_str,
365
        "OOO",
366
        torch_api_function,
367
        py_types.ptr(),
368
        args));
369
  } else {
370
    ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
371
        mode_obj.ptr(),
372
        torch_function_name_str,
373
        "OOOO",
374
        torch_api_function,
375
        py_types.ptr(),
376
        args,
377
        kwargs));
378
  }
379
  if (ret.ptr() == nullptr) {
380
    throw python_error();
381
  }
382
  return std::make_tuple(ret, mode_obj);
383
}
384

385
// See Note: [Overloaded args] for what they hold
386
auto handle_torch_function_no_python_arg_parser(
387
    at::ArrayRef<PyObject*> overloaded_args,
388
    PyObject* args,
389
    PyObject* kwargs,
390
    const char* func_name,
391
    PyObject* torch_api_function,
392
    const char* module_name,
393
    TorchFunctionName torch_function_name) -> PyObject* {
394
  const char* torch_function_name_str = nullptr;
395
  switch (torch_function_name) {
396
    case TorchFunctionName::TorchFunction:
397
      torch_function_name_str = "__torch_function__";
398
      break;
399
    case TorchFunctionName::TorchDispatch:
400
      torch_function_name_str = "__torch_dispatch__";
401
      break;
402
    default:
403
      TORCH_INTERNAL_ASSERT(0, static_cast<int>(torch_function_name));
404
  }
405
  // overloaded_args already all have unique types
406
  // nb: modes don't go in the overloaded types list, as they are not
407
  // necessarily types
408
  std::vector<py::object> overloaded_types;
409
  overloaded_types.reserve(overloaded_args.size());
410
  for (auto& arg : overloaded_args) {
411
    overloaded_types.push_back(
412
        py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg)));
413
  }
414
  py::tuple py_types = py::cast(overloaded_types);
415
  py::object ret;
416
  py::object mode_obj;
417

418
  // Step 1: Try to dispatch based on the mode stack, *ignoring* infra
419
  // torch_dispatch modes.
420
  const bool is_torch_function =
421
      torch_function_name == TorchFunctionName::TorchFunction;
422
  const auto is_mode_active = [&]() {
423
    return is_torch_function
424
        ? at::impl::torch_function_mode_enabled()
425
        // Check if any *user* torch_dispatch modes are active (not including
426
        // fake and proxy modes, which are special)
427
        : c10::impl::dispatch_mode_enabled();
428
  };
429
  // Note [__torch_dispatch__ dispatching order]
430
  // The high-level idea motivating the dispatching
431
  // order below is that: (1) modes get higher dispatch precedence over
432
  // subclasses (2) "user" modes/subclasses get higher dispatch precedence over
433
  // "infra" modes/subclasses.
434
  //
435
  // To give a complete example: let's say we are running torch.compile, with
436
  // the following "user" modes and subclasses:
437
  //   mode_stack: [ModeA]
438
  //   user_args: [MyWrapperSubclassB(torchTensor)]
439

440
  // During tracing in AOTAutograd tracing, we use some additional infra modes
441
  // and subclasses to perform tracing:
442
  //   FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode,
443
  //   FunctionalTensor, FakeTensor
444
  // The modified mode stack and tracing arguments will look like this:
445
  //   mode_stack (user modes): [ModeA]
446
  //   mode_stack (infra modes): [
447
  //     FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode
448
  //   ]
449
  //   tracing_args: [
450
  //     MyWrapperSubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor)))
451
  //   ]
452

453
  // And the dispatching order that we want is as follows:
454
  // (1) ModeA.__torch_dispatch__ (user modes highest)
455
  // (2) MyWrapperSubclassB.__torch_dispatch__ (user subclasses next highest)
456
  // (3) FunctionalTensorMode.__torch_dispatch__ (infra modes next highest)
457
  // (4) ProxyTorchDispatchMode.__torch_dispatch__ (infra modes next highest)
458
  // (5) FakeTensorMode.__torch_dispatch__ (infra modes next highest)
459
  // (6) FakeTensor.__torch_fake_dispatch__ (infra subclasses next highest)
460

461
  // Why does do FunctionalTensor and FakeTensor even need to be special-cased
462
  // in the ordering?
463
  // In theory we could remove their __torch_dispatch__, but both of these
464
  // subclasses override sizes/strides metadata calls with __torch_dispatch__,
465
  // which would mean a mode would be **required** to access their metadata.
466

467
  if (is_mode_active()) {
468
    // Step 1: Try to dispatch on any user TorchDispatchModes (including infra
469
    // modes, which will always be at the bottom of the mode stack).
470
    auto ret_ = dispatch_on_mode(
471
        args,
472
        kwargs,
473
        py_types,
474
        torch_api_function,
475
        is_torch_function,
476
        torch_function_name_str);
477
    ret = std::get<0>(ret_);
478
    mode_obj = std::get<1>(ret_);
479
  }
480

481
  // Step 2: Try to dispatch based on any user subclasses,
482
  // ignoring any subclasses that have a _mode_key field
483
  // (corresponding to infra subclasses)
484
  // Note: user subclasses should always run *before* infra modes like
485
  // proxy/fake. This is handles by having proxy/fake modes return
486
  // NotImplemented when they see a user subclass that they don't understand.
487
  if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
488
    auto curr_ret = dispatch_on_subclass(
489
        args,
490
        kwargs,
491
        overloaded_args,
492
        py_types,
493
        torch_api_function,
494
        is_torch_function,
495
        torch_function_name_str);
496
    if (curr_ret.ptr() != nullptr) {
497
      ret = curr_ret;
498
    }
499
  }
500

501
  if (ret.ptr() == nullptr) {
502
    // if an exception occurred in a user's implementation of
503
    // __torch_function__, throw it
504
    throw python_error();
505
  } else if (ret.ptr() == Py_NotImplemented) {
506
    // all __torch_function__ implementations in overloaded_args
507
    // returned NotImplemented, so we raise a TypeError.
508
    std::stringstream ss;
509
    ss << "Multiple dispatch failed for '";
510
    if (module_name && func_name) {
511
      ss << module_name << "." << func_name;
512
    } else {
513
      py::handle fn = torch_api_function;
514
      ss << py::str(fn.attr("__module__")) << "."
515
         << py::str(fn.attr("__name__"));
516
    }
517
    ss << "'; all " << torch_function_name_str
518
       << " handlers returned NotImplemented:\n\n";
519
    if (mode_obj) {
520
      ss << "  - mode object " << py::repr(mode_obj) << "\n";
521
    }
522
    for (auto& arg : overloaded_args) {
523
      ss << "  - tensor subclass " << py::repr(get_type_of_overloaded_arg(arg))
524
         << "\n";
525
    }
526
    ss << "\nFor more information, try re-running with TORCH_LOGS=not_implemented";
527
    const std::string& tmp = ss.str();
528
    PyErr_SetString(PyExc_TypeError, tmp.c_str());
529
    throw python_error();
530
  }
531
  return ret.release().ptr();
532
}
533

534
auto handle_torch_function(
535
    PythonArgs& r,
536
    PyObject* self,
537
    PyObject* args,
538
    PyObject* kwargs,
539
    PyObject* torch_api,
540
    const char* module_name,
541
    const char* func_name_override) -> PyObject* {
542
  py::object torch_api_function = PyObject_FastGetAttrString(
543
      torch_api,
544
      (char*)(func_name_override ? func_name_override
545
                                 : r.get_func_name().c_str()));
546
  TORCH_INTERNAL_ASSERT(
547
      torch_api_function.ptr() != nullptr, "torch API function must exist");
548
  py::tuple args_ = combine_self_args(self, args);
549
  return handle_torch_function_no_python_arg_parser(
550
      r.overloaded_args,
551
      args_.ptr(),
552
      kwargs,
553
      r.get_func_name().c_str(),
554
      torch_api_function.ptr(),
555
      module_name);
556
}
557

558
auto handle_torch_function(
559
    PythonArgs& r,
560
    PyObject* args,
561
    PyObject* kwargs,
562
    PyObject* torch_api,
563
    const char* module_name,
564
    const char* func_name_override) -> PyObject* {
565
  return handle_torch_function(
566
      r, nullptr, args, kwargs, torch_api, module_name, func_name_override);
567
}
568

569
auto handle_torch_function_indexing(
570
    PyObject* self,
571
    PyObject* index,
572
    PyObject* val) -> PyObject* {
573
  const char* func_name = (val == nullptr) ? "__getitem__" : "__setitem__";
574
  py::object index_tup;
575
  if (PyTuple_Check(index)) {
576
    index_tup = py::reinterpret_borrow<py::object>(index);
577
  } else {
578
    index_tup = py::make_tuple(py::handle(index));
579
  }
580
  std::vector<PyObject*> overridable_args;
581
  is_tensor_and_append_overloaded(self, &overridable_args);
582
  auto size = PyTuple_GET_SIZE(index_tup.ptr());
583
  for (auto i : c10::irange(size)) {
584
    auto* obj = PyTuple_GetItem(index_tup.ptr(), i);
585
    is_tensor_and_append_overloaded(obj, &overridable_args);
586
  }
587
  if (val != nullptr) {
588
    is_tensor_and_append_overloaded(val, &overridable_args);
589
  }
590
  py::object func =
591
      PyObject_FastGetAttrString(THPVariableClass, (char*)func_name);
592
  py::object args = (val == nullptr)
593
      ? py::make_tuple(py::handle(self), py::handle(index))
594
      : py::make_tuple(py::handle(self), py::handle(index), py::handle(val));
595
  return handle_torch_function_no_python_arg_parser(
596
      overridable_args,
597
      args.ptr(),
598
      nullptr,
599
      func_name,
600
      func.ptr(),
601
      "torch.Tensor");
602
}
603

604
/*
605
 *  obj has a __torch_function__ implementation and may either be a
606
 *  subclass of Tensor or a Tensor-like duck type. We may need to
607
 *  append this object to the overloaded_args vector, which tracks all
608
 *  of the arguments with distinct __torch_function__ implementations
609
 *  we've seen so far.
610
 *
611
 *  If this is the first argument we've seen with __torch_function__
612
 *  defined, we unconditionally add obj to the overloaded_args vector.
613
 *
614
 *  If we've already seen arguments with __torch_function__ defined,
615
 *  then we first need to check if obj is the same type as any of the
616
 *  entries in overloaded_args.  If so, we can ignore obj since we
617
 *  already have an entry in overloaded_args with the same
618
 *  __torch_function__ implementation.
619
 *
620
 *  If it's a different type, we then need to check if it's a subclass
621
 *  of one of the types we've already seen. If so, we need to insert an
622
 *  entry in overloaded_args for this type with higher precedence than
623
 *  the superclass.
624
 *
625
 *  See torch._overrides._get_overloaded_args for the equivalent
626
 *  function in the Python __torch_function__ implementation.
627
 *
628
 *  The precedence-determining algorithm implemented in this function is
629
 *  described in NEP-0018:
630
 *  https://numpy.org/neps/nep-0018-array-function-protocol.html
631
 *
632
 *  'overloaded_args' is a raw pointer to a vector of pybind11 handles
633
 *  that have distinct __torch_function__ implementations, in order of calling
634
 *  precedence.
635
 *
636
 *  'obj' is an object to check for a __torch_function__ implementation
637
 *
638
 * If changing this file in a way that can affect the __torch_function__
639
 * overhead, please report the benchmarks in 'benchmarks/overrides_benchmark'.
640
 * See the instructions in the 'README.md' in that directory.
641
 *
642
 */
643

644
static void append_overloaded_arg(
645
    std::vector<PyObject*>* overloaded_args,
646
    PyObject* obj,
647
    bool obj_is_type) {
648
  bool class_not_seen_yet = true;
649
  PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
650
  for (auto& arg : *overloaded_args) {
651
    if (obj_type == get_type_of_overloaded_arg(arg)) {
652
      // obj is the same type as another parameter we've seen in a prior
653
      // iteration of the loop over parameters so we already have an entry
654
      // with the proper __torch_function__ implementation to call, so skip
655
      // this parameter
656
      class_not_seen_yet = false;
657
      break;
658
    }
659
  }
660
  if (class_not_seen_yet) {
661
    auto arg_index = overloaded_args->size();
662
    for (const auto j : c10::irange(arg_index)) {
663
      if (PyObject_IsSubclass(
664
              obj_type, get_type_of_overloaded_arg((*overloaded_args)[j]))) {
665
        // obj is a subclass of another object we've seen already so its
666
        // __torch_function__ should be called first, therefore we
667
        // insert it into overloaded_args before the superclass
668
        arg_index = j;
669
        break;
670
      }
671
    }
672
    // add object to overloaded_args. If it's a subclass of another class
673
    // we've already seen it will be inserted before the superclass,
674
    // otherwise it will be inserted at the end of the array
675
    overloaded_args->insert(
676
        overloaded_args->begin() + static_cast<long>(arg_index), obj);
677
  }
678
}
679

680
void append_overloaded_tensor(
681
    std::vector<PyObject*>* overloaded_args,
682
    PyObject* obj) {
683
  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ false);
684
}
685

686
void append_overloaded_type(
687
    std::vector<PyObject*>* overloaded_args,
688
    PyObject* obj) {
689
  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ true);
690
}
691

692
bool is_tensor_and_append_overloaded(
693
    PyObject* obj,
694
    std::vector<PyObject*>* overloaded_args) {
695
  if (THPVariable_CheckExact(obj)) {
696
    // torch.Tensor instances (not subclasses, except for Parameter)
697
    return true;
698
  }
699

700
  if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
701
    // tensor subclasses and unrelated objects with __torch_function__
702
    append_overloaded_tensor(overloaded_args, obj);
703
    return true;
704
  } else if (THPVariable_Check(obj)) {
705
    // tensor subclasses without __torch_function__
706
    return true;
707
  }
708

709
  return false;
710
}
711

712
// True when `obj` is a tuple or list whose elements all pass
// THPUtils_checkScalar. An empty tuple/list qualifies.
static bool is_scalar_list(PyObject* obj) {
  const bool is_tuple = six::isTuple(obj);
  if (!is_tuple && !PyList_Check(obj)) {
    return false;
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto count = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (Py_ssize_t pos = 0; pos < count; ++pos) {
    PyObject* element =
        is_tuple ? PyTuple_GET_ITEM(obj, pos) : PyList_GET_ITEM(obj, pos);
    if (!THPUtils_checkScalar(element)) {
      return false;
    }
  }
  return true;
}
728

729
bool is_tensor_list_and_append_overloaded(
730
    PyObject* obj,
731
    std::vector<PyObject*>* overloaded_args,
732
    int argnum,
733
    bool throw_error) {
734
  auto tuple = six::isTuple(obj);
735
  if (!(tuple || PyList_Check(obj))) {
736
    return false;
737
  }
738
  // NOLINTNEXTLINE(bugprone-branch-clone)
739
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
740
  for (long idx = 0; idx < size; idx++) {
741
    PyObject* iobj =
742
        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
743
    if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) {
744
      if (throw_error) {
745
        TORCH_CHECK_TYPE(
746
            false,
747
            "expected Tensor as element ",
748
            idx,
749
            " in argument ",
750
            argnum,
751
            ", but got ",
752
            Py_TYPE(iobj)->tp_name);
753
      }
754
      return false;
755
    }
756
  }
757
  return true;
758
}
759

760
// True when `obj` is a tuple or list that is either empty or whose *first*
// element passes THPUtils_checkDouble or is a Python complex. Only the first
// element is inspected.
static bool is_float_or_complex_list(PyObject* obj) {
  const bool is_tuple = six::isTuple(obj);
  if (!is_tuple && !PyList_Check(obj)) {
    return false;
  }

  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto count = is_tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  if (count <= 0) {
    return true;
  }

  PyObject* first =
      is_tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
  return THPUtils_checkDouble(first) || PyComplex_Check(first);
}
777

778
// True when `obj` can bind to an int / SymInt parameter. The three checks
// below must stay in this order.
static bool is_int_or_symint(PyObject* obj) {
  // THPUtils_checkIndex may call __index__ or __int__
  // which may have side effects if obj is a symint node
  // so we do `is_symint` check first
  // TODO: maybe we should be using checkLong here?
  if (torch::is_symint(py::handle(obj))) {
    return true;
  }

  // FakeTensor(..., size=()) is qualified for SymInt param,
  // but we can't go via __index__ (below) as we would normally
  // do for regular tensors, because __index__ first forces a
  // conversion into an int, which in general you cannot do
  // if you have an unbacked SymInt.  So this fastpath ensures
  // that we still allow for fake tensors in this case, but
  // for regular tensors it's redundant with the test below.
  if (THPVariable_Check(obj)) {
    const auto& tensor = THPVariable_Unpack(obj);
    const bool single_element = tensor.numel() == 1;
    if (single_element &&
        at::isIntegralType(
            tensor.dtype().toScalarType(), /*include_bool*/ true)) {
      return true;
    }
  }

  // Last resort: anything usable as an index.
  return THPUtils_checkIndex(obj);
}
808

809
static bool is_int_or_symint_list(
810
    PyObject* obj,
811
    int broadcast_size,
812
    int64_t* failed_idx = nullptr) {
813
  if (PyTuple_Check(obj) || PyList_Check(obj)) {
814
    if (PySequence_Size(obj) == 0) {
815
      return true;
816
    }
817
    auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
818

819
    if (is_int_or_symint(item.ptr())) {
820
      return true;
821
    }
822

823
    // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
824
    // in an intlist argument. Even float or complex scalar tensors.
825
    bool r =
826
        (jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
827
         THPVariable_Unpack(item.ptr()).sizes().empty());
828
    if (!r && failed_idx != nullptr) {
829
      *failed_idx = 0;
830
    }
831
    return r;
832
  }
833

834
  // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
835
  // int
836
  return broadcast_size > 0 && is_int_or_symint(obj);
837
}
838

839
// argnum is needed for raising the TypeError, it's used in the error message.
840
auto FunctionParameter::check(
841
    PyObject* obj,
842
    std::vector<PyObject*>& overloaded_args,
843
    int argnum,
844
    int64_t* failed_idx) -> bool {
845
  switch (type_) {
846
    case ParameterType::TENSOR: {
847
      if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
848
        return true;
849
      }
850
      if (allow_numbers_as_tensors) {
851
        return THPUtils_checkScalar(obj);
852
      }
853
      return false;
854
    }
855
    case ParameterType::SCALAR:
856
      if (THPUtils_checkScalar(obj)) {
857
        return true;
858
      }
859
      [[fallthrough]];
860
    case ParameterType::COMPLEX:
861
      if (PyComplex_Check(obj)) {
862
        return true;
863
      }
864
      [[fallthrough]];
865
    case ParameterType::DOUBLE: {
866
      if (THPUtils_checkDouble(obj)) {
867
        return true;
868
      }
869
      if (THPVariable_Check(obj)) {
870
        const auto& var = THPVariable_Unpack(obj);
871
        return !var.requires_grad() && var.dim() == 0;
872
      }
873
      if (torch::is_symfloat(py::handle(obj)) ||
874
          torch::is_symint(py::handle(obj))) {
875
        // This will induce a guard
876
        return true;
877
      }
878
      return false;
879
    }
880
    case ParameterType::INT64: {
881
      if (THPUtils_checkLong(obj)) {
882
        return true;
883
      }
884
      if (THPVariable_Check(obj)) {
885
        const auto& var = THPVariable_Unpack(obj);
886
        return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) &&
887
            !var.requires_grad() && var.dim() == 0;
888
      }
889
      if (torch::is_symint(py::handle(obj))) {
890
        // This will induce a guard
891
        return true;
892
      }
893
      return false;
894
    }
895
    case ParameterType::DIMNAME:
896
      return THPUtils_checkDimname(obj);
897
    case ParameterType::DIMNAME_LIST: {
898
      if (THPUtils_checkDimnameList(obj)) {
899
        return true;
900
      }
901
      // if a size is specified (e.g. DimnameList[1]) we also allow passing a
902
      // single Dimname
903
      return size == 1 && THPUtils_checkDimname(obj);
904
    }
905
    case ParameterType::TENSOR_LIST: {
906
      return is_tensor_list_and_append_overloaded(
907
          obj, &overloaded_args, argnum, true /* throw_error */);
908
    }
909
    case ParameterType::FLOAT_LIST:
910
      return is_float_or_complex_list(obj);
911
    case ParameterType::GENERATOR:
912
      return THPGenerator_Check(obj);
913
    case ParameterType::BOOL:
914
      return PyBool_Check(obj);
915
    case ParameterType::STORAGE:
916
      return isStorage(obj);
917
    case ParameterType::PYOBJECT:
918
      return true;
919
    case ParameterType::SCALARTYPE:
920
      return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
921
    case ParameterType::LAYOUT:
922
      return THPLayout_Check(obj);
923
    case ParameterType::MEMORY_FORMAT:
924
      return THPMemoryFormat_Check(obj);
925
    case ParameterType::QSCHEME:
926
      return THPQScheme_Check(obj);
927
    case ParameterType::DEVICE:
928
      return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
929
          THPDevice_Check(obj);
930
    case ParameterType::STREAM:
931
      return THPStream_Check(obj);
932
    case ParameterType::STRING:
933
      return THPUtils_checkString(obj);
934
    case ParameterType::SCALAR_LIST:
935
      return is_scalar_list(obj);
936
    case ParameterType::SYM_INT:
937
      return is_int_or_symint(obj);
938
    // Allow SymInt where int is expected; we'll guard in this case
939
    case ParameterType::INT_LIST:
940
    case ParameterType::SYM_INT_LIST:
941
      return is_int_or_symint_list(obj, size, failed_idx);
942
    case ParameterType::DISPATCH_KEY_SET:
943
      return py::isinstance<c10::DispatchKeySet>(py::handle(obj));
944
    default:
945
      throw std::runtime_error("unknown parameter type");
946
  }
947
}
948

949
// WARNING: these strings are parsed invalid_arguments.cpp
950
std::string FunctionParameter::type_name() const {
951
  switch (type_) {
952
    case ParameterType::TENSOR:
953
      return "Tensor";
954
    case ParameterType::SCALAR:
955
      return "Number";
956
    case ParameterType::INT64:
957
    // NB: SymInt is intentionally not mentioned here, as conventional user
958
    // use will only know about ints
959
    case ParameterType::SYM_INT:
960
      return "int";
961
    case ParameterType::DOUBLE:
962
      return "float";
963
    case ParameterType::COMPLEX:
964
      return "complex";
965
    case ParameterType::TENSOR_LIST:
966
      return "tuple of Tensors";
967
    case ParameterType::INT_LIST:
968
      return "tuple of ints";
969
    case ParameterType::FLOAT_LIST:
970
      return "tuple of floats";
971
    case ParameterType::GENERATOR:
972
      return "torch.Generator";
973
    case ParameterType::BOOL:
974
      return "bool";
975
    case ParameterType::STORAGE:
976
      return "torch.Storage";
977
    case ParameterType::PYOBJECT:
978
      return "object";
979
    case ParameterType::SCALARTYPE:
980
      return "torch.dtype";
981
    case ParameterType::LAYOUT:
982
      return "torch.layout";
983
    case ParameterType::MEMORY_FORMAT:
984
      return "torch.memory_format";
985
    case ParameterType::QSCHEME:
986
      return "torch.qscheme";
987
    case ParameterType::DEVICE:
988
      return "torch.device";
989
    case ParameterType::STRING:
990
      return "str";
991
    case ParameterType::DIMNAME:
992
      return "name";
993
    case ParameterType::DIMNAME_LIST:
994
      return "tuple of names";
995
    case ParameterType::SCALAR_LIST:
996
      return "tuple of Scalars";
997
    case ParameterType::SYM_INT_LIST:
998
      return "tuple of ints";
999
    case ParameterType::DISPATCH_KEY_SET:
1000
      return "DispatchKeySet";
1001
    default:
1002
      throw std::runtime_error("unknown parameter type");
1003
  }
1004
}
1005

1006
// Parses `s` as an integer literal, with the base auto-detected from the
// prefix (decimal, 0x hex, leading-0 octal). Returns the value when the
// whole string was consumed, and c10::nullopt when `s` is empty or has
// trailing characters that are not part of the integer.
static inline c10::optional<int64_t> parse_as_integer(const std::string& s) {
  if (s.empty()) {
    return c10::nullopt;
  }
  char* str_end = nullptr;
  // strtoll rather than strtol: `long` is only 32 bits on LLP64 targets, so
  // strtol would saturate values that fit in int64_t but not in 32 bits.
  const long long ans = strtoll(s.c_str(), &str_end, 0);
  // *str_end == 0 if the entire string was parsed as an integer.
  if (*str_end != '\0') {
    return c10::nullopt;
  }
  return c10::optional<int64_t>(static_cast<int64_t>(ans));
}
1014

1015
/*
1016
Parse default value of IntArrayRef declared at native_functions.yaml
1017

1018
There are two kinds of default values:
1019
1. IntArrayRef[2] x=1 (where size=2, value={1,1}
1020
2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be
1021
space after comma since native_parse.py uses ', ' to split args)
1022
*/
1023
static inline std::vector<int64_t> parse_intlist_args(
1024
    const std::string& s,
1025
    int64_t size) {
1026
  size_t n = s.size();
1027

1028
  if (s.empty())
1029
    return std::vector<int64_t>();
1030

1031
  // case 1. s is an int (e.g., s=2)
1032
  if (s[0] != '{') {
1033
    TORCH_CHECK(size > 0, "Incorrect size of IntArrayRef: ", size);
1034
    return std::vector<int64_t>(size, std::stol(s));
1035
  }
1036

1037
  // case 2. s is a list of dims (e.g., s={1,2})
1038

1039
  // since already checked left brace '{' above, here only checks right brace
1040
  // '}'
1041
  TORCH_CHECK(
1042
      s[n - 1] == '}',
1043
      "Default value of IntArrayRef is missing right brace '}', found ",
1044
      s[n - 1]);
1045

1046
  auto args = std::vector<int64_t>();
1047
  std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
1048
  std::string tok;
1049

1050
  while (std::getline(ss, tok, ',')) {
1051
    args.emplace_back(std::stol(tok));
1052
  }
1053
  return args;
1054
}
1055

1056
// Parse a string literal to remove quotes and escape sequences
1057
static std::string parse_string_literal(c10::string_view str) {
1058
  TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
1059

1060
  if (str.front() == '"') {
1061
    TORCH_CHECK(
1062
        str.back() == '"', "Mismatched quotes in string default: ", str);
1063
  } else {
1064
    TORCH_CHECK(
1065
        str.front() == '\'' && str.back() == '\'',
1066
        "Invalid quotes in string default: ",
1067
        str)
1068
  }
1069

1070
  std::string parsed;
1071
  parsed.reserve(str.size());
1072
  for (size_t i = 1; i < str.size() - 1;) {
1073
    if (str[i] != '\\') {
1074
      parsed.push_back(str[i]);
1075
      ++i;
1076
      continue;
1077
    }
1078

1079
    // Handle escape sequences
1080
    TORCH_CHECK(
1081
        i < str.size() - 2, "String ends with escaped final quote: ", str)
1082
    char c = str[i + 1];
1083
    switch (c) {
1084
      case '\\':
1085
      case '\'':
1086
      case '\"':
1087
        break;
1088
      case 'a':
1089
        c = '\a';
1090
        break;
1091
      case 'b':
1092
        c = '\b';
1093
        break;
1094
      case 'f':
1095
        c = '\f';
1096
        break;
1097
      case 'n':
1098
        c = '\n';
1099
        break;
1100
      case 'v':
1101
        c = '\v';
1102
        break;
1103
      case 't':
1104
        c = '\t';
1105
        break;
1106
      default:
1107
        TORCH_CHECK(
1108
            false,
1109
            "Unsupported escape sequence in string default: \\",
1110
            str[i + 1]);
1111
    }
1112
    parsed.push_back(c);
1113
    i += 2;
1114
  }
1115
  return parsed;
1116
}
1117

1118
void FunctionParameter::set_default_str(const std::string& str) {
1119
  if (str == "None") {
1120
    allow_none = true;
1121
  }
1122
  if (type_ == ParameterType::TENSOR ||
1123
      type_ == ParameterType::DISPATCH_KEY_SET) {
1124
    if (str != "None") {
1125
      throw std::runtime_error(
1126
          "default value for Tensor must be none, got: " + str);
1127
    }
1128
  } else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
1129
    default_int = atol(str.c_str());
1130
  } else if (type_ == ParameterType::BOOL) {
1131
    default_bool = (str == "True" || str == "true");
1132
  } else if (type_ == ParameterType::DOUBLE) {
1133
    default_double = atof(str.c_str());
1134
  } else if (type_ == ParameterType::COMPLEX) {
1135
    default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
1136
    default_complex[1] = 0;
1137
  } else if (type_ == ParameterType::SCALAR) {
1138
    if (str != "None") {
1139
      // we sometimes rely on integer-vs-float values, e.g. with arange.
1140
      const auto as_integer = parse_as_integer(str);
1141
      default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value())
1142
                                              : at::Scalar(atof(str.c_str()));
1143
    }
1144
  } else if (
1145
      type_ == ParameterType::INT_LIST ||
1146
      type_ == ParameterType::SYM_INT_LIST) {
1147
    if (str != "None") {
1148
      default_intlist = parse_intlist_args(str, size);
1149
    }
1150
  } else if (type_ == ParameterType::FLOAT_LIST) {
1151
    if (str != "None") {
1152
      throw std::runtime_error("Defaults not supported for float[]");
1153
    }
1154
  } else if (type_ == ParameterType::SCALARTYPE) {
1155
    if (str == "None") {
1156
      default_scalartype = at::ScalarType::Undefined;
1157
    } else if (str == "torch.int64") {
1158
      default_scalartype = at::ScalarType::Long;
1159
    } else {
1160
      throw std::runtime_error("invalid default value for ScalarType: " + str);
1161
    }
1162
  } else if (type_ == ParameterType::LAYOUT) {
1163
    if (str == "None") {
1164
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
1165
    } else if (str == "torch.strided") {
1166
      default_layout = at::Layout::Strided;
1167
    } else if (str == "torch.sparse_coo") {
1168
      default_layout = at::Layout::Sparse;
1169
    } else {
1170
      throw std::runtime_error("invalid default value for layout: " + str);
1171
    }
1172
  } else if (type_ == ParameterType::DEVICE) {
1173
    if (str != "None") {
1174
      throw std::runtime_error("invalid device: " + str);
1175
    }
1176
  } else if (type_ == ParameterType::STREAM) {
1177
    if (str != "None") {
1178
      throw std::runtime_error("invalid stream: " + str);
1179
    }
1180
  } else if (type_ == ParameterType::STRING) {
1181
    if (str != "None") {
1182
      default_string = parse_string_literal(str);
1183
    }
1184
  }
1185
  // These types weren't handled here before. Adding a default error
1186
  // led to a lot of test failures so adding this skip for now.
1187
  // We should correctly handle these though because it might be causing
1188
  // silent failures.
1189
  else if (type_ == ParameterType::TENSOR_LIST) { // NOLINT
1190
    // throw std::runtime_error("Invalid Tensor List");
1191
  } else if (type_ == ParameterType::GENERATOR) { // NOLINT
1192
    // throw std::runtime_error("ParameterType::GENERATOR");
1193
  } else if (type_ == ParameterType::PYOBJECT) { // NOLINT
1194
    // throw std::runtime_error("ParameterType::PYOBJECT");
1195
  } else if (type_ == ParameterType::MEMORY_FORMAT) { // NOLINT
1196
    // throw std::runtime_error("ParameterType::MEMORY_FORMAT");
1197
  } else if (type_ == ParameterType::DIMNAME) { // NOLINT
1198
    // throw std::runtime_error("ParameterType::DIMNAME");
1199
  } else if (type_ == ParameterType::DIMNAME_LIST) { // NOLINT
1200
    // throw std::runtime_error("ParameterType::DIMNAME_LIST");
1201
  } else if (type_ == ParameterType::SCALAR_LIST) { // NOLINT
1202
    // throw std::runtime_error("ParameterType::SCALAR_LIST");
1203
  } else if (type_ == ParameterType::STORAGE) { // NOLINT
1204
    // throw std::runtime_error("ParameterType::STORAGE");
1205
  } else if (type_ == ParameterType::QSCHEME) { // NOLINT
1206
    // throw std::runtime_error("ParameterType::QSCHEME");
1207
  } else {
1208
    throw std::runtime_error("unknown parameter type");
1209
  }
1210
}
1211

1212
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1213
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
1214
    : min_args(0),
1215
      max_args(0),
1216
      max_pos_args(0),
1217
      index(index),
1218
      hidden(false),
1219
      deprecated(false) {
1220
  auto open_paren = fmt.find('(');
1221
  if (open_paren == std::string::npos) {
1222
    throw std::runtime_error("missing opening parenthesis: " + fmt);
1223
  }
1224
  name = fmt.substr(0, open_paren);
1225

1226
  bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);
1227

1228
  auto last_offset = open_paren + 1;
1229
  bool keyword_only = false;
1230
  bool done = false;
1231
  while (!done) {
1232
    auto offset = fmt.find(", ", last_offset);
1233
    auto next_offset = offset + 2;
1234
    if (offset == std::string::npos) {
1235
      offset = fmt.find(')', last_offset);
1236
      done = true;
1237
      next_offset = offset + 1;
1238
      // this 'if' happens for an empty parameter list, i.e. fn().
1239
      if (offset == last_offset) {
1240
        last_offset = next_offset;
1241
        break;
1242
      }
1243
    }
1244
    if (offset == std::string::npos) {
1245
      throw std::runtime_error("missing closing parenthesis: " + fmt);
1246
    }
1247
    if (offset == last_offset) {
1248
      throw std::runtime_error("malformed signature: " + fmt);
1249
    }
1250

1251
    auto param_str = fmt.substr(last_offset, offset - last_offset);
1252
    last_offset = next_offset;
1253
    if (param_str == "*") {
1254
      keyword_only = true;
1255
    } else {
1256
      params.emplace_back(param_str, keyword_only);
1257
      params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
1258
    }
1259
  }
1260

1261
  if (fmt.substr(last_offset) == "|deprecated") {
1262
    hidden = true;
1263
    // TODO: raise warning when parsing deprecated signatures
1264
    deprecated = true;
1265
  } else if (fmt.substr(last_offset) == "|hidden") {
1266
    hidden = true;
1267
  }
1268

1269
  max_args = params.size();
1270

1271
  // count the number of non-optional args
1272
  for (auto& param : params) {
1273
    if (!param.optional) {
1274
      min_args++;
1275
    }
1276
    if (!param.keyword_only) {
1277
      max_pos_args++;
1278
    }
1279
  }
1280
}
1281

1282
std::string FunctionSignature::toString() const {
1283
  // TODO: consider printing more proper schema strings with defaults,
1284
  // optionals, etc.
1285
  std::ostringstream ss;
1286
  bool keyword_already = false;
1287
  ss << "(";
1288
  int i = 0;
1289
  for (auto& param : params) {
1290
    if (i != 0) {
1291
      ss << ", ";
1292
    }
1293
    if (param.keyword_only && !keyword_already) {
1294
      ss << "*, ";
1295
      keyword_already = true;
1296
    }
1297
    ss << param.type_name() << " " << param.name;
1298
    i++;
1299
  }
1300
  ss << ")";
1301
  return ss.str();
1302
}
1303

1304
[[noreturn]] static void extra_args(
1305
    const FunctionSignature& signature,
1306
    Py_ssize_t nargs) {
1307
  const auto max_pos_args = signature.max_pos_args;
1308
  const auto min_args = signature.min_args;
1309
  const long nargs_ = nargs;
1310
  if (min_args != max_pos_args) {
1311
    throw TypeError(
1312
        "%s() takes from %zu to %zu positional arguments but %ld were given",
1313
        signature.name.c_str(),
1314
        min_args,
1315
        max_pos_args,
1316
        nargs_);
1317
  }
1318
  throw TypeError(
1319
      "%s() takes %zu positional argument%s but %ld %s given",
1320
      signature.name.c_str(),
1321
      max_pos_args,
1322
      max_pos_args == 1 ? "" : "s",
1323
      nargs_,
1324
      nargs == 1 ? "was" : "were");
1325
}
1326

1327
[[noreturn]] static void missing_args(
1328
    const FunctionSignature& signature,
1329
    int idx) {
1330
  int num_missing = 0;
1331
  std::stringstream ss;
1332

1333
  auto& params = signature.params;
1334
  for (auto it = params.begin() + idx; it != params.end(); ++it) {
1335
    if (!it->optional) {
1336
      if (num_missing > 0) {
1337
        ss << ", ";
1338
      }
1339
      ss << '"' << it->name << '"';
1340
      num_missing++;
1341
    }
1342
  }
1343

1344
  throw TypeError(
1345
      "%s() missing %d required positional argument%s: %s",
1346
      signature.name.c_str(),
1347
      num_missing,
1348
      num_missing == 1 ? "s" : "",
1349
      ss.str().c_str());
1350
}
1351

1352
// Returns the index of the parameter in `signature` whose python_name
// compares equal to `name`, or -1 when there is none. Throws python_error
// if the Python comparison itself fails.
static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) {
  auto& params = signature.params;
  const auto count = static_cast<Py_ssize_t>(params.size());
  for (Py_ssize_t idx = 0; idx < count; ++idx) {
    const int cmp =
        PyObject_RichCompareBool(name, params[idx].python_name, Py_EQ);
    if (cmp < 0) {
      throw python_error();
    }
    if (cmp != 0) {
      return idx;
    }
  }
  return -1;
}
1365

1366
[[noreturn]] static void extra_kwargs(
1367
    FunctionSignature& signature,
1368
    PyObject* kwargs,
1369
    Py_ssize_t num_pos_args) {
1370
  PyObject* key = nullptr;
1371
  PyObject* value = nullptr;
1372
  Py_ssize_t pos = 0;
1373

1374
  while (PyDict_Next(kwargs, &pos, &key, &value)) {
1375
    if (!THPUtils_checkString(key)) {
1376
      throw TypeError("keywords must be strings");
1377
    }
1378

1379
    auto param_idx = find_param(signature, key);
1380
    if (param_idx < 0) {
1381
      throw TypeError(
1382
          "%s() got an unexpected keyword argument '%s'",
1383
          signature.name.c_str(),
1384
          THPUtils_unpackString(key).c_str());
1385
    }
1386

1387
    if (param_idx < num_pos_args) {
1388
      throw TypeError(
1389
          "%s() got multiple values for argument '%s'",
1390
          signature.name.c_str(),
1391
          THPUtils_unpackString(key).c_str());
1392
    }
1393
  }
1394

1395
  // this should never be hit
1396
  throw TypeError("invalid keyword arguments");
1397
}
1398

1399
bool FunctionSignature::parse(
1400
    PyObject* self,
1401
    PyObject* args,
1402
    PyObject* kwargs,
1403
    PyObject* dst[], // NOLINT
1404
    std::vector<PyObject*>& overloaded_args,
1405
    bool raise_exception) {
1406
  Py_ssize_t nargs = args ? PyTuple_GET_SIZE(args) : 0;
1407
  auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
1408
  size_t arg_pos = 0;
1409
  bool allow_varargs_intlist = false;
1410

1411
  // if there is a single positional IntArrayRef argument, i.e. expand(..),
1412
  // view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as
1413
  // expand((5,3))
1414
  if (max_pos_args == 1 &&
1415
      (params[0].type_ == ParameterType::INT_LIST ||
1416
       params[0].type_ == ParameterType::SYM_INT_LIST)) {
1417
    allow_varargs_intlist = true;
1418
  }
1419

1420
  if (static_cast<size_t>(nargs) > max_pos_args && !allow_varargs_intlist) {
1421
    if (raise_exception) {
1422
      // foo() takes takes 2 positional arguments but 3 were given
1423
      extra_args(*this, nargs);
1424
    }
1425
    return false;
1426
  }
1427

1428
  int i = 0;
1429
  if (self != nullptr && check_has_torch_function(self, /*ignore_mode*/ true)) {
1430
    append_overloaded_tensor(&overloaded_args, self);
1431
  }
1432
  for (auto& param : params) {
1433
    PyObject* obj = nullptr;
1434
    bool is_kwd = false;
1435
    if (arg_pos < static_cast<size_t>(nargs)) {
1436
      // extra positional args given after single positional IntArrayRef arg
1437
      if (param.keyword_only) {
1438
        if (raise_exception) {
1439
          extra_args(*this, nargs);
1440
        }
1441
        return false;
1442
      }
1443
      obj = PyTuple_GET_ITEM(args, arg_pos);
1444
    } else if (kwargs) {
1445
      obj = PyDict_GetItem(kwargs, param.python_name);
1446
      for (PyObject* numpy_name : param.numpy_python_names) {
1447
        if (obj) {
1448
          break;
1449
        }
1450
        obj = PyDict_GetItem(kwargs, numpy_name);
1451
      }
1452
      is_kwd = true;
1453
    }
1454

1455
    int64_t failed_idx = -1;
1456
    bool varargs_eligible = allow_varargs_intlist && arg_pos == 0 && !is_kwd;
1457
    if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
1458
      dst[i++] = nullptr;
1459
    } else if (!obj) {
1460
      if (raise_exception) {
1461
        // foo() missing 1 required positional argument: "b"
1462
        missing_args(*this, i);
1463
      }
1464
      return false;
1465
    } else if (param.check(obj, overloaded_args, i, &failed_idx)) {
1466
      dst[i++] = obj;
1467
      // XXX: the Variable check is necessary because sizes become tensors when
1468
      // tracer is enabled. This behavior easily leads to ambiguities, and we
1469
      // should avoid having complex signatures that make use of it...
1470
    } else if (
1471
        varargs_eligible &&
1472
        (is_int_or_symint_list(args, param.size, &failed_idx))) {
1473
      // take all positional arguments as this parameter
1474
      // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
1475
      dst[i++] = args;
1476
      arg_pos = nargs;
1477
      continue;
1478
    } else if (raise_exception) {
1479
      if (is_kwd) {
1480
        // foo(): argument 'other' must be str, not int
1481
        throw TypeError(
1482
            "%s(): argument '%s' must be %s, not %s",
1483
            name.c_str(),
1484
            param.name.c_str(),
1485
            param.type_name().c_str(),
1486
            Py_TYPE(obj)->tp_name);
1487
      } else {
1488
        // foo(): argument 'other' (position 2) must be str, not int
1489
        if (failed_idx != -1) {
1490
          if (!(PyTuple_Check(obj) || PyList_Check(obj))) {
1491
            TORCH_INTERNAL_ASSERT(varargs_eligible);
1492
            obj = args;
1493
          }
1494
          TORCH_INTERNAL_ASSERT(failed_idx < PySequence_Size(obj));
1495
          throw TypeError(
1496
              "%s(): argument '%s' (position %ld) must be %s, but found element of type %s at pos %ld",
1497
              name.c_str(),
1498
              param.name.c_str(),
1499
              static_cast<long>(arg_pos + 1),
1500
              param.type_name().c_str(),
1501
              Py_TYPE(py::reinterpret_steal<py::object>(
1502
                          PySequence_GetItem(obj, failed_idx))
1503
                          .ptr())
1504
                  ->tp_name,
1505
              static_cast<long>(failed_idx));
1506
        }
1507
        throw TypeError(
1508
            "%s(): argument '%s' (position %ld) must be %s, not %s",
1509
            name.c_str(),
1510
            param.name.c_str(),
1511
            static_cast<long>(arg_pos + 1),
1512
            param.type_name().c_str(),
1513
            Py_TYPE(obj)->tp_name);
1514
      }
1515
    } else {
1516
      return false;
1517
    }
1518

1519
    if (!is_kwd) {
1520
      arg_pos++;
1521
    } else if (obj) {
1522
      remaining_kwargs--;
1523
    }
1524
  }
1525

1526
  if (remaining_kwargs > 0) {
1527
    if (raise_exception) {
1528
      // foo() got an unexpected keyword argument "b"
1529
      extra_kwargs(*this, kwargs, nargs);
1530
    }
1531
    return false;
1532
  }
1533
  return true;
1534
}
1535

1536
// Builds one FunctionSignature per format string (indexed in the order
// given), records the largest max_args among them and the name of the first
// signature, then reorders the signatures so that deprecated ones come last.
PythonArgParser::PythonArgParser(
    const std::vector<std::string>& fmts,
    bool traceable)
    : max_args(0), traceable(traceable) {
  int next_index = 0;
  for (const auto& fmt : fmts) {
    signatures_.emplace_back(fmt, next_index);
    ++next_index;

    const auto& added = signatures_.back();
    if (added.max_args > max_args) {
      max_args = added.max_args;
    }
  }

  if (!signatures_.empty()) {
    function_name = signatures_.front().name;
  }

  // Check deprecated signatures last
  const auto is_current = [](const FunctionSignature& sig) {
    return !sig.deprecated;
  };
  std::stable_partition(signatures_.begin(), signatures_.end(), is_current);
}
1560

1561
void PythonArgParser::check_deprecated(const FunctionSignature& signature) {
1562
  if (signature.deprecated) {
1563
    auto msg = c10::str(
1564
        "This overload of ",
1565
        signature.name,
1566
        " is deprecated:\n\t",
1567
        signature.name,
1568
        signature.toString());
1569
    auto signatures = get_signatures();
1570
    if (!signatures.empty()) {
1571
      msg += "\nConsider using one of the following signatures instead:";
1572
      for (const auto& sig : signatures) {
1573
        msg += "\n\t";
1574
        msg += signature.name;
1575
        msg += sig;
1576
      }
1577
    }
1578
    TORCH_WARN_ONCE(msg);
1579
  }
1580
}
1581

1582
// Parses the call against the registered signatures and returns the
// PythonArgs for the first one that matches.
//
// With exactly one signature, that signature is parsed with exceptions
// enabled and its result is used unconditionally. With several, each is
// tried in order without raising; if none matches, print_error is called.
PythonArgs PythonArgParser::raw_parse(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* parsed_args[]) { // NOLINT
  const bool sole_signature = signatures_.size() == 1;

  for (auto& signature : signatures_) {
    // Fresh per attempt: a failed parse may already have appended entries.
    std::vector<PyObject*> overloaded_args;
    const bool matched = signature.parse(
        self,
        args,
        kwargs,
        parsed_args,
        overloaded_args,
        /*raise_exception=*/sole_signature);
    if (matched || sole_signature) {
      check_deprecated(signature);
      return PythonArgs(
          traceable, signature, parsed_args, std::move(overloaded_args));
    }
  }

  print_error(self, args, kwargs, parsed_args);
}
1608

1609
void PythonArgParser::print_error(
1610
    PyObject* self,
1611
    PyObject* args,
1612
    PyObject* kwargs,
1613
    PyObject* parsed_args[]) { // NOLINT
1614
  size_t num_args =
1615
      (args ? PyTuple_GET_SIZE(args) : 0) + (kwargs ? PyDict_Size(kwargs) : 0);
1616
  std::vector<unsigned> plausible_idxs;
1617
  unsigned i = 0;
1618
  for (auto& signature : signatures_) {
1619
    if (num_args >= signature.min_args && num_args <= signature.max_args &&
1620
        !signature.hidden) {
1621
      plausible_idxs.push_back(i);
1622
    }
1623
    i++;
1624
  }
1625

1626
  if (plausible_idxs.size() == 1) {
1627
    auto& signature = signatures_[plausible_idxs[0]];
1628
    std::vector<PyObject*> overloaded_args;
1629
    signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
1630
  }
1631

1632
  auto options = get_signatures();
1633
  auto msg =
1634
      torch::format_invalid_args(args, kwargs, function_name + "()", options);
1635
  throw TypeError("%s", msg.c_str());
1636
}
1637

1638
std::vector<std::string> PythonArgParser::get_signatures() const {
1639
  std::vector<std::string> options;
1640
  for (auto& signature : signatures_) {
1641
    if (!signature.hidden) {
1642
      options.push_back(signature.toString());
1643
    }
1644
  }
1645
  return options;
1646
}
1647

1648
// Slow path for reading argument `i` as a Tensor.
//
// - A null slot yields a default-constructed at::Tensor.
// - A THPVariable is unpacked and returned directly.
// - Python bool / int / complex / float objects and symbolic
//   int / float / bool objects are converted to an at::Scalar, turned into a
//   tensor with scalar_to_tensor(), and flagged as a wrapped number.
// - Anything else throws TypeError.
//
// For symbolic inputs the original Python object is additionally attached to
// the resulting tensor as the `_wrapped_number` attribute.
at::Tensor PythonArgs::tensor_slow(int i) {
  PyObject* py_arg = args[i];
  if (py_arg == nullptr) {
    return at::Tensor();
  }
  if (THPVariable_Check(py_arg)) {
    return THPVariable_Unpack(py_arg);
  }

  at::Scalar wrapped;
  bool is_symbolic = false;
  if (PyBool_Check(py_arg)) {
    wrapped = at::Scalar(THPUtils_unpackBool(py_arg));
  } else if (THPUtils_checkLong(py_arg)) {
    int overflow = -1;
    const long long signed_value =
        PyLong_AsLongLongAndOverflow(py_arg, &overflow);
    if (signed_value == -1 && PyErr_Occurred()) {
      throw python_error();
    }
    if (overflow == 0) {
      wrapped = at::Scalar(static_cast<int64_t>(signed_value));
    } else {
      // Did not fit in a signed long long; retry as unsigned.
      const unsigned long long unsigned_value =
          PyLong_AsUnsignedLongLong(py_arg);
      if (unsigned_value == static_cast<unsigned long long>(-1) &&
          PyErr_Occurred()) {
        throw python_error();
      }
      wrapped = at::Scalar(static_cast<uint64_t>(unsigned_value));
    }
  } else if (PyComplex_Check(py_arg)) {
    wrapped = at::Scalar(THPUtils_unpackComplexDouble(py_arg));
  } else if (THPUtils_checkDouble(py_arg)) {
    wrapped = at::Scalar(THPUtils_unpackDouble(py_arg));
    // NB: symbolic ints/floats are deliberately NOT stored in the Scalar
    // itself: Scalar supports SymInt/SymFloat, but the subsequent conversion
    // to Tensor does not. They are carried out of band instead (see the
    // `_wrapped_number` attribute set below).
  } else if (torch::is_symint(py::handle(py_arg))) {
    is_symbolic = true;
    // The placeholder value should never actually be read out. It is a big,
    // weird-looking number to help people figure out if there's a problem.
    wrapped = at::Scalar(7777777);
  } else if (torch::is_symfloat(py::handle(py_arg))) {
    is_symbolic = true;
    wrapped = at::Scalar(std::numeric_limits<double>::quiet_NaN());
  } else if (torch::is_symbool(py::handle(py_arg))) {
    is_symbolic = true;
    wrapped = at::Scalar(true);
  } else {
    // NB: If you got here by passing None to a Variable method while
    // expecting an undefined tensor back, do not add a Py_None test here.
    // Mark the argument as *allowing none* instead, by writing 'Tensor?'
    // rather than 'Tensor' in the ATen metadata.
    throw TypeError(
        "expected Tensor as argument %d, but got %s",
        i,
        Py_TYPE(py_arg)->tp_name);
  }

  at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
  at::tracer::impl::NoTracerDispatchMode tracer_guard;

  at::Tensor result = scalar_to_tensor(wrapped);
  result.unsafeGetTensorImpl()->set_wrapped_number(true);

  if (is_symbolic) {
    // Keep the original symbolic object reachable from the Python tensor.
    auto py_result = py::cast(result);
    if (PyObject_SetAttrString(py_result.ptr(), "_wrapped_number", py_arg) <
        0) {
      throw python_error();
    }
  }

  return result;
}
1720

1721
// Slow path for reading argument `i` as a Scalar.
//
// When this parser is traceable, the JIT tracer is active, and the argument
// is a THPVariable, the unpacked variable is first stashed with the tracer
// under the parameter's name. The conversion itself is delegated to the
// PyObject* overload of scalar_slow().
at::Scalar PythonArgs::scalar_slow(int i) {
  PyObject* raw_arg = args[i];

  const bool stash_for_tracer =
      traceable && jit::tracer::isTracing() && THPVariable_Check(raw_arg);
  if (stash_for_tracer) {
    auto& traced_var = THPVariable_Unpack(raw_arg);
    jit::tracer::ArgumentStash::stashValue(
        signature.params[i].name, idx, traced_var, c10::NumberType::get());
  }

  return scalar_slow(raw_arg);
}
1730

1731
// Converts a Python object to an at::Scalar.
//
// Checks are applied in this order: tensor (via .item()), integer, bool,
// complex, SymInt, SymFloat, SymBool. Anything that matches none of these is
// unpacked with THPUtils_unpackDouble().
at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
  // Zero-dim tensors are converted to Scalars as-is. Note this doesn't
  // currently handle most NumPy scalar types except np.float64.
  if (THPVariable_Check(arg)) {
    return THPVariable_Unpack(arg).item();
  }

  if (THPUtils_checkLong(arg)) {
    int overflow = -1;
    const long long signed_value = PyLong_AsLongLongAndOverflow(arg, &overflow);
    if (signed_value == -1 && PyErr_Occurred()) {
      throw python_error();
    }
    if (overflow == 0) {
      return at::Scalar(static_cast<int64_t>(signed_value));
    }
    // Did not fit in a signed long long; retry as unsigned.
    const unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg);
    if (unsigned_value == static_cast<unsigned long long>(-1) &&
        PyErr_Occurred()) {
      throw python_error();
    }
    return at::Scalar(static_cast<uint64_t>(unsigned_value));
  }

  if (PyBool_Check(arg)) {
    return at::Scalar(THPUtils_unpackBool(arg));
  }

  if (PyComplex_Check(arg)) {
    return at::Scalar(THPUtils_unpackComplexDouble(arg));
  }

  if (torch::is_symint(arg)) {
    return at::Scalar(py::cast<c10::SymInt>(arg));
  }

  if (torch::is_symfloat(arg)) {
    return at::Scalar(py::cast<c10::SymFloat>(arg));
  }

  if (torch::is_symbool(arg)) {
    // Writing at::Scalar(py::cast<c10::SymBool>(arg)) directly fails the
    // Windows build with C2440: '<function-style-cast>', so the cast result
    // is bound to a named local first.
    auto symbolic_flag = py::handle(arg).cast<c10::SymBool>();
    return at::Scalar(symbolic_flag);
  }

  return at::Scalar(THPUtils_unpackDouble(arg));
}
1781

1782
} // namespace torch
1783

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.