pytorch

Форк
0
/
python_dispatch.cpp 
869 строк · 30.3 Кб
1
#include <torch/csrc/jit/frontend/function_schema_parser.h>
2
#include <torch/csrc/utils/python_dispatch.h>
3

4
#include <ATen/ATen.h>
5
#include <ATen/FuncTorchTLS.h>
6
#include <ATen/FunctionalTensorWrapper.h>
7
#include <ATen/TensorSubclassLikeUtils.h>
8
#include <ATen/core/NestedIntSymNodeImpl.h>
9
#include <ATen/core/PythonOpRegistrationTrampoline.h>
10
#include <ATen/core/dispatch/Dispatcher.h>
11

12
#include <ATen/functorch/BatchedTensorImpl.h>
13
#include <torch/library.h>
14

15
#include <c10/core/SafePyObject.h>
16
#include <torch/csrc/PyInterpreter.h>
17
#include <torch/csrc/autograd/python_variable.h>
18
#include <torch/csrc/jit/python/pybind_utils.h>
19

20
#include <c10/util/flat_hash_map.h>
21
#include <pybind11/operators.h>
22
#include <pybind11/stl.h>
23
#include <torch/csrc/utils/pybind.h>
24
#include <torch/csrc/utils/python_raii.h>
25

26
#include <iostream>
27
#include <utility>
28

29
namespace py = pybind11;
30

31
namespace torch {
32
namespace impl {
33
namespace dispatch {
34

35
// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
36
// guarantee that the main interpreter has finish doing all registrations before
37
// the other interpreters start banging on it
38
static ska::flat_hash_map<
39
    c10::OperatorName,
40
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
41
    python_registrations_;
42

43
// Translates the library-kind string passed from Python ("DEF", "IMPL" or
// "FRAGMENT") into torch::Library::Kind.
//
// @param k  kind name; matched case-sensitively.
// @return   the corresponding torch::Library::Kind.
// @throws   c10::Error (via TORCH_CHECK) for any other string.
static torch::Library::Kind parseKind(const std::string& k) {
  // Fixed lookup table: const so nothing can accidentally mutate the shared
  // function-local static.
  static const std::unordered_map<std::string, torch::Library::Kind>
      kind_map = {
          {"DEF", torch::Library::DEF},
          {"IMPL", torch::Library::IMPL},
          {"FRAGMENT", torch::Library::FRAGMENT},
      };
  const auto it = kind_map.find(k);
  TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
  return it->second;
}
53
// Translates an alias-analysis name passed from Python into
// c10::AliasAnalysisKind.
//
// @param k  one of "CONSERVATIVE", "FROM_SCHEMA", "PURE_FUNCTION", or "" —
//           the empty string selects the default, FROM_SCHEMA.
// @return   the corresponding c10::AliasAnalysisKind.
// @throws   c10::Error (via TORCH_CHECK) for any other string.
static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  // Fixed lookup table: const so nothing can accidentally mutate the shared
  // function-local static.
  static const std::unordered_map<std::string, c10::AliasAnalysisKind>
      key_map = {
          {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
          {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
          {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
          {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
      };
  const auto it = key_map.find(k);
  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
  return it->second;
}
64

65
template <typename Func>
66
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
67
  auto mb_key = std::string(key).empty()
68
      ? c10::nullopt
69
      : c10::make_optional(c10::parseDispatchKey(key));
70
  if (mb_key) {
71
    return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
72
  } else {
73
    torch::CppFunction f(std::forward<Func>(raw_f));
74
    return f;
75
  }
76
}
77

78
struct EnableHermeticPyObject {
79
  EnableHermeticPyObject()
80
      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
81
        old_excluded_python_(
82
            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
83
        old_python_(
84
            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
85
        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
86
            at::DispatchKey::PythonTLSSnapshot)) {
87
    c10::impl::HermeticPyObjectTLS::set_state(true);
88
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
89
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
90
    c10::impl::tls_set_dispatch_key_included(
91
        at::DispatchKey::PythonTLSSnapshot, false);
92
  }
93
  ~EnableHermeticPyObject() {
94
    c10::impl::HermeticPyObjectTLS::set_state(old_);
95
    c10::impl::tls_set_dispatch_key_excluded(
96
        at::DispatchKey::Python, old_excluded_python_);
97
    c10::impl::tls_set_dispatch_key_included(
98
        at::DispatchKey::Python, old_python_);
99
    c10::impl::tls_set_dispatch_key_included(
100
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
101
  }
102
  bool old_;
103
  bool old_excluded_python_;
104
  bool old_python_;
105
  bool old_python_snapshot_;
106
};
107

108
class PythonKernelHolder : public c10::OperatorKernel {
109
  c10::SafePyObject func_;
110
  c10::DispatchKey dispatch_key_;
111

112
 public:
113
  PythonKernelHolder(py::object func, c10::DispatchKey dispatch_key)
114
      : func_(func.release().ptr(), getPyInterpreter()),
115
        dispatch_key_(dispatch_key) {}
116

117
  void operator()(
118
      const c10::OperatorHandle& op,
119
      c10::DispatchKeySet keyset,
120
      torch::jit::Stack* stack) {
121
    // Figure out if we can handle it hermetically, or if we have
122
    // to double dispatch
123

124
    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
125
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
126
    if (mode_stack_len > 0) {
127
      const auto& cur_torch_dispatch_mode_state =
128
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
129
      cur_torch_dispatch_mode_state->pyinterpreter()
130
          ->python_op_registration_trampoline(op, dispatch_key_, stack);
131
      return;
132
    }
133

134
    const auto& schema = op.schema();
135
    const auto num_arguments = schema.arguments().size();
136

137
    // Otherwise, find a PyInterpreter on a Tensor IF if has Python key (which
138
    // means it's a nontrivial tensor subclass)
139
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
140
      if (ivalue.isTensor()) {
141
        auto* interpreter =
142
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
143
        if (interpreter &&
144
            ivalue.unsafeToTensorImpl()->key_set().has(
145
                at::DispatchKey::Python)) {
146
          (*interpreter)
147
              ->python_op_registration_trampoline(op, dispatch_key_, stack);
148
          return;
149
        }
150
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
151
        // NB: use toListRef as it doesn't induce refcount bumps
152
        // (toTensorListRef is not a thing)
153
        for (const auto& nv : ivalue.toListRef()) {
154
          if (nv.isNone()) {
155
            continue;
156
          }
157
          auto* interpreter =
158
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
159
          if (interpreter &&
160
              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
161
            (*interpreter)
162
                ->python_op_registration_trampoline(op, dispatch_key_, stack);
163
            return;
164
          }
165
        }
166
      }
167
    }
168

169
    // Nothing requires the operator to be homed to a specific interpreter, so
170
    // run it on the current interpreter
171

172
    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
173
    py::gil_scoped_acquire g;
174
    // Jan 2024: We're slated to get rid of multipy, so stop forcing hermetic
175
    // mode unconditionally in all situations when you're using multipy.
176
    // Eventually just delete this entirely.  (Note that you may break multipy
177
    // anyway this way with dispatcher registered functions that require
178
    // hermetic to be off.)
179
#if defined(USE_DEPLOY)
180
    EnableHermeticPyObject g2;
181
#endif
182
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
183
    auto obj = py::reinterpret_steal<py::object>(PyObject_Call(
184
        func_.ptr(getPyInterpreter()),
185
        args_kwargs.first.ptr(),
186
        args_kwargs.second.ptr()));
187
    if (!obj) {
188
      throw python_error();
189
    }
190
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
191
  }
192
};
193

194
// Chooses the registration mode passed to torch::Library::def/impl: the main
// Python interpreter gets REGISTER, every other interpreter gets VERIFY.
static torch::_RegisterOrVerify register_or_verify() {
  return isMainPyInterpreter() ? torch::_RegisterOrVerify::REGISTER
                               : torch::_RegisterOrVerify::VERIFY;
}
201

202
// Invokes an operator from Python through the boxed calling convention.
//
// The Python positional/keyword arguments are packed into an IValue stack
// according to the operator's schema. The operator runs with the GIL
// released, and whatever it leaves on the stack is converted back into a
// Python object.
static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    py::args args,
    const py::kwargs& kwargs) {
  const auto& op_schema = handle.schema();
  auto boxed_stack = torch::jit::createStackForSchema(
      op_schema, std::move(args), kwargs, /*self=*/c10::nullopt);

  {
    // The kernel does not need the GIL; let other Python threads run.
    pybind11::gil_scoped_release release_gil;
    handle.callBoxed(boxed_stack);
  }

  return torch::jit::createPyObjectForStack(std::move(boxed_stack));
}
217

218
// A small RAII guard that lets you explicitly *remove* a key from the TLS
219
// exclude set.
220
class SetExcludeDispatchKeyGuard {
221
 public:
222
  SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded)
223
      : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) {
224
    c10::impl::tls_set_dispatch_key_excluded(k, set_excluded);
225
  }
226
  ~SetExcludeDispatchKeyGuard() {
227
    c10::impl::tls_set_dispatch_key_excluded(k, old);
228
  }
229
  SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete;
230
  SetExcludeDispatchKeyGuard operator=(const SetExcludeDispatchKeyGuard&) =
231
      delete;
232
  SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete;
233
  SetExcludeDispatchKeyGuard operator=(SetExcludeDispatchKeyGuard&&) = delete;
234

235
 private:
236
  at::DispatchKey k;
237
  bool old;
238
};
239

240
void initDispatchBindings(PyObject* module) {
241
  auto m = py::handle(module).cast<py::module>();
242

243
  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
244
      .def("schema", &c10::OperatorHandle::schema);
245

246
  m.def("_dispatch_call_boxed", &ophandle_call_boxed);
247

248
  // TODO: figure out how to do chaining
249
  py::class_<torch::Library>(m, "_DispatchModule")
250
      .def(
251
          "reset",
252
          [](const py::object& self) {
253
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
254
            self.cast<torch::Library&>().reset();
255
            return;
256
          },
257
          "")
258
      // Some of these APIs are only for testing and do not work in multipy
259
      // environment
260
      .def(
261
          "def_",
262
          [](py::object self, const char* schema, const char* alias) {
263
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
264
            self.cast<torch::Library&>().def(
265
                torch::schema(schema, parseAliasAnalysisKind(alias)));
266
            return self;
267
          },
268
          "",
269
          py::arg("schema"),
270
          py::arg("alias") = "")
271
      // Simulated "legacy" def where alias analysis kind is not set.
272
      // Ordinarily this can only be exercised from RegisterOperators() API
273
      // but I am not going to bind that here
274
      .def(
275
          "def_legacy",
276
          [](py::object self, const char* schema) {
277
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
278
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
279
            return self;
280
          },
281
          "",
282
          py::arg("schema"))
283
      // We can't conveniently turn Python functions into valid functions
284
      // in the dispatcher.  So instead we provide a bunch of precanned
285
      // functions for testing purposes.  You're NOT intended to actually
286
      // call these functions; they're just here so we can actually register
287
      // something
288
      //
289
      // Mangling scheme: args_rets.  One character per.
290
      //  t = Tensor
291
      .def(
292
          "def_name_t_t",
293
          [](py::object self,
294
             const char* name,
295
             const char* dispatch,
296
             const char* debug) {
297
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
298
            self.cast<torch::Library&>().def(
299
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
300
                        return a;
301
                      }).debug(debug));
302
            return self;
303
          },
304
          "",
305
          py::arg("name"),
306
          py::arg("dispatch") = "",
307
          py::arg("debug") = "default_def_name_t_t")
308
      .def(
309
          "def_schema_t_t",
310
          [](py::object self,
311
             const char* schema,
312
             const char* dispatch,
313
             const char* alias,
314
             const char* debug) {
315
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
316
            self.cast<torch::Library&>().def(
317
                torch::schema(schema, parseAliasAnalysisKind(alias)),
318
                dispatch_str(dispatch, [](const at::Tensor& a) {
319
                  return a;
320
                }).debug(debug));
321
            return self;
322
          },
323
          "",
324
          py::arg("name"),
325
          py::arg("dispatch") = "",
326
          py::arg("alias") = "",
327
          py::arg("debug") = "default_def_schema_t_t")
328
      // TODO: maybe consider deduplicating the definitions here, it's getting
329
      // pretty long
330
      .def(
331
          "impl_t_t",
332
          [](py::object self,
333
             const char* name,
334
             const char* dispatch,
335
             const char* debug) {
336
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
337
            self.cast<torch::Library&>().impl(
338
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
339
                        return a;
340
                      }).debug(debug));
341
            return self;
342
          },
343
          "",
344
          py::arg("name"),
345
          py::arg("dispatch") = "",
346
          py::arg("debug") = "impl_t_t")
347
      .def(
348
          "impl",
349
          [](const py::object& self,
350
             const char* name,
351
             // TODO: empty string no longer works
352
             c10::DispatchKey dispatch,
353
             py::object func) {
354
            HANDLE_TH_ERRORS
355
            auto& lib = self.cast<torch::Library&>();
356
            if (func.is(py::module::import("torch.library")
357
                            .attr("fallthrough_kernel"))) {
358
              lib.impl(
359
                  name,
360
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
361
                  register_or_verify());
362
            } else {
363
              lib.impl(
364
                  name,
365
                  torch::dispatch(
366
                      dispatch,
367
                      CppFunction::makeFromBoxedFunctor(
368
                          std::make_unique<PythonKernelHolder>(
369
                              func, dispatch))),
370
                  register_or_verify());
371
              python_registrations_[lib._resolve(name)].insert_or_assign(
372
                  dispatch,
373
                  std::make_shared<c10::SafePyObject>(
374
                      func.release().ptr(), getPyInterpreter()));
375
            }
376
            END_HANDLE_TH_ERRORS_PYBIND
377
          },
378
          "",
379
          py::arg("name"),
380
          py::arg("dispatch"),
381
          py::arg("func"))
382
      .def(
383
          "define",
384
          [](const py::object& self,
385
             const char* schema,
386
             const char* alias_analysis,
387
             const std::vector<at::Tag>& tags) {
388
            auto parsed_schema =
389
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
390
            self.cast<torch::Library&>().def(
391
                std::move(parsed_schema), tags, register_or_verify());
392
            // TODO: this is dumb, had to make a second copy
393
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
394
                .name();
395
          },
396
          "",
397
          py::arg("schema"),
398
          py::arg("alias_analysis") = "",
399
          py::arg("tags") = std::vector<at::Tag>())
400
      .def(
401
          "fallback_fallthrough",
402
          [](py::object self, const char* dispatch) {
403
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
404
            self.cast<torch::Library&>().fallback(
405
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
406
            return self;
407
          },
408
          "",
409
          py::arg("dispatch") = "");
410

411
  m.def(
412
      "_dispatch_library",
413
      [](const char* kind,
414
         std::string name,
415
         const char* dispatch,
416
         const char* file,
417
         uint32_t linenum) {
418
        HANDLE_TH_ERRORS
419
        return std::make_unique<torch::Library>(
420
            parseKind(kind),
421
            std::move(name),
422
            std::string(dispatch).empty()
423
                ? c10::nullopt
424
                : c10::make_optional(c10::parseDispatchKey(dispatch)),
425
            "/dev/null", // temporary workaround
426
            linenum);
427
        END_HANDLE_TH_ERRORS_PYBIND
428
      },
429
      "",
430
      py::arg("kind"),
431
      py::arg("name"),
432
      py::arg("dispatch"),
433
      py::arg("file") = "/dev/null",
434
      py::arg("linenum") = 0);
435

436
  m.def(
437
      "_dispatch_find_schema_or_throw",
438
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
439
        return c10::Dispatcher::singleton().findSchemaOrThrow(
440
            name, overload_name);
441
      });
442

443
  m.def("_dispatch_dump", [](const char* name) -> std::string {
444
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
445
    if (!op) {
446
      return "";
447
    } else {
448
      return op->dumpState();
449
    }
450
  });
451

452
  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
453
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
454
    if (!op) {
455
      return "";
456
    } else {
457
      return op->dumpComputedTable();
458
    }
459
  });
460

461
  m.def("_dispatch_check_invariants", [](const char* name) {
462
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
463
    if (!op) {
464
    } else {
465
      return op->checkInvariants();
466
    }
467
  });
468

469
  m.def("_dispatch_check_all_invariants", []() {
470
    c10::Dispatcher::singleton().checkInvariants();
471
  });
472

473
  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
474
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
475
    return static_cast<bool>(op);
476
  });
477

478
  m.def(
479
      // Returns whether or not a direct kernel registration exists
480
      // for this <op_name, dispatch_key> pair.
481
      "_dispatch_has_kernel_for_dispatch_key",
482
      [](const char* name, c10::DispatchKey dispatch) -> bool {
483
        auto op =
484
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
485
        TORCH_CHECK(op, "operator ", name, " does not exist");
486
        return op->hasKernelForDispatchKey(dispatch);
487
      });
488

489
  m.def(
490
      "_dispatch_has_kernel_for_any_dispatch_key",
491
      [](const char* name, c10::DispatchKeySet ks) -> bool {
492
        auto op =
493
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
494
        TORCH_CHECK(op, "operator ", name, " does not exist");
495
        return op->hasKernelForAnyDispatchKey(ks);
496
      });
497

498
  m.def(
499
      // Returns whether or not there is an entry in the runtime computed
500
      // dispatch table, for this <op_name, dispatch_key> pair. For example, if
501
      // "op" has a `CompositeImplicitAutograd` kernel, Then
502
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
503
      // true for all backends that are part of the alias set for
504
      // CompositeImplicitAutograd.
505
      "_dispatch_has_computed_kernel_for_dispatch_key",
506
      [](const char* name, const char* dispatch) -> bool {
507
        auto op =
508
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
509
        TORCH_CHECK(op, "operator ", name, " does not exist");
510
        return op->hasComputedKernelForDispatchKey(
511
            c10::parseDispatchKey(dispatch));
512
      });
513

514
  m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
515
    auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
516

517
    std::vector<std::string> states;
518
    states.reserve(danglingImpls.size());
519
    for (auto& danglingImpl : danglingImpls) {
520
      states.emplace_back(danglingImpl.dumpState());
521
    }
522

523
    return states;
524
  });
525

526
  m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
527
    auto op_names = c10::Dispatcher::singleton().getAllOpNames();
528

529
    std::vector<std::string> names;
530
    names.reserve(op_names.size());
531
    for (auto& op : op_names) {
532
      std::stringstream ss;
533
      ss << op.name;
534
      if (!op.overload_name.empty()) {
535
        ss << "." << op.overload_name;
536
      }
537
      names.emplace_back(ss.str());
538
    }
539

540
    return names;
541
  });
542

543
  m.def(
544
      "_dispatch_tls_set_dispatch_key_excluded",
545
      [](c10::DispatchKey dispatch_key, bool desired_state) {
546
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
547
      });
548
  m.def(
549
      "_dispatch_tls_is_dispatch_key_excluded",
550
      [](c10::DispatchKey dispatch_key) {
551
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
552
      });
553
  m.def(
554
      "_dispatch_tls_set_dispatch_key_included",
555
      [](c10::DispatchKey dispatch_key, bool desired_state) {
556
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
557
      });
558
  m.def(
559
      "_dispatch_tls_is_dispatch_key_included",
560
      [](c10::DispatchKey dispatch_key) {
561
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
562
      });
563

564
  m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
565
    return at::isTensorSubclassLike(tensor);
566
  });
567

568
  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
569
    return c10::toString(k);
570
  });
571
  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
572
  m.def("_to_functionality_key", [](c10::DispatchKey k) {
573
    return c10::toFunctionalityKey(k);
574
  });
575
  // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
576
  //  AutogradCPU
577
  //  AutogradCUDA
578
  //  ...
579
  //  AutogradPrivateUse3
580
  m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
581
    std::vector<c10::DispatchKey> keys;
582
    if (c10::isPerBackendFunctionalityKey(key)) {
583
      auto ks = c10::DispatchKeySet(key) |
584
          c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
585
      for (auto k : ks) {
586
        keys.push_back(k);
587
      }
588
    } else {
589
      keys.push_back(key);
590
    }
591
    return keys;
592
  });
593
  m.def("_dispatch_num_backends", []() { return c10::num_backends; });
594

595
#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)
596

597
  py::enum_<c10::DispatchKey>(m, "DispatchKey")
598
      // clang-format off
599
      DEF_ONE(Undefined)
600
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
601
      DEF_ONE(CompositeExplicitAutograd)
602
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
603
      DEF_ONE(CompositeImplicitAutograd)
604
      // NestedTensor is not a backend key
605
      DEF_ONE(AutogradNestedTensor)
606
      DEF_ONE(AutogradOther)
607
      DEF_ONE(Autograd)
608
      DEF_ONE(Conjugate)
609
      DEF_ONE(ZeroTensor)
610
      DEF_ONE(Negative)
611
      DEF_ONE(BackendSelect)
612
      DEF_ONE(ADInplaceOrView)
613
      DEF_ONE(PythonTLSSnapshot)
614
      DEF_ONE(Python)
615
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
616
      DEF_ONE(FuncTorchDynamicLayerBackMode)
617
      DEF_ONE(FuncTorchBatchedDecomposition)
618
      DEF_ONE(FuncTorchBatched)
619
      DEF_ONE(FuncTorchVmapMode)
620
      DEF_ONE(FuncTorchGradWrapper)
621
      DEF_ONE(PythonDispatcher)
622
      DEF_ONE(PreDispatch)
623
      DEF_ONE(Functionalize)
624
      DEF_ONE(AutocastCPU)
625
      DEF_ONE(AutocastXPU)
626
      DEF_ONE(AutocastHPU)
627
      DEF_ONE(AutocastIPU)
628
      DEF_ONE(AutocastCUDA)
629
      DEF_ONE(AutocastPrivateUse1)
630
  // clang-format on
631

632
#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
633
#define DEF_MULTIPLE(fullname, prefix)              \
634
  DEF_SINGLE(, fullname)                            \
635
  DEF_SINGLE(, StartOf##fullname##Backends)         \
636
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
637
  DEF_SINGLE(, EndOf##fullname##Backends)
638

639
      // clang-format off
640
  C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
641
  // clang-format on
642

643
#undef DEF_MULTIPLE
644
#undef DEF_SINGLE
645
          ;
646

647
  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
648
      .def(py::init<c10::DispatchKey>())
649
      .def("__or__", &c10::DispatchKeySet::operator|)
650
      .def("__sub__", &c10::DispatchKeySet::operator-)
651
      .def("__and__", &c10::DispatchKeySet::operator&)
652
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
653
      .def(
654
          "remove",
655
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
656
            return self.remove(k);
657
          })
658
      .def(
659
          "add",
660
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
661
            return self.add(k);
662
          })
663
      .def("has", &c10::DispatchKeySet::has)
664
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });
665

666
  m.attr("_dispatch_autogradother_backends") =
667
      py::cast(c10::autogradother_backends);
668

669
  m.attr("_additional_keys_to_prop_for_wrapper_tensors") =
670
      py::cast(at::functorch::kKeysToPropagateToWrapper);
671

672
  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
673
    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
674
  });
675

676
  m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
677
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
678
  });
679

680
  m.def("_dispatch_keyset_full", []() {
681
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
682
  });
683

684
  m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);
685

686
  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
687
    return c10::toString(keyset);
688
  });
689

690
  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
691
    return c10::getBackendKeySetFromAutograd(k);
692
  });
693

694
  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
695
    auto* impl = tensor.unsafeGetTensorImpl();
696
    return impl->key_set();
697
  });
698
  m.def("_dispatch_tls_local_include_set", []() {
699
    return c10::impl::tls_local_dispatch_key_set().included_;
700
  });
701
  m.def("_dispatch_tls_local_exclude_set", []() {
702
    return c10::impl::tls_local_dispatch_key_set().excluded_;
703
  });
704
  m.def("_functionalization_reapply_views_tls", []() {
705
    return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
706
  });
707
  m.def(
708
      "_dispatch_is_included_in_alias",
709
      [](c10::DispatchKey a, c10::DispatchKey b) {
710
        return c10::isIncludedInAlias(a, b);
711
      });
712

713
  // DEPRECATED, please don't use this. Instead use
714
  // torch._C._ExcludeDispatchKeyGuard
715
  py_context_manager_DEPRECATED<
716
      c10::impl::ExcludeDispatchKeyGuard,
717
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");
718

719
  py_context_manager<
720
      c10::impl::ForceDispatchKeyGuard,
721
      c10::DispatchKeySet,
722
      c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
723
  py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
724
      m, "_IncludeDispatchKeyGuard");
725
  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
726
      m, "_ExcludeDispatchKeyGuard");
727
  py_context_manager<SetExcludeDispatchKeyGuard, c10::DispatchKey, bool>(
728
      m, "_SetExcludeDispatchKeyGuard");
729

730
  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
731
      m, "_AutoDispatchBelowAutograd");
732

733
  // Prints out the name of every operator that has a kernel registered to the
734
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll print
735
  // out the name of every operator that the Dispatcher knows of. This can be
736
  // useful to answer questions like "list all operators that do not have a CPU
737
  // kernel".
738
  m.def(
739
      "_dispatch_print_registrations_for_dispatch_key",
740
      [](const char* dispatch_key = "") {
741
        auto k = std::string(dispatch_key).empty()
742
            ? c10::nullopt
743
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
744
        auto op_names =
745
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
746
        for (auto& op : op_names) {
747
          std::cout << op << std::endl;
748
        }
749
      },
750
      py::arg("dispatch_key") = static_cast<const char*>(""));
751

752
  m.def(
753
      "_parse_dispatch_key",
754
      [](const char* dispatch_key) -> c10::optional<c10::DispatchKey> {
755
        try {
756
          return c10::parseDispatchKey(dispatch_key);
757
        } catch (const c10::Error& err) {
758
          return c10::nullopt;
759
        }
760
      });
761

762
  m.def(
763
      "_dispatch_get_registrations_for_dispatch_key",
764
      [](const char* dispatch_key = "") {
765
        auto k = std::string(dispatch_key).empty()
766
            ? c10::nullopt
767
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
768
        auto op_names =
769
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
770
        std::vector<std::string> names;
771
        names.reserve(op_names.size());
772
        for (auto& op : op_names) {
773
          names.emplace_back(
774
              op.name +
775
              (op.overload_name.empty() ? "" : "." + op.overload_name));
776
        }
777
        return names;
778
      },
779
      py::arg("dispatch_key") = static_cast<const char*>(""));
780
  m.def(
781
      "_dispatch_set_report_error_callback",
782
      [](c10::OperatorHandle& handle, py::object callback) {
783
        auto obj = callback.release().ptr();
784
        auto callback_obj =
785
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
786
        handle.setReportErrorCallback_(std::move(callback_obj));
787
      });
788

789
  m.def(
790
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
791
  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
792
    return c10::Dispatcher::singleton().getAbstractImplPyStub(
793
        c10::OperatorName(name, overload));
794
  });
795

796
  m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
797
    return at::functionalization::impl::replace_(a, b);
798
  });
799
  m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
800
    at::functionalization::impl::propagate_xla_data(a, b);
801
  });
802
  m.def("_commit_update", [](const at::Tensor& a) {
803
    return at::functionalization::impl::commit_update(a);
804
  });
805
  m.def("_unsafe_reset_storage", [](const at::Tensor& a) {
806
    return at::functionalization::impl::unsafe_reset_storage(a);
807
  });
808

809
  m.def("_dispatch_key_for_device", [](const std::string& device_type) {
810
    auto device = c10::Device(device_type);
811
    TORCH_CHECK(
812
        !device.has_index(),
813
        "Expected device_type string to not have a device index; got ",
814
        device_type);
815
    return c10::toString(
816
        c10::computeDispatchKey(c10::nullopt, c10::nullopt, device));
817
  });
818

819
  m.def("_are_functorch_transforms_active", []() {
820
    auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
821
    return (
822
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
823
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
824
  });
825

826
  m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
827
    return c10::SymInt(c10::SymNode(
828
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
829
  });
830

831
  m.def("_get_constant_bool_symnode", [](int64_t data) {
832
    return c10::SymNode(
833
        c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
834
  });
835

836
  m.def("_non_sym_sizes", [](const at::Tensor& a) {
837
    return a.sizes(); // NB: NOT sym_size
838
  });
839

840
  using c10::impl::TorchDispatchModeKey;
841
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
842
      .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
843
      .value("PROXY", TorchDispatchModeKey::PROXY)
844
      .value("FAKE", TorchDispatchModeKey::FAKE);
845
}
846

847
// TODO: dedupe with the kernel
848
void python_op_registration_trampoline_impl(
849
    const c10::OperatorHandle& op,
850
    c10::DispatchKey key,
851
    torch::jit::Stack* stack) {
852
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
853
  py::gil_scoped_acquire g;
854
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
855
  const auto& func = python_registrations_[op.operator_name()][key];
856
  TORCH_INTERNAL_ASSERT(func != nullptr);
857
  auto* pyobj = func->ptr(getPyInterpreter());
858
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
859
  auto obj = py::reinterpret_steal<py::object>(
860
      PyObject_Call(pyobj, args_kwargs.first.ptr(), args_kwargs.second.ptr()));
861
  if (!obj) {
862
    throw python_error();
863
  }
864
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
865
}
866

867
} // namespace dispatch
868
} // namespace impl
869
} // namespace torch
870

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.