#include <ATen/Functions.h>
#include <ATen/Utils.h>
#include <c10/core/TensorImpl.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_exception.h>

#include <caffe2/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h>
#include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>

namespace torch {
namespace jit {
namespace xnnpack {
namespace delegate {
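
// Wraps the compiled XNNExecutor in a CustomClassHolder so that it can be
// stashed inside an IValue capsule and handed back to the JIT runtime.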
class XNNModelWrapper : public CustomClassHolder {
 public:
  XNNExecutor executor_;
  XNNModelWrapper(XNNExecutor executor) : executor_(std::move(executor)) {}

  XNNModelWrapper() = delete;

  XNNModelWrapper(const XNNModelWrapper& oldObject) = delete;
};
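
// JIT backend implementation: compiles serialized XNNPACK models and runs
// them through the XNNExecutor.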
class XNNPackBackend : public PyTorchBackendInterface {
 public:
  // NOLINTNEXTLINE(modernize-use-equals-default)
  explicit XNNPackBackend() {}
  virtual ~XNNPackBackend() override = default;

  // The backend is available whenever XNNPACK initializes successfully.
  bool is_available() override {
    return xnn_status_success == xnn_initialize(/*allocator=*/nullptr);
  }
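
  // compile() consumes the dict produced by preprocessing: it deserializes
  // "ser_model" into an XNNExecutor and returns a per-method handle dict.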
  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    auto dict = processed.toGenericDict();

    // Compiling and wrapping execution object
    const std::string& ser_model = dict.at("ser_model").toStringRef();
    XNNExecutor executor;
    XNNCompiler::compileModel(ser_model.data(), ser_model.length(), &executor);

    auto model_ptr = c10::make_intrusive<XNNModelWrapper>(std::move(executor));
    auto runtime_handle = IValue::make_capsule(model_ptr);
    auto wrapper = c10::static_intrusive_pointer_cast<XNNModelWrapper>(
        runtime_handle.toCapsule());
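    // `wrapper` is unused beyond confirming that the capsule casts back to the
    // model wrapper; execute() performs the same cast to recover the executor.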

    // Packing outputs into generic dict
    c10::Dict<c10::IValue, c10::IValue> handles(
        c10::StringType::get(), c10::AnyType::get());

    c10::Dict<c10::IValue, c10::IValue> ret(
        c10::StringType::get(), c10::AnyType::get());

    ret.insert("runtime", runtime_handle);
    ret.insert("output_shapes", dict.at("outputs"));

    handles.insert("forward", ret);

    return handles;
  }

  // Currently this is not implemented; everything is computed ahead of time.
  // The current implementation just takes the results computed ahead of time
  // and grabs them. The inputs are fed in through the compile spec for the
  // sake of testing. In reality, the inputs will be fed in at this stage and
  // run here.
  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
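    // `handle` is the per-method dict built in compile(): the runtime capsule
    // plus the output shapes recorded there.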
    auto dict = handle.toGenericDict();
    auto output_shapes = dict.at("output_shapes").toList();

    auto capsule = dict.at("runtime").toCapsule();
    auto model_wrapper =
        c10::static_intrusive_pointer_cast<XNNModelWrapper>(capsule);

    XNNExecutor& executor = model_wrapper->executor_;
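
    // Collect raw float pointers from the input tensors; the executor consumes
    // bare buffers rather than at::Tensors.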
    std::vector<float*> input_pointers;
    for (int i = 0; i < inputs.size(); ++i) {
      at::IValue val = inputs.get(i);
      TORCH_CHECK(val.isTensor(), "Non-tensor inputs not supported");
      input_pointers.push_back(val.toTensor().data_ptr<float>());
    }
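
    // Allocate empty float tensors matching the shapes recorded at compile
    // time; XNNPACK writes its results directly into these buffers.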
    std::vector<at::Tensor> output_tensors;
    std::vector<float*> output_pointers;
    output_tensors.reserve(output_shapes.size());
    for (int i = 0; i < output_shapes.size(); i++) {
      auto o_shape = output_shapes.get(i).toIntVector();
      auto output = at::empty(o_shape, c10::ScalarType::Float);
      output_tensors.push_back(output);
      output_pointers.push_back(output.data_ptr<float>());
    }

    // Executing on inputs
    TORCH_CHECK(
        executor.set_inputs(input_pointers, output_pointers),
        "Number of inputs/outputs does not match expected number of inputs/outputs");
    TORCH_CHECK(executor.forward(), "Failed to invoke XNNPack runtime");

    // Pack the outputs into a GenericList for the JIT runtime.
    c10::List<at::Tensor> output_list(output_tensors);
    return c10::impl::toList(output_list);
  }
};
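
// Registers the class as the handler for the "xnnpack" backend so that
// modules lowered with to_backend("xnnpack", ...) dispatch to it.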
namespace {
constexpr auto backend_name = "xnnpack";
static auto cls = torch::jit::backend<XNNPackBackend>(backend_name);
} // namespace

} // namespace delegate
} // namespace xnnpack
} // namespace jit
} // namespace torch
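
// Example lowering flow from Python (a sketch; the exact method_compile_spec
// keys shown here — "inputs"/"outputs" — are assumptions for illustration):
//
//   scripted = torch.jit.script(model)
//   lowered = torch._C._jit_to_backend(
//       "xnnpack", scripted, {"forward": {"inputs": ..., "outputs": ...}})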