pytorch

Форк
0
/
parse_bytecode.cpp 
199 строк · 7.3 Кб
1
#include <algorithm>
#include <array>
#include <cstddef>

#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/mobile/code.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/custom_class_detail.h>
10

11
namespace torch {
12
namespace jit {
13
OpCode parseOpCode(const char* str);
14
using c10::IValue;
15

16
// Extracts the value stored at `entry` of `elements`, where each entry is a
// (name, value) tuple. Asserts that the stored name equals `expected_name`
// and returns the associated value, moved out of `elements`.
IValue expect_field(
    c10::ivalue::TupleElements& elements,
    const std::string& expected_name,
    size_t entry) {
  auto field = std::move(elements.at(entry)).toTuple();
  const auto& actual_name = field->elements().at(0).toStringRef();
  TORCH_INTERNAL_ASSERT(
      actual_name == expected_name,
      "Expected ",
      expected_name,
      " found ",
      actual_name);
  return std::move(field)->elements().at(1);
}
29

30
namespace mobile {
31

32
namespace {
33
#define COUNT_OPCODE(_, _a) 1 +
34
constexpr size_t numOpcodes = FORALL_OPCODES(COUNT_OPCODE) 0;
35
#undef COUNT_OPCODE
36

37
// Pickled strings are memoized, so we can cache a mapping from
38
// pointers to parsed OpCodes to speed up parsing.
39
class OpCodeCache {
40
 private:
41
  // We store as void* to emphasize that we care only about the
42
  // address and should not be dereferencing these pointers.
43
  std::array<const void*, numOpcodes> keys_{};
44
  std::array<OpCode, numOpcodes> values_{};
45
  size_t usedEntries_ = 0;
46

47
 public:
48
  OpCodeCache() {
49
    memset(keys_.data(), 0, keys_.size() * sizeof(keys_[0]));
50
  }
51

52
  OpCode parse(const c10::ivalue::ConstantString& s) {
53
    const auto endIt = keys_.begin() + usedEntries_;
54
    auto it = std::find_if(
55
        keys_.begin(), endIt, [&s](const void* k) { return k == &s; });
56
    if (it == endIt) {
57
      OpCode result = parseOpCode(s.string().c_str());
58
      if (usedEntries_ < numOpcodes) {
59
        keys_[usedEntries_] = &s;
60
        values_[usedEntries_++] = result;
61
      }
62
      return result;
63
    }
64
    // NOTE: I tried implementing the transpose heuristic here to
65
    // speed up the search, but it removed the benefit of this cache.
66
    return values_[it - keys_.begin()];
67
  }
68
};
69
} // namespace
70

71
// Rewrites the bytecode of `function` so that operators serialized under an
// older operator schema are routed through their upgrader functions.
//
// For every OP instruction whose operator has an upgrader whose
// [min_version, max_version] range covers `operator_version`, the
// instruction is changed to CALL with X pointing into code.functions_ (the
// upgrader functions there are initialized in the same order as the global
// kUpgraderBytecode vector).
void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
  Code& code = function->get_code();
  auto& operator_version_map = getOperatorVersionMapForMobile();
  for (size_t i = 0; i < code.instructions_.size(); i++) {
    Instruction& inst = code.instructions_[i];
    if (inst.op == OpCode::OP) {
      const auto& op_name = code.op_names_[inst.X];
      // Fully qualified name with optional overload, e.g. "aten::div.Tensor".
      std::string operator_name = op_name.name +
          (op_name.overload_name.empty() ? ""
                                         : "." + op_name.overload_name);

      auto it = operator_version_map.find(operator_name);
      // Find out if there is an upgrader for this operator
      if (it != operator_version_map.end()) {
        // NOTE: bind by const reference; the previous code copied the whole
        // upgrader vector for every matching OP instruction.
        const auto& upgrader_list = it->second;
        // Loop all upgraders for this operator, and find out if there exists a
        // valid upgrader. Use iteration here instead of other faster search
        // algorithm, because the number of upgrader per operator will be just a
        // few and tend to keep the code light-weight from binary size concern.
        for (const auto& upgrader : upgrader_list) {
          if (static_cast<int>(operator_version) <= upgrader.max_version &&
              static_cast<int>(operator_version) >= upgrader.min_version) {
            // A valid upgrader exists: change the instruction OP to CALL, and
            // make the index point to the corresponding upgrader function in
            // function->get_code().functions_.
            TORCH_CHECK(
                upgrader.index < static_cast<int>(code.functions_.size()),
                "upgrader index is, ",
                upgrader.index,
                " and it's larger than the upgrader function list length ",
                code.functions_.size());
            inst.op = OpCode::CALL;
            inst.X = upgrader.index;
          }
        }
      }
    }
  }
}
118

119
// Parses the serialized instruction list `ins_list` (each entry an
// (opcode-string, X, N) tuple) and appends the decoded instructions to
// `function`. If `debug_handles_m_tuple` is non-empty, it supplies one debug
// handle per instruction, which is attached alongside each one; the table's
// function name must match `function_name` and its handle count must match
// the instruction count.
void parseInstructions(
    const std::string& function_name,
    c10::ivalue::TupleElements&& ins_list,
    c10::ivalue::TupleElements& debug_handles_m_tuple,
    mobile::Function* function) {
  c10::List<int64_t> debug_handles_list;
  if (!debug_handles_m_tuple.empty()) {
    const std::string& debug_info_function_name =
        debug_handles_m_tuple[0].toStringRef();
    TORCH_CHECK(
        debug_info_function_name == function_name,
        "The function names in the bytecode table and the debug info table do not match.");
    IValue& debug_handles_table = debug_handles_m_tuple[1];
    // Move the debug-handles table out of the tuple so expect_field can move
    // fields out of it in turn.
    auto debugHandlesTableElements =
        std::move(*std::move(debug_handles_table).toTuple()).elements();
    debug_handles_list = (expect_field(
                              debugHandlesTableElements,
                              "function_debug_handles",
                              BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
                              .toTupleRef()
                              .elements())[0]
                             .toIntList();
    TORCH_CHECK(
        debug_handles_list.size() == ins_list.size(),
        "The numbers of instructions and debug handles strings do not match.");
  }

  // NOTE: this won't perform particularly well if the ins_list IValue
  // didn't come from unpickler and thus have its strings
  // interned. Consider adding a flag to bypass the cache if that
  // becomes an important use case.
  OpCodeCache opCodeCache;
  for (const auto j : c10::irange(ins_list.size())) {
    // Move each instruction tuple out of ins_list (ins_list was passed by
    // rvalue reference and is consumed here).
    auto ins_tuple = std::move(ins_list[j]).toTuple();
    c10::ArrayRef<IValue> ins_item = ins_tuple->elements();
    TORCH_CHECK(
        ins_item.size() == 3,
        "There should be three parts in an instruction. The function name is ",
        function_name);
    OpCode op_code = opCodeCache.parse(*ins_item[0].toString());
    int X = ins_item[1].toInt();
    int N = ins_item[2].toInt();

    if (!debug_handles_list.empty()) {
      int64_t debug_handle = debug_handles_list[j];
      function->append_instruction(op_code, X, N, debug_handle);
    } else {
      function->append_instruction(op_code, X, N);
    }
  }
}
170

171
void parseConstants(
172
    const c10::ivalue::TupleElements& consts_list,
173
    mobile::Function* function) {
174
  for (const auto& constant : consts_list) {
175
    function->append_constant(constant);
176
  }
177
}
178
void parseTypes(
179
    const c10::ivalue::TupleElements& types_list,
180
    mobile::Function* function) {
181
  std::vector<std::string> types_string_list;
182
  types_string_list.resize(types_list.size());
183
  for (size_t i = 0; i < types_list.size(); i++) {
184
    types_string_list[i] = types_list[i].toStringRef();
185
  }
186

187
  std::vector<c10::TypePtr> types_ptr_list = c10::parseType(types_string_list);
188
  for (auto& type_ptr : types_ptr_list) {
189
    function->append_type(type_ptr);
190
  }
191
}
192

193
// Records the number of registers (`rsize`) the function's bytecode needs on
// `function`.
void parseRegisterSize(size_t rsize, mobile::Function* function) {
  function->set_register_size(rsize);
}
196

197
} // namespace mobile
198
} // namespace jit
199
} // namespace torch
200

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.