1
#include <ATen/core/ivalue.h>
2
#include <torch/csrc/jit/mobile/code.h>
3
#include <torch/csrc/jit/mobile/parse_bytecode.h>
4
#include <torch/csrc/jit/mobile/type_parser.h>
5
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
6
#include <torch/csrc/jit/runtime/instruction.h>
7
#include <torch/csrc/jit/serialization/import_export_constants.h>
8
#include <torch/csrc/jit/serialization/import_export_functions.h>
9
#include <torch/custom_class_detail.h>
13
// Converts an opcode's textual name (a C string) into the corresponding
// OpCode value. Only declared here; its definition lives elsewhere (not in
// this view). OpCodeCache::parse calls it on cache misses.
OpCode parseOpCode(const char* str);
17
c10::ivalue::TupleElements& elements,
18
const std::string& expected_name,
20
auto row = std::move(elements.at(entry)).toTuple();
21
TORCH_INTERNAL_ASSERT(
22
row->elements().at(0).toStringRef() == expected_name,
26
row->elements().at(0).toStringRef());
27
return std::move(row)->elements().at(1);
33
// Compile-time count of opcodes. COUNT_OPCODE ignores its arguments and
// expands to "1 +", so FORALL_OPCODES(COUNT_OPCODE) 0 becomes "1 + 1 + ... + 0".
// Assumes FORALL_OPCODES(_) invokes _ exactly once per opcode, as its name
// suggests; it is defined outside this view — confirm against instruction.h.
#define COUNT_OPCODE(_, _a) 1 +
34
// Sizes OpCodeCache's fixed key/value arrays and bounds how many entries it caches.
constexpr size_t numOpcodes = FORALL_OPCODES(COUNT_OPCODE) 0;
37
// Pickled strings are memoized, so we can cache a mapping from
38
// pointers to parsed OpCodes to speed up parsing.
41
// We store as void* to emphasize that we care only about the
42
// address and should not be dereferencing these pointers.
43
std::array<const void*, numOpcodes> keys_{};
44
std::array<OpCode, numOpcodes> values_{};
45
size_t usedEntries_ = 0;
49
memset(keys_.data(), 0, keys_.size() * sizeof(keys_[0]));
52
OpCode parse(const c10::ivalue::ConstantString& s) {
53
const auto endIt = keys_.begin() + usedEntries_;
54
auto it = std::find_if(
55
keys_.begin(), endIt, [&s](const void* k) { return k == &s; });
57
OpCode result = parseOpCode(s.string().c_str());
58
if (usedEntries_ < numOpcodes) {
59
keys_[usedEntries_] = &s;
60
values_[usedEntries_++] = result;
64
// NOTE: I tried implementing the transpose heuristic here to
65
// speed up the search, but it removed the benefit of this cache.
66
return values_[it - keys_.begin()];
71
void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
72
Code& code = function->get_code();
73
auto& operator_version_map = getOperatorVersionMapForMobile();
74
for (size_t i = 0; i < code.instructions_.size(); i++) {
75
Instruction& inst = code.instructions_[i];
76
if (inst.op == OpCode::OP) {
77
std::string op_name = code.op_names_[inst.X].name;
78
std::string operator_name = code.op_names_[inst.X].name +
79
(code.op_names_[inst.X].overload_name.empty()
81
: "." + code.op_names_[inst.X].overload_name);
83
auto it = operator_version_map.find(operator_name);
84
// Find out if there is an upgrader for this operator
85
if (it != operator_version_map.end()) {
86
auto upgrader_list = it->second;
87
// Loop all upgraders for this operator, and find out if there exists a
88
// valid upgrader. Use iteration here instead of other faster search
89
// algorithm, because the number of upgrader per operator will be just a
90
// few and tend to keep the code light-weight from binary size concern.
91
for (const auto& upgrader : upgrader_list) {
92
if (static_cast<int>(operator_version) <= upgrader.max_version &&
93
static_cast<int>(operator_version) >= upgrader.min_version) {
94
// If there exists a valid upgrader, change the instruction OP to
95
// CALL, and the index will point to the according upgrader
96
// function. All upgrader function are available in
97
// function->get_code().functions_. It's a vector of function
98
// pointer and they are initialized in the same order as the global
99
// vector kUpgraderBytecode.
100
// Instruction new_inst = inst;
101
// new_inst.op = OpCode::CALL;
102
// new_inst.X = upgrader.index;
103
// code->instructions_[i] = new_inst;
105
upgrader.index < static_cast<int>(code.functions_.size()),
106
"upgrader index is, ",
108
" and it's larger than the upgrader function list length ",
109
code.functions_.size());
110
inst.op = OpCode::CALL;
111
inst.X = upgrader.index;
119
void parseInstructions(
120
const std::string& function_name,
121
c10::ivalue::TupleElements&& ins_list,
122
c10::ivalue::TupleElements& debug_handles_m_tuple,
123
mobile::Function* function) {
124
c10::List<int64_t> debug_handles_list;
125
if (!debug_handles_m_tuple.empty()) {
126
const std::string& debug_info_function_name =
127
debug_handles_m_tuple[0].toStringRef();
129
debug_info_function_name == function_name,
130
"The function names in the bytecode table and the debug info table do not match.");
131
IValue& debug_handles_table = debug_handles_m_tuple[1];
132
auto debugHandlesTableElements =
133
std::move(*std::move(debug_handles_table).toTuple()).elements();
134
debug_handles_list = (expect_field(
135
debugHandlesTableElements,
136
"function_debug_handles",
137
BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
142
debug_handles_list.size() == ins_list.size(),
143
"The numbers of instructions and debug handles strings do not match.");
146
// NOTE: this won't perform particularly well if the ins_list IValue
147
// didn't come from unpickler and thus have its strings
148
// interned. Consider adding a flag to bypass the cache if that
149
// becomes an important use case.
150
OpCodeCache opCodeCache;
151
for (const auto j : c10::irange(ins_list.size())) {
152
auto ins_tuple = std::move(ins_list[j]).toTuple();
153
c10::ArrayRef<IValue> ins_item = ins_tuple->elements();
155
ins_item.size() == 3,
156
"There should be three parts in an instruction. The function name is ",
158
OpCode op_code = opCodeCache.parse(*ins_item[0].toString());
159
int X = ins_item[1].toInt();
160
int N = ins_item[2].toInt();
162
if (!debug_handles_list.empty()) {
163
int64_t debug_handle = debug_handles_list[j];
164
function->append_instruction(op_code, X, N, debug_handle);
166
function->append_instruction(op_code, X, N);
172
const c10::ivalue::TupleElements& consts_list,
173
mobile::Function* function) {
174
for (const auto& constant : consts_list) {
175
function->append_constant(constant);
179
const c10::ivalue::TupleElements& types_list,
180
mobile::Function* function) {
181
std::vector<std::string> types_string_list;
182
types_string_list.resize(types_list.size());
183
for (size_t i = 0; i < types_list.size(); i++) {
184
types_string_list[i] = types_list[i].toStringRef();
187
std::vector<c10::TypePtr> types_ptr_list = c10::parseType(types_string_list);
188
for (auto& type_ptr : types_ptr_list) {
189
function->append_type(type_ptr);
193
void parseRegisterSize(size_t rsize, mobile::Function* function) {
194
function->set_register_size(rsize);