#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <ATen/functorch/BatchedTensorImpl.h>
#include <torch/library.h>

#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>

#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>

#include <iostream>
#include <sstream>
#include <utility>

namespace py = pybind11;

namespace torch::impl::dispatch {
// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_;
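
// The table is keyed first by operator name, then by dispatch key; the stored
// SafePyObject is the Python callable registered for that key. A lookup (see
// python_op_registration_trampoline_impl at the bottom of this file) boils
// down to:
//
//   const auto& func = python_registrations_[op.operator_name()][key];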

static torch::Library::Kind parseKind(const std::string& k) {
  static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
      {"DEF", torch::Library::DEF},
      {"IMPL", torch::Library::IMPL},
      {"FRAGMENT", torch::Library::FRAGMENT},
  };
  auto it = kind_map.find(k);
  TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
  return it->second;
}
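
// parseKind maps the Python-facing strings onto torch::Library kinds, e.g.
// parseKind("FRAGMENT") == torch::Library::FRAGMENT; it is used by the
// library-creation binding further below.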

static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
      {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
      {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
      {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
      {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
  };
  auto it = key_map.find(k);
  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
  return it->second;
}

template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  auto mb_key = std::string(key).empty()
      ? c10::nullopt
      : c10::make_optional(c10::parseDispatchKey(key));
  if (mb_key) {
    return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
  } else {
    torch::CppFunction f(std::forward<Func>(raw_f));
    return f;
  }
}
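
// Illustrative: dispatch_str("CPU", fn) attaches the CPU dispatch key to fn,
// while dispatch_str("", fn) wraps fn without pinning it to any key. The
// testing-only def_*/impl_* bindings below route their `dispatch` argument
// through this helper.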

struct EnableHermeticPyObject {
  EnableHermeticPyObject()
      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
        old_excluded_python_(
            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
        old_python_(
            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
            at::DispatchKey::PythonTLSSnapshot)) {
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }
  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }
  bool old_;
  bool old_excluded_python_;
  bool old_python_;
  bool old_python_snapshot_;
};
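
// While an EnableHermeticPyObject guard is alive, hermetic PyObject mode is on
// and the Python / PythonTLSSnapshot dispatch keys are masked out of TLS, so a
// kernel running under the guard cannot re-enter Python dispatch; the previous
// TLS state is restored on destruction.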

class PythonKernelHolder : public c10::OperatorKernel {
  c10::SafePyObject func_;
  c10::DispatchKey dispatch_key_;

 public:
  PythonKernelHolder(py::object func, c10::DispatchKey dispatch_key)
      : func_(func.release().ptr(), getPyInterpreter()),
        dispatch_key_(dispatch_key) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // Figure out if we can handle it hermetically, or if we have
    // to double dispatch

    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    if (mode_stack_len > 0) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      cur_torch_dispatch_mode_state->pyinterpreter()
          ->python_op_registration_trampoline(op, dispatch_key_, stack);
      return;
    }

    const auto& schema = op.schema();
    const auto num_arguments = schema.arguments().size();

    // Otherwise, find a PyInterpreter on a Tensor IF it has the Python key
    // (which means it's a nontrivial tensor subclass)
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
      if (ivalue.isTensor()) {
        auto* interpreter =
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
        if (interpreter &&
            ivalue.unsafeToTensorImpl()->key_set().has(
                at::DispatchKey::Python)) {
          (*interpreter)
              ->python_op_registration_trampoline(op, dispatch_key_, stack);
          return;
        }
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          if (nv.isNone()) {
            continue;
          }
          auto* interpreter =
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
          if (interpreter &&
              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
            (*interpreter)
                ->python_op_registration_trampoline(op, dispatch_key_, stack);
            return;
          }
        }
      }
    }

    // Nothing requires the operator to be homed to a specific interpreter, so
    // run it on the current interpreter

    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
    py::gil_scoped_acquire g;
    // Jan 2024: We're slated to get rid of multipy, so stop forcing hermetic
    // mode unconditionally in all situations when you're using multipy.
    // Eventually just delete this entirely.  (Note that you may break multipy
    // anyway this way with dispatcher registered functions that require
    // hermetic to be off.)
#if defined(USE_DEPLOY)
    EnableHermeticPyObject g2;
#endif
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
    auto obj = py::reinterpret_steal<py::object>(PyObject_Call(
        func_.ptr(getPyInterpreter()),
        args_kwargs.first.ptr(),
        args_kwargs.second.ptr()));
    if (!obj) {
      throw python_error();
    }
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
  }
};
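
// Rough Python-side sketch of how a kernel ends up wrapped in a
// PythonKernelHolder via the "impl" binding further below (illustrative only;
// the supported entry point is torch.library):
//
//   lib = torch._C._dispatch_library("FRAGMENT", "aten", "")
//   lib.impl("add.Tensor", torch._C.DispatchKey.CPU, my_python_kernel)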

static torch::_RegisterOrVerify register_or_verify() {
  if (isMainPyInterpreter()) {
    return torch::_RegisterOrVerify::REGISTER;
  } else {
    return torch::_RegisterOrVerify::VERIFY;
  }
}

static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    const py::args& args,
    const py::kwargs& kwargs) {
  auto stack = torch::jit::createStackForSchema(
      handle.schema(),
      args,
      kwargs,
      /*self=*/c10::nullopt);
  {
    pybind11::gil_scoped_release no_gil_guard;
    handle.callBoxed(stack);
  }
  return torch::jit::createPyObjectForStack(std::move(stack));
}
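
// Illustrative Python-side usage of the boxed-call path exposed below as
// torch._C._dispatch_call_boxed (a sketch):
//
//   op = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
//   out = torch._C._dispatch_call_boxed(op, x, y, alpha=1)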

// A small RAII guard that lets you explicitly *remove* a key from the TLS
// exclude set.
class SetExcludeDispatchKeyGuard {
 public:
  SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded)
      : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) {
    c10::impl::tls_set_dispatch_key_excluded(k, set_excluded);
  }
  ~SetExcludeDispatchKeyGuard() {
    c10::impl::tls_set_dispatch_key_excluded(k, old);
  }
  SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete;
  SetExcludeDispatchKeyGuard operator=(const SetExcludeDispatchKeyGuard&) =
      delete;
  SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete;
  SetExcludeDispatchKeyGuard operator=(SetExcludeDispatchKeyGuard&&) = delete;

 private:
  at::DispatchKey k;
  bool old;
};
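
// Unlike c10::impl::ExcludeDispatchKeyGuard, this guard can also *clear* the
// exclusion bit for a key (set_excluded = false). It is exposed to Python
// below as the _SetExcludeDispatchKeyGuard context manager.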

void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema);

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);

  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      .def(
          "reset",
          [](const py::object& self) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().reset();
            return;
          },
          "")
      // Some of these APIs are only for testing and do not work in multipy
      // environment
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from RegisterOperators() API
      // but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher. So instead we provide a bunch of precanned
      // functions for testing purposes. You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something.
      //
      // Mangling scheme: args_rets. One character per.
      //  t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().impl(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl",
          [](const py::object& self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.impl(
                  name,
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
                  register_or_verify());
            } else {
              lib.impl(
                  name,
                  torch::dispatch(
                      dispatch,
                      CppFunction::makeFromBoxedFunctor(
                          std::make_unique<PythonKernelHolder>(
                              func, dispatch))),
                  register_or_verify());
              python_registrations_[lib._resolve(name)].insert_or_assign(
                  dispatch,
                  std::make_shared<c10::SafePyObject>(
                      func.release().ptr(), getPyInterpreter()));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"))
      .def(
          "define",
          [](const py::object& self,
             const char* schema,
             const char* alias_analysis,
             const std::vector<at::Tag>& tags) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), tags, register_or_verify());
            // TODO: this is dumb, had to make a second copy
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                .name();
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "",
          py::arg("tags") = std::vector<at::Tag>())
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "");

  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* file,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? c10::nullopt
                : c10::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

  m.def("_dispatch_check_invariants", [](const char* name) {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (op) {
      return op->checkInvariants();
    }
  });

  m.def("_dispatch_check_all_invariants", []() {
    c10::Dispatcher::singleton().checkInvariants();
  });

  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    return static_cast<bool>(op);
  });
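
  // Illustrative Python-side use of the introspection helpers (a sketch; the
  // exact output format is an implementation detail):
  //
  //   print(torch._C._dispatch_dump("aten::add"))        # registration state
  //   print(torch._C._dispatch_dump_table("aten::add"))  # computed dispatch table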

  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <op_name, dispatch_key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table for this <op_name, dispatch_key> pair. For example, if
      // "op" has a `CompositeImplicitAutograd` kernel, then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
      // true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });

  m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
    auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();

    std::vector<std::string> states;
    states.reserve(danglingImpls.size());
    for (auto& danglingImpl : danglingImpls) {
      states.emplace_back(danglingImpl.dumpState());
    }

    return states;
  });

  m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
    auto op_names = c10::Dispatcher::singleton().getAllOpNames();

    std::vector<std::string> names;
    names.reserve(op_names.size());
    for (auto& op : op_names) {
      std::stringstream ss;
      ss << op.name;
      if (!op.overload_name.empty()) {
        ss << "." << op.overload_name;
      }
      names.emplace_back(ss.str());
    }

    return names;
  });
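
  // Names are rendered as "ns::name", or "ns::name.overload_name" when an
  // overload is present, matching the format accepted by the _dispatch_*
  // lookups above.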

  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

  m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
    return at::isTensorSubclassLike(tensor);
  });

  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
    return c10::toString(k);
  });
  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
  m.def("_to_functionality_key", [](c10::DispatchKey k) {
    return c10::toFunctionalityKey(k);
  });
  // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
  //  AutogradCPU
  //  AutogradCUDA
  //  ...
  //  AutogradPrivateUse3
  m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
    std::vector<c10::DispatchKey> keys;
    if (c10::isPerBackendFunctionalityKey(key)) {
      auto ks = c10::DispatchKeySet(key) |
          c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
      for (auto k : ks) {
        keys.push_back(k);
      }
    } else {
      keys.push_back(key);
    }
    return keys;
  });
  m.def("_dispatch_num_backends", []() { return c10::num_backends; });

#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      // NestedTensor is not a backend key
      DEF_ONE(AutogradNestedTensor)
      DEF_ONE(AutogradOther)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)

#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

  C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)

#undef DEF_MULTIPLE
#undef DEF_SINGLE
          ;

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });

  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);

  m.attr("_additional_keys_to_prop_for_wrapper_tensors") =
      py::cast(at::functorch::kKeysToPropagateToWrapper);

  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
  });

  m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
  });

  m.def("_dispatch_keyset_full", []() {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
  });

  m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);

  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
    return c10::toString(keyset);
  });

  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
    return c10::getBackendKeySetFromAutograd(k);
  });

  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
    auto* impl = tensor.unsafeGetTensorImpl();
    return impl->key_set();
  });

  m.def("_dispatch_tls_local_include_set", []() {
    return c10::impl::tls_local_dispatch_key_set().included_;
  });

  m.def("_dispatch_tls_local_exclude_set", []() {
    return c10::impl::tls_local_dispatch_key_set().excluded_;
  });

  m.def("_functionalization_reapply_views_tls", []() {
    return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  });

  m.def(
      "_dispatch_is_included_in_alias",
      [](c10::DispatchKey a, c10::DispatchKey b) {
        return c10::isIncludedInAlias(a, b);
      });

  // DEPRECATED, please don't use this. Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");

  py_context_manager<
      c10::impl::ForceDispatchKeyGuard,
      c10::DispatchKeySet,
      c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
  py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
      m, "_IncludeDispatchKeyGuard");
  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");
  py_context_manager<SetExcludeDispatchKeyGuard, c10::DispatchKey, bool>(
      m, "_SetExcludeDispatchKeyGuard");

  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");

  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll print
  // out the name of every operator that the Dispatcher knows of. This can be
  // useful to answer questions like "list all operators that do not have a CPU
  // kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << std::endl;
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_parse_dispatch_key",
      [](const char* dispatch_key) -> c10::optional<c10::DispatchKey> {
        try {
          return c10::parseDispatchKey(dispatch_key);
        } catch (const c10::Error& err) {
          return c10::nullopt;
        }
      });

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def(
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });

  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
    return c10::Dispatcher::singleton().getAbstractImplPyStub(
        c10::OperatorName(name, overload));
  });

  m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
    return at::functionalization::impl::replace_(a, b);
  });

  m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
    at::functionalization::impl::propagate_xla_data(a, b);
  });

  m.def("_commit_update", [](const at::Tensor& a) {
    return at::functionalization::impl::commit_update(a);
  });

  m.def("_unsafe_reset_storage", [](const at::Tensor& a) {
    return at::functionalization::impl::unsafe_reset_storage(a);
  });

  m.def("_dispatch_key_for_device", [](const std::string& device_type) {
    auto device = c10::Device(device_type);
    TORCH_CHECK(
        !device.has_index(),
        "Expected device_type string to not have a device index; got ",
        device_type);
    return c10::toString(
        c10::computeDispatchKey(c10::nullopt, c10::nullopt, device));
  });
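
  // Illustrative: torch._C._dispatch_key_for_device("cuda") is expected to
  // return the string "CUDA" (the backend dispatch key computed for that
  // device type), assuming the device string carries no index.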

  m.def("_are_functorch_transforms_active", []() {
    auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
    return (
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
  });

  m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
  });

  m.def("_get_constant_bool_symnode", [](int64_t data) {
    return c10::SymNode(
        c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
  });

  m.def("_non_sym_sizes", [](const at::Tensor& a) {
    return a.sizes(); // NB: NOT sym_size
  });

  using c10::impl::TorchDispatchModeKey;
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
      .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
      .value("PROXY", TorchDispatchModeKey::PROXY)
      .value("FAKE", TorchDispatchModeKey::FAKE);
}

// TODO: dedupe with the kernel
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    torch::jit::Stack* stack) {
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  const auto& func = python_registrations_[op.operator_name()][key];
  TORCH_INTERNAL_ASSERT(func != nullptr);
  auto* pyobj = func->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  auto obj = py::reinterpret_steal<py::object>(
      PyObject_Call(pyobj, args_kwargs.first.ptr(), args_kwargs.second.ptr()));
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}

} // namespace torch::impl::dispatch