pytorch

Форк
0
/
export_data.cpp 
154 строки · 4.7 Кб
1
#include <torch/csrc/jit/mobile/train/export_data.h>

#include <torch/csrc/jit/mobile/import_export_common.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

#include <caffe2/serialize/inline_container.h>

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

#include <fstream>
#include <string>
#include <vector>
17

18
namespace torch {
19
namespace jit {
20
namespace mobile {
21

22
char const* toString(OpCode op);
23

24
namespace {
25

26
/**
27
 * Serializes an IValue using Pickle, and puts it in a file named "data.pkl"
28
 * in a ZIP wrapper.
29
 */
30
class IValuePickler final {
31
 public:
32
  explicit IValuePickler(const std::string& filename) : writer_(filename) {}
33

34
  explicit IValuePickler(
35
      const std::function<size_t(const void*, size_t)>& writer_func)
36
      : writer_(writer_func) {}
37

38
  void serialize(const IValue& object) {
39
    // Serialize just the data
40
    writeArchive("data", object);
41
  }
42

43
 private:
44
  void writeArchive(const std::string& archive_name, const IValue& value) {
45
    std::vector<char> data;
46
    // Vector to capture the run-time class types during pickling the IValues
47
    std::vector<c10::ClassTypePtr> memoizedClassTypes;
48
    Pickler data_pickle(
49
        [&](const char* buf, size_t size) {
50
          data.insert(data.end(), buf, buf + size);
51
        },
52
        nullptr,
53
        [&](const c10::ClassTypePtr& t) {
54
          return type_name_uniquer_.getUniqueName(t);
55
        },
56
        &memoizedClassTypes);
57
    data_pickle.protocol();
58
    data_pickle.pushIValue(value);
59
    data_pickle.stop();
60
    size_t i = 0;
61
    std::string prefix = archive_name + "/";
62
    for (const auto& td : data_pickle.tensorData()) {
63
      WriteableTensorData writable_td = getWriteableTensorData(td);
64
      std::string fname = prefix + c10::to_string(i++);
65
      writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
66
    }
67
    std::string fname = archive_name + ".pkl";
68
    writer_.writeRecord(fname, data.data(), data.size());
69
  }
70

71
  caffe2::serialize::PyTorchStreamWriter writer_;
72
  TypeNameUniquer type_name_uniquer_;
73
};
74

75
} // namespace
76

77
/**
78
 * Converts a map of named tensors to a c10::Dict.
79
 */
80
c10::Dict<std::string, at::Tensor> tensor_map_to_dict(
81
    const std::map<std::string, at::Tensor>& map) {
82
  c10::Dict<std::string, at::Tensor> dict;
83
  for (const auto& e : map) {
84
    dict.insert(e.first, e.second);
85
  }
86
  return dict;
87
}
88

89
/**
 * Returns a Module with a single attribute, with the attribute name specified
 * by #internal::kSavedParametersAttributeName, whose value is the provided
 * dict.
 *
 * Importers can then read the parameters back by looking up that one
 * attribute on the loaded module.
 */
mobile::Module tensor_dict_to_mobile(
    const c10::Dict<std::string, at::Tensor>& dict) {
  // Create an Object to back the Module, with an attribute to hold the dict.
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  // Note that the name doesn't really matter, but it must begin with
  // "__torch__." to be treated as a valid class when being imported.
  auto cls = c10::ClassType::create(
      "__torch__.SavedParameters", cu, /*is_module=*/true);
  // Declare the single attribute slot, typed to match the dict's key/value
  // types. This must happen before `cls` is moved into the StrongTypePtr.
  cls->addAttribute(
      internal::kSavedParametersAttributeName,
      c10::DictType::create(dict.keyType(), dict.valueType()));
  // One slot: exactly the attribute declared above.
  auto object = c10::ivalue::Object::create(
      c10::StrongTypePtr(std::move(cu), std::move(cls)), /*numSlots=*/1);

  // Add the dict as an attribute.
  object->setAttr(internal::kSavedParametersAttributeName, dict);

  // Wrap the Object in a Module. The mobile CompilationUnit is empty; the
  // module carries no methods, only the parameter dict.
  auto mcu = std::make_shared<mobile::CompilationUnit>();
  return mobile::Module(object, mcu);
}
115

116
} // namespace mobile
117

118
// Global hook with the same shape as the flatbuffer save routine.
// NOTE(review): defined here as nullptr and never read in this file
// (_save_parameters calls save_mobile_module_to_func directly); presumably
// assigned elsewhere to break a link-time dependency — verify before removing.
void (*_save_mobile_module_to)(
    const mobile::Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func) = nullptr;
121

122
void _save_parameters(
123
    const std::map<std::string, at::Tensor>& map,
124
    std::ostream& out,
125
    bool use_flatbuffer) {
126
  auto dict = mobile::tensor_map_to_dict(map);
127

128
  auto write_func = [&out](const void* buf, size_t nbytes) -> size_t {
129
    out.write(
130
        static_cast<const char*>(buf), static_cast<std::streamsize>(nbytes));
131
    return !out ? 0 : nbytes;
132
  };
133

134
  if (use_flatbuffer) {
135
    save_mobile_module_to_func(mobile::tensor_dict_to_mobile(dict), write_func);
136
  } else {
137
    // For Pickle, we only serialize the dict itself.
138
    mobile::IValuePickler pickler(write_func);
139
    pickler.serialize(dict);
140
  }
141
}
142

143
void _save_parameters(
144
    const std::map<std::string, at::Tensor>& map,
145
    const std::string& filename,
146
    bool use_flatbuffer) {
147
  auto dict = mobile::tensor_map_to_dict(map);
148

149
  std::ofstream ifile(filename);
150
  _save_parameters(map, ifile, use_flatbuffer);
151
}
152

153
} // namespace jit
154
} // namespace torch
155

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.