#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/mobile/nnc/context.h>
class NNCBackend : public PyTorchBackendInterface {
13
explicit NNCBackend() = default;
14
~NNCBackend() override = default;
16
bool is_available() override {
20
c10::impl::GenericDict compile(
21
c10::IValue processed,
22
c10::impl::GenericDict method_compile_spec) override {
23
cu_ = std::make_shared<CompilationUnit>(processed);
25
// Input method_compile_spec:
27
// Value: compile spec for each method
30
// Value: a backend handle for each method
32
c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
33
auto handles = c10::Dict<std::string, std::string>();
34
for (const auto& it : spec) {
35
// The handle for each method is the key (method name) itself.
36
handles.insert(it.key(), it.key());
38
return c10::impl::toGenericDict(handles);
41
c10::impl::GenericList execute(
43
c10::impl::GenericList inputs) override {
44
const std::string& method_name = handle.toStringRef();
45
auto function_name = c10::QualifiedName(method_name);
46
return cu_->run(function_name, inputs);
50
std::shared_ptr<CompilationUnit> cu_;
namespace {
// TODO(mvz): temporarily disable NNC backend in mobile builds.
// static const auto cls = torch::jit::backend<NNCBackend>("nnc");
} // namespace