pytorch

Форк
0
/
xnnpack_backend_preprocess.cpp 
132 строки · 4.4 Кб
1
#include <torch/csrc/jit/backends/backend.h>
2
#include <torch/csrc/jit/backends/backend_preprocess.h>
3

4
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
5
#include <torch/torch.h>
6
#include <xnnpack.h>
7

8
#include <ATen/core/List.h>
9
#include <torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h>
10

11
namespace torch {
12
namespace jit {
13
namespace xnnpack {
14
namespace delegate {
15

16
// Expected method_compile_spec should look something like this:
17
// {
18
//     "forward" : {"inputs" : at::Tensor}
19
// }
20
// or
21
// {
22
//     "forward" : {
23
//                  "inputs" : c10::List<at::Tensor>,
24
//                  "outputs" : c10::List<at::Tensor>
25
//                  }
26
// }
27
// in which the value for "inputs" is the input shape to the module.
28
// The module fed to the xnnpack backend must first be traced in order
29
// to propagate input shapes through the module. This is important
30
// for building the xnnpack_subgraph_t object.
31
c10::IValue preprocess(
32
    const Module& mod,
33
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
34
    const BackendDebugHandleGenerator& generate_debug_handles) {
35
  auto eval_mod = mod.clone();
36
  eval_mod.eval();
37
  eval_mod = torch::jit::freeze(eval_mod);
38

39
  c10::Dict<IValue, IValue> compiled(StringType::get(), TensorType::get());
40

41
  c10::IValue inp;
42
  c10::IValue out;
43

44
  TORCH_CHECK(
45
      method_compile_spec.contains("forward"),
46
      "method_compile_spec does not contain the \"forward\" key.");
47
  auto innerDict = method_compile_spec.at("forward");
48

49
  TORCH_CHECK(
50
      innerDict.isGenericDict() &&
51
          innerDict.toGenericDict().contains("inputs") &&
52
          innerDict.toGenericDict().contains("outputs"),
53
      "method_compile_spec does not contain a dictionary with an \"inputs\" key, under \"forward\" key.");
54

55
  inp = innerDict.toGenericDict().at("inputs");
56
  out = innerDict.toGenericDict().at("outputs");
57

58
  TORCH_CHECK(
59
      inp.isTensor() || inp.isTensorList(),
60
      "method_compile_spec does not contain either a Tensor or TensorList, under it's \"inputs\" key.");
61
  TORCH_CHECK(
62
      out.isTensor() || out.isTensorList(),
63
      "method_compile_spec does not contain either a Tensor or TensorList, under it's \"outputs\" key.");
64

65
  // Graph preprocessing
66
  const auto& forward_method = eval_mod.get_method("forward");
67

68
  auto graph = toGraphFunction(forward_method.function()).graph()->copy();
69
  graph = tensorexpr::removeUnusedSelfArgument(graph);
70
  std::vector<c10::IValue> example_inputs;
71
  if (inp.isTensorList()) {
72
    c10::List<at::Tensor> inp_list = inp.toTensorList();
73
    TORCH_CHECK(
74
        graph->inputs().size() == inp_list.size(),
75
        "method_compile_spec inputs do not match expected number of forward inputs");
76

77
    example_inputs.reserve(inp_list.size());
78
    for (const auto i : c10::irange(inp_list.size())) {
79
      example_inputs.emplace_back(inp_list[i]);
80
    }
81
  } else {
82
    TORCH_CHECK(
83
        graph->inputs().size() == 1,
84
        "method_compile_spec inputs do not match expected number of forward inputs");
85

86
    example_inputs.emplace_back(inp.toTensor());
87
  }
88

89
  // inp above has been confirmed to be either Tensor or TensorList
90
  XNNGraph graph_builder;
91
  graph_builder.buildXNNGraph(graph, example_inputs);
92
  // at this point graph is complete, for the sake of testing preprocess at this
93
  // point we will do runtime setup and run with some default values
94

95
  // grabbing the inputs from compile spec for testing
96

97
  // gather sample inputs from compile spec
98
  std::vector<at::Tensor> inputs;
99
  auto input_list = inp.toList();
100

101
  for (int i = 0; i < input_list.size(); i++) {
102
    inputs.push_back(input_list.get(i).toTensor());
103
  }
104
  std::vector<at::Tensor> outputs;
105
  auto output_list = out.toList();
106
  std::vector<c10::IntList> output_shapes;
107

108
  // gather sample outputs from compile spec
109
  for (int i = 0; i < output_list.size(); i++) {
110
    auto sample_output = output_list.get(i).toTensor();
111
    outputs.push_back(sample_output);
112
    // also gather output shapes to forward along to device
113
    output_shapes.push_back(sample_output.sizes());
114
  }
115

116
  // sample run on sample inputs
117
  graph_builder.runGraphOnInputs(inputs, outputs);
118
  c10::List<c10::IntList> shapes_list(output_shapes);
119

120
  compiled.insert("ser_model", graph_builder.serializedXNNGraph());
121
  compiled.insert("outputs", shapes_list);
122
  compiled.insert("Answer", outputs);
123

124
  return compiled;
125
}
126
constexpr auto backend_name = "xnnpack";
127
static auto pre_reg = backend_preprocess_register(backend_name, preprocess);
128

129
} // namespace delegate
130
} // namespace xnnpack
131
} // namespace jit
132
} // namespace torch
133

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.