1
#ifdef FLATBUFFERS_VERSION_MAJOR
2
#error "flatbuffer_loader.h must not include any flatbuffers headers"
3
#endif // FLATBUFFERS_VERSION_MAJOR
10
#include <unordered_map>
11
#include <unordered_set>
16
#include <ATen/core/dynamic_type.h>
17
#include <ATen/core/ivalue.h>
18
#include <ATen/core/qualified_name.h>
19
#include <c10/core/CPUAllocator.h>
20
#include <c10/core/impl/alloc_cpu.h>
21
#include <c10/util/Exception.h>
22
#include <c10/util/Optional.h>
23
#include <c10/util/ScopeExit.h>
24
#include <caffe2/serialize/inline_container.h>
25
#include <torch/csrc/jit/mobile/file_format.h>
26
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
27
#include <torch/csrc/jit/mobile/function.h>
28
#include <torch/csrc/jit/mobile/import.h>
29
#include <torch/csrc/jit/mobile/interpreter.h>
30
#include <torch/csrc/jit/mobile/module.h>
31
#include <torch/csrc/jit/mobile/observer.h>
32
#include <torch/csrc/jit/mobile/type_parser.h>
33
#include <torch/csrc/jit/runtime/instruction.h>
34
#include <torch/csrc/jit/serialization/export_bytecode.h>
35
#include <torch/csrc/jit/serialization/import_export_constants.h>
36
#include <torch/csrc/jit/serialization/import_read.h>
37
#include <torch/custom_class.h>
39
#ifndef DISABLE_UPGRADER
40
#include <torch/csrc/jit/mobile/parse_bytecode.h>
41
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
50
#if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
51
#include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
52
namespace flatbuffers = flatbuffers_fbsource;
53
#define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
55
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
61
// Our own alignment requirement does not need to be exactly the same as what
62
// flatbuffers supports, but what flatbuffers supports needs to satisfy our
65
kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
66
"Sizes must be compatible");
68
(kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
69
kFlatbufferDataAlignmentBytes,
70
"Must be a power of 2");
74
static constexpr c10::string_view kCustomClassPrefix =
75
"__torch__.torch.classes";
76
static constexpr c10::string_view kTorchPrefix = "__torch__";
77
static constexpr c10::string_view kJitPrefix = "torch.jit";
79
class FlatbufferLoader final {
84
*IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
85
void registerIValueParser(
86
mobile::serialization::IValueUnion ivalue_type,
88
mobile::Module parseModule(mobile::serialization::Module* module, char* end);
90
void extractJitSourceAndConstants(
91
ExtraFilesMap* jit_sources,
92
std::vector<IValue>* constants);
94
typedef TypePtr (*TypeResolver)(
95
const std::string& type_str,
96
std::shared_ptr<CompilationUnit> cu);
98
void internal_registerTypeResolver(TypeResolver type_resolver);
100
// Returns a mutable reference to the already-parsed IValue stored at `pos`
// in the loader's ivalue table. Fails with TORCH_CHECK if `pos` is out of
// range. (Restores the closing brace that was dropped from this extraction.)
IValue& getIValue(uint32_t pos) {
  TORCH_CHECK(pos < all_ivalues_.size());
  return all_ivalues_[pos];
}
105
// Looks up the mobile::Function parsed from ivalue slot `pos`.
// NOTE(review): all_functions_ is an unordered_map, so operator[] will
// default-insert nullptr for an unknown `pos` — callers appear to pass only
// indices of Function ivalues; confirm before tightening to .at().
mobile::Function* getFunction(uint32_t pos) {
  return all_functions_[pos];
}
109
// Returns the ClassTypePtr registered for object-type index `pos`
// (may be null if the slot has not been populated yet). Bounds-checked.
// (Restores the closing brace that was dropped from this extraction.)
ClassTypePtr getType(uint32_t pos) {
  TORCH_CHECK(pos < all_types_.size());
  return all_types_[pos];
}
114
c10::Storage getStorage(uint32_t index);
115
TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
116
ClassTypePtr getOrCreateClassTypeForObject(
117
const mobile::serialization::Object* object);
119
const mobile::serialization::Module* getCurrentFlatbufferInput() {
123
void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
124
should_copy_tensor_memory_ = should_copy_tensor_memory;
127
std::shared_ptr<mobile::CompilationUnit> mcu_;
128
std::shared_ptr<CompilationUnit> cu_;
131
IValue parseIValue(const mobile::serialization::IValue* ivalue);
132
std::unique_ptr<mobile::Function> parseFunction(
133
const mobile::serialization::Function* method);
134
void parseAndPopulate(
136
const mobile::serialization::IValue* ivalue);
138
std::unordered_map<uint32_t, mobile::Function*> all_functions_;
139
std::vector<ClassTypePtr> all_types_;
140
std::unordered_set<uint32_t> initialized_types_;
141
std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
142
std::vector<bool> storage_loaded_;
143
std::vector<c10::Storage> storages_;
144
std::vector<IValue> all_ivalues_;
147
static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
149
TypeResolver type_resolver_ = nullptr;
150
mobile::serialization::Module* module_ = nullptr;
151
bool module_parsed_ = false;
152
bool should_copy_tensor_memory_ = false;
153
// 0 -> mobile_ivalue_size_ elements are from the mobile module.
154
uint32_t mobile_ivalue_size_ = 0;
159
const mobile::serialization::IValue& ivalue);
162
const mobile::serialization::IValue& ivalue);
165
const mobile::serialization::IValue& ivalue);
168
const mobile::serialization::IValue& ivalue);
171
const mobile::serialization::IValue& ivalue);
174
const mobile::serialization::IValue& ivalue);
175
IValue parseDoubleList(
177
const mobile::serialization::IValue& ivalue);
180
const mobile::serialization::IValue& ivalue);
183
const mobile::serialization::IValue& ivalue);
186
const mobile::serialization::IValue& ivalue);
189
const std::string& type_string,
190
std::shared_ptr<CompilationUnit> cu) {
192
c10::string_view type_str(type_string);
193
if (type_str.starts_with(kCustomClassPrefix)) {
194
type = getCustomClass(type_string);
196
type, "The implementation of class ", type_string, " cannot be found.");
198
type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) {
199
c10::QualifiedName qn(type_string);
200
if (cu->get_class(qn) == nullptr) {
201
auto classtype = ClassType::create(qn, cu, true);
202
cu->register_type(classtype);
205
type = cu->get_class(qn);
208
type = c10::parseType(type_string);
213
FlatbufferLoader::FlatbufferLoader()
214
: mcu_(std::make_shared<mobile::CompilationUnit>()),
215
cu_(std::make_shared<CompilationUnit>()),
216
ivalue_parsers_{nullptr} {
217
registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic);
218
registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic);
219
registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic);
220
registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic);
221
registerIValueParser(
222
mobile::serialization::IValueUnion::ComplexDouble, &parseBasic);
223
registerIValueParser(
224
mobile::serialization::IValueUnion::TensorMetadata, &parseTensor);
225
registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic);
226
registerIValueParser(mobile::serialization::IValueUnion::List, &parseList);
227
registerIValueParser(
228
mobile::serialization::IValueUnion::IntList, &parseIntList);
229
registerIValueParser(
230
mobile::serialization::IValueUnion::DoubleList, &parseDoubleList);
231
registerIValueParser(
232
mobile::serialization::IValueUnion::BoolList, &parseBoolList);
233
registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple);
234
registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict);
235
registerIValueParser(
236
mobile::serialization::IValueUnion::Object, &parseObject);
237
registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
238
registerIValueParser(
239
mobile::serialization::IValueUnion::EnumValue, &parseEnum);
240
internal_registerTypeResolver(&resolveType);
243
void FlatbufferLoader::registerIValueParser(
244
mobile::serialization::IValueUnion ivalue_type,
245
IValueParser parser) {
246
ivalue_parsers_[static_cast<uint8_t>(ivalue_type)] = parser;
249
void FlatbufferLoader::internal_registerTypeResolver(
250
TypeResolver type_resolver) {
251
type_resolver_ = type_resolver;
254
void parseExtraFilesFromVector(
255
const flatbuffers::Vector<flatbuffers::Offset<
256
torch::jit::mobile::serialization::ExtraFile>>* files,
257
ExtraFilesMap* extra_files) {
258
for (uint32_t i = 0; i < files->size(); ++i) {
259
const auto* extra_file = files->Get(i);
260
(*extra_files)[extra_file->name()->str()] = extra_file->content()->str();
265
mobile::serialization::Module* module,
266
ExtraFilesMap& extra_files) {
267
auto extra_files_offsets = module->extra_files();
268
parseExtraFilesFromVector(extra_files_offsets, &extra_files);
271
void FlatbufferLoader::parseAndPopulate(
273
const mobile::serialization::IValue* ivalue) {
274
if (const auto* func = ivalue->val_as_Function()) {
275
auto func_ptr = parseFunction(func);
276
all_functions_[i] = func_ptr.get();
277
mcu_->register_function(std::move(func_ptr));
279
all_ivalues_[i] = parseIValue(ivalue);
283
mobile::Module FlatbufferLoader::parseModule(
284
mobile::serialization::Module* module,
287
all_ivalues_.clear();
290
storage_loaded_.clear();
291
module_parsed_ = false;
293
const auto* ivalues = module->ivalues();
295
ivalues && module->object_types(),
296
"Parsing flatbuffer module: Corrupted ivalues/object_types field");
298
reinterpret_cast<const char*>(ivalues) < end, "Corrupted ivalues field");
299
all_ivalues_.resize(ivalues->size());
300
all_types_.resize(module->object_types()->size());
301
storages_.resize(module->storage_data_size());
302
storage_loaded_.resize(module->storage_data_size(), false);
304
mobile_ivalue_size_ = module_->mobile_ivalue_size();
305
if (mobile_ivalue_size_ == 0 || mobile_ivalue_size_ > ivalues->size()) {
306
mobile_ivalue_size_ = ivalues->size();
309
for (uint32_t i = 0; i < mobile_ivalue_size_; i++) {
310
const auto* ival = ivalues->Get(i);
312
reinterpret_cast<const char*>(ival) < end, "Corrupted ivalue item")
313
parseAndPopulate(i, ival);
315
IValue& module_ivalue = getIValue(module->state_obj());
317
// register functions
318
for (const auto& f : all_functions_) {
319
uint32_t class_index =
320
ivalues->Get(f.first)->val_as_Function()->class_type();
321
ClassTypePtr class_type = all_types_[class_index];
322
class_type->addMethod(f.second);
325
module_parsed_ = true;
326
auto m = mobile::Module(module_ivalue.toObject(), mcu_);
327
m.set_min_operator_version(module->operator_version());
328
m.set_bytecode_version(module->bytecode_version());
332
// Registers every bytecode-upgrader function with `function` so that CALL
// instructions targeting upgraders can resolve at run time. Compiles to a
// no-op when DISABLE_UPGRADER is defined.
// (Restores the loop brace / #endif / function brace dropped from this
// extraction; the matching #ifndef is visible above.)
void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    function->append_function(byteCodeFunctionWithOperator.function);
  }
#endif
}
340
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
341
const mobile::serialization::Function* method) {
342
auto function = std::make_unique<mobile::Function>(
343
c10::QualifiedName(method->qn()->str()));
344
// TODO(qihan) add debug handle
345
// const auto* debug_handle = method->debug_info()->debug_handle();
346
for (const auto* inst : *method->instructions()) {
347
function->append_instruction(
348
static_cast<OpCode>(inst->op()), inst->x(), inst->n());
351
for (uint32_t i : *method->constants()) {
352
function->append_constant(getIValue(i));
355
appendUpgraderFunctions(function.get());
356
// 2. Decides if upgrader is needed
357
const uint32_t operator_version = module_->operator_version();
359
(operator_version < caffe2::serialize::kProducedFileFormatVersion);
361
for (const auto* op : *method->operators()) {
362
c10::optional<int> num_args = c10::nullopt;
363
if (op->num_args_serialized() > -1) {
364
num_args = op->num_args_serialized();
367
function->append_operator(
368
op->name()->str(), op->overload_name()->str(), num_args);
371
function->initialize_operators(true);
373
for (const auto i : *method->type_annotations()) {
374
function->append_type(getOrCreateTypeAnnotations(i));
377
// 3. If upgrader is needed, change change the OP instrunction to CALL
378
// instruction (In next PR, use_upgrader will be parsed to parseInstruction
379
// function and do the actual change)
381
#ifndef DISABLE_UPGRADER
382
applyUpgrader(function.get(), operator_version);
386
function->set_register_size(method->register_size());
387
if (method->schema()) {
389
auto parseArgList = [this](const auto* args_fb) {
390
std::vector<c10::Argument> args;
391
for (const auto* arg_tb : *args_fb) {
392
IValue default_value = getIValue(arg_tb->default_value());
393
TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
394
auto arg = c10::Argument(
395
arg_tb->name()->str(),
398
std::move(default_value));
399
args.emplace_back(std::move(arg));
403
c10::FunctionSchema schema(
405
"" /*overload_name*/,
406
parseArgList(method->schema()->arguments()),
407
parseArgList(method->schema()->returns()),
408
false /*is_varargs*/,
409
false /*is_varret*/);
411
function->setSchema(std::move(schema));
412
} catch (const c10::Error& e) {
419
FlatbufferLoader& loader,
420
const mobile::serialization::IValue& ivalue) {
421
const auto* enum_val = ivalue.val_as_EnumValue();
422
auto enum_type = loader.getOrCreateTypeAnnotations(enum_val->type_name())
423
->cast<c10::EnumType>();
426
"Enum with type: " + enum_val->type_name()->str() + " not found.");
427
IValue val = loader.getIValue(enum_val->value());
428
for (const auto& p : enum_type->enumNamesValues()) {
429
if (p.second == val) {
430
auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
431
enum_type, p.first, p.second);
432
return IValue(std::move(enum_holder));
436
false, "Enum with type: " + enum_val->type_name()->str() + " not found.");
441
const mobile::serialization::IValue& ivalue) {
442
switch (ivalue.val_type()) {
443
case mobile::serialization::IValueUnion::NONE:
445
case mobile::serialization::IValueUnion::Int:
446
return ivalue.val_as_Int()->int_val();
447
case mobile::serialization::IValueUnion::Bool:
448
return ivalue.val_as_Bool()->bool_val();
449
case mobile::serialization::IValueUnion::Double:
450
return ivalue.val_as_Double()->double_val();
451
case mobile::serialization::IValueUnion::ComplexDouble: {
452
const auto* comp = ivalue.val_as_ComplexDouble();
453
return c10::complex<double>(comp->real(), comp->imag());
455
case mobile::serialization::IValueUnion::String:
456
return ivalue.val_as_String()->data()->str();
457
case mobile::serialization::IValueUnion::Device: {
458
return c10::Device(ivalue.val_as_Device()->str()->str());
465
at::Tensor parseTensorFromMetadata(
466
FlatbufferLoader* loader,
467
const mobile::serialization::TensorMetadata* tensor_md) {
468
at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
469
auto options = at::CPU(type).options();
471
if (tensor_md->quantized_schema() != nullptr) {
473
const auto* schema = tensor_md->quantized_schema();
474
auto qscheme_type = static_cast<at::QScheme>(schema->qscheme());
475
switch (qscheme_type) {
476
case at::kPerTensorAffine: {
477
tensor = at::_empty_affine_quantized(
478
{0}, options, schema->scale(), schema->zero_point());
480
case at::kPerChannelAffineFloatQParams:
481
case at::kPerChannelAffine: {
482
at::Tensor scales = parseTensorFromMetadata(loader, schema->scales());
483
at::Tensor zero_points =
484
parseTensorFromMetadata(loader, schema->zero_points());
485
tensor = at::_empty_per_channel_affine_quantized(
486
{0}, scales, zero_points, schema->axis(), options);
491
"Unsupported tensor quantization type in serialization ",
492
toString(qscheme_type));
496
tensor = at::empty({0}, options);
498
at::TensorImpl* impl = tensor.unsafeGetTensorImpl();
500
c10::Storage storage;
501
storage = loader->getStorage(tensor_md->storage_location_index());
502
impl->set_storage_keep_dtype(storage);
503
impl->set_storage_offset(tensor_md->storage_offset());
505
std::vector<int64_t> size{
506
tensor_md->sizes()->begin(), tensor_md->sizes()->end()};
507
std::vector<int64_t> stride{
508
tensor_md->strides()->begin(), tensor_md->strides()->end()};
509
impl->set_sizes_and_strides(size, stride);
510
#ifndef MIN_EDGE_RUNTIME
511
tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
517
FlatbufferLoader& loader,
518
const mobile::serialization::IValue& ivalue) {
519
const mobile::serialization::TensorMetadata* tensor_md =
520
ivalue.val_as_TensorMetadata();
521
return parseTensorFromMetadata(&loader, tensor_md);
525
FlatbufferLoader& loader,
526
const mobile::serialization::IValue& ivalue) {
527
const mobile::serialization::List* list = ivalue.val_as_List();
528
auto res = c10::impl::GenericList(AnyType::get());
529
for (int i : *list->items()) {
530
res.emplace_back(loader.getIValue(i));
532
auto type = loader.getOrCreateTypeAnnotations(list->annotation_str());
533
res.unsafeSetElementType(type->containedType(0));
537
// Materializes a flatbuffer list payload into a std::vector<T>. `U` must
// expose items() returning a flatbuffers vector whose elements are
// convertible to T. Debug builds assert the list table pointer is non-null.
// (Restores the closing brace that was dropped from this extraction.)
template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
  return {list->items()->begin(), list->items()->end()};
}
545
const mobile::serialization::IValue& ivalue) {
546
const auto& list = ivalue.val_as_IntList();
547
return parseListNative<int64_t>(list);
550
IValue parseDoubleList(
552
const mobile::serialization::IValue& ivalue) {
553
const auto& list = ivalue.val_as_DoubleList();
554
return parseListNative<double>(list);
559
const mobile::serialization::IValue& ivalue) {
560
const auto& list = ivalue.val_as_BoolList();
561
std::vector<uint8_t> res = parseListNative<uint8_t>(list);
562
c10::List<bool> boollist;
564
boollist.push_back(x);
570
FlatbufferLoader& loader,
571
const mobile::serialization::IValue& ivalue) {
572
const auto& tuple = ivalue.val_as_Tuple();
573
std::vector<IValue> res;
574
for (int i : *tuple->items()) {
575
res.emplace_back(loader.getIValue(i));
577
return c10::ivalue::Tuple::create(res);
581
FlatbufferLoader& loader,
582
const mobile::serialization::IValue& ivalue) {
583
const auto* dict = ivalue.val_as_Dict();
584
auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
585
const auto* keys = dict->keys();
586
const auto* values = dict->values();
587
for (size_t i = 0; i < keys->size(); ++i) {
588
uint32_t key = keys->Get(i);
589
uint32_t val = values->Get(i);
590
result.insert_or_assign(loader.getIValue(key), loader.getIValue(val));
592
auto type = loader.getOrCreateTypeAnnotations(dict->annotation_str());
593
result.unsafeSetKeyType(type->containedType(0));
594
result.unsafeSetValueType(type->containedType(1));
598
ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
599
const mobile::serialization::Object* object) {
600
auto cls = getType(object->type_index());
601
const mobile::serialization::ObjectType* obj_type =
602
module_->object_types()->Get(object->type_index());
603
if (cls == nullptr) {
604
c10::string_view qn_str(
605
obj_type->type_name()->c_str(), obj_type->type_name()->size());
606
if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
607
c10::QualifiedName qn(obj_type->type_name()->str());
608
cls = cu_->get_class(qn);
609
if (cls == nullptr) {
610
cls = ClassType::create(qn, cu_, true);
611
cu_->register_type(cls);
614
cls = c10::parseType(std::string(qn_str))->cast<ClassType>();
616
TORCH_CHECK(object->type_index() < all_ivalues_.size());
617
all_types_[object->type_index()] = cls;
619
if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) {
620
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
621
IValue val = getIValue(object->attrs()->Get(i));
622
// Need to use concrete object's field's type to set type of field.
624
obj_type->attr_names()->Get(i)->str(),
625
val.type<c10::DynamicType>());
628
initialized_types_.insert(object->type_index());
634
FlatbufferLoader& loader,
635
const mobile::serialization::IValue& ivalue) {
636
const mobile::serialization::Object* object = ivalue.val_as_Object();
637
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(object != nullptr);
638
const auto* cur_input = loader.getCurrentFlatbufferInput();
639
const mobile::serialization::ObjectType* obj_type =
640
cur_input->object_types()->Get(object->type_index());
641
auto cls = loader.getOrCreateClassTypeForObject(object);
643
switch (obj_type->type()) {
644
case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
645
auto obj = c10::ivalue::Object::create(
646
at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
647
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
648
IValue val = loader.getIValue(object->attrs()->Get(i));
649
obj->setSlot(i, std::move(val));
653
case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
654
IValue input = loader.getIValue(object->state());
655
mobile::Function* setstate = loader.getFunction(object->setstate_func());
657
c10::ivalue::Object::create(at::StrongTypePtr(loader.cu_, cls), 0);
658
stack.emplace_back(obj);
659
stack.emplace_back(std::move(input));
660
setstate->run(stack);
663
case mobile::serialization::TypeType::CUSTOM_CLASS: {
664
auto custom_class_type =
665
torch::jit::getCustomClass(cls->name()->qualifiedName());
666
IValue input = loader.getIValue(object->state());
667
auto obj = c10::ivalue::Object::create(
668
c10::StrongTypePtr(nullptr, custom_class_type), 1);
669
stack.emplace_back(obj);
670
stack.emplace_back(std::move(input));
671
custom_class_type->getMethod("__setstate__").run(stack);
675
AT_ASSERT(false, "need to be object");
679
IValue FlatbufferLoader::parseIValue(
680
const mobile::serialization::IValue* ivalue) {
681
return ivalue_parsers_[static_cast<uint32_t>(ivalue->val_type())](
685
// Intentional no-op deleter. getStorage() uses it to wrap flatbuffer-owned
// tensor bytes in an at::DataPtr without copying, so destroying the Storage
// must not free the underlying buffer. The forward declaration keeps
// missing-prototype warnings quiet for this non-static definition.
void deleteNothing2(void*);
void deleteNothing2(void* /*unused*/) {}
688
c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
689
TORCH_CHECK(index < storage_loaded_.size());
690
TORCH_CHECK(index < storages_.size());
691
if (!storage_loaded_[index]) {
692
auto* storage = module_->storage_data()->GetMutableObject(index);
693
size_t size = storage->data()->size();
696
if (should_copy_tensor_memory_) {
697
auto* allocator = at::GetCPUAllocator();
698
data = allocator->allocate(size);
699
memcpy(data.get(), storage->data()->data(), size);
701
void* ptr = static_cast<void*>(storage->mutable_data()->data());
702
data = at::DataPtr(ptr, ptr, deleteNothing2, DeviceType::CPU);
705
c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
706
storage_loaded_[index] = true;
708
return storages_[index];
711
TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
712
const flatbuffers::String* offset) {
713
auto iter = type_annotations_.find(offset);
714
if (iter != type_annotations_.end()) {
717
TypePtr type = type_resolver_(offset->str(), cu_);
718
type_annotations_[offset] = type;
722
void FlatbufferLoader::extractJitSourceAndConstants(
723
ExtraFilesMap* jit_sources,
724
std::vector<IValue>* constants) {
727
"Need to first parse a flatbuffer file before extracting jit_sources");
729
const auto* ivalues = module_->ivalues();
730
for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) {
731
const auto* ival = ivalues->Get(i);
732
parseAndPopulate(i, ival);
734
// register functions
735
for (const auto& f : all_functions_) {
736
if (f.first >= mobile_ivalue_size_) {
737
uint32_t class_index =
738
ivalues->Get(f.first)->val_as_Function()->class_type();
739
ClassTypePtr class_type = all_types_[class_index];
740
class_type->addMethod(f.second);
743
const auto* jit_constants = module_->jit_constants();
744
for (const auto i : c10::irange(jit_constants->size())) {
745
constants->emplace_back(getIValue(jit_constants->Get(i)));
747
parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
752
mobile::Module parse_and_initialize_mobile_module(
755
c10::optional<at::Device>,
756
ExtraFilesMap* extra_files,
757
bool should_copy_tensor_memory) {
758
// TODO(T128189662): If not copying, enforce that data is aligned to
759
// kFlatbufferDataAlignmentBytes, and add unit tests.
761
// Validate Flatbuffer module before parsing.
762
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
764
mobile::serialization::VerifyModuleBuffer(verifier),
765
"Malformed Flatbuffer module");
767
FlatbufferLoader loader;
768
loader.setShouldCopyTensorMemory(should_copy_tensor_memory);
770
// Flatbuffer doesn't seem to have a way to provide the buffer size when
771
// interacting with the buffer.
772
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
773
auto* end = static_cast<char*>(data) + size;
774
mobile::Module m = loader.parseModule(flatbuffer_module, end);
775
if (extra_files != nullptr) {
776
parseExtraFiles(flatbuffer_module, *extra_files);
781
mobile::Module parse_and_initialize_mobile_module(
782
std::shared_ptr<char> data,
784
c10::optional<at::Device> device,
785
ExtraFilesMap* extra_files) {
786
mobile::Module m = parse_and_initialize_mobile_module(
791
/*should_copy_tensor_memory=*/false);
792
m.set_delete_memory(std::move(data));
796
mobile::Module parse_and_initialize_mobile_module_for_jit(
799
ExtraFilesMap& jit_sources,
800
std::vector<IValue>& jit_constants,
801
c10::optional<at::Device>,
802
ExtraFilesMap* extra_files) {
804
mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
805
// TODO(T128189662): Enforce that data is aligned to
806
// kFlatbufferDataAlignmentBytes, and add unit tests.
808
// Validate Flatbuffer module before parsing.
809
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
811
mobile::serialization::VerifyModuleBuffer(verifier),
812
"Malformed Flatbuffer module");
814
FlatbufferLoader loader;
815
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
816
auto* end = static_cast<char*>(data) + size;
817
mobile::Module m = loader.parseModule(flatbuffer_module, end);
818
if (extra_files != nullptr) {
819
parseExtraFiles(flatbuffer_module, *extra_files);
822
loader.extractJitSourceAndConstants(&jit_sources, &jit_constants);
826
mobile::Module load_mobile_module_from_file(
827
const std::string& filename,
828
c10::optional<c10::Device> device,
829
ExtraFilesMap* extra_files) {
830
auto [data, size] = get_file_content(filename.c_str());
831
return parse_and_initialize_mobile_module(
832
std::move(data), size, device, extra_files);
835
uint64_t get_bytecode_version(std::istream& in) {
836
auto [data, size] = get_stream_content(in);
837
return get_bytecode_version_from_bytes(data.get());
840
uint64_t get_bytecode_version(const std::string& filename) {
841
auto [data, size] = get_file_content(filename.c_str());
842
return get_bytecode_version_from_bytes(data.get());
845
uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) {
847
mobile::serialization::ModuleBufferHasIdentifier(flatbuffer_content),
849
auto* flatbuffer_module =
850
mobile::serialization::GetMutableModule(flatbuffer_content);
851
return flatbuffer_module->bytecode_version();
854
mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
855
auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
856
mobile::ModuleInfo minfo;
857
minfo.operator_version = ff_module->operator_version();
858
minfo.bytecode_version = ff_module->bytecode_version();
860
uint32_t mobile_ivalue_size = ff_module->mobile_ivalue_size();
861
if (mobile_ivalue_size == 0) {
862
mobile_ivalue_size = ff_module->ivalues()->size();
865
std::vector<std::string> type_name_list;
866
for (uint32_t i = 0; i < mobile_ivalue_size; i++) {
867
const auto* ival = ff_module->ivalues()->Get(i);
868
if (const auto* func = ival->val_as_Function()) {
869
minfo.function_names.insert(func->qn()->str());
870
for (const auto* op : *func->operators()) {
871
at::OperatorName opname(op->name()->str(), op->overload_name()->str());
872
minfo.opname_to_num_args[mobile::operator_str(opname)] =
873
op->num_args_serialized();
875
for (const auto* type_ann : *func->type_annotations()) {
876
type_name_list.push_back(type_ann->str());
880
c10::TypeParser parser(type_name_list);
882
minfo.type_names = parser.getContainedTypes();
886
mobile::Module load_mobile_module_from_stream_with_copy(
888
c10::optional<at::Device> device,
889
ExtraFilesMap* extra_files) {
890
auto [data, size] = get_stream_content(in);
891
return parse_and_initialize_mobile_module(
892
std::move(data), size, device, extra_files);
895
mobile::Module parse_flatbuffer_no_object(
896
std::shared_ptr<char> data,
898
c10::optional<at::Device> device) {
902
// Validate Flatbuffer module before parsing.
903
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data.get()), size);
905
mobile::serialization::VerifyModuleBuffer(verifier),
906
"Malformed Flatbuffer module");
908
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
909
FlatbufferLoader loader;
910
// replace parserObject with to handle only class with field case
912
loader.registerIValueParser(
913
mobile::serialization::IValueUnion::Object,
914
+[](FlatbufferLoader& loader,
915
const mobile::serialization::IValue& ivalue) {
916
const mobile::serialization::Object* object = ivalue.val_as_Object();
917
auto cls = loader.getOrCreateClassTypeForObject(object);
918
auto obj = c10::ivalue::Object::create(
919
at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
920
for (uint32_t i = 0; i < object->attrs()->size(); i++) {
921
IValue val = loader.getIValue(object->attrs()->Get(i));
922
obj->setSlot(i, std::move(val));
924
return static_cast<c10::IValue>(obj);
927
auto* end = data.get() + size;
928
mobile::Module m = loader.parseModule(flatbuffer_module, end);
929
m.set_delete_memory(std::move(data));
933
bool register_flatbuffer_loader() {