pytorch

Форк
0
/
context.cpp 
348 строк · 10.5 Кб
1
#include <torch/csrc/jit/mobile/nnc/context.h>
2

3
#include <ATen/Functions.h>
4
#include <ATen/core/functional.h>
5
#include <c10/core/CPUAllocator.h>
6
#include <c10/util/irange.h>
7

8
#include <torch/csrc/jit/mobile/nnc/registry.h>
9

10
namespace torch {
11
namespace jit {
12
namespace mobile {
13
namespace nnc {
14

15
// Format version stored as the first element of the tuple emitted by
// CompilationUnit::serialize().
constexpr int64_t kProducedNNCFileFormatVersion = 1;
16

17
namespace {
18

19
c10::IValue Tup(std::initializer_list<c10::IValue> ivalues) {
20
  return c10::ivalue::Tuple::create(ivalues);
21
}
22

23
c10::IValue Tup(std::vector<c10::IValue>&& ivalues) {
24
  return c10::ivalue::Tuple::create(ivalues);
25
}
26

27
} // namespace
28

29
InputSpec::InputSpec(const c10::IValue& value) {
  // Reads back the generic dict written by InputSpec::serialize():
  // "sizes" (int list) followed by "dtype" (scalar type).
  auto fields = value.toGenericDict();
  const c10::IValue sizes_field = fields.at("sizes");
  sizes_ = sizes_field.toIntVector();
  const c10::IValue dtype_field = fields.at("dtype");
  dtype_ = dtype_field.toScalarType();
}
34

35
c10::IValue InputSpec::serialize() const {
  // String keys with heterogeneous values, hence AnyType for the value type.
  auto key_type = at::StringType::get();
  auto value_type = at::AnyType::get();
  c10::Dict<c10::IValue, c10::IValue> fields(key_type, value_type);
  fields.insert("sizes", sizes_);
  fields.insert("dtype", dtype_);
  return fields;
}
42

43
// Returns true when `input` matches this spec: same rank, same dtype, and
// every non-dynamic dimension has exactly the expected extent.
// Called once per input on every Function::run(). The spec's sizes are
// therefore read in place instead of being copied into a temporary vector.
bool InputSpec::validate(const at::Tensor& input) const {
  const auto input_sizes = input.sizes();
  if (sizes_.size() != input_sizes.size() || input.scalar_type() != dtype_) {
    return false;
  }
  for (const auto i : c10::irange(sizes_.size())) {
    // InputSpec size 0 means that the dimension is dynamic
    if (sizes_[i] != 0 && sizes_[i] != input_sizes[i]) {
      return false;
    }
  }
  return true;
}
56

57
OutputSpec::OutputSpec(const c10::IValue& value) {
  // Reads back the generic dict written by OutputSpec::serialize().
  auto fields = value.toGenericDict();
  sizes_ = fields.at("sizes").toIntVector();
  dtype_ = fields.at("dtype").toScalarType();

  // The quantization parameters are optional keys: serialize() writes each
  // one only when it is set.
  const bool has_qscale = fields.contains("qscale");
  if (has_qscale) {
    qscale_ = fields.at("qscale").toDouble();
  }
  const bool has_qzero = fields.contains("qzero");
  if (has_qzero) {
    qzero_ = fields.at("qzero").toInt();
  }
}
68

69
c10::IValue OutputSpec::serialize() const {
  // Mandatory "sizes"/"dtype" entries, plus "qscale"/"qzero" only when the
  // corresponding optional member holds a value.
  c10::Dict<c10::IValue, c10::IValue> fields(
      at::StringType::get(), at::AnyType::get());
  fields.insert("sizes", sizes_);
  fields.insert("dtype", dtype_);
  if (qscale_.has_value()) {
    fields.insert("qscale", *qscale_);
  }
  if (qzero_.has_value()) {
    fields.insert("qzero", *qzero_);
  }
  return fields;
}
82

83
at::Tensor OutputSpec::allocate() const {
  // Every output is a strided CPU tensor with requires_grad disabled; only
  // the factory differs between plain and quantized dtypes.
  const auto options = at::TensorOptions()
                           .dtype(dtype_)
                           .layout(at::kStrided)
                           .device(at::kCPU)
                           .requires_grad(false);

  if (!isQIntType(dtype_)) {
    return at::empty(sizes_, options);
  }

  TORCH_CHECK(
      qscale_ && qzero_,
      "Quantized output tensor must have qscale_ and qzero_");
  return at::_empty_affine_quantized(sizes_, options, *qscale_, *qzero_);
}
106

107
MemoryPlan::MemoryPlan(const c10::IValue& value) {
  // The serialized plan carries a single entry: the list of buffer sizes.
  const c10::IValue sizes_field = value.toGenericDict().at("buffer_sizes");
  buffer_sizes_ = sizes_field.toIntVector();
}
111

112
c10::IValue MemoryPlan::serialize() const {
  // Single-entry dict, read back by MemoryPlan(const c10::IValue&).
  auto key_type = at::StringType::get();
  auto value_type = at::AnyType::get();
  c10::Dict<c10::IValue, c10::IValue> plan(key_type, value_type);
  plan.insert("buffer_sizes", buffer_sizes_);
  return plan;
}
118

119
void MemoryPlan::allocate(ExecutionState* state) const {
  // Discards any previous buffers in `state` and allocates one fresh buffer
  // from the CPU allocator per planned size, in plan order.
  auto& buffers = state->preallocations_;
  buffers.clear();
  buffers.reserve(buffer_sizes_.size());
  for (const auto idx : c10::irange(buffer_sizes_.size())) {
    buffers.emplace_back(c10::GetCPUAllocator()->allocate(buffer_sizes_[idx]));
  }
}
128

129
// Deserializes a Function from the generic dict produced by
// Function::serialize().
// Throws (via IValue conversions / TORCH_CHECK) if a key is missing or a
// symbolic shape entry is malformed.
Function::Function(const c10::IValue& value) {
  auto dict = value.toGenericDict();
  name_ = c10::QualifiedName(dict.at("name").toStringRef());
  nnc_kernel_id_ = dict.at("nnc_kernel_id").toStringRef();
  parameters_ = dict.at("parameters").toList();

  // input_specs_
  for (const auto& input_value :
       dict.at("input_specs").toTupleRef().elements()) {
    input_specs_.emplace_back(input_value);
  }

  // output_specs_
  for (const auto& output_value :
       dict.at("output_specs").toTupleRef().elements()) {
    output_specs_.emplace_back(output_value);
  }

  // memory_plan_
  memory_plan_ = MemoryPlan(dict.at("memory_plan"));

  // Symbolic shape positions. Each entry is an (input index, dim index) pair.
  // serialize() always writes exactly two elements, but the data comes from a
  // model file. The size is therefore checked before indexing, so a corrupt
  // entry raises an error instead of reading out of bounds.
  for (const auto& sym_shape_pos :
       dict.at("sym_shape_pos").toTupleRef().elements()) {
    const auto& sym_shape_elements = sym_shape_pos.toTupleRef().elements();
    TORCH_CHECK(
        sym_shape_elements.size() == 2,
        "Invalid sym_shape_pos entry: expected 2 elements, got ",
        sym_shape_elements.size());
    sym_shape_positions_.emplace_back(
        sym_shape_elements[0].toInt(), sym_shape_elements[1].toInt());
  }
}
158

159
c10::IValue Function::serialize() const {
  // Serializes each element of a spec container with its own serialize()
  // and packs the results into a tuple IValue.
  auto serialize_all = [](const auto& specs) {
    std::vector<c10::IValue> packed;
    packed.reserve(specs.size());
    for (const auto& spec : specs) {
      packed.emplace_back(spec.serialize());
    }
    return Tup(std::move(packed));
  };

  c10::Dict<c10::IValue, c10::IValue> fields(
      at::StringType::get(), at::AnyType::get());

  fields.insert("name", name_.qualifiedName());
  fields.insert("nnc_kernel_id", nnc_kernel_id_);
  // TODO: should serialize parameters with Module instead of with each Method.
  // And ideally the parameters should be shared between the compiled model
  // and the original model if we can serialize both in the same model file.
  fields.insert("parameters", parameters_);
  fields.insert("input_specs", serialize_all(input_specs_));
  fields.insert("output_specs", serialize_all(output_specs_));
  fields.insert("memory_plan", memory_plan_.serialize());

  // Each symbolic shape position becomes a 2-tuple (input_idx, dim_idx).
  std::vector<c10::IValue> sym_shape_entries;
  sym_shape_entries.reserve(sym_shape_positions_.size());
  for (const auto& pos : sym_shape_positions_) {
    sym_shape_entries.emplace_back(Tup({pos.input_idx_, pos.dim_idx_}));
  }
  fields.insert("sym_shape_pos", Tup(std::move(sym_shape_entries)));

  return fields;
}
200

201
void Function::init_execution_state() const {
202
  if (execution_state_.get() != nullptr) {
203
    return;
204
  }
205

206
  ExecutionState state;
207
  memory_plan_.allocate(&state);
208

209
  // The arguments vector consists of 5 sections: inputs, symbolic shapes,
210
  // outputs, parameters and buffers.
211
  auto input_args = input_specs_.size();
212
  auto sym_shape_args = sym_shape_positions_.size();
213
  auto output_args = output_specs_.size();
214
  auto param_args = parameters_.size();
215
  auto buffer_args = state.preallocations_.size();
216

217
  auto& arguments = state.arguments_;
218
  arguments.reserve(
219
      input_args + sym_shape_args + output_args + param_args + buffer_args);
220

221
  // Keep empty slots to fill in inputs/outputs pointers at execution time.
222
  arguments.resize(input_args + sym_shape_args + output_args);
223

224
  // Fill in parameters as untyped raw pointers.
225
  // The underlying storage of the parameters should be owned by `parameters_`,
226
  // which should be alive when `execution_state_` is being used.
227
  for (const auto& param : parameters_) {
228
    const c10::IValue& ivalue = (c10::IValue)param;
229
    if (ivalue.isTensor()) {
230
      arguments.emplace_back(ivalue.toTensor().data_ptr());
231
    } else if (torch::isCustomClass(ivalue)) {
232
      arguments.emplace_back(ivalue.toObjectRef().getSlot(0).toCapsule().get());
233
    } else {
234
      TORCH_CHECK(false, "Invalid parameter: ", ivalue);
235
    }
236
  }
237

238
  // Fill in preallocated buffer pointers.
239
  for (const auto& preallocation : state.preallocations_) {
240
    arguments.emplace_back(preallocation.get());
241
  }
242

243
  execution_state_ = std::make_unique<ExecutionState>(std::move(state));
244
}
245

246
// Runs the registered NNC kernel on `inputs` and returns freshly allocated
// output tensors.
// Throws if the kernel is not registered, if the input count differs from the
// spec, or if an input fails InputSpec::validate().
c10::impl::GenericList Function::run(
    const c10::impl::GenericList& inputs) const {
  TORCH_CHECK(
      registry::has_nnc_kernel(nnc_kernel_id_),
      "Cannot find NNC kernel: ",
      nnc_kernel_id_);

  init_execution_state();

  // Argument layout set up by init_execution_state(): inputs, symbolic
  // shapes, outputs, parameters, buffers. Only the first three sections are
  // (re)filled on every call; parameters and buffers were set once.
  std::vector<void*>& args = execution_state_->arguments_;

  // Fill in input tensors.
  TORCH_CHECK(
      input_specs_.size() == inputs.size(),
      "Input size doesn't match the spec, expect: ",
      input_specs_.size(),
      " actual: ",
      inputs.size());
  // Index of the first slot of the current section. It is size_t because it
  // is accumulated from container sizes; the old `int` narrowed them.
  size_t offset = 0;
  for (const auto i : c10::irange(inputs.size())) {
    const c10::IValue& input = inputs[i];
    const auto& spec = input_specs_[i];
    const auto& input_tensor = input.toTensor();
    TORCH_CHECK(spec.validate(input_tensor), "Invalid input at pos: ", i);
    args[offset + i] = input_tensor.data_ptr();
  }
  offset += inputs.size();

  // Fill in symbolic shape values: the runtime extent of selected input
  // dimensions. The kernel receives pointers INTO `scalar_values`, so the
  // vector is reserved up front to keep those pointers stable while appending.
  // It must also outlive kernel->execute() below.
  // NOTE(review): once run() returns, these slots of `args` dangle until the
  // next call overwrites them.
  std::vector<int64_t> scalar_values;
  scalar_values.reserve(sym_shape_positions_.size());
  for (const auto i : c10::irange(sym_shape_positions_.size())) {
    const auto& sym_shape_pos = sym_shape_positions_[i];
    const c10::IValue& input = inputs[sym_shape_pos.input_idx_];
    scalar_values.push_back(input.toTensor().size(sym_shape_pos.dim_idx_));
    args[offset + i] = &scalar_values.back();
  }
  offset += sym_shape_positions_.size();

  // Preallocate and fill in output tensors.
  c10::List<at::Tensor> outputs;
  outputs.reserve(output_specs_.size());
  for (const auto i : c10::irange(output_specs_.size())) {
    at::Tensor output = output_specs_[i].allocate();
    outputs.emplace_back(output);
    args[offset + i] = output.data_ptr();
  }

  // TODO: check consistency, e.g.: code version, input shape and compiled
  // shape, etc.
  auto kernel = registry::get_nnc_kernel(nnc_kernel_id_);
  kernel->execute(args.data());

  return c10::impl::toList(outputs);
}
301

302
// Deserializes a compilation unit from the tuple produced by
// CompilationUnit::serialize(): (format version, tuple of serialized
// functions). Each function is registered under its qualified name.
// NOTE(review): the format version at index 0 is not validated here.
// Consider rejecting versions newer than kProducedNNCFileFormatVersion.
CompilationUnit::CompilationUnit(const c10::IValue& value) {
  const auto& root = value.toTupleRef().elements();
  // Guard the index below: a truncated or corrupt model file must raise an
  // error rather than read past the end of the tuple.
  TORCH_CHECK(
      root.size() >= 2,
      "Invalid NNC compilation unit: expected a (version, functions) tuple, got ",
      root.size(),
      " element(s)");
  const auto& functions = root[1].toTupleRef().elements();
  for (const auto& function : functions) {
    register_function(std::make_unique<Function>(function));
  }
}
309

310
c10::IValue CompilationUnit::serialize() const {
  // Layout: (format version, (function_0, function_1, ...)), with functions
  // in functions_ iteration order.
  std::vector<c10::IValue> serialized_functions;
  serialized_functions.reserve(functions_.size());
  for (const auto& entry : functions_) {
    serialized_functions.emplace_back(entry.second->serialize());
  }
  return Tup(
      {kProducedNNCFileFormatVersion, Tup(std::move(serialized_functions))});
}
317

318
c10::impl::GenericList CompilationUnit::run(
    const c10::QualifiedName& name,
    const c10::impl::GenericList& inputs) const {
  // Dispatches to the registered function; unknown names are an error.
  Function* const target = find_function(name);
  TORCH_CHECK(
      target, "Function '", name.qualifiedName(), "' is not defined.");
  return target->run(inputs);
}
326

327
void CompilationUnit::register_function(std::unique_ptr<Function> fn) {
328
  TORCH_CHECK(
329
      0 == functions_.count(fn->name()),
330
      "method '",
331
      fn->name().qualifiedName(),
332
      "' already defined.");
333
  const auto& name = fn->name();
334
  functions_.emplace(name, std::move(fn));
335
}
336

337
Function* CompilationUnit::find_function(const c10::QualifiedName& name) const {
  // Non-owning lookup: returns nullptr instead of throwing for unknown names.
  const auto entry = functions_.find(name);
  return entry != functions_.end() ? entry->second.get() : nullptr;
}
344

345
} // namespace nnc
346
} // namespace mobile
347
} // namespace jit
348
} // namespace torch
349

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.