pytorch

Форк
0
105 строк · 2.8 Кб
1
// Copyright (c) Meta Platforms, Inc. and affiliates.
2
//
3
// This source code is licensed under the BSD-style license found in the
4
// LICENSE file in the root directory of this source tree.
5

6
#include <caffe2/torch/csrc/jit/backends/xnnpack/serialization/serializer.h>
7
#include <torch/csrc/jit/backends/xnnpack/serialization/schema_generated.h>
8

9
#include <sstream>
10

11
namespace torch {
12
namespace jit {
13
namespace xnnpack {
14
namespace delegate {
15

16
using namespace fb_xnnpack;
17

18
void XNNSerializer::serializeAddNode(
19
    uint32_t input1_id,
20
    uint32_t input2_id,
21
    uint32_t output_id,
22
    uint32_t flags) {
23
  const auto addNode =
24
      CreateXNNAdd(_builder, input1_id, input2_id, output_id, flags);
25
  const auto flatbufferNode =
26
      CreateXNode(_builder, XNodeUnion::XNNAdd, addNode.Union());
27
  _nodes.push_back(flatbufferNode);
28
}
29

30
size_t XNNSerializer::serializeData(const uint8_t* data_ptr, size_t num_bytes) {
31
  size_t constant_buffer_idx = 0;
32
  // Handling the tensor _values with data
33
  if (data_ptr != nullptr) {
34
    // steps:
35
    // 1. creating flatbuffer byte-vector for tensor data
36
    auto storage = _builder.CreateVector(data_ptr, num_bytes);
37

38
    // 2. put it in the common buffer
39
    constant_buffer_idx = _constantBuffer.size();
40
    _constantBuffer.emplace_back(CreateBuffer(_builder, storage));
41

42
    // 3. record size into bufferSizes
43
    _bufferSizes.push_back(num_bytes);
44
    assert(_bufferSizes.size() == _constantBuffer.size());
45
  }
46
  return constant_buffer_idx;
47
}
48

49
void XNNSerializer::serializeTensorValue(
50
    uint32_t xnn_datatype,
51
    size_t num_dims,
52
    std::vector<size_t> dims,
53
    size_t data_buffer_idx,
54
    uint32_t external_id,
55
    uint32_t flags,
56
    uint32_t id_out) {
57
  std::vector<uint32_t> serialized_dims;
58
  serialized_dims.reserve(dims.size());
59
  for (auto dim : dims) {
60
    serialized_dims.push_back(static_cast<uint32_t>(dim));
61
  }
62

63
  const auto tensorValue = CreateXNNTensorValueDirect(
64
      _builder,
65
      XNNDatatype(xnn_datatype),
66
      num_dims,
67
      &serialized_dims,
68
      data_buffer_idx,
69
      external_id,
70
      flags,
71
      id_out);
72

73
  const auto flatbufferValue =
74
      CreateXValue(_builder, XValueUnion::XNNTensorValue, tensorValue.Union());
75
  _values.push_back(flatbufferValue);
76
}
77

78
std::string XNNSerializer::finishAndSerialize(
79
    std::vector<uint32_t> input_ids,
80
    std::vector<uint32_t> output_ids,
81
    size_t num_extern_ids) {
82
  auto xnnGraph = CreateXNNGraphDirect(
83
      _builder,
84
      _version_sha1,
85
      &_nodes,
86
      &_values,
87
      num_extern_ids,
88
      &input_ids,
89
      &output_ids,
90
      &_constantBuffer,
91
      &_bufferSizes);
92

93
  _builder.Finish(xnnGraph);
94

95
  std::stringstream ss;
96
  ss.write(
97
      reinterpret_cast<char*>(_builder.GetBufferPointer()), _builder.GetSize());
98

99
  return ss.str();
100
}
101

102
} // namespace delegate
103
} // namespace xnnpack
104
} // namespace jit
105
} // namespace torch
106

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.