pytorch

Форк
0
/
import.cpp 
734 строки · 25.9 Кб
1
#include <torch/csrc/jit/mobile/import.h>
2
#include <torch/csrc/jit/mobile/parse_bytecode.h>
3
#include <torch/csrc/jit/mobile/parse_operators.h>
4

5
#include <ATen/core/ivalue.h>
6
#include <ATen/core/qualified_name.h>
7
#include <c10/util/Exception.h>
8
#include <c10/util/Optional.h>
9
#include <c10/util/ScopeExit.h>
10
#include <c10/util/irange.h>
11
#include <caffe2/serialize/in_memory_adapter.h>
12
#include <caffe2/serialize/inline_container.h>
13
#include <caffe2/serialize/read_adapter_interface.h>
14
#include <caffe2/serialize/versions.h>
15
#include <torch/csrc/jit/api/compilation_unit.h>
16
#include <torch/csrc/jit/mobile/file_format.h>
17
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
18
#include <torch/csrc/jit/mobile/observer.h>
19
#include <torch/csrc/jit/mobile/type_parser.h>
20
#include <torch/csrc/jit/mobile/upgrader_mobile.h>
21
#include <torch/csrc/jit/runtime/instruction.h>
22
#include <torch/csrc/jit/serialization/import_export_constants.h>
23
#include <torch/csrc/jit/serialization/import_export_functions.h>
24
#include <torch/csrc/jit/serialization/import_read.h>
25
#include <torch/custom_class.h>
26
#include <string>
27
#include <vector>
28

29
// The import process to deserialize the bytecode package.
30
// An example for bytecode.pkl of a small mobile_module looks like:
31
// (4,  # model version number (caffe2::serialize::kProducedBytecodeVersion)
32
//  # first method
33
//  (
34
//   # function name
35
//   '__torch__.m.forward',
36
//   # code
37
//   (('instructions',
38
//     (('STOREN', 1, 2),
39
//      ('DROPR', 1, 0),
40
//      ('MOVE', 2, 0),
41
//      ('OP', 0, 0),
42
//      ('RET', 0, 0))),
43
//    ('operators', (('aten::Int', 'Tensor'),)),
44
//    ('constants', ()),
45
//    ('types', ()),
46
//    ('register_size', 2)),
47
//   # schema -- optional (forward-compatible addition to version 4)
48
//   (('arguments',
49
//     ((('name', 'x'), ('type', 'Tensor'), ('default_value', 13)),
50
//      ...)),  # more args follow here
51
//    ('returns',
52
//     ((('name', ''), ('type', 'Tensor'), ('default_value', None)),
53
//      ...)),  # more return values follow here
54
//   )),
55
//  # more methods follow here
56
//  ...)
57

58
// In addition, the module debugging information can be saved
59
// in mobile_debug_handles.pkl. An example for it looks like:
60
// (4,
61
//  ('__torch__.m.forward',
62
//   (('module_debug_handles', 10))))
63
//   Here 10 is the debug handle.
64
// We also store separately and optionally callstack_debug_map.
65
// This serializes inlined callstack (InlinedCallStack data structure)
66
// corresponding to the debug handles.
67
// Callstack_debug_map serializes tuples of
68
// (int64_t(debug_handle), int64_t(source_range_tag), InlinedCallStack)
69
// source_range_tag maps to .debug_pkl files where this tag maps it to
70
// source range.
71
// InlinedCallStack is serialized as:
72
// IValue(InlinedCallStack) = {IValue(ModuleInstanceInfo),
73
// int64_t(source_range_tag), IValue(InlinedCallStack)} ModuleInstanceInfo is
74
// serialized as a tuple of (class_type_name, instance_name)
75

76
// Note that currently the backward compatibility is not supported by bytecode.
77
// This format and process need to be revisited and redesigned if we want to
78
// support backward compatibility in future.
79

80
// Note that the following function-schema fields are not supported:
81
//  - Argument::{known_length_,kwarg_only_}
82
//  - FunctionSchema::{overload_name_, is_vararg_, is_varret_}
83

84
namespace torch {
85
namespace jit {
86
using caffe2::serialize::MemoryReadAdapter;
87
using caffe2::serialize::PyTorchStreamReader;
88
using caffe2::serialize::ReadAdapterInterface;
89

90
// Forward declaration only; the definition is not in this file.
// NOTE(review): presumably maps an instruction mnemonic string (e.g.
// "STOREN") to its OpCode value -- confirm against the definition.
OpCode parseOpCode(const char* str);
91

92
// Resolves the qualified name `qn` to a type for the mobile importer.
//
// Names under the "__torch__" or "torch.jit" prefixes are treated as class
// types: the class is looked up in `compilation_unit` and, if it is not there
// yet, a new ClassType is created and registered first. Every other name is
// handed to c10::parseType().
TypePtr resolveTypeNameMobile(
    const c10::QualifiedName& qn,
    std::shared_ptr<CompilationUnit> compilation_unit) {
  // HACK: the only way we tell a supported pytorch class type apart from other
  // types is its name prefix: "__torch__" for nn modules and "torch.jit" for
  // classes coming from to_backend. This is a reliable check today, but there
  // is no guarantee it stays that way. The real solution is to merge the type
  // parsers so that class resolution logic can be shared.
  static const c10::QualifiedName kTorchPrefix = "__torch__";
  static const c10::QualifiedName kJitPrefix = "torch.jit";

  const bool names_a_class =
      kTorchPrefix.isPrefixOf(qn) || kJitPrefix.isPrefixOf(qn);
  if (!names_a_class) {
    return c10::parseType(qn.qualifiedName());
  }

  // Lazily create and register the class on first sight.
  if (compilation_unit->get_class(qn) == nullptr) {
    auto new_class = ClassType::create(qn, compilation_unit, true);
    compilation_unit->register_type(new_class);
  }
  return compilation_unit->get_class(qn);
}
114

115
// Resolves `qn` with resolveTypeNameMobile() and pairs the resulting type
// with the compilation unit it was resolved against.
c10::StrongTypePtr typeResolverMobile(
    const c10::QualifiedName& qn,
    const std::shared_ptr<CompilationUnit>& compilation_unit) {
  TypePtr resolved = resolveTypeNameMobile(qn, compilation_unit);
  return c10::StrongTypePtr(compilation_unit, std::move(resolved));
}
121

122
// Reconstructs an object of class `type` from its pickled state `input`.
//
// Three strategies are tried in order:
//   1. A "<class>.__setstate__" function found in `mobile_compilation_unit`
//      is run on a fresh object with zero slots.
//   2. Otherwise, a registered custom class of the same qualified name that
//      has a "__setstate__" method is run on a fresh one-slot object.
//   3. Otherwise `input` is treated as a generic dict and each entry becomes
//      one attribute slot of the object, in dict iteration order.
c10::intrusive_ptr<c10::ivalue::Object> objLoaderMobile(
    const at::StrongTypePtr& type,
    const IValue& input,
    mobile::CompilationUnit& mobile_compilation_unit) {
  auto cls = type.type_->expect<at::ClassType>();
  auto qn = cls->name();
  c10::QualifiedName method_name(qn.value(), "__setstate__");

  // 1. __setstate__ defined in the mobile compilation unit.
  if (auto setstate = mobile_compilation_unit.find_function(method_name)) {
    auto obj = c10::ivalue::Object::create(type, 0);
    Stack stack({obj, input});
    setstate->run(stack);
    return obj;
  }

  // 2. __setstate__ provided by a registered custom class.
  auto custom_class_type = torch::jit::getCustomClass(qn->qualifiedName());
  if (custom_class_type && custom_class_type->findMethod("__setstate__")) {
    auto obj = c10::ivalue::Object::create(
        c10::StrongTypePtr(nullptr, custom_class_type), 1);
    Stack stack({obj, input});
    custom_class_type->getMethod("__setstate__").run(stack);
    return obj;
  }

  // 3. Plain attribute dictionary.
  // `input` is a const reference, so std::move() yields a const rvalue here.
  auto attributes = std::move(input).toGenericDict();
  size_t attribute_count = attributes.size();
  auto obj = c10::ivalue::Object::create(type, attribute_count);
  auto entry = attributes.begin();
  for (const auto slot : c10::irange(attribute_count)) {
    // NOTE(review): the type recorded for the attribute is the type of the
    // dict *key*, not of the value stored in the slot -- confirm intended.
    cls->addOrCheckAttribute(entry->key().toStringRef(), entry->key().type());
    obj->setSlot(slot, entry->value());
    ++entry;
  }
  return obj;
}
161

162
bool isTensorInBytecodeArchive(
163
    caffe2::serialize::PyTorchStreamReader& stream_reader) {
164
  auto records = stream_reader.getAllRecords();
165
  for (const auto& record : records) {
166
    if (record.find("bytecode/") != std::string::npos) {
167
      return true;
168
    }
169
  }
170
  return false;
171
}
172

173
namespace {
174

175
// Registers `func` as a method of the class named by its first argument.
//
// Nothing happens unless the first argument is called "self" and its type is
// a ClassType, or when the class already has a method with the same name.
void tryRegisterMethod(const std::vector<c10::Argument>& args, Function& func) {
  const bool starts_with_self = !args.empty() && args[0].name() == "self";
  if (!starts_with_self) {
    return;
  }

  auto* cls = args[0].type()->castRaw<ClassType>();
  if (cls == nullptr) {
    return;
  }

  // Never register the same method name twice.
  if (C10_UNLIKELY(cls->findMethod(func.name()))) {
    return;
  }
  cls->addMethod(&func);
}
187

188
// The deserializer class which loads the bytecode package from bc files.
189
class BytecodeDeserializer final {
190
 public:
191
  explicit BytecodeDeserializer(
192
      std::unique_ptr<PyTorchStreamReader> reader,
193
      uint64_t module_load_options = 0);
194
  mobile::Module deserialize(c10::optional<at::Device> device);
195
  mobile::Module deserialize(
196
      c10::optional<at::Device> device,
197
      ExtraFilesMap& extra_files);
198
  void deserialize_only_extra(
199
      c10::optional<at::Device> device,
200
      ExtraFilesMap& extra_files);
201

202
 private:
203
  TypePtr resolveTypeName(const c10::QualifiedName& qn);
204
  void init_upgrader(mobile::Function* function);
205
  void parseMethods(
206
      c10::ivalue::TupleElements&& vals,
207
      c10::optional<c10::ivalue::TupleElements>&& debug_handles,
208
      mobile::CompilationUnit& mcu);
209
  c10::IValue readArchive(
210
      const std::string& archive_name,
211
      std::shared_ptr<mobile::CompilationUnit> mcu);
212
  void parseFunctionSchema(
213
      const std::string& function_name,
214
      IValue* schemaTable,
215
      const int64_t& model_version,
216
      mobile::Function* function);
217
  std::shared_ptr<CompilationUnit> compilation_unit_;
218
  std::unordered_set<std::string> imported_libs_;
219
  std::unique_ptr<PyTorchStreamReader> reader_{};
220
  c10::optional<at::Device> device_;
221
  uint64_t module_load_options_;
222
  // From `version` or `.data/version` in model.ptl and it's compute
223
  // dynamically. It's used for finding the minimum required runtime to run all
224
  // operators from the given model. If it's less than the current runtime,
225
  // upgrader will be applied at loading stage.
226
  uint64_t operator_version_;
227
  uint64_t bytecode_version_;
228
};
229

230
BytecodeDeserializer::BytecodeDeserializer(
231
    std::unique_ptr<PyTorchStreamReader> reader,
232
    uint64_t module_load_options)
233
    : compilation_unit_(std::make_shared<CompilationUnit>()),
234
      reader_(std::move(reader)),
235
      module_load_options_(module_load_options) {}
236

237
// Resolves `qn` against this deserializer's own compilation unit.
TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) {
  TypePtr resolved = resolveTypeNameMobile(qn, compilation_unit_);
  return resolved;
}
240

241
// It requires compilation_unit_ when parsing function schema. Keep it in
242
// BytecodeDeserializer. It may be refacotred later to make it independent
243
// of the specific BytecodeDeserializer, like parsing other tables
244
void BytecodeDeserializer::parseFunctionSchema(
245
    const std::string& function_name,
246
    IValue* schemaTable,
247
    const int64_t& model_version,
248
    mobile::Function* function) {
249
  // function schema
250
  if (schemaTable) { // (schema is optional for back compat)
251
    auto parseArgList = [this,
252
                         function](c10::ivalue::TupleElements&& argTables) {
253
      std::vector<c10::Argument> args;
254
      for (auto& argTable : argTables) {
255
        auto argTableElements = std::move(argTable.toTupleRef()).elements();
256
        auto name =
257
            expect_field(argTableElements, "name", BYTECODE_INDEX_ARGUMENT_NAME)
258
                .toStringRef();
259
        c10::TypePtr type = resolveTypeName(
260
            (expect_field(
261
                 argTableElements, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
262
                .toStringRef());
263
        IValue default_value = expect_field(
264
            argTableElements,
265
            "default_value",
266
            BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE);
267
        args.emplace_back(
268
            name,
269
            std::move(type),
270
            c10::nullopt /*N*/,
271
            std::move(default_value));
272
      }
273
      tryRegisterMethod(args, *function);
274
      return args;
275
    };
276
    auto schemaTableElements = std::move(schemaTable->toTupleRef()).elements();
277
    auto arg_list = std::move(expect_field(
278
                                  schemaTableElements,
279
                                  "arguments",
280
                                  BYTECODE_INDEX_SCHEMA_ARGUMENTS)
281
                                  .toTupleRef())
282
                        .elements();
283
    auto ret_list =
284
        std::move(
285
            expect_field(
286
                schemaTableElements, "returns", BYTECODE_INDEX_SCHEMA_RETURNS)
287
                .toTupleRef())
288
            .elements();
289
    c10::FunctionSchema schema(
290
        function_name,
291
        "" /*overload_name*/,
292
        parseArgList(std::move(arg_list)),
293
        parseArgList(std::move(ret_list)),
294
        false /*is_varargs*/,
295
        false /*is_varret*/);
296
    function->setSchema(std::move(schema));
297
  }
298
}
299

300
// Appends the function of every entry returned by getUpgraderBytecodeList()
// to `function`.
void BytecodeDeserializer::init_upgrader(mobile::Function* function) {
  for (auto& upgrader_entry : getUpgraderBytecodeList()) {
    function->append_function(upgrader_entry.function);
  }
}
305

306
void BytecodeDeserializer::parseMethods(
307
    c10::ivalue::TupleElements&& vals,
308
    c10::optional<c10::ivalue::TupleElements>&& debug_handles,
309
    mobile::CompilationUnit& mcu) {
310
  TORCH_CHECK(!vals.empty(), "Bytecode has no elements. ");
311
  // Initialized with the version number when kProducedBytecodeVersion was
312
  // introduced. The old models (some of them already in production) without
313
  // version number are seen as version 3 (deprecated).
314
  constexpr uint64_t default_version = 0x3L;
315
  bytecode_version_ = default_version;
316
  size_t method_i_start = 0;
317
  if (vals[0].isInt()) {
318
    bytecode_version_ = vals[0].toInt();
319
    method_i_start = 1;
320
  }
321
  TORCH_CHECK(
322
      // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
323
      caffe2::serialize::kMinSupportedBytecodeVersion <= bytecode_version_ &&
324
          // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
325
          bytecode_version_ <= caffe2::serialize::kMaxSupportedBytecodeVersion,
326
      "Lite Interpreter version number does not match. ",
327
      "The model version must be between ",
328
      caffe2::serialize::kMinSupportedBytecodeVersion,
329
      " and ",
330
      caffe2::serialize::kMaxSupportedBytecodeVersion,
331
      " but the model version is ",
332
      bytecode_version_);
333

334
  if (debug_handles) {
335
    TORCH_CHECK(
336
        debug_handles->size() == vals.size(),
337
        "The numbers of bytecode values and debug info values do not match.");
338
  }
339

340
  // Process all methods in this mobile module.
341
  for (const auto i : c10::irange(method_i_start, vals.size())) {
342
    auto element = std::move(vals[i]);
343
    auto m_tuple = std::move(element.toTupleRef()).elements();
344
    const std::string& function_name = m_tuple[0].toStringRef();
345
    auto codeTableElements =
346
        std::move(std::move(m_tuple[1]).toTupleRef()).elements();
347
    IValue* schemaTable = // older files do not store function schema
348
        (bytecode_version_ > 0x4L ||
349
         (bytecode_version_ == 0x4L && m_tuple.size() >= 3))
350
        ? &m_tuple[2]
351
        : nullptr;
352
    auto function =
353
        std::make_unique<mobile::Function>(c10::QualifiedName(function_name));
354

355
    auto ins_list =
356
        std::move(
357
            expect_field(
358
                codeTableElements, "instructions", BYTECODE_INDEX_INSTRUCTION)
359
                .toTupleRef())
360
            .elements();
361
    auto ops_list =
362
        std::move(expect_field(
363
                      codeTableElements, "operators", BYTECODE_INDEX_OPERATOR)
364
                      .toTupleRef())
365
            .elements();
366
    auto consts_list =
367
        std::move(expect_field(
368
                      codeTableElements, "constants", BYTECODE_INDEX_CONSTANT)
369
                      .toTupleRef())
370
            .elements();
371
    auto types_list =
372
        std::move(expect_field(codeTableElements, "types", BYTECODE_INDEX_TYPE)
373
                      .toTupleRef())
374
            .elements();
375
    int64_t register_size =
376
        expect_field(
377
            codeTableElements, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
378
            .toInt();
379

380
    c10::ivalue::TupleElements debug_handles_m_tuple;
381
    if (debug_handles) {
382
      debug_handles_m_tuple =
383
          std::move(std::move((*debug_handles)[i]).toTupleRef()).elements();
384
    }
385
    init_upgrader(function.get());
386
    // 1. First pass all operators from models
387
    parseOperators(std::move(ops_list), module_load_options_, function.get());
388

389
    // 2. Decides if upgrader is needed
390
    bool use_upgrader =
391
        (operator_version_ < caffe2::serialize::kProducedFileFormatVersion);
392

393
    parseInstructions(
394
        function_name,
395
        std::move(ins_list),
396
        debug_handles_m_tuple,
397
        function.get());
398

399
    // 3. If upgrader is needed, change change the OP instrunction to CALL
400
    // instruction (In next PR, use_upgrader will be parsed to parseInstruction
401
    // function and do the actual change)
402
    if (use_upgrader) {
403
      applyUpgrader(function.get(), operator_version_);
404
    }
405

406
    parseConstants(consts_list, function.get());
407

408
    parseTypes(types_list, function.get());
409

410
    function->set_register_size(register_size);
411

412
    parseFunctionSchema(
413
        function_name, schemaTable, bytecode_version_, function.get());
414

415
    mcu.register_function(std::move(function));
416
  }
417
}
418

419
void BytecodeDeserializer::deserialize_only_extra(
420
    c10::optional<at::Device> device,
421
    ExtraFilesMap& extra_files) {
422
  device_ = device;
423
  for (const auto& kv : extra_files) {
424
    const std::string& key = "extra/" + kv.first;
425
    if (reader_->hasRecord(key)) {
426
      auto [meta_ptr, meta_size] = reader_->getRecord(key);
427
      extra_files[kv.first] =
428
          std::string(static_cast<char*>(meta_ptr.get()), meta_size);
429
    }
430
  }
431
}
432

433
// Loads the module and, before doing so, fills `extra_files` with the
// contents of the matching "extra/" records.
mobile::Module BytecodeDeserializer::deserialize(
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  // Extra files are read first, so they are populated even if parsing the
  // module below throws.
  deserialize_only_extra(device, extra_files);
  mobile::Module loaded_module = deserialize(device);
  return loaded_module;
}
439

440
// Loads the complete mobile module from the archive.
//
// Reads the "bytecode" archive, the optional "mobile_debug_handles" archive
// and finally the "data" archive, parses all methods into a fresh mobile
// compilation unit and returns the resulting module annotated with the
// operator version, the bytecode version and whether debug handles exist.
mobile::Module BytecodeDeserializer::deserialize(
    c10::optional<at::Device> device) {
  device_ = device;
  auto mcu = std::make_shared<mobile::CompilationUnit>();

  // The bytecode values can have 2 possible formats:
  //
  // 1. Old format: an array (Tuple) of N elements, each element being itself
  // a Tuple(method_name, method_table).
  //
  // 2. New format: an array (Tuple) of 1+N elements whose first element is
  // the integer bytecode version number. The rest of the elements are the
  // same as before.
  auto bytecode_values =
      std::move(readArchive("bytecode", mcu).toTupleRef()).elements();

  c10::optional<c10::ivalue::TupleElements> debug_handles;
  const bool has_debug_handles = reader_->hasRecord("mobile_debug_handles.pkl");
  if (has_debug_handles) {
    debug_handles =
        std::move(readArchive("mobile_debug_handles", mcu).toTupleRef())
            .elements();
  }

  // parseMethods() reads operator_version_, so set it first.
  operator_version_ = reader_->version();
  parseMethods(std::move(bytecode_values), std::move(debug_handles), *mcu);

  auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu);
  m.set_min_operator_version(operator_version_);
  m.set_bytecode_version(bytecode_version_);
  m.setHasDebugHandles(has_debug_handles);
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
  MobileDebugTable debug_table = MobileDebugTable(reader_, compilation_unit_);
  m.setDebugTable(std::move(debug_table));
#endif
  return m;
}
476

477
// Unpickles the archive `archive_name` from the reader and returns its root
// IValue.
//
// Types are resolved against compilation_unit_ and objects are rebuilt with
// objLoaderMobile() using `mcu`. For the "bytecode" archive, tensors are read
// with the "constants/" prefix whenever no record name in the file contains
// "bytecode/"; in every other case the tensor prefix is empty.
c10::IValue BytecodeDeserializer::readArchive(
    const std::string& archive_name,
    std::shared_ptr<mobile::CompilationUnit> mcu) {
  auto type_resolver = [this](const c10::QualifiedName& qn) {
    return typeResolverMobile(qn, compilation_unit_);
  };

  auto obj_loader = [&](const at::StrongTypePtr& type, const IValue& input) {
    return objLoaderMobile(type, input, *mcu);
  };

  const bool bytecode_tensor_in_constants_archive =
      archive_name == "bytecode" && !isTensorInBytecodeArchive(*reader_);
  const char* tensor_prefix =
      bytecode_tensor_in_constants_archive ? "constants/" : "";

  return torch::jit::readArchiveAndTensors(
      archive_name,
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/tensor_prefix,
      type_resolver,
      obj_loader,
      device_,
      *reader_,
      nullptr);
}
504

505
// Loads a mobile module from a zip-format archive behind `rai`, reporting
// progress to the module observer when one is configured.
//
// `extra_files` is both input and output: keys requested by the caller (plus
// the observer's default extra files and, with PARSE_ALL_EXTRA_FILE_MAPS,
// every "extra/" record in the archive) are filled with file contents.
// On failure the observer's onFailLoadModel() is invoked from a scope guard
// before the exception leaves this function.
mobile::Module _load_for_mobile_impl(
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options) {
  auto observer = torch::observerConfig().getModuleObserver();
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  auto instance_key = std::rand();

  std::unordered_map<std::string, std::string> metadata_map;
  if (observer) {
    observer->onEnterLoadModel(instance_key);
    // Request the observer's default extra files too; entries the caller
    // already supplied are kept as they are.
    for (const auto& default_file : observer->getDefaultExtraFiles()) {
      extra_files.emplace(default_file, "");
    }
  }

  const size_t model_size = rai != nullptr ? rai->size() : 0;
  auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
  if (module_load_options &
      MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS) {
    // Extra files are serialized under the "extra/" prefix. That prefix is 6
    // characters long and has to be stripped to obtain the map key, hence the
    // substr(6) below. Please refer to the following link for details:
    // https://www.internalfb.com/code/fbsource/[9996fcb7a6fb]/fbcode/caffe2/torch/csrc/jit/mobile/import.cpp?lines=427-434
    std::vector<std::string> record_names = reader->getAllRecords();
    for (auto& record_name : record_names) {
      if (record_name.find("extra/") == 0) {
        extra_files[record_name.substr(6)] = "";
      }
    }
  }
  BytecodeDeserializer deserializer(std::move(reader), module_load_options);

  // Failure path: runs on scope exit unless released after a successful load.
  std::string error_message;
  auto guard = c10::make_scope_exit([&]() {
    if (!observer) {
      return;
    }
    deserializer.deserialize_only_extra(device, extra_files);

    metadata_map = observer->processMetadataFromExtra(extra_files);

    observer->onFailLoadModel(
        instance_key,
        error_message.empty() ? "Unknown exception" : error_message.c_str(),
        metadata_map);
  });

  try {
    mobile::Module result = deserializer.deserialize(device, extra_files);
    if (observer) {
      // Add model_name and model_size to metadata_map
      extra_files.emplace("model_name", result.name());
      extra_files.emplace("model_size", std::to_string(model_size));
      metadata_map = observer->processMetadataFromExtra(extra_files);
      observer->onExitLoadModel(instance_key, metadata_map);
    }
    result.setMetadata(metadata_map);
    guard.release();
    return result;
  } catch (c10::Error& error) {
    // Only c10::Error carries a message to the observer; any other exception
    // type is reported by the guard as "Unknown exception".
    error_message = error.what();
    TORCH_RETHROW(error);
  }
}
576

577
// Loads a mobile module from an in-memory buffer of `size` bytes, dispatching
// on the file format detected from the buffer header: zip archives go through
// _load_for_mobile_impl(), flatbuffer files through
// parse_and_initialize_mobile_module(). Fails with "Format error" when the
// buffer is shorter than kFileFormatHeaderSize or the format is neither.
mobile::Module _load_mobile_from_bytes(
    const std::shared_ptr<char>& data,
    size_t size,
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options) {
  TORCH_CHECK(size >= kFileFormatHeaderSize, "Format error");
  const auto format = getFileFormat(data.get());

  if (format == FileFormat::ZipFileFormat) {
    std::unique_ptr<ReadAdapterInterface> rai =
        std::make_unique<MemoryReadAdapter>(data.get(), size);
    return _load_for_mobile_impl(
        std::move(rai), device, extra_files, module_load_options);
  }

  if (format == FileFormat::FlatbufferFileFormat) {
    return parse_and_initialize_mobile_module(data, size, device, &extra_files);
  }

  TORCH_CHECK(false, "Format error");
}
601

602
} // namespace
603

604
// Convenience overload: loads from a stream when the caller does not care
// about extra files. The map passed down is a throwaway local.
mobile::Module _load_for_mobile(
    std::istream& in,
    c10::optional<at::Device> device) {
  ExtraFilesMap discarded_extra_files;
  return _load_for_mobile(in, device, discarded_extra_files);
}
610

611
// Convenience overload: loads from a file path when the caller does not care
// about extra files. The map passed down is a throwaway local.
mobile::Module _load_for_mobile(
    const std::string& filename,
    c10::optional<at::Device> device) {
  ExtraFilesMap discarded_extra_files;
  return _load_for_mobile(filename, device, discarded_extra_files);
}
617

618
// Convenience overload: loads from a read adapter when the caller does not
// care about extra files. The map passed down is a throwaway local.
mobile::Module _load_for_mobile(
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<c10::Device> device) {
  ExtraFilesMap discarded_extra_files;
  return _load_for_mobile(std::move(rai), device, discarded_extra_files);
}
624

625
// Loads a mobile module from a stream.
//
// Flatbuffer content is read fully into memory and handed to
// _load_mobile_from_bytes(); anything else is read through an IStreamAdapter
// by _load_for_mobile_impl().
mobile::Module _load_for_mobile(
    std::istream& in,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options) {
  const bool is_flatbuffer =
      getFileFormat(in) == FileFormat::FlatbufferFileFormat;
  if (!is_flatbuffer) {
    return _load_for_mobile_impl(
        std::make_unique<IStreamAdapter>(&in),
        device,
        extra_files,
        module_load_options);
  }

  auto [data, size] = get_stream_content(in);
  return _load_mobile_from_bytes(
      data, size, device, extra_files, module_load_options);
}
640

641
// Loads a mobile module from a file path using kDefaultMobileLoadOptions.
mobile::Module _load_for_mobile(
    const std::string& filename,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  const auto load_options = kDefaultMobileLoadOptions;
  return _load_for_mobile(filename, device, extra_files, load_options);
}
648

649
// Loads a mobile module from a file path.
//
// Flatbuffer files are read fully into memory and handed to
// _load_mobile_from_bytes(); anything else is read through a FileAdapter by
// _load_for_mobile_impl().
mobile::Module _load_for_mobile(
    const std::string& filename,
    c10::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options) {
  const bool is_flatbuffer =
      getFileFormat(filename) == FileFormat::FlatbufferFileFormat;
  if (!is_flatbuffer) {
    return _load_for_mobile_impl(
        std::make_unique<FileAdapter>(filename),
        device,
        extra_files,
        module_load_options);
  }

  auto [data, size] = get_file_content(filename.c_str());
  return _load_mobile_from_bytes(
      data, size, device, extra_files, module_load_options);
}
666

667
// Loads a mobile module from a read adapter. The adapter's whole content is
// first copied into memory and then dispatched by _load_mobile_from_bytes().
TORCH_API mobile::Module _load_for_mobile(
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options) {
  // TODO optimize file read for non-flatbuffer models
  auto [content, content_size] = get_rai_content(rai.get());
  return _load_mobile_from_bytes(
      content, content_size, device, extra_files, module_load_options);
}
677

678
void _load_extra_only_for_mobile(
679
    const std::string& filename,
680
    c10::optional<at::Device> device,
681
    ExtraFilesMap& extra_files) {
682
  auto observer = torch::observerConfig().getModuleObserver();
683
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
684
  auto instance_key = std::rand();
685
  if (observer) {
686
    observer->onEnterLoadModel(instance_key);
687
  }
688

689
  auto format = getFileFormat(filename);
690
  switch (format) {
691
    case FileFormat::ZipFileFormat: {
692
      std::unique_ptr<FileAdapter> rai =
693
          std::make_unique<FileAdapter>(filename);
694
      auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
695
      BytecodeDeserializer deserializer(std::move(reader));
696
      deserializer.deserialize_only_extra(device, extra_files);
697
      break;
698
    }
699
    case FileFormat::FlatbufferFileFormat: {
700
      // TODO: the current flatbuffers implementation will always load the
701
      // whole module including the extra files. Ideally it should be
702
      // possible to just get the extra files given data
703
      load_mobile_module_from_file(filename, c10::nullopt, &extra_files);
704
      break;
705
    }
706
    default: {
707
      TORCH_CHECK(false, "Format error");
708
    }
709
  }
710
}
711

712
namespace mobile {
713

714
std::set<std::string> _export_operator_list(
715
    torch::jit::mobile::Module& module) {
716
  std::set<std::string> operator_list;
717
  for (Method func : module.get_methods()) {
718
    const Function& function = func.function();
719
    const auto& code = function.get_code();
720
    // op_names below isn't a list of unique operator names. In fact
721
    // it can contain the same operator name many many times, so we need
722
    // to de-dup the list by adding all the operator names into
723
    // an std::set<std::string>.
724
    std::vector<c10::OperatorName> const& op_names = code.op_names_;
725
    for (auto& op_name : op_names) {
726
      operator_list.insert(toString(op_name));
727
    }
728
  }
729
  return operator_list;
730
}
731

732
} // namespace mobile
733
} // namespace jit
734
} // namespace torch
735

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.