pytorch

Форк
0
/
aot_compiler.cpp 
450 строк · 16.1 Кб
1
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
2

3
#include <ATen/Functions.h>
4
#include <ATen/NativeFunctions.h>
5
#include <torch/csrc/jit/backends/backend.h>
6
#include <torch/csrc/jit/backends/backend_detail.h>
7
#include <torch/csrc/jit/backends/backend_preprocess.h>
8
#include <torch/csrc/jit/ir/ir.h>
9
#include <torch/csrc/jit/jit_log.h>
10
#include <torch/csrc/jit/passes/constant_propagation.h>
11
#include <torch/csrc/jit/passes/dead_code_elimination.h>
12
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
13
#include <torch/csrc/jit/passes/lower_tuples.h>
14
#include <torch/csrc/jit/passes/peephole.h>
15
#include <torch/csrc/jit/passes/remove_mutation.h>
16
#include <torch/csrc/jit/passes/shape_analysis.h>
17
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
18
#include <torch/csrc/jit/runtime/jit_trace.h>
19
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
20
#include <torch/csrc/jit/tensorexpr/ir.h>
21
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
22
#include <torch/csrc/jit/tensorexpr/kernel.h>
23
#include <fstream>
24

25
using namespace torch::jit;
26
using namespace torch::jit::tensorexpr;
27

28
namespace torch {
29
namespace jit {
30
namespace mobile {
31
namespace nnc {
32

33
static std::vector<int64_t> getConstSizes(const BufPtr b) {
34
  std::vector<int64_t> r;
35
  for (const auto& dim : b->dims()) {
36
    LongImmPtr imm_dim = to<LongImm>(dim);
37
    // TODO: assert it's actually immediate
38
    int64_t s = imm_dim->value();
39
    r.push_back(s);
40
  }
41
  return r;
42
}
43

44
// Construct input-specs vector from the inputs of the original graph
45
static std::vector<mobile::nnc::InputSpec> toInputSpecs(
46
    const std::shared_ptr<tensorexpr::TensorExprKernel>& kernel) {
47
  const std::shared_ptr<Graph>& g = kernel->graph();
48
  std::vector<mobile::nnc::InputSpec> specs;
49

50
  // Graph inputs include scalar values for symbolic shapes, for which we
51
  // don't need input specs. These scalar values come last among the graph
52
  // inputs
53
  auto num_inputs =
54
      g->inputs().size() - kernel->getSymbolicShapeInputs().size();
55

56
  for (const auto i : c10::irange(num_inputs)) {
57
    auto v = g->inputs()[i];
58
    const auto& t = v->type();
59
    mobile::nnc::InputSpec spec;
60
    TORCH_CHECK(t->kind() == TypeKind::TensorType, "Unsupported input type");
61
    const auto& tt = t->cast<TensorType>();
62
    spec.sizes_ = {};
63
    auto sizes_vec = *tt->sizes().sizes();
64
    for (auto s : sizes_vec) {
65
      spec.sizes_.push_back(s ? *s : 0);
66
    }
67
    spec.dtype_ = *tt->scalarType();
68
    specs.emplace_back(std::move(spec));
69
  }
70
  return specs;
71
}
72

73
// Locate symbolic shapes in shapes of the inputs.
74
//
75
// For each symbolic shape we're trying to find the input from which it can be
76
// extracted and the dimension index in that input.
77
// For instance, if we have
78
// graph(%x : Float(SS(-1), 10), %y : Long(20, SS(-2), %ss_1 : int, %ss_2 : int)
79
// then we would need to find locations of two symbolic shapes: SS(-1) and
80
// SS(-2). The first one corresponds to the first dimension of the first input,
81
// the second one corresponds to the second dimension of the second input,
82
// so we will return {{0, 0}, {1, 1}}.
83
//
84
// If a symbolic shape cannot be found among dimensions of inputs, we
85
// will throw an error (this situation is possible when symbolic shape
86
// corresponds to the size of an intermediate - we don't support this
87
// case here yet).
88
//
89
// If a symbolic shape can be found in several different positions, we
90
// return the first one we find (TODO: maybe we should return all and
91
// verify that they all match at runtime).
92
static std::vector<SymbolicShapePosition> findSymbolicShapePositions(
93
    std::shared_ptr<tensorexpr::TensorExprKernel> kernel) {
94
  std::vector<SymbolicShapePosition> res;
95
  for (int64_t sym_idx : kernel->getSymbolicShapeInputs()) {
96
    bool found = false;
97
    for (int64_t input_idx : c10::irange(kernel->graph()->inputs().size())) {
98
      auto input = kernel->graph()->inputs()[input_idx];
99

100
      if (!input->type()->cast<TensorType>()) {
101
        continue;
102
      }
103
      auto tt = input->type()->expect<TensorType>();
104
      if (!tt->symbolic_sizes().sizes()) {
105
        continue;
106
      }
107
      std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
108
      for (int64_t dim_idx : c10::irange(shape_vec.size())) {
109
        if (shape_vec[dim_idx].value() == sym_idx) {
110
          res.emplace_back(input_idx, dim_idx);
111
          found = true;
112
          break;
113
        }
114
      }
115
      if (found) {
116
        break;
117
      }
118
    }
119
    TORCH_CHECK(
120
        found, "Could not locate a symbolic shape among input tensor shapes");
121
  }
122
  return res;
123
}
124

125
static std::unique_ptr<Function> compileMethod(
126
    std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
127
    const std::string& method_name,
128
    const std::vector<std::vector<int64_t>>& sizes,
129
    const std::vector<at::ScalarType>& types) {
130
  auto func = std::make_unique<Function>();
131
  func->set_name(method_name);
132
  func->set_input_specs(toInputSpecs(kernel));
133

134
  auto params = c10::impl::GenericList(c10::AnyType::get());
135
  auto const_descriptors = kernel->getConstantDescriptors();
136
  for (const auto& cd : const_descriptors) {
137
    auto sizes = getConstSizes(cd.buf);
138
    if (!cd.node) {
139
      // sizes.empty() needs to be handled as sizes can be empty for Scalar
140
      // Tensors
141
      at::Tensor const_tensor = !sizes.empty()
142
          ? at::from_blob(cd.ptr, sizes).clone()
143
          : at::native::wrapped_scalar_tensor(*static_cast<double*>(cd.ptr));
144
      params.push_back(const_tensor);
145
    } else {
146
      params.emplace_back(toIValue(cd.node->output()));
147
    }
148
  }
149
  func->set_parameters(params);
150

151
  MemoryPlan plan;
152
  plan.buffer_sizes_ = {}; // temp_sizes_;
153
  // TODO: implement prealloc optimization and fill in temp_sizes
154
  func->set_memory_plan(plan);
155

156
  int64_t n_inputs = kernel->graph()->inputs().size();
157
  int64_t n_outputs = kernel->graph()->outputs().size();
158
  std::vector<OutputSpec> out_spec;
159
  for (int64_t idx = n_inputs; idx < n_inputs + n_outputs; idx++) {
160
    const auto& ba = kernel->getBufferArgs()[idx];
161
    OutputSpec output;
162
    output.sizes_ = getConstSizes(ba.buf());
163
    // TODO: assert the output is a buffer and not a scalar
164
    output.dtype_ = ba.buf()->dtype().scalar_type();
165
    if (isQIntType(output.dtype_)) {
166
      // Supporting only static qscale/qzero
167
      output.qscale_ =
168
          to<DoubleImm>(torch::jit::tensorexpr::IRSimplifier::simplify(
169
                            ba.buf()->qscale()))
170
              ->value();
171
      output.qzero_ =
172
          to<LongImm>(
173
              torch::jit::tensorexpr::IRSimplifier::simplify(ba.buf()->qzero()))
174
              ->value();
175
    }
176
    out_spec.push_back(output);
177
  }
178
  func->set_output_specs(out_spec);
179
  func->set_sym_shape_positions(findSymbolicShapePositions(kernel));
180

181
  return func;
182
}
183

184
static std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
185
    const std::string& method_name,
186
    std::shared_ptr<Graph>& g,
187
    const std::vector<std::vector<int64_t>>& sizes,
188
    const std::vector<at::ScalarType>& types,
189
    const std::string& kernel_func_name,
190
    const std::vector<int64_t>& symbolic_ind) {
191
  GRAPH_DEBUG("Input sizes ", sizes);
192
  GRAPH_DEBUG("Input types ", types);
193
  GRAPH_DEBUG("Method name ", method_name);
194
  GRAPH_DEBUG("Kernel func name ", kernel_func_name);
195
  GRAPH_DEBUG("Symbolic indices ", symbolic_ind);
196

197
  std::shared_ptr<tensorexpr::TensorExprKernel> kernel;
198
  std::vector<torch::jit::StrideInput> stride_desc = {
199
      torch::jit::StrideInput::TENSOR_CONT};
200
  std::unordered_map<
201
      const torch::jit::Value*,
202
      std::vector<torch::jit::StrideInput>>
203
      symbolic_strides;
204
  if (!symbolic_ind.empty()) {
205
    for (auto i : g->inputs()) {
206
      symbolic_strides[i] = stride_desc;
207
    }
208
    for (auto o : g->outputs()) {
209
      symbolic_strides[o] = stride_desc;
210
    }
211
  }
212
  kernel = std::make_shared<tensorexpr::TensorExprKernel>(TensorExprKernel(
213
      g, kernel_func_name, {}, symbolic_ind, false, symbolic_strides));
214

215
  const std::string compiled_assembly = kernel->getCodeText();
216
  auto func = compileMethod(kernel, method_name, sizes, types);
217
  return std::make_pair(std::move(func), compiled_assembly);
218
}
219

220
static void writeOutputLlvmAssembly(
221
    const std::string& asm_code,
222
    const std::string& output_llvm_file_name) {
223
  std::ofstream output(output_llvm_file_name);
224
  output << asm_code;
225
  GRAPH_DEBUG(
226
      "The compiled llvm assembly code was saved to ", output_llvm_file_name);
227
}
228

229
// Splits `string` on every occurrence of `separator`.
//
// Behaves like repeated std::getline: a trailing separator does not produce a
// final empty piece, and an empty input yields no pieces. When `ignore_empty`
// is true (the default), empty pieces are dropped from the result.
static std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> result;
  std::string::size_type begin = 0;
  while (begin < string.size()) {
    std::string::size_type end = string.find(separator, begin);
    if (end == std::string::npos) {
      end = string.size();
    }
    std::string piece = string.substr(begin, end - begin);
    if (!(ignore_empty && piece.empty())) {
      result.push_back(std::move(piece));
    }
    begin = end + 1;
  }
  return result;
}
243

244
static std::vector<std::vector<int64_t>> parseInputShapes(
245
    const std::string& input_dims_s) {
246
  std::vector<std::string> input_dims_list = split(';', input_dims_s);
247
  std::vector<std::vector<int64_t>> inputs;
248
  for (const auto& input_dims_item : input_dims_list) {
249
    auto input_dims_str = split(',', input_dims_item);
250
    std::vector<int64_t> input_dims;
251
    input_dims.reserve(input_dims_str.size());
252
    for (const auto& s : input_dims_str) {
253
      input_dims.push_back(std::stoi(s));
254
    }
255
    inputs.push_back(input_dims);
256
  }
257
  return inputs;
258
}
259

260
static std::vector<at::ScalarType> parseInputTypes(
261
    const std::string& input_types_str) {
262
  std::vector<std::string> inputTypes = split(';', input_types_str);
263
  std::vector<at::ScalarType> scalarTypes;
264
  for (const auto& inputType : inputTypes) {
265
    at::ScalarType scalarType;
266
    if (inputType == "float") {
267
      scalarType = at::ScalarType::Float;
268
    } else if (inputType == "uint8") {
269
      scalarType = at::ScalarType::Byte;
270
    } else if (inputType == "int64") {
271
      scalarType = at::ScalarType::Long;
272
    } else {
273
      CAFFE_THROW("Unsupported input type: ", inputType);
274
    }
275
    scalarTypes.push_back(scalarType);
276
  }
277
  return scalarTypes;
278
}
279

280
static std::vector<at::MemoryFormat> parseInputMemoryFormats(
281
    const std::string& input_memory_format_str) {
282
  std::vector<std::string> memFormatsStr = split(';', input_memory_format_str);
283
  std::vector<at::MemoryFormat> memFormats;
284
  for (const auto& memFormatStr : memFormatsStr) {
285
    at::MemoryFormat memFormat;
286
    if (memFormatStr == "contiguous") {
287
      memFormat = at::MemoryFormat::Contiguous;
288
    } else if (memFormatStr == "channels_last") {
289
      memFormat = at::MemoryFormat::ChannelsLast;
290
    } else {
291
      CAFFE_THROW("Unsupported memory format: ", memFormatStr);
292
    }
293
    memFormats.push_back(memFormat);
294
  }
295
  return memFormats;
296
}
297

298
static std::vector<int64_t> parseInputDynamicShapes(
299
    const std::string& dynamic_dims_s) {
300
  std::vector<std::string> dynamic_dims_list = split(',', dynamic_dims_s);
301
  std::vector<int64_t> dynamic_dims;
302
  dynamic_dims.reserve(dynamic_dims_list.size());
303
  for (const auto& dim : dynamic_dims_list) {
304
    dynamic_dims.push_back(std::stoi(dim));
305
  }
306
  return dynamic_dims;
307
}
308

309
// Builds the NNC kernel id
// "<model_name>:<model_version>:<method_name>:<version_token>".
// The version token is currently the fixed placeholder "VERTOKEN".
static std::string getNncKernelId(
    const std::string& model_name,
    const std::string& model_version,
    const std::string& method_name) {
  // TODO: calculate the version_token.
  const std::string version_token = "VERTOKEN";
  std::string id = model_name;
  id += ':';
  id += model_version;
  id += ':';
  id += method_name;
  id += ':';
  id += version_token;
  return id;
}
318

319
// Builds the name of the generated kernel function:
// "nnc_<model_name>_<model_version>_<method_name>".
static std::string getNncKernelFuncName(
    const std::string& model_name,
    const std::string& model_version,
    const std::string& method_name) {
  std::string func_name = "nnc_";
  func_name.append(model_name)
      .append("_")
      .append(model_version)
      .append("_")
      .append(method_name);
  return func_name;
}
325

326
// Preprocess the graph and returns the processed graph and
327
// symbolic values if dynamic input shapes are specified
328
static std::pair<std::shared_ptr<Graph>, std::vector<int64_t>>
329
preprocessGraphPasses(
330
    std::shared_ptr<Graph>& graph,
331
    const std::vector<c10::optional<at::Tensor>>& example_inputs,
332
    const std::vector<int64_t>& dynamic_sizes) {
333
  GRAPH_DEBUG("Before preprocessing graph passes: ", *graph);
334
  torch::jit::RemoveTensorMutation(graph);
335
  torch::jit::EliminateDeadCode(graph->block());
336
  graph = torch::jit::tensorexpr::removeUnusedSelfArgument(graph);
337

338
  torch::jit::tensorexpr::annotateInputShapes(graph, example_inputs);
339
  torch::jit::OptimizeFrozenGraph(graph, true);
340
  torch::jit::PropagateShapesOnGraph(graph);
341
  torch::jit::PeepholeOptimize(graph, false);
342
  torch::jit::ConstantPropagation(graph);
343
  torch::jit::PropagateShapesOnGraph(graph);
344
  torch::jit::PeepholeOptimize(graph, false);
345
  torch::jit::ConstantPropagation(graph);
346

347
  tensorexpr::removeUnusedSelfArgument(graph);
348

349
  std::vector<at::IValue> example_values;
350
  example_values.reserve(example_inputs.size());
351
  for (auto example_input : example_inputs) {
352
    example_values.emplace_back(*example_input);
353
  }
354
  graph = TraceGraph(graph, example_values);
355
  // TODO: Remove annotateInputShapes pass when TraceGraph can also capture
356
  // input shapes
357
  tensorexpr::annotateInputShapes(graph, example_inputs);
358

359
  RemoveListMutation(graph);
360
  RemoveTensorMutation(graph);
361
  EliminateDeadCode(graph);
362
  LowerAllTuples(graph);
363

364
  auto sym_val =
365
      torch::jit::tensorexpr::makeShapesSymbolic(graph, dynamic_sizes);
366

367
  GRAPH_DEBUG("After preprocessing graph passes: ", *graph);
368
  return std::make_pair(graph, sym_val);
369
}
370

371
static std::vector<c10::optional<at::Tensor>> generateExampleInputs(
372
    const std::vector<std::vector<int64_t>>& inputShapes,
373
    const std::vector<at::ScalarType>& inputTypes,
374
    const std::vector<at::MemoryFormat>& inputMemoryFormats) {
375
  std::vector<c10::optional<at::Tensor>> example_inputs;
376
  example_inputs.reserve(inputShapes.size());
377
  for (const auto i : c10::irange(inputShapes.size())) {
378
    const auto dtype = at::dtype(inputTypes[i]);
379
    const auto memory_format = inputMemoryFormats[i];
380
    example_inputs.emplace_back(
381
        at::rand(inputShapes[i]).to(dtype).contiguous(memory_format));
382
  }
383
  return example_inputs;
384
}
385

386
// Backend preprocess entry point for the NNC backend.
//
// `compile_spec` maps each method name of `mod` to a dict with:
//   "model_name", "model_version", "asmfile" : strings (required)
//   "sizes"          : input shapes, e.g. "1,3,224,224;1,10" (required)
//   "types"          : input dtypes, e.g. "float;int64" (required)
//   "dynamic_sizes"  : ','-separated values passed to makeShapesSymbolic
//                      (optional; a missing key behaves like "")
//   "memory_formats" : ';'-separated memory formats (optional; defaults to
//                      contiguous for every input)
// For each method the graph is preprocessed using random example inputs and
// compiled with NNC. The generated assembly is written to "asmfile", and the
// resulting function is registered in a CompilationUnit. The serialized
// CompilationUnit is returned.
// `generate_debug_handles` is currently unused.
static c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
  torch::jit::mobile::nnc::CompilationUnit cu;
  for (const auto& kv : compile_spec) {
    GRAPH_DEBUG("Key: ", kv.key());
    GRAPH_DEBUG("Value: ", kv.value());
    std::string method_name = *(kv.key().toString());
    GRAPH_DEBUG("Method name: ", method_name);
    auto method_spec = kv.value().toGenericDict();
    std::string model_name = *method_spec.at("model_name").toString();
    std::string model_version = *method_spec.at("model_version").toString();
    std::string asmfile_name = *method_spec.at("asmfile").toString();
    GRAPH_DEBUG("Model name: ", model_name);
    GRAPH_DEBUG("Model version: ", model_version);
    GRAPH_DEBUG("Asm file name: ", asmfile_name);

    auto method = mod.get_method(method_name);
    auto graph = toGraphFunction(method.function()).graph()->copy();

    auto sizes = parseInputShapes(*method_spec.at("sizes").toString());
    auto types = parseInputTypes(*method_spec.at("types").toString());

    // "dynamic_sizes" is optional, like "memory_formats": a missing key is
    // treated exactly like an empty string (no dynamic dimensions).
    std::string dynamic_sizes_str = method_spec.contains("dynamic_sizes")
        ? (*method_spec.at("dynamic_sizes").toString()).string()
        : "";
    auto dynamic_sizes = parseInputDynamicShapes(dynamic_sizes_str);

    std::string memory_formats_str = method_spec.contains("memory_formats")
        ? (*method_spec.at("memory_formats").toString()).string()
        : "";
    auto memory_formats = memory_formats_str.empty()
        ? std::vector<at::MemoryFormat>(
              sizes.size(), at::MemoryFormat::Contiguous)
        : parseInputMemoryFormats(memory_formats_str);

    auto example_inputs = generateExampleInputs(sizes, types, memory_formats);
    auto preprocessed =
        preprocessGraphPasses(graph, example_inputs, dynamic_sizes);

    auto kernel_func_name =
        getNncKernelFuncName(model_name, model_version, method_name);
    auto processed_graph = preprocessed.first;
    auto sym_values = preprocessed.second;
    auto compiled = torch::jit::mobile::nnc::aotCompile(
        method_name,
        processed_graph,
        sizes,
        types,
        kernel_func_name,
        sym_values);
    writeOutputLlvmAssembly(compiled.second, asmfile_name);
    auto func = std::move(compiled.first);
    func->set_nnc_kernel_id(
        getNncKernelId(model_name, model_version, method_name));
    cu.register_function(std::move(func));
  }
  return cu.serialize();
}
443

444
// TODO(mvz): temporarily disable NNC backend in mobile builds.
445
// static auto reg = torch::jit::backend_preprocess_register("nnc", preprocess);
446

447
} // namespace nnc
448
} // namespace mobile
449
} // namespace jit
450
} // namespace torch
451

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.