pytorch
1#include <c10/util/Exception.h>
2#include <operator_registry.h>
3
4namespace torch {
5namespace executor {
6
7KernelRegistry& getKernelRegistry() {
8static KernelRegistry kernel_registry;
9return kernel_registry;
10}
11
12bool register_kernels(const ArrayRef<Kernel>& kernels) {
13return getKernelRegistry().register_kernels(kernels);
14}
15
16bool KernelRegistry::register_kernels(
17const ArrayRef<Kernel>& kernels) {
18for (const auto& kernel : kernels) {
19this->kernels_map_[kernel.name_] = kernel.kernel_;
20}
21return true;
22}
23
24bool hasKernelFn(const char* name) {
25return getKernelRegistry().hasKernelFn(name);
26}
27
28bool KernelRegistry::hasKernelFn(const char* name) {
29auto kernel = this->kernels_map_.find(name);
30return kernel != this->kernels_map_.end();
31}
32
33KernelFunction& getKernelFn(const char* name) {
34return getKernelRegistry().getKernelFn(name);
35}
36
37KernelFunction& KernelRegistry::getKernelFn(const char* name) {
38auto kernel = this->kernels_map_.find(name);
39TORCH_CHECK_MSG(kernel != this->kernels_map_.end(), "Kernel not found!");
40return kernel->second;
41}
42
43
44} // namespace executor
45} // namespace torch
46