pytorch

Форк
0
/
xnnpack_graph_builder.cpp 
327 строк · 10.0 Кб
1
// Copyright (c) Meta Platforms, Inc. and affiliates.
2
//
3
// This source code is licensed under the BSD-style license found in the
4
// LICENSE file in the root directory of this source tree.
5

6
#include <caffe2/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h>
7
#include <torch/csrc/jit/runtime/graph_iterator.h>
8
#include <xnnpack.h>
9

10
// graph passes
11
#include <torch/csrc/jit/passes/constant_propagation.h>
12
#include <torch/csrc/jit/passes/dead_code_elimination.h>
13
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
14
#include <torch/csrc/jit/passes/lower_tuples.h>
15
#include <torch/csrc/jit/passes/remove_mutation.h>
16
#include <torch/csrc/jit/runtime/jit_trace.h>
17
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
18

19
namespace torch {
20
namespace jit {
21
namespace xnnpack {
22
namespace delegate {
23

24
std::shared_ptr<torch::jit::Graph> XNNGraph::optimizeAndTraceGraph(
25
    std::shared_ptr<torch::jit::Graph> graph,
26
    std::vector<c10::IValue>& example_inputs) {
27
  OptimizeFrozenGraph(graph, true);
28
  RemoveListMutation(graph);
29
  RemoveTensorMutation(graph);
30
  LowerAllTuples(graph);
31
  ConstantPropagation(graph);
32
  graph = TraceGraph(graph, example_inputs);
33

34
  return graph;
35
}
36

37
void XNNGraph::buildXNNGraph(
38
    std::shared_ptr<torch::jit::Graph>& graph,
39
    std::vector<c10::IValue> example_inputs) {
40
  graph = optimizeAndTraceGraph(graph, example_inputs);
41
  checkOpsToDelegate(graph);
42
  gatherTensorValues(graph);
43

44
  // count unique input/outputs (some inputs can be outputs)
45
  std::unordered_set<torch::jit::Value*> externals;
46
  for (auto inp : _inputs) {
47
    externals.insert(inp);
48
  }
49
  for (auto out : _outputs) {
50
    externals.insert(out);
51
  }
52

53
  // create subgraph
54
  xnn_status status = xnn_create_subgraph(
55
      /*external_value_ids=*/externals.size(),
56
      /*flags=*/0,
57
      &_subgraph_ptr);
58
  TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph");
59

60
  defineAllTensorValues();
61
  defineAllNodes(graph);
62
  // at this point graph is complete, for the sake of testing preprocess at
63
  // this point we will do runtime setup and run with some default values
64
}
65

66
void XNNGraph::runGraphOnInputs(
67
    std::vector<at::Tensor> tensor_inputs,
68
    std::vector<at::Tensor> tensor_outputs) {
69
  TORCH_CHECK(
70
      _subgraph_ptr != nullptr,
71
      "run buildXNNGraph before running graph on inputs");
72
  xnn_runtime_t runtime = nullptr;
73
  xnn_status status =
74
      xnn_create_runtime_v2(_subgraph_ptr, nullptr, /*flags=*/0, &runtime);
75
  TORCH_CHECK(
76
      xnn_status_success == status,
77
      "failed to create runtime for running inputs");
78

79
  // smart pointer for runtime
80
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(
81
      runtime, xnn_delete_runtime);
82

83
  std::vector<xnn_external_value> external_values;
84
  TORCH_CHECK(
85
      tensor_inputs.size() == _inputs.size(),
86
      "supplied inputs does not match expected inputs");
87
  for (int i = 0; i < tensor_inputs.size(); i++) {
88
    external_values.push_back(
89
        {_val_to_ids[_inputs[i]], tensor_inputs[i].data_ptr<float>()});
90
  }
91

92
  TORCH_CHECK(
93
      tensor_outputs.size() == _outputs.size(),
94
      "supplied outputs does not match expected outputs");
95
  for (int i = 0; i < tensor_outputs.size(); i++) {
96
    external_values.push_back(
97
        {_val_to_ids[_outputs[i]], tensor_outputs[i].data_ptr<float>()});
98
  }
99
  status = xnn_setup_runtime(
100
      auto_runtime.get(), external_values.size(), external_values.data());
101
  TORCH_CHECK(xnn_status_success == status, "runtime not properly setup");
102

103
  TORCH_CHECK(xnn_status_success == xnn_invoke_runtime(auto_runtime.get()));
104
}
105

106
// Walks `graph` depth-first and collects the display names of all node kinds
// other than prim::Constant and aten::add. Raises (via TORCH_CHECK) with the
// list of collected names when at least one such node is found.
void XNNGraph::checkOpsToDelegate(std::shared_ptr<torch::jit::Graph>& graph) {
  std::unordered_set<std::string> rejected_ops;

  DepthFirstGraphNodeIterator walker(graph);
  for (Node* cur = walker.next(); cur != nullptr; cur = walker.next()) {
    const bool delegatable =
        cur->kind() == prim::Constant || cur->kind() == aten::add;
    if (!delegatable) {
      rejected_ops.insert(cur->kind().toDisplayString());
    }
  }

  // One op name per line, appended to the error message below.
  std::stringstream listing;
  for (const auto& op_name : rejected_ops) {
    listing << op_name << '\n';
  }

  TORCH_CHECK(
      rejected_ops.empty(),
      "the module contains the following unsupported ops:\n" + listing.str());
}
131

132
std::string XNNGraph::serializedXNNGraph() {
133
  std::vector<uint32_t> input_ids;
134
  std::vector<uint32_t> output_ids;
135
  std::unordered_set<uint32_t> num_externs;
136

137
  for (auto val : _inputs) {
138
    input_ids.push_back(_val_to_ids[val]);
139
    num_externs.emplace(_val_to_ids[val]);
140
  }
141

142
  for (auto val : _outputs) {
143
    output_ids.push_back(_val_to_ids[val]);
144
    num_externs.emplace(_val_to_ids[val]);
145
  }
146

147
  return _serializer.finishAndSerialize(
148
      input_ids, output_ids, num_externs.size());
149
}
150

151
std::vector<std::vector<long>> XNNGraph::getGraphOutputShapes() {
152
  std::vector<std::vector<long>> output_shapes;
153
  for (auto val : _outputs) {
154
    auto tensor_ptr = val->type()->cast<TensorType>();
155
    std::vector<long> sizes = tensor_ptr->sizes().concrete_sizes().value();
156
    output_shapes.push_back(sizes);
157
  }
158

159
  return output_shapes;
160
}
161

162
// Walks every node of `graph` depth-first and defines the matching XNNPACK
// node in the subgraph, recording each defined node in the serializer too.
//
// prim::Constant nodes are skipped here, and aten::add is lowered with
// xnn_define_add2. Any other node kind raises an error (via TORCH_CHECK)
// naming the unsupported node kind.
void XNNGraph::defineAllNodes(std::shared_ptr<torch::jit::Graph>& graph) {
  DepthFirstGraphNodeIterator it(graph);
  Node* node = nullptr;
  while ((node = it.next()) != nullptr) {
    switch (node->kind()) {
      case prim::Constant: {
        break;
      }
      case aten::add: {
        // todo: handle alpha for aten::add
        uint32_t input1_id = _val_to_ids[node->inputs()[0]];
        uint32_t input2_id = _val_to_ids[node->inputs()[1]];
        // NOTE(review): this compares the result of a type cast against 1,
        // not the constant value of the third input; confirm that it really
        // enforces alpha == 1 as the message suggests.
        TORCH_CHECK(
            node->inputs()[2]->type()->cast<IntType>() == 1,
            "non-1 alpha values not supported");
        uint32_t output_id = _val_to_ids[node->outputs()[0]];

        xnn_status status = xnn_define_add2(
            _subgraph_ptr,
            output_min,
            output_max,
            input1_id,
            input2_id,
            output_id,
            /*flags=*/0);
        // Check the status before serializing so a node that XNNPACK rejected
        // is never recorded in the serialized graph.
        TORCH_CHECK(status == xnn_status_success, "failed to create add node");
        _serializer.serializeAddNode(input1_id, input2_id, output_id, 0);
        break;
      }
      default: {
        // Always fails: reports the unsupported node kind by name.
        TORCH_CHECK(
            false,
            "The node of ",
            node->kind().toQualString(),
            " is not supported yet");
      }
    }
  }
}
203

204
void XNNGraph::defineAllTensorValues() {
205
  uint32_t external_id =
206
      std::numeric_limits<decltype(XNN_INVALID_VALUE_ID)>::min();
207
  for (auto val : _intermediate_tensors) {
208
    if (_val_to_ids.find(val) == _val_to_ids.end()) {
209
      uint32_t id = XNN_INVALID_VALUE_ID;
210

211
      // cast value to tensortype
212
      auto tensor_ptr = val->type()->cast<TensorType>();
213
      auto num_dims = tensor_ptr->dim().value();
214

215
      // create size_t* for tensor shape, casting must be done from long ->
216
      // size_t
217
      std::vector<long> sizes = tensor_ptr->sizes().concrete_sizes().value();
218
      std::vector<size_t> tensor_shape;
219
      tensor_shape.reserve(sizes.size());
220
      for (auto dim : sizes) {
221
        TORCH_CHECK(dim >= 0, "Input Dims should be unsigned");
222
        tensor_shape.push_back(static_cast<size_t>(dim));
223
      }
224

225
      // ext_id value
226
      uint32_t ext_id = XNN_INVALID_VALUE_ID;
227

228
      // update flag for if tensor is either graph input/output
229
      uint32_t flags = 0;
230

231
      // Check if value was produced by prim::Constant
232
      void* value_data = nullptr;
233
      size_t buffer_idx = 0;
234
      size_t num_bytes = 0;
235
      if (val->node()->kind() == prim::Constant) {
236
        c10::optional<IValue> constant = val->node()->t(attr::value);
237
        auto const_val = constant->toIValue().toTensor();
238
        // Need tensor data to be contiguous for serialization
239
        auto cont_const_val = const_val.contiguous();
240
        value_data = cont_const_val.data_ptr();
241

242
        num_bytes = const_val.storage().nbytes();
243
        buffer_idx = _serializer.serializeData(
244
            static_cast<const uint8_t*>(value_data), num_bytes);
245
      }
246

247
      if (isGraphInput(val) || isGraphOutput(val)) {
248
        if (isGraphInput(val)) {
249
          flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT;
250
        }
251
        if (isGraphOutput(val)) {
252
          flags |= XNN_VALUE_FLAG_EXTERNAL_OUTPUT;
253
        }
254
        ext_id = external_id++;
255
      }
256
      xnn_status status = xnn_define_tensor_value(
257
          /*subgraph=*/_subgraph_ptr,
258
          /*datatype=*/xnn_datatype_fp32,
259
          /*num_dims=*/num_dims,
260
          /*dims=*/tensor_shape.data(),
261
          /*data=*/value_data,
262
          /*external_id=*/ext_id,
263
          /*flags=*/flags,
264
          /*id_out=*/&id);
265
      TORCH_CHECK(
266
          status == xnn_status_success,
267
          "failed to define xnn_tensor_id for: " + val->debugName());
268
      _serializer.serializeTensorValue(
269
          xnn_datatype_fp32,
270
          num_dims,
271
          tensor_shape,
272
          buffer_idx,
273
          ext_id,
274
          flags,
275
          id);
276
      _val_to_ids.insert({val, id});
277
    }
278
  }
279
}
280

281
// Records the tensor values of `graph`: complete-tensor graph inputs go into
// _inputs, complete-tensor graph outputs into _outputs, and both (together
// with the node inputs gathered by gatherNodeInputs) into
// _intermediate_tensors.
void XNNGraph::gatherTensorValues(std::shared_ptr<torch::jit::Graph>& graph) {
  // Graph inputs first.
  for (auto graph_in : graph->inputs()) {
    if (!graph_in->isCompleteTensor()) {
      continue;
    }
    _intermediate_tensors.insert(graph_in);
    _inputs.push_back(graph_in);
  }

  // Then the inputs of every node, visited depth-first.
  DepthFirstGraphNodeIterator walker(graph);
  for (Node* cur = walker.next(); cur != nullptr; cur = walker.next()) {
    gatherNodeInputs(*cur);
  }

  // Graph outputs last.
  for (auto graph_out : graph->outputs()) {
    if (!graph_out->isCompleteTensor()) {
      continue;
    }
    _intermediate_tensors.insert(graph_out);
    _outputs.push_back(graph_out);
  }
}
302

303
// Adds the complete-tensor inputs of an aten::add node to
// _intermediate_tensors. Nodes of any other kind are ignored.
void XNNGraph::gatherNodeInputs(torch::jit::Node& node) {
  // Intended to cover all ops with only two inputs (i.e. sub, add); only
  // aten::add is handled for now.
  if (node.kind() == aten::add) {
    for (auto operand : node.inputs()) {
      if (operand->isCompleteTensor()) {
        _intermediate_tensors.insert(operand);
      }
    }
  }
}
315

316
// Returns true when `val` is one of the gathered graph inputs.
bool XNNGraph::isGraphInput(torch::jit::Value* val) {
  const auto last = _inputs.end();
  return std::find(_inputs.begin(), last, val) != last;
}
319

320
// Returns true when `val` is one of the gathered graph outputs.
bool XNNGraph::isGraphOutput(torch::jit::Value* val) {
  const auto last = _outputs.end();
  return std::find(_outputs.begin(), last, val) != last;
}
323

324
} // namespace delegate
325
} // namespace xnnpack
326
} // namespace jit
327
} // namespace torch
328

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.