pytorch

Форк
0
/
MobileModelRunner.cpp 
237 строк · 7.9 Кб
1
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
2
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
3

4
namespace torch {
5
namespace jit {
6
namespace mobile {
7

8
// Converts the IValue returned by a bundled-inputs getter method (e.g. the
// model's `get_all_bundled_inputs`) into one argument vector per bundled
// input.
//
// Throws (via CAFFE_ENFORCE) if `bundled_inputs` is not a list, if the list is
// empty, or if any list element is not a tuple. The elements of each tuple
// become one std::vector<at::IValue> in the result, in list order.
std::vector<std::vector<at::IValue>> MobileModelRunner::
    ivalue_to_bundled_inputs(const c10::IValue& bundled_inputs) {
  CAFFE_ENFORCE(
      bundled_inputs.isList(),
      "Expected get_all_bundled_inputs to ",
      "return a list but got a ",
      bundled_inputs.tagKind(),
      " instead");

  c10::List<at::IValue> all_inputs = bundled_inputs.toList();
  CAFFE_ENFORCE(
      !all_inputs.empty(),
      "Expected at least 1 bundled input, ",
      "but found none. Please use ",
      "torch.utils.bundled_inputs.augment_model_with_bundled_inputs to add.");

  std::vector<std::vector<at::IValue>> ret;
  // Exactly one result entry per list element: allocate once up front.
  ret.reserve(all_inputs.size());
  for (at::IValue input : all_inputs) {
    CAFFE_ENFORCE(
        input.isTuple(),
        "Expected list element to be a tuple ",
        "but got a ",
        input.tagKind(),
        " instead");
    // The tuple's elements are the positional arguments for one invocation.
    ret.push_back(input.toTupleRef().elements());
  }

  return ret;
}
37

38
std::unordered_map<std::string, std::string> MobileModelRunner::
39
    ivalue_to_bundled_inputs_map(const c10::IValue& bundled_inputs) {
40
  CAFFE_ENFORCE(
41
      bundled_inputs.isGenericDict(),
42
      "Expected get_bundled_inputs_functions_and_info to ",
43
      "return a dict but got a ",
44
      bundled_inputs.tagKind(),
45
      " instead");
46

47
  c10::Dict<at::IValue, at::IValue> all_inputs = bundled_inputs.toGenericDict();
48
  CAFFE_ENFORCE(
49
      !all_inputs.empty(),
50
      "Expected at least 1 function with bundled inputs, ",
51
      "but found none. Please use ",
52
      "torch.utils.bundled_inputs.augment_model_with_bundled_inputs to add.");
53

54
  std::unordered_map<std::string, std::string> ret;
55
  for (auto& input : all_inputs) {
56
    at::IValue function_name = input.key();
57
    at::IValue nested_dict = input.value();
58
    CAFFE_ENFORCE(
59
        function_name.isString(),
60
        "Expected function with inputs to be a string ",
61
        "but got a ",
62
        function_name.tagKind(),
63
        " instead");
64
    CAFFE_ENFORCE(
65
        nested_dict.isGenericDict(),
66
        "Expected function name to map to dictionary ",
67
        "but got a ",
68
        nested_dict.tagKind(),
69
        " instead");
70

71
    // Got the nested dict now need to convert that into std types
72
    c10::Dict<at::IValue, at::IValue> function_and_info_ival_dict =
73
        nested_dict.toGenericDict();
74
    std::unordered_map<std::string, std::vector<std::string>>
75
        function_and_info_dict;
76
    for (auto& entry : function_and_info_ival_dict) {
77
      at::IValue key = entry.key();
78
      at::IValue value = entry.value();
79
      CAFFE_ENFORCE(
80
          key.isString(),
81
          "Expected extra information key to be a string ",
82
          "but got a ",
83
          value.tagKind(),
84
          " instead");
85
      CAFFE_ENFORCE(
86
          value.isList(),
87
          "Expected extra information values to be a list ",
88
          "but got a ",
89
          value.tagKind(),
90
          " instead");
91

92
      // Got the value of the nested dict entry now need to convert it to std
93
      // types
94
      std::vector<std::string> data_list;
95
      c10::List<at::IValue> ival_data = value.toList();
96
      for (at::IValue data : ival_data) {
97
        CAFFE_ENFORCE(
98
            data.isString(),
99
            "Expected list element of nested dict entries to be a string ",
100
            "but got a ",
101
            data.tagKind(),
102
            " instead");
103
        data_list.push_back(data.toStringRef());
104
      }
105

106
      // Add entry into std type mapping
107
      function_and_info_dict[key.toStringRef()] = data_list;
108
    }
109

110
    // Could store the full mapping of std types, but the 'info' section isnt
111
    // needed here
112
    std::string input_function =
113
        function_and_info_dict["get_inputs_function_name"][0];
114
    ret[function_name.toStringRef()] = input_function;
115
  }
116

117
  return ret;
118
}
119

120
// Runs the model's `get_all_bundled_inputs` method and decodes its result with
// ivalue_to_bundled_inputs: one argument vector per bundled input.
//
// Throws (via CAFFE_ENFORCE) if the model does not define
// `get_all_bundled_inputs`, or if the decoded value is malformed.
std::vector<std::vector<at::IValue>> MobileModelRunner::
    get_all_bundled_inputs() {
  // Fail early with an actionable hint when the model was never augmented.
  CAFFE_ENFORCE(
      module_->find_method("get_all_bundled_inputs"),
      "Model does not have bundled inputs. ",
      "Use torch.utils.bundled_inputs.augment_model_with_bundled_inputs to add.");

  return ivalue_to_bundled_inputs(
      module_->run_method("get_all_bundled_inputs"));
}
131

132
// For every function listed by the model's
// `get_bundled_inputs_functions_and_info` method, runs that function's
// inputs-getter method and decodes the result with ivalue_to_bundled_inputs.
//
// Returns a map keyed by function name. Throws (via CAFFE_ENFORCE) if the model
// lacks `get_bundled_inputs_functions_and_info` or lacks any inputs-getter
// method named in its result.
std::unordered_map<std::string, std::vector<std::vector<at::IValue>>>
MobileModelRunner::get_many_functions_bundled_inputs() {
  auto has_bundled_input =
      module_->find_method("get_bundled_inputs_functions_and_info");
  CAFFE_ENFORCE(
      has_bundled_input,
      "Model does not have bundled inputs. ",
      "Use torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs to add.");

  auto ival_bundled_inputs_mapping =
      module_->run_method("get_bundled_inputs_functions_and_info");
  // function name -> name of the method that returns its bundled inputs
  auto bundled_inputs_mapping =
      ivalue_to_bundled_inputs_map(ival_bundled_inputs_mapping);

  std::unordered_map<std::string, std::vector<std::vector<at::IValue>>> ret;

  for (const auto& entry : bundled_inputs_mapping) {
    // Bind by reference: the mapping outlives the loop, no copies needed.
    const std::string& function_name = entry.first;
    const std::string& function_to_call = entry.second;

    auto has_func_to_call = module_->find_method(function_to_call);
    // The ". " separator keeps the method name from running into the hint.
    CAFFE_ENFORCE(
        has_func_to_call,
        "Model does not have ",
        function_to_call,
        ". Use torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs to add.");

    c10::IValue bundled_inputs = module_->run_method(function_to_call);
    ret[function_name] = ivalue_to_bundled_inputs(bundled_inputs);
  }
  return ret;
}
164

165
std::vector<at::IValue> MobileModelRunner::run_with_inputs(
166
    std::vector<std::vector<at::IValue>> const& bundled_inputs) {
167
  std::vector<at::IValue> ret;
168
  ret.reserve(bundled_inputs.size());
169
  for (std::vector<at::IValue> const& input : bundled_inputs) {
170
    ret.emplace_back(module_->forward(input));
171
  }
172
  return ret;
173
}
174

175
std::vector<at::IValue> MobileModelRunner::run_with_inputs(
176
    const std::string& function_name,
177
    std::vector<std::vector<at::IValue>> const& bundled_inputs) const {
178
  std::vector<at::IValue> ret;
179
  ret.reserve(bundled_inputs.size());
180
  auto has_bundled_input = module_->find_method(function_name);
181
  CAFFE_ENFORCE(
182
      has_bundled_input,
183
      "Model does not have the method named ",
184
      function_name,
185
      "Please ensure that it was exported correctly");
186
  for (std::vector<at::IValue> const& input : bundled_inputs) {
187
    auto func = module_->get_method(function_name);
188
    ret.emplace_back(func(input));
189
  }
190
  return ret;
191
}
192

193
void MobileModelRunner::run_argless_functions(
194
    const std::vector<std::string>& functions) {
195
  for (auto& function_name : functions) {
196
    if (module_->find_method(function_name)) {
197
      module_->run_method(function_name);
198
    }
199
  }
200
}
201

202
// Returns true if any operator name in `op_list` starts with one of the
// prefixes "metal::", "metal_prepack::" or "metal_prepack_unet::".
// NOTE(review): despite the `set_` prefix, this definition only reads
// `op_list` and returns a flag; it writes no state itself.
bool MobileModelRunner::set_has_metal_gpu_operators(
    std::set<std::string> const& op_list) {
  // "metal_prepack_unet::" is not matched by "metal_prepack::" (they differ
  // right after "metal_prepack"), so it must be listed separately.
  static constexpr const char* kMetalPrefixes[] = {
      "metal::", "metal_prepack::", "metal_prepack_unet::"};
  for (std::string const& op : op_list) {
    for (const char* prefix : kMetalPrefixes) {
      // rfind(prefix, 0) only tests position 0. find(prefix) == 0 gives the
      // same answer but scans the whole name when the prefix is absent.
      if (op.rfind(prefix, 0) == 0) {
        return true;
      }
    }
  }
  return false;
}
212

213
void MobileModelRunner::for_each_tensor_in_bundled_inputs(
214
    std::function<void(const ::at::Tensor&)> const& func) {
215
  if (has_new_style_bundled_inputs()) {
216
    // Get the bundled inputs and access the arg level ivalues stored within
217
    auto bundled_inputs_mapping = this->get_many_functions_bundled_inputs();
218

219
    // Loop over functions
220
    for (auto& entry : bundled_inputs_mapping) {
221
      std::vector<std::vector<at::IValue>> bundled_inputs = entry.second;
222
      // Loop through inputs
223
      for (const std::vector<at::IValue>& input : bundled_inputs) {
224
        // Loop through values in an input
225
        for (const at::IValue& iv : input) {
226
          for_each_tensor_in_ivalue(iv, func);
227
        }
228
      }
229
    }
230
  } else {
231
    c10::IValue iv = module_->run_method("get_all_bundled_inputs");
232
    for_each_tensor_in_ivalue(iv, func);
233
  }
234
}
235
} // namespace mobile
236
} // namespace jit
237
} // namespace torch
238

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.