pytorch

Форк
0
121 строка · 4.2 Кб
1
// Copyright (c) Meta Platforms, Inc. and affiliates.
2
//
3
// This source code is licensed under the BSD-style license found in the
4
// LICENSE file in the root directory of this source tree.
5

6
#include <caffe2/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h>
7
#include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>
8

9
#include <ATen/Utils.h>
10

11
namespace torch {
12
namespace jit {
13
namespace xnnpack {
14
namespace delegate {
15

16
void XNNCompiler::compileModel(
17
    const void* buffer_pointer,
18
    size_t num_bytes,
19
    XNNExecutor* executor) {
20
  auto output_min = -std::numeric_limits<float>::infinity();
21
  auto output_max = std::numeric_limits<float>::infinity();
22

23
  auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(buffer_pointer);
24
  // initialize xnnpack
25
  xnn_status status = xnn_initialize(/*allocator =*/nullptr);
26
  TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack");
27

28
  // create xnnpack subgraph
29
  xnn_subgraph_t subgraph_ptr = nullptr;
30
  status = xnn_create_subgraph(
31
      /*external_value_ids=*/flatbuffer_graph->num_externs(),
32
      /*flags=*/0,
33
      &subgraph_ptr);
34
  TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph");
35

36
  // mapping from old ids to new created value ids
37
  // The old ids that were serialied were generated AoT, since
38
  // we are re-defining tensor values, the defined IDs could be
39
  // different from the ones generated AoT, as a result, we need
40
  // a new mapping from the old ids to the newly created ones
41
  std::unordered_map<uint32_t, uint32_t> remapped_ids;
42

43
  for (auto value : *flatbuffer_graph->xvalues()) {
44
    switch (value->xvalue_type()) {
45
      case fb_xnnpack::XValueUnion::XNNTensorValue: {
46
        auto tensor_value = value->xvalue_as_XNNTensorValue();
47

48
        std::vector<size_t> dims_data;
49
        for (auto dim : *tensor_value->dims()) {
50
          dims_data.push_back(static_cast<size_t>(dim));
51
        }
52

53
        uint32_t id = XNN_INVALID_VALUE_ID;
54
        const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
55
        auto buffer_idx = tensor_value->constant_buffer_idx();
56
        const auto buffer_ptr = buffer_idx == 0
57
            ? nullptr
58
            : constant_buffer[buffer_idx]->storage()->data();
59
        status = xnn_define_tensor_value(
60
            /*subgraph=*/subgraph_ptr,
61
            /*datatype=*/xnn_datatype_fp32,
62
            /*num_dims=*/tensor_value->num_dims(),
63
            /*dims=*/dims_data.data(),
64
            /*data=*/buffer_ptr,
65
            /*external_id=*/tensor_value->external_id(),
66
            /*flags=*/tensor_value->flags(),
67
            /*id_out=*/&id);
68
        TORCH_CHECK(
69
            status == xnn_status_success,
70
            "Failed to define tensor values in graph")
71
        // map serialized id to newly generated id
72
        remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id));
73
        break;
74
      }
75
      default: {
76
        TORCH_CHECK(false, "Unhandled value type found in deserialization");
77
      }
78
    }
79
  }
80

81
  for (auto node : *flatbuffer_graph->xnodes()) {
82
    switch (node->xnode_type()) {
83
      case fb_xnnpack::XNodeUnion::XNNAdd: {
84
        auto graph_node = node->xnode_as_XNNAdd();
85
        status = xnn_define_add2(
86
            subgraph_ptr,
87
            output_min,
88
            output_max,
89
            remapped_ids.at(graph_node->input1_id()),
90
            remapped_ids.at(graph_node->input2_id()),
91
            remapped_ids.at(graph_node->output_id()),
92
            graph_node->flags());
93
        TORCH_CHECK(status == xnn_status_success, "Failed to create add node")
94
        break;
95
      }
96
      default:
97
        TORCH_CHECK(false, "Unhandled node type found in deserialization");
98
    }
99
  }
100

101
  xnn_runtime_t runtime_ptr = nullptr;
102
  status = xnn_create_runtime_v2(subgraph_ptr, nullptr, 0, &runtime_ptr);
103
  TORCH_CHECK(xnn_status_success == status);
104

105
  executor->runtime_ =
106
      std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)>(
107
          runtime_ptr, xnn_delete_runtime);
108

109
  for (auto old_id : *flatbuffer_graph->input_ids()) {
110
    executor->input_ids_.emplace_back(remapped_ids.at(old_id));
111
  }
112

113
  for (auto old_id : *flatbuffer_graph->output_ids()) {
114
    executor->output_ids_.emplace_back(remapped_ids.at(old_id));
115
  }
116
};
117

118
} // namespace delegate
119
} // namespace xnnpack
120
} // namespace jit
121
} // namespace torch
122

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.