pytorch

Форк
0
/
flatbuffer_loader.cpp 
938 строк · 31.6 Кб
1
#ifdef FLATBUFFERS_VERSION_MAJOR
2
#error "flatbuffer_loader.h must not include any flatbuffers headers"
3
#endif // FLATBUFFERS_VERSION_MAJOR
4

5
#include <array>
6
#include <istream>
7
#include <memory>
8
#include <string>
9
#include <tuple>
10
#include <unordered_map>
11
#include <unordered_set>
12
#include <utility>
13
#include <vector>
14

15
#include <ATen/ATen.h>
16
#include <ATen/core/dynamic_type.h>
17
#include <ATen/core/ivalue.h>
18
#include <ATen/core/qualified_name.h>
19
#include <c10/core/CPUAllocator.h>
20
#include <c10/core/impl/alloc_cpu.h>
21
#include <c10/util/Exception.h>
22
#include <c10/util/Optional.h>
23
#include <c10/util/ScopeExit.h>
24
#include <caffe2/serialize/inline_container.h>
25
#include <torch/csrc/jit/mobile/file_format.h>
26
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
27
#include <torch/csrc/jit/mobile/function.h>
28
#include <torch/csrc/jit/mobile/import.h>
29
#include <torch/csrc/jit/mobile/interpreter.h>
30
#include <torch/csrc/jit/mobile/module.h>
31
#include <torch/csrc/jit/mobile/observer.h>
32
#include <torch/csrc/jit/mobile/type_parser.h>
33
#include <torch/csrc/jit/runtime/instruction.h>
34
#include <torch/csrc/jit/serialization/export_bytecode.h>
35
#include <torch/csrc/jit/serialization/import_export_constants.h>
36
#include <torch/csrc/jit/serialization/import_read.h>
37
#include <torch/custom_class.h>
38

39
#ifndef DISABLE_UPGRADER
40
#include <torch/csrc/jit/mobile/parse_bytecode.h>
41
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
42
#endif
43

44
#ifdef _WIN32
45
#include <malloc.h>
46
#else
47
#include <cstdlib>
48
#endif
49

50
#if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
51
#include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
52
namespace flatbuffers = flatbuffers_fbsource;
53
#define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
54
#else
55
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
56
#endif
57

58
namespace torch {
59
namespace jit {
60

61
// Our own alignment requirement does not need to be exactly the same as what
// flatbuffers supports, but what flatbuffers supports needs to satisfy our
// requirement.
static_assert(
    kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
    "Sizes must be compatible");
// x & ~(x - 1) isolates the lowest set bit of x; it equals x only when x has
// exactly one bit set, i.e. when x is a power of two.
static_assert(
    (kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
        kFlatbufferDataAlignmentBytes,
    "Must be a power of 2");
71

72
namespace {
73

74
// Prefixes of serialized type-name strings; resolveType() and
// getOrCreateClassTypeForObject() dispatch on these to decide whether a name
// refers to a torchbind custom class, a TorchScript class, or a builtin type.
static constexpr c10::string_view kCustomClassPrefix =
    "__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";
78

79
// Parses one flatbuffer-serialized mobile::Module. parseModule() fills the
// caches below; extractJitSourceAndConstants() can then pull out the
// JIT-only leftovers (sources and constants past the mobile ivalue range).
class FlatbufferLoader final {
 public:
  FlatbufferLoader();

  // Callback that materializes one serialized IValue of a particular
  // IValueUnion kind.
  typedef IValue (
      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
  // Install (or override) the parser used for one union discriminator.
  void registerIValueParser(
      mobile::serialization::IValueUnion ivalue_type,
      IValueParser parser);
  // Parse `module` into a mobile::Module. `end` points one past the last
  // valid byte of the buffer and is used for corruption checks.
  mobile::Module parseModule(mobile::serialization::Module* module, char* end);

  // After parseModule(), parse the remaining (JIT-only) ivalues and copy the
  // jit sources / jit constants out of the module.
  void extractJitSourceAndConstants(
      ExtraFilesMap* jit_sources,
      std::vector<IValue>* constants);

  // Maps a serialized type string to a TypePtr, creating class types inside
  // `cu` on demand.
  typedef TypePtr (*TypeResolver)(
      const std::string& type_str,
      std::shared_ptr<CompilationUnit> cu);

  void internal_registerTypeResolver(TypeResolver type_resolver);

  // Access a previously parsed ivalue by its index in the module-wide table.
  IValue& getIValue(uint32_t pos) {
    TORCH_CHECK(pos < all_ivalues_.size());
    return all_ivalues_[pos];
  }

  // NOTE: unlike getIValue/getType, this is not bounds-checked; an index
  // that was never populated default-inserts and returns nullptr.
  mobile::Function* getFunction(uint32_t pos) {
    return all_functions_[pos];
  }

  // Class type for object-type slot `pos`; nullptr until it is created by
  // getOrCreateClassTypeForObject().
  ClassTypePtr getType(uint32_t pos) {
    TORCH_CHECK(pos < all_types_.size());
    return all_types_[pos];
  }

  // Lazily materialize the c10::Storage backing storage slot `index`.
  c10::Storage getStorage(uint32_t index);
  // Type resolution memoized on the flatbuffer string pointer identity.
  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
  ClassTypePtr getOrCreateClassTypeForObject(
      const mobile::serialization::Object* object);

  const mobile::serialization::Module* getCurrentFlatbufferInput() {
    return module_;
  }

  // When true, tensor bytes are copied out of the flatbuffer; when false,
  // storages alias the buffer, which must then outlive the loaded module.
  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
    should_copy_tensor_memory_ = should_copy_tensor_memory;
  }

  std::shared_ptr<mobile::CompilationUnit> mcu_;
  std::shared_ptr<CompilationUnit> cu_;

 private:
  IValue parseIValue(const mobile::serialization::IValue* ivalue);
  std::unique_ptr<mobile::Function> parseFunction(
      const mobile::serialization::Function* method);
  // Parse ivalue-table slot `i` into either all_functions_ or all_ivalues_.
  void parseAndPopulate(
      uint32_t i,
      const mobile::serialization::IValue* ivalue);

  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
  std::vector<ClassTypePtr> all_types_;
  std::unordered_set<uint32_t> initialized_types_;
  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
  std::vector<bool> storage_loaded_;
  std::vector<c10::Storage> storages_;
  std::vector<IValue> all_ivalues_;
  // One parser per IValueUnion discriminator, indexed by its numeric value.
  std::array<
      IValueParser,
      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
      ivalue_parsers_;
  TypeResolver type_resolver_ = nullptr;
  mobile::serialization::Module* module_ = nullptr;
  bool module_parsed_ = false;
  bool should_copy_tensor_memory_ = false;
  // 0 -> mobile_ivalue_size_ elements are from the mobile module.
  uint32_t mobile_ivalue_size_ = 0;
};
156

157
// Forward declarations of the per-kind ivalue parsers registered in the
// FlatbufferLoader constructor below.
IValue parseList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTensor(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTuple(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDict(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseObject(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseIntList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDoubleList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBoolList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseEnum(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
187

188
// Default TypeResolver: turn a serialized type string into a TypePtr.
// Custom (torchbind) classes must already be registered; TorchScript class
// names are looked up in `cu` and created there on demand; everything else
// is treated as a builtin type expression and parsed directly.
TypePtr resolveType(
    const std::string& type_string,
    std::shared_ptr<CompilationUnit> cu) {
  const c10::string_view view(type_string);
  if (view.starts_with(kCustomClassPrefix)) {
    TypePtr custom = getCustomClass(type_string);
    TORCH_CHECK(
        custom,
        "The implementation of class ",
        type_string,
        " cannot be found.");
    return custom;
  }
  if (view.starts_with(kTorchPrefix) || view.starts_with(kJitPrefix)) {
    c10::QualifiedName qn(type_string);
    if (auto existing = cu->get_class(qn)) {
      return existing;
    }
    auto created = ClassType::create(qn, cu, /*is_module=*/true);
    cu->register_type(created);
    return created;
  }
  // Builtin type expression, e.g. "List[int]" or "Dict[str, Tensor]".
  return c10::parseType(type_string);
}
212

213
// Wire up the default parser for every serialized ivalue kind, plus the
// default type resolver. Scalar-like kinds all share parseBasic.
FlatbufferLoader::FlatbufferLoader()
    : mcu_(std::make_shared<mobile::CompilationUnit>()),
      cu_(std::make_shared<CompilationUnit>()),
      ivalue_parsers_{nullptr} {
  registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::ComplexDouble, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::TensorMetadata, &parseTensor);
  registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::List, &parseList);
  registerIValueParser(
      mobile::serialization::IValueUnion::IntList, &parseIntList);
  registerIValueParser(
      mobile::serialization::IValueUnion::DoubleList, &parseDoubleList);
  registerIValueParser(
      mobile::serialization::IValueUnion::BoolList, &parseBoolList);
  registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple);
  registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict);
  registerIValueParser(
      mobile::serialization::IValueUnion::Object, &parseObject);
  registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::EnumValue, &parseEnum);
  internal_registerTypeResolver(&resolveType);
}
242

243
// Install `parser` as the handler for ivalues whose union tag is
// `ivalue_type`, replacing any previously registered handler.
void FlatbufferLoader::registerIValueParser(
    mobile::serialization::IValueUnion ivalue_type,
    IValueParser parser) {
  ivalue_parsers_[static_cast<uint8_t>(ivalue_type)] = parser;
}
248

249
// Replace the function used to turn serialized type strings into TypePtrs
// (default: resolveType above).
void FlatbufferLoader::internal_registerTypeResolver(
    TypeResolver type_resolver) {
  type_resolver_ = type_resolver;
}
253

254
void parseExtraFilesFromVector(
255
    const flatbuffers::Vector<flatbuffers::Offset<
256
        torch::jit::mobile::serialization::ExtraFile>>* files,
257
    ExtraFilesMap* extra_files) {
258
  for (uint32_t i = 0; i < files->size(); ++i) {
259
    const auto* extra_file = files->Get(i);
260
    (*extra_files)[extra_file->name()->str()] = extra_file->content()->str();
261
  }
262
}
263

264
void parseExtraFiles(
265
    mobile::serialization::Module* module,
266
    ExtraFilesMap& extra_files) {
267
  auto extra_files_offsets = module->extra_files();
268
  parseExtraFilesFromVector(extra_files_offsets, &extra_files);
269
}
270

271
// Parse ivalue-table slot `i`. Function entries are registered with the
// mobile compilation unit and remembered by index in all_functions_;
// everything else is parsed into all_ivalues_[i].
void FlatbufferLoader::parseAndPopulate(
    uint32_t i,
    const mobile::serialization::IValue* ivalue) {
  if (const auto* func = ivalue->val_as_Function()) {
    auto func_ptr = parseFunction(func);
    all_functions_[i] = func_ptr.get();
    // mcu_ takes ownership; all_functions_ keeps a non-owning pointer.
    mcu_->register_function(std::move(func_ptr));
  } else {
    all_ivalues_[i] = parseIValue(ivalue);
  }
}
282

283
// Parse the flatbuffer `module` into a mobile::Module. `end` is one past the
// last valid byte; pointers derived from the buffer are checked against it
// to reject truncated/corrupted input.
mobile::Module FlatbufferLoader::parseModule(
    mobile::serialization::Module* module,
    char* end) {
  module_ = module;
  // Reset per-parse state so a loader instance could be reused.
  all_ivalues_.clear();
  all_types_.clear();
  storages_.clear();
  storage_loaded_.clear();
  module_parsed_ = false;

  const auto* ivalues = module->ivalues();
  TORCH_CHECK(
      ivalues && module->object_types(),
      "Parsing flatbuffer module: Corrupted ivalues/object_types field");
  TORCH_CHECK(
      reinterpret_cast<const char*>(ivalues) < end, "Corrupted ivalues field");
  all_ivalues_.resize(ivalues->size());
  all_types_.resize(module->object_types()->size());
  storages_.resize(module->storage_data_size());
  storage_loaded_.resize(module->storage_data_size(), false);

  // Only the first mobile_ivalue_size_ table entries belong to the mobile
  // module; later entries (if any) are JIT-only and are parsed in
  // extractJitSourceAndConstants(). 0 or out-of-range means "all of them".
  mobile_ivalue_size_ = module_->mobile_ivalue_size();
  if (mobile_ivalue_size_ == 0 || mobile_ivalue_size_ > ivalues->size()) {
    mobile_ivalue_size_ = ivalues->size();
  }

  for (uint32_t i = 0; i < mobile_ivalue_size_; i++) {
    const auto* ival = ivalues->Get(i);
    TORCH_CHECK(
        reinterpret_cast<const char*>(ival) < end, "Corrupted ivalue item")
    parseAndPopulate(i, ival);
  }
  // The module's top-level state object must already have been parsed above.
  IValue& module_ivalue = getIValue(module->state_obj());

  // register functions: attach each parsed method to its owning class type.
  for (const auto& f : all_functions_) {
    uint32_t class_index =
        ivalues->Get(f.first)->val_as_Function()->class_type();
    ClassTypePtr class_type = all_types_[class_index];
    class_type->addMethod(f.second);
  }

  module_parsed_ = true;
  auto m = mobile::Module(module_ivalue.toObject(), mcu_);
  m.set_min_operator_version(module->operator_version());
  m.set_bytecode_version(module->bytecode_version());
  return m;
}
331

332
// Append all builtin upgrader bytecode functions to `function` so that calls
// to old operator versions can be redirected to them. Compiles to a no-op
// when the upgrader is disabled at build time.
void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    function->append_function(byteCodeFunctionWithOperator.function);
  }
#endif
}
339

340
// Reconstruct one mobile::Function from its serialized form: instructions,
// constants, operators, type annotations and (optionally) its schema.
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
    const mobile::serialization::Function* method) {
  auto function = std::make_unique<mobile::Function>(
      c10::QualifiedName(method->qn()->str()));
  // TODO(qihan) add debug handle
  // const auto* debug_handle = method->debug_info()->debug_handle();
  for (const auto* inst : *method->instructions()) {
    function->append_instruction(
        static_cast<OpCode>(inst->op()), inst->x(), inst->n());
  }

  // Constants are stored as indices into the module-wide ivalue table.
  for (uint32_t i : *method->constants()) {
    function->append_constant(getIValue(i));
  }

  appendUpgraderFunctions(function.get());
  // 2. Decides if upgrader is needed
  const uint32_t operator_version = module_->operator_version();
  bool use_upgrader =
      (operator_version < caffe2::serialize::kProducedFileFormatVersion);

  for (const auto* op : *method->operators()) {
    // A serialized value of -1 means "argument count unknown".
    c10::optional<int> num_args = c10::nullopt;
    if (op->num_args_serialized() > -1) {
      num_args = op->num_args_serialized();
    }

    function->append_operator(
        op->name()->str(), op->overload_name()->str(), num_args);
  }

  function->initialize_operators(true);

  for (const auto i : *method->type_annotations()) {
    function->append_type(getOrCreateTypeAnnotations(i));
  }

  // 3. If upgrader is needed, change the OP instruction to CALL
  // instruction (In next PR, use_upgrader will be parsed to parseInstruction
  // function and do the actual change)
  if (use_upgrader) {
#ifndef DISABLE_UPGRADER
    applyUpgrader(function.get(), operator_version);
#endif
  }

  function->set_register_size(method->register_size());
  if (method->schema()) {
    try {
      // Build c10::Arguments for either the argument or return list; default
      // values and types come from the shared ivalue / annotation tables.
      auto parseArgList = [this](const auto* args_fb) {
        std::vector<c10::Argument> args;
        for (const auto* arg_tb : *args_fb) {
          IValue default_value = getIValue(arg_tb->default_value());
          TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
          auto arg = c10::Argument(
              arg_tb->name()->str(),
              std::move(type_ptr),
              c10::nullopt /*N*/,
              std::move(default_value));
          args.emplace_back(std::move(arg));
        }
        return args;
      };
      c10::FunctionSchema schema(
          method->qn()->str(),
          "" /*overload_name*/,
          parseArgList(method->schema()->arguments()),
          parseArgList(method->schema()->returns()),
          false /*is_varargs*/,
          false /*is_varret*/);

      function->setSchema(std::move(schema));
    } catch (const c10::Error& e) {
      // Schema is optional metadata: a malformed schema is deliberately
      // ignored rather than failing the whole load.
    }
  }
  return function;
}
417

418
// Rebuild an enum value: resolve its EnumType by name, then match the
// serialized payload against the type's (name, value) pairs.
IValue parseEnum(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* enum_val = ivalue.val_as_EnumValue();
  auto enum_type = loader.getOrCreateTypeAnnotations(enum_val->type_name())
                       ->cast<c10::EnumType>();
  AT_ASSERT(
      enum_type,
      "Enum with type: " + enum_val->type_name()->str() + " not found.");
  IValue val = loader.getIValue(enum_val->value());
  for (const auto& p : enum_type->enumNamesValues()) {
    if (p.second == val) {
      auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
          enum_type, p.first, p.second);
      return IValue(std::move(enum_holder));
    }
  }
  // No member of the enum type carries this value: corrupted or
  // incompatible module.
  AT_ASSERT(
      false, "Enum with type: " + enum_val->type_name()->str() + " not found.");
}
438

439
// Handles every "leaf" ivalue kind that needs no recursion: None, scalars,
// strings and devices. Unrecognized tags fall through to a None IValue.
IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  switch (ivalue.val_type()) {
    case mobile::serialization::IValueUnion::NONE:
      return {};
    case mobile::serialization::IValueUnion::Int:
      return ivalue.val_as_Int()->int_val();
    case mobile::serialization::IValueUnion::Bool:
      return ivalue.val_as_Bool()->bool_val();
    case mobile::serialization::IValueUnion::Double:
      return ivalue.val_as_Double()->double_val();
    case mobile::serialization::IValueUnion::ComplexDouble: {
      const auto* comp = ivalue.val_as_ComplexDouble();
      return c10::complex<double>(comp->real(), comp->imag());
    }
    case mobile::serialization::IValueUnion::String:
      return ivalue.val_as_String()->data()->str();
    case mobile::serialization::IValueUnion::Device: {
      // Device is serialized as its string form, e.g. "cpu" or "cuda:0".
      return c10::Device(ivalue.val_as_Device()->str()->str());
    }
    default:
      return {};
  }
}
464

465
// Build an at::Tensor from serialized metadata: first create an empty tensor
// of the right dtype (quantized or not), then patch in storage, offset,
// sizes and strides from the flatbuffer.
at::Tensor parseTensorFromMetadata(
    FlatbufferLoader* loader,
    const mobile::serialization::TensorMetadata* tensor_md) {
  at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
  auto options = at::CPU(type).options();
  at::Tensor tensor;
  if (tensor_md->quantized_schema() != nullptr) {
    // is quantized
    const auto* schema = tensor_md->quantized_schema();
    auto qscheme_type = static_cast<at::QScheme>(schema->qscheme());
    switch (qscheme_type) {
      case at::kPerTensorAffine: {
        tensor = at::_empty_affine_quantized(
            {0}, options, schema->scale(), schema->zero_point());
      } break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        // Per-channel scales/zero-points are themselves serialized tensors;
        // recurse to materialize them first.
        at::Tensor scales = parseTensorFromMetadata(loader, schema->scales());
        at::Tensor zero_points =
            parseTensorFromMetadata(loader, schema->zero_points());
        tensor = at::_empty_per_channel_affine_quantized(
            {0}, scales, zero_points, schema->axis(), options);
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(qscheme_type));
        break;
    }
  } else {
    tensor = at::empty({0}, options);
  }
  at::TensorImpl* impl = tensor.unsafeGetTensorImpl();

  // Swap in the real storage (shared or copied, per loader settings).
  c10::Storage storage;
  storage = loader->getStorage(tensor_md->storage_location_index());
  impl->set_storage_keep_dtype(storage);
  impl->set_storage_offset(tensor_md->storage_offset());

  std::vector<int64_t> size{
      tensor_md->sizes()->begin(), tensor_md->sizes()->end()};
  std::vector<int64_t> stride{
      tensor_md->strides()->begin(), tensor_md->strides()->end()};
  impl->set_sizes_and_strides(size, stride);
#ifndef MIN_EDGE_RUNTIME
  tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
#endif
  return tensor;
}
515

516
// Thin adapter from the IValueParser signature to parseTensorFromMetadata.
IValue parseTensor(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* metadata = ivalue.val_as_TensorMetadata();
  return parseTensorFromMetadata(&loader, metadata);
}
523

524
// Heterogeneous list: items are indices into the module-wide ivalue table.
// The element type is recovered afterwards from the annotation string.
IValue parseList(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::List* list = ivalue.val_as_List();
  auto res = c10::impl::GenericList(AnyType::get());
  for (int i : *list->items()) {
    res.emplace_back(loader.getIValue(i));
  }
  auto type = loader.getOrCreateTypeAnnotations(list->annotation_str());
  res.unsafeSetElementType(type->containedType(0));
  return res;
}
536

537
// Copy the primitive items of a serialized list (IntList/DoubleList/
// BoolList) into a std::vector<T>.
template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
  return {list->items()->begin(), list->items()->end()};
}
542

543
// Specialized list of int64 values stored inline (no ivalue indirection).
IValue parseIntList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  const auto& list = ivalue.val_as_IntList();
  return parseListNative<int64_t>(list);
}
549

550
// Specialized list of double values stored inline (no ivalue indirection).
IValue parseDoubleList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  const auto& list = ivalue.val_as_DoubleList();
  return parseListNative<double>(list);
}
556

557
// Specialized list of bools. Serialized as bytes; converted element-wise
// into a c10::List<bool>.
IValue parseBoolList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  const auto* raw_list = ivalue.val_as_BoolList();
  c10::List<bool> result;
  for (const uint8_t byte : parseListNative<uint8_t>(raw_list)) {
    result.push_back(static_cast<bool>(byte));
  }
  return result;
}
568

569
// Tuple: items are indices into the module-wide ivalue table.
IValue parseTuple(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto& tuple = ivalue.val_as_Tuple();
  std::vector<IValue> res;
  for (int i : *tuple->items()) {
    res.emplace_back(loader.getIValue(i));
  }
  return c10::ivalue::Tuple::create(res);
}
579

580
// Dict: keys and values are parallel vectors of ivalue-table indices; key
// and value types are recovered from the annotation string afterwards.
IValue parseDict(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* dict = ivalue.val_as_Dict();
  auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
  const auto* keys = dict->keys();
  const auto* values = dict->values();
  for (size_t i = 0; i < keys->size(); ++i) {
    uint32_t key = keys->Get(i);
    uint32_t val = values->Get(i);
    result.insert_or_assign(loader.getIValue(key), loader.getIValue(val));
  }
  auto type = loader.getOrCreateTypeAnnotations(dict->annotation_str());
  result.unsafeSetKeyType(type->containedType(0));
  result.unsafeSetValueType(type->containedType(1));
  return result;
}
597

598
// Return the ClassType for `object`, creating and caching it on first use.
// TorchScript classes are registered into cu_; anything else is parsed as a
// (custom-class) type expression. For field-style classes, attributes are
// added from the concrete object's slot values.
ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
    const mobile::serialization::Object* object) {
  auto cls = getType(object->type_index());
  const mobile::serialization::ObjectType* obj_type =
      module_->object_types()->Get(object->type_index());
  if (cls == nullptr) {
    c10::string_view qn_str(
        obj_type->type_name()->c_str(), obj_type->type_name()->size());
    if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
      c10::QualifiedName qn(obj_type->type_name()->str());
      cls = cu_->get_class(qn);
      if (cls == nullptr) {
        cls = ClassType::create(qn, cu_, true);
        cu_->register_type(cls);
      }
    } else {
      cls = c10::parseType(std::string(qn_str))->cast<ClassType>();
    }
    // Bounds-check against the table we are about to write into.
    // (Bug fix: this previously checked all_ivalues_.size(), which is the
    // wrong container — all_types_ is sized by object_types()->size(),
    // all_ivalues_ by ivalues()->size(), and the two can differ.)
    TORCH_CHECK(object->type_index() < all_types_.size());
    all_types_[object->type_index()] = cls;

    if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) {
      for (uint32_t i = 0; i < object->attrs()->size(); i++) {
        IValue val = getIValue(object->attrs()->Get(i));
        // Need to use concrete object's field's type to set type of field.
        cls->addAttribute(
            obj_type->attr_names()->Get(i)->str(),
            val.type<c10::DynamicType>());
      }
    }
    initialized_types_.insert(object->type_index());
  }
  return cls;
}
632

633
// Reconstruct an object. Three serialized layouts exist: plain
// attribute-slot objects, objects restored via their __setstate__ method,
// and torchbind custom classes (also restored via __setstate__).
IValue parseObject(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::Object* object = ivalue.val_as_Object();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(object != nullptr);
  const auto* cur_input = loader.getCurrentFlatbufferInput();
  const mobile::serialization::ObjectType* obj_type =
      cur_input->object_types()->Get(object->type_index());
  auto cls = loader.getOrCreateClassTypeForObject(object);
  Stack stack;
  switch (obj_type->type()) {
    case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
      // Attributes are ivalue-table indices, one per slot, in order.
      auto obj = c10::ivalue::Object::create(
          at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
      for (uint32_t i = 0; i < object->attrs()->size(); i++) {
        IValue val = loader.getIValue(object->attrs()->Get(i));
        obj->setSlot(i, std::move(val));
      }
      return obj;
    }
    case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
      // Run the class's serialized __setstate__ on (obj, state).
      IValue input = loader.getIValue(object->state());
      mobile::Function* setstate = loader.getFunction(object->setstate_func());
      auto obj =
          c10::ivalue::Object::create(at::StrongTypePtr(loader.cu_, cls), 0);
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      setstate->run(stack);
      return obj;
    }
    case mobile::serialization::TypeType::CUSTOM_CLASS: {
      // Torchbind custom class: look up the registered runtime type and run
      // its native __setstate__ on (obj, state).
      auto custom_class_type =
          torch::jit::getCustomClass(cls->name()->qualifiedName());
      IValue input = loader.getIValue(object->state());
      auto obj = c10::ivalue::Object::create(
          c10::StrongTypePtr(nullptr, custom_class_type), 1);
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      custom_class_type->getMethod("__setstate__").run(stack);
      return obj;
    }
    default:
      AT_ASSERT(false, "need to be object");
  }
}
678

679
// Dispatch on the union tag to the parser registered for that ivalue kind.
IValue FlatbufferLoader::parseIValue(
    const mobile::serialization::IValue* ivalue) {
  return ivalue_parsers_[static_cast<uint32_t>(ivalue->val_type())](
      *this, *ivalue);
}
684

685
// No-op deleter for DataPtrs that alias memory owned by the flatbuffer
// (see FlatbufferLoader::getStorage); the buffer's owner frees the memory.
void deleteNothing2(void*);
void deleteNothing2(void*) {}
687

688
// Lazily create the Storage for slot `index`. Either copies the bytes out of
// the flatbuffer (should_copy_tensor_memory_) or wraps them with a no-op
// deleter, in which case the flatbuffer must outlive every tensor that
// references it. The result is cached per slot.
c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
  TORCH_CHECK(index < storage_loaded_.size());
  TORCH_CHECK(index < storages_.size());
  if (!storage_loaded_[index]) {
    auto* storage = module_->storage_data()->GetMutableObject(index);
    size_t size = storage->data()->size();

    at::DataPtr data;
    if (should_copy_tensor_memory_) {
      auto* allocator = at::GetCPUAllocator();
      data = allocator->allocate(size);
      memcpy(data.get(), storage->data()->data(), size);
    } else {
      // Alias the flatbuffer bytes directly; deleteNothing2 keeps the
      // DataPtr from freeing memory it does not own.
      void* ptr = static_cast<void*>(storage->mutable_data()->data());
      data = at::DataPtr(ptr, ptr, deleteNothing2, DeviceType::CPU);
    }
    storages_[index] =
        c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
    storage_loaded_[index] = true;
  }
  return storages_[index];
}
710

711
// Resolve a serialized type string via type_resolver_, memoizing on the
// flatbuffer string's pointer identity (each annotation string is stored
// once in the buffer, so pointer equality implies string equality).
TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
    const flatbuffers::String* offset) {
  auto entry = type_annotations_.find(offset);
  if (entry == type_annotations_.end()) {
    entry =
        type_annotations_.emplace(offset, type_resolver_(offset->str(), cu_))
            .first;
  }
  return entry->second;
}
721

722
// Parse the JIT-only tail of the ivalue table (entries past
// mobile_ivalue_size_) and copy out the jit sources and jit constants.
// Must be called after parseModule().
void FlatbufferLoader::extractJitSourceAndConstants(
    ExtraFilesMap* jit_sources,
    std::vector<IValue>* constants) {
  AT_ASSERT(
      module_parsed_,
      "Need to first parse a flatbuffer file before extracting jit_sources");

  const auto* ivalues = module_->ivalues();
  for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) {
    const auto* ival = ivalues->Get(i);
    parseAndPopulate(i, ival);
  }
  // register functions: attach only the newly parsed (JIT-only) methods;
  // mobile-range functions were already attached in parseModule().
  for (const auto& f : all_functions_) {
    if (f.first >= mobile_ivalue_size_) {
      uint32_t class_index =
          ivalues->Get(f.first)->val_as_Function()->class_type();
      ClassTypePtr class_type = all_types_[class_index];
      class_type->addMethod(f.second);
    }
  }
  const auto* jit_constants = module_->jit_constants();
  for (const auto i : c10::irange(jit_constants->size())) {
    constants->emplace_back(getIValue(jit_constants->Get(i)));
  }
  parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
}
749

750
} // namespace
751

752
// Parse a mobile module from a raw buffer. The caller retains ownership of
// `data`; unless should_copy_tensor_memory is set, the returned module's
// tensors alias the buffer, so it must outlive the module.
mobile::Module parse_and_initialize_mobile_module(
    void* data,
    size_t size,
    c10::optional<at::Device>,
    ExtraFilesMap* extra_files,
    bool should_copy_tensor_memory) {
  // TODO(T128189662): If not copying, enforce that data is aligned to
  // kFlatbufferDataAlignmentBytes, and add unit tests.

  // Validate Flatbuffer module before parsing.
  flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
  TORCH_CHECK(
      mobile::serialization::VerifyModuleBuffer(verifier),
      "Malformed Flatbuffer module");

  FlatbufferLoader loader;
  loader.setShouldCopyTensorMemory(should_copy_tensor_memory);

  // Flatbuffer doesn't seem to have a way to provide the buffer size when
  // interacting with the buffer.
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  auto* end = static_cast<char*>(data) + size;
  mobile::Module m = loader.parseModule(flatbuffer_module, end);
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }
  return m;
}
780

781
// Shared-ownership overload: tensors alias `data` (no copy), and the module
// keeps the buffer alive via set_delete_memory.
mobile::Module parse_and_initialize_mobile_module(
    std::shared_ptr<char> data,
    size_t size,
    c10::optional<at::Device> device,
    ExtraFilesMap* extra_files) {
  mobile::Module m = parse_and_initialize_mobile_module(
      data.get(),
      size,
      device,
      extra_files,
      /*should_copy_tensor_memory=*/false);
  m.set_delete_memory(std::move(data));
  return m;
}
795

796
// Like parse_and_initialize_mobile_module, but also extracts the JIT-only
// sources and constants needed to rebuild a full JIT module. The caller
// owns `data` and must keep it alive (tensors alias the buffer).
mobile::Module parse_and_initialize_mobile_module_for_jit(
    void* data,
    size_t size,
    ExtraFilesMap& jit_sources,
    std::vector<IValue>& jit_constants,
    c10::optional<at::Device>,
    ExtraFilesMap* extra_files) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
  // TODO(T128189662): Enforce that data is aligned to
  // kFlatbufferDataAlignmentBytes, and add unit tests.

  // Validate Flatbuffer module before parsing.
  flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
  TORCH_CHECK(
      mobile::serialization::VerifyModuleBuffer(verifier),
      "Malformed Flatbuffer module");

  FlatbufferLoader loader;
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  auto* end = static_cast<char*>(data) + size;
  mobile::Module m = loader.parseModule(flatbuffer_module, end);
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }

  loader.extractJitSourceAndConstants(&jit_sources, &jit_constants);
  return m;
}
825

826
// Read the whole file into memory and parse it as a mobile module; the
// returned module owns the buffer.
mobile::Module load_mobile_module_from_file(
    const std::string& filename,
    c10::optional<c10::Device> device,
    ExtraFilesMap* extra_files) {
  auto [data, size] = get_file_content(filename.c_str());
  return parse_and_initialize_mobile_module(
      std::move(data), size, device, extra_files);
}
834

835
// Read the stream fully and report the module's bytecode version.
uint64_t get_bytecode_version(std::istream& in) {
  auto [data, size] = get_stream_content(in);
  return get_bytecode_version_from_bytes(data.get());
}
839

840
// Read the file fully and report the module's bytecode version.
uint64_t get_bytecode_version(const std::string& filename) {
  auto [data, size] = get_file_content(filename.c_str());
  return get_bytecode_version_from_bytes(data.get());
}
844

845
// Extract the bytecode version field from an in-memory flatbuffer module
// after checking the buffer's file identifier.
uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(flatbuffer_content),
      "Format error");
  auto* flatbuffer_module =
      mobile::serialization::GetMutableModule(flatbuffer_content);
  return flatbuffer_module->bytecode_version();
}
853

854
// Collect lightweight metadata (versions, function names, operators with
// arg counts, contained type names) from a flatbuffer module without fully
// loading it.
mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
  auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
  mobile::ModuleInfo minfo;
  minfo.operator_version = ff_module->operator_version();
  minfo.bytecode_version = ff_module->bytecode_version();

  // 0 means "all ivalues belong to the mobile module".
  uint32_t mobile_ivalue_size = ff_module->mobile_ivalue_size();
  if (mobile_ivalue_size == 0) {
    mobile_ivalue_size = ff_module->ivalues()->size();
  }

  std::vector<std::string> type_name_list;
  for (uint32_t i = 0; i < mobile_ivalue_size; i++) {
    const auto* ival = ff_module->ivalues()->Get(i);
    if (const auto* func = ival->val_as_Function()) {
      minfo.function_names.insert(func->qn()->str());
      for (const auto* op : *func->operators()) {
        at::OperatorName opname(op->name()->str(), op->overload_name()->str());
        minfo.opname_to_num_args[mobile::operator_str(opname)] =
            op->num_args_serialized();
      }
      for (const auto* type_ann : *func->type_annotations()) {
        type_name_list.push_back(type_ann->str());
      }
    }
  }
  // Expand annotation strings into the full set of contained type names.
  c10::TypeParser parser(type_name_list);
  parser.parseList();
  minfo.type_names = parser.getContainedTypes();
  return minfo;
}
885

886
// Read the whole stream into memory and parse it as a mobile module; the
// returned module owns the buffer.
mobile::Module load_mobile_module_from_stream_with_copy(
    std::istream& in,
    c10::optional<at::Device> device,
    ExtraFilesMap* extra_files) {
  auto [data, size] = get_stream_content(in);
  return parse_and_initialize_mobile_module(
      std::move(data), size, device, extra_files);
}
894

895
// Parse a module while treating every object as a plain attribute-slot
// object: the default Object parser is replaced with one that never invokes
// __setstate__ or custom-class machinery. The module keeps `data` alive.
mobile::Module parse_flatbuffer_no_object(
    std::shared_ptr<char> data,
    size_t size,
    c10::optional<at::Device> device) {
  (void)device;
  (void)size;

  // Validate Flatbuffer module before parsing.
  flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data.get()), size);
  TORCH_CHECK(
      mobile::serialization::VerifyModuleBuffer(verifier),
      "Malformed Flatbuffer module");

  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
  FlatbufferLoader loader;
  // replace parserObject with to handle only class with field case
  // function.
  loader.registerIValueParser(
      mobile::serialization::IValueUnion::Object,
      +[](FlatbufferLoader& loader,
          const mobile::serialization::IValue& ivalue) {
        const mobile::serialization::Object* object = ivalue.val_as_Object();
        auto cls = loader.getOrCreateClassTypeForObject(object);
        auto obj = c10::ivalue::Object::create(
            at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
        for (uint32_t i = 0; i < object->attrs()->size(); i++) {
          IValue val = loader.getIValue(object->attrs()->Get(i));
          obj->setSlot(i, std::move(val));
        }
        return static_cast<c10::IValue>(obj);
      });

  auto* end = data.get() + size;
  mobile::Module m = loader.parseModule(flatbuffer_module, end);
  m.set_delete_memory(std::move(data));
  return m;
}
932

933
// Flatbuffer support is linked in statically, so there is nothing to set up
// at runtime; this hook exists so callers can force this translation unit
// to be linked. Always reports success.
bool register_flatbuffer_loader() {
  constexpr bool kAlwaysRegistered = true;
  return kAlwaysRegistered;
}
936

937
} // namespace jit
938
} // namespace torch
939

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.