pytorch

Форк
0
/
xnnpack_backend_lib.cpp 
118 строк · 3.9 Кб
1
#include <ATen/Functions.h>
2
#include <ATen/Utils.h>
3
#include <c10/core/TensorImpl.h>
4
#include <torch/csrc/jit/backends/backend.h>
5
#include <torch/csrc/jit/backends/backend_exception.h>
6

7
#include <caffe2/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h>
8
#include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>
9

10
namespace torch {
11
namespace jit {
12
namespace xnnpack {
13
namespace delegate {
14

15
class XNNModelWrapper : public CustomClassHolder {
16
 public:
17
  XNNExecutor executor_;
18
  XNNModelWrapper(XNNExecutor executor) : executor_(std::move(executor)){};
19

20
  XNNModelWrapper() = delete;
21

22
  XNNModelWrapper(const XNNModelWrapper& oldObject) = delete;
23
};
24

25
class XNNPackBackend : public PyTorchBackendInterface {
26
 public:
27
  // Constructor.
28
  // NOLINTNEXTLINE(modernize-use-equals-default)
29
  explicit XNNPackBackend() {}
30
  virtual ~XNNPackBackend() override = default;
31

32
  bool is_available() override {
33
    return xnn_status_success == xnn_initialize(/*allocator=*/nullptr);
34
  }
35

36
  c10::impl::GenericDict compile(
37
      c10::IValue processed,
38
      c10::impl::GenericDict method_compile_spec) override {
39
    auto dict = processed.toGenericDict();
40

41
    // Compiling and wrapping exeuction object
42
    const std::string& ser_model = dict.at("ser_model").toStringRef();
43
    XNNExecutor executor;
44
    XNNCompiler::compileModel(ser_model.data(), ser_model.length(), &executor);
45

46
    auto model_ptr = c10::make_intrusive<XNNModelWrapper>(std::move(executor));
47
    auto runtime_handle = IValue::make_capsule(model_ptr);
48
    auto wrapper = c10::static_intrusive_pointer_cast<XNNModelWrapper>(
49
        runtime_handle.toCapsule());
50

51
    // Packing outputs into generic dict
52
    c10::Dict<c10::IValue, c10::IValue> handles(
53
        c10::StringType::get(), c10::AnyType::get());
54

55
    c10::Dict<c10::IValue, c10::IValue> ret(
56
        c10::StringType::get(), c10::AnyType::get());
57

58
    ret.insert("runtime", runtime_handle);
59
    ret.insert("output_shapes", dict.at("outputs"));
60

61
    handles.insert("forward", ret);
62

63
    return handles;
64
  }
65

66
  // Currently this is not implemented, and everything is computed a head of
67
  // time the current implementation just takes the computed results from ahead
68
  // of time and grabs them. The inputs are fed in through the compile spec for
69
  // the sake of testing. In reality, the inputs will be fed in at this stage
70
  // and ran here.
71
  c10::impl::GenericList execute(
72
      c10::IValue handle,
73
      c10::impl::GenericList inputs) override {
74
    auto dict = handle.toGenericDict();
75
    auto output_shapes = dict.at("output_shapes").toList();
76

77
    auto capsule = dict.at("runtime").toCapsule();
78
    auto model_wrapper =
79
        c10::static_intrusive_pointer_cast<XNNModelWrapper>(capsule);
80

81
    XNNExecutor& executor = model_wrapper->executor_;
82

83
    std::vector<float*> input_pointers;
84
    for (int i = 0; i < inputs.size(); ++i) {
85
      at::IValue val = inputs.get(i);
86
      TORCH_CHECK(val.isTensor(), "Non-tensor inputs not supported");
87
      input_pointers.push_back(val.toTensor().data_ptr<float>());
88
    }
89

90
    std::vector<at::Tensor> output_tensors;
91
    std::vector<float*> output_pointers;
92
    output_tensors.reserve(output_shapes.size());
93
    for (int i = 0; i < output_shapes.size(); i++) {
94
      auto o_shape = output_shapes.get(i).toIntVector();
95
      auto output = at::empty(o_shape, c10::ScalarType::Float);
96
      output_tensors.push_back(output);
97
      output_pointers.push_back(output.data_ptr<float>());
98
    }
99

100
    TORCH_CHECK(
101
        executor.set_inputs(input_pointers, output_pointers),
102
        "Number of inputs/outputs does not match expected number of inputs/outputs");
103
    TORCH_CHECK(executor.forward(), "Failed to invoke XNNPack runtime");
104

105
    c10::List<at::Tensor> output_list(output_tensors);
106
    return c10::impl::toList(output_list);
107
  }
108
};
109

110
namespace {
111
constexpr auto backend_name = "xnnpack";
112
static auto cls = torch::jit::backend<XNNPackBackend>(backend_name);
113
} // namespace
114

115
} // namespace delegate
116
} // namespace xnnpack
117
} // namespace jit
118
} // namespace torch
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.