#include <ATen/Functions.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/ObservedOperators.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/script.h>

namespace torch {
namespace jit {
namespace mobile {

// Fetched from caffe2/aten/src/ATen/native/metal/MetalAten.mm
// Diffusion Link: https://fburl.com/diffusion/atwwmax2
const std::vector<std::string> gpu_metal_operators = {
    "aten::conv2d",
    "aten::add.Tensor",
    "aten::add_.Tensor",
    "aten::addmm",
    "aten::empty.memory_format",
    "aten::empty_strided",
    "aten::log_softmax.int",
    "aten::max_pool2d",
    "aten::mul.Tensor",
    "aten::relu",
    "aten::relu_",
    "aten::sigmoid",
    "aten::sub.Tensor",
    "aten::upsample_nearest2d.vec",
    "aten::view",
    "aten::adaptive_avg_pool2d",
    "aten::hardtanh_",
    "aten::reshape",
    "aten::flatten.using_ints",
};
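
// When run_model() below infers a Metal GPU model, this list is merged into
// the model's root operator set so that these Metal-only operators are
// included in the build.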

/**
 * A collection of common ATen methods that are usually called outside of a
 * model's forward() run. They need to be traced to ensure that the operators
 * they use are included in the build. If/when this list becomes too long, we
 * can consider making it a per-model list.
 */
void call_setup_methods() {
  at::zeros({2, 2});
  at::ones({2, 2});
  at::Tensor t1 = at::empty({7, 7});
  at::Tensor t2 = t1.fill_(3);
  at::Tensor t3 = t1.new_empty_strided(
      {2, 3},
      {3,
       1}); // TODO investigate how this is different from normal empty_strided
  at::narrow(t2, 1, 0, 1);
  at::eq(t1, t2);
  const volatile bool nz = at::native::is_nonzero(at::zeros({1}));
  (void)nz;

  // Create a byte tensor and copy it
  auto zb = at::zeros({10}, at::kByte);
  auto zf = at::zeros({10}, at::kFloat);
  zb.copy_(zf);
  t2.div(1);

  // Typically, failures show up in CopyKernel.cpp, so enumerate the
  // common dtypes that may show up.
  const auto all_dtypes_for_copy = {
      at::kBool,
      at::kByte,
      at::kFloat,
      at::kInt,
      at::kChar,
      at::kDouble,
      at::kShort,
      at::kLong};
  for (const auto dtype : all_dtypes_for_copy) {
    auto tensor1 = at::empty({10}, dtype);
    tensor1.copy_(at::zeros({10}, at::kBool));
    tensor1.copy_(at::zeros({10}, at::kFloat));
    tensor1.copy_(at::zeros({10}, at::kInt));
  }

  torch::zeros({0, 0}, torch::ScalarType::Float);
  std::vector<float> storage(20, 1.0);
  std::vector<int64_t> sizes({2, 10});
  torch::from_blob(storage.data(), at::IntArrayRef(sizes), at::kFloat);
}

/**
 * Similar to the setup methods, there is a suite of functions that often
 * appear under certain conditions but may avoid getting called in the trace
 * due to the narrow nature of bundled inputs.
 */
void call_dependent_methods(std::set<std::string>& root_ops) {
  bool is_training = false;
  bool has_batchnorm = false;
  bool has_dropout = false;
  for (const std::string& op : root_ops) {
    if (op.find("backward") != std::string::npos ||
        op.find("requires_grad_") != std::string::npos) {
      is_training = true;
    }
    if (op.find("batch_norm") != std::string::npos) {
      has_batchnorm = true;
    }
    if (op.find("dropout") != std::string::npos) {
      has_dropout = true;
    }
  }
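  // Exercise the training-only paths that narrow bundled inputs typically
  // miss: a batch_norm call in training mode and a dropout call with
  // train=true.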
  if (is_training && has_batchnorm) {
    at::batch_norm(
        at::ones({2, 2}),
        c10::nullopt,
        c10::nullopt,
        c10::nullopt,
        c10::nullopt,
        true,
        0.1,
        0.1,
        false);
  }
  if (is_training && has_dropout) {
    at::dropout(at::ones({20, 20, 20}), 0.2, true);
  }
}

/**
 * Call methods on the Tensor object that we expect to be called
 * in production on this Tensor.
 */
void consume_tensor(const at::Tensor& t) {
  const at::Tensor& c = t;
  c.copy_(t.cpu());
}
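
// Builds a map from fully qualified operator name ("name.overload_name",
// e.g. "aten::add.Tensor") to its FunctionSchema, covering both
// JIT-registered and dispatcher-registered operators.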
std::unordered_map<std::string, c10::FunctionSchema>
_get_runtime_ops_and_schema() {
  std::unordered_map<std::string, c10::FunctionSchema> result;

  // Grab the jit operators
  auto nonDispatcherOperators = torch::jit::getAllOperators();
  for (const auto& full_op : nonDispatcherOperators) {
    auto op = full_op->schema();
    auto op_name = op.name();
    if (!op.overload_name().empty()) {
      op_name += ("." + op.overload_name());
    }
    result.emplace(op_name, op);
  }

  // Grab the dispatcher operators
  auto dispatcherOperators = c10::Dispatcher::singleton().getAllOpNames();
  for (auto& op : dispatcherOperators) {
    // grab schema
    const auto op_handle = c10::Dispatcher::singleton().findOp(op);
    if (op_handle->hasSchema()) {
      auto op_name = op.name;
      if (!op.overload_name.empty()) {
        op_name += ("." + op.overload_name);
      }
      result.emplace(op_name, op_handle->schema());
    }
  }

  return result;
}

/**
 * For the vast majority of use cases, the instrumentation in getCustomClass
 * will catch any custom classes referenced by a model. There are, however,
 * niche situations that avoid the getCustomClass instrumentation due to some
 * nuances of mobile model deserialization. To get around that, we can search
 * through all the used ops and inspect their schemas for any referenced
 * classes.
 * Example schema: prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None,
 *   Scalar? output_min=None, Scalar? output_max=None) ->
 *   __torch__.torch.classes.xnnpack.LinearOpContext
 */
void recordCustomClassesFromOpSchemas(
    std::set<std::string>& root_ops,
    std::set<std::string>& traced_ops,
    std::set<std::string>& loaded_classes) {
  std::set<std::string> ops;
  ops.insert(root_ops.begin(), root_ops.end());
  ops.insert(traced_ops.begin(), traced_ops.end());
  auto ops_and_schemas = _get_runtime_ops_and_schema();

  auto record_if_class = [&](std::string type_name) {
    // All custom class types start with __torch__; not sure if this is by
    // chance or guaranteed.
    if (type_name.find("__torch__") != std::string::npos) {
      // The name of a customClassType here is its fully qualified name, but
      // in registration only the class name is used, so only record that.
      auto class_name = type_name.substr(type_name.find_last_of('.') + 1);
      // Function schemas can include other type indicators such as [], so we
      // need to trim to just alphanumeric + '_' characters as well.
      class_name = class_name.substr(
          0,
          class_name.find_first_not_of(
              "aAbBcCdDeEfFgGhHiIjJkKlLmMnNoOpPqQrRsStTuUvVwWxXyYzZ_1234567890"));
      loaded_classes.insert(class_name);
    }
  };
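  // E.g. for the example schema in the comment above, record_if_class() sees
  // "__torch__.torch.classes.xnnpack.LinearOpContext" and records the
  // trailing class name "LinearOpContext".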

  for (auto& op_name : ops) {
    // This check is only necessary because of GPU models.
    // Certain models can only run on a specific backend, say Metal.
    // Those ops will be present in the model's root ops, but likely
    // not in the tracer's runtime on Linux.
    if (ops_and_schemas.find(op_name) != ops_and_schemas.end()) {
      auto& schema = ops_and_schemas.at(op_name);
      for (auto& arg : schema.arguments()) {
        record_if_class(arg.type()->annotation_str());
      }
      for (auto& ret : schema.returns()) {
        record_if_class(ret.type()->annotation_str());
      }
    }
  }
}

void run_model(
    const std::string& input_module_path,
    std::set<std::string>& root_ops,
    std::set<std::string>& enabled_backends,
    KernelDTypeTracer::kernel_tags_type& called_kernel_tags) {
  // Load the module on CPU with the flag to skip the operator exists check.
  // This is needed so that we can load any TorchBind objects (custom classes)
  // that this model refers to so that any operators being called from those
  // TorchBind objects can be traced by the model tracer.
  torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
  root_ops = module_runner.get_root_operators();
  std::cout << "Got " << root_ops.size() << " Root Operators." << std::endl;

  if (torch::jit::mobile::MobileModelRunner::set_has_metal_gpu_operators(
          root_ops)) {
    std::cout << "Inferred Metal GPU Model." << std::endl;
    root_ops.insert(gpu_metal_operators.begin(), gpu_metal_operators.end());
    called_kernel_tags["__unused__"] = {"Float"};
    enabled_backends.insert("Metal GPU");

    // When we encounter a GPU model, we should call .cpu().copy_() on the
    // tensors in the bundled inputs, since this is what will happen when
    // such a model is executed on an iOS device (to copy the Tensor to Metal
    // memory via a call to .metal()).
    module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
  } else {
    std::cout << "Inferred CPU Model." << std::endl;
    enabled_backends.insert("CPU");
    torch::jit::mobile::MobileModelRunner mobile_module_runner(
        input_module_path);

    // When we encounter a CPU model, we should call .cpu().copy_() on the
    // tensors in the bundled inputs, since this is what will happen when
    // such a model is executed on an Android device: the PyTorch JNI
    // bindings call .cpu() in JIValue::newJIValueFromAtIValue().
    module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);

    // If a user bundled inputs after the API was updated to accept bundled
    // inputs for multiple methods, they should go down this route. Even if
    // they only bundle inputs for forward, they will have the new-style
    // bundled inputs. Since at this point in tracer.cpp we do not know which
    // functions have bundled inputs, we must call
    // get_bundled_inputs_functions_and_info, if it exists, to get the set.
    if (mobile_module_runner.has_new_style_bundled_inputs()) {
      auto bundled_inputs_mapping =
          mobile_module_runner.get_many_functions_bundled_inputs();
      for (auto& entry : bundled_inputs_mapping) {
        std::string function_name = entry.first;
        std::vector<std::vector<at::IValue>> bundled_inputs = entry.second;
        std::cout << "Got " << bundled_inputs.size() << " bundled input(s) for "
                  << function_name << "\n\n";
        std::vector<at::IValue> results =
            mobile_module_runner.run_with_inputs(function_name, bundled_inputs);

        for (auto& result : results) {
          // Consume the result Tensor(s) when tracing on CPU since the
          // Android/Java JNI bindings will do the same.
          torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
        }
      }
      // If get_bundled_inputs_functions_and_info does not exist, we default
      // to assuming the inputs were bundled before that change was made. If
      // no bundled inputs are found here either, an error will be thrown.
    } else {
      std::vector<std::vector<at::IValue>> bundled_inputs =
          mobile_module_runner.get_all_bundled_inputs();
      std::cout << "Got " << bundled_inputs.size() << " bundled input(s)\n\n";
      std::vector<at::IValue> results =
          mobile_module_runner.run_with_inputs(bundled_inputs);

      for (auto& result : results) {
        // Consume the result Tensor(s) when tracing on CPU since the
        // Android/Java JNI bindings will do the same.
        torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
      }
    }
  }
}

TracerResult trace_run(const std::string& input_module_path) {
  return trace_run(std::vector<std::string>(1, input_module_path));
}
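
/**
 * Loads each module, runs it with its bundled inputs under QNNPACK, FBGEMM
 * (best effort), and an inference guard, and collects the operators, kernel
 * dtypes, custom classes, and build features exercised along the way into a
 * TracerResult.
 */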
TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
  at::globalContext().setQEngine(at::QEngine::QNNPACK);
  c10::ObservedOperators::getUnobservedOperatorList().clear();

  torch::jit::mobile::OperatorCallTracer op_tracer;
  torch::jit::mobile::KernelDTypeTracer kdtype_tracer;
  torch::jit::mobile::CustomClassTracer custom_class_tracer;
  torch::jit::mobile::BuildFeatureTracer build_feature_tracer;
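  // The four tracers declared above record which operators, kernel dtypes,
  // custom classes, and build features get exercised while the models run;
  // their results are collected further down via withLock().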
319

320
  call_setup_methods();
321

322
  std::set<std::string> root_ops, traced_operators, enabled_backends,
323
      loaded_classes, build_features;
324
  torch::jit::mobile::KernelDTypeTracer::kernel_tags_type called_kernel_tags;
325

326
  using torch::jit::MobileModuleLoadOptions;
327

328
  for (auto& input_module_path : input_module_paths) {
329
    // run with QNNPACK
330
    at::globalContext().setQEngine(at::QEngine::QNNPACK);
331

332
    run_model(
333
        input_module_path, root_ops, enabled_backends, called_kernel_tags);
334
    // Not every model can be successfully run with fbgemm,
335
    // but for those that can this can help broaden the tracers scope around
336
    // hyper optimized QNNPack paths
337
    try {
338
      at::globalContext().setQEngine(at::QEngine::FBGEMM);
339
      run_model(
340
          input_module_path, root_ops, enabled_backends, called_kernel_tags);
341
    } catch (std::exception& ex) {
342
      std::cerr
343
          << "ModelTracer encountered an error while attempting to run the model in FBGEMM mode"
344
          << ex.what() << "\n Skipping FBGEMM execution" << std::endl;
345
    }
346
    try {
347
      at::globalContext().setQEngine(at::QEngine::QNNPACK);
348
      c10::InferenceMode guard(true);
349
      run_model(
350
          input_module_path, root_ops, enabled_backends, called_kernel_tags);
351
    } catch (std::exception& ex) {
352
      std::cerr
353
          << "ModelTracer encountered an error while attempting to run the model under an inference guard"
354
          << ex.what() << "\n Skipping inference guard execution" << std::endl;
355
    }
356
  }

  call_dependent_methods(root_ops);

  op_tracer.getCalledOperators().withLock(
      [&](std::set<std::string>& called_operators) {
        traced_operators = called_operators;
      });

  recordCustomClassesFromOpSchemas(root_ops, traced_operators, loaded_classes);

  kdtype_tracer.getCalledKernelTags().withLock(
      [&](KernelDTypeTracer::kernel_tags_type& kernel_tags) {
        called_kernel_tags.insert(kernel_tags.begin(), kernel_tags.end());
      });

  traced_operators.insert(
      always_included_traced_ops.begin(), always_included_traced_ops.end());

  custom_class_tracer.getLoadedClasses().withLock(
      [&](CustomClassTracer::custom_classes_type& custom_classes) {
        loaded_classes.insert(custom_classes.begin(), custom_classes.end());
      });

  build_feature_tracer.getBuildFeatures().withLock(
      [&](BuildFeatureTracer::build_feature_type& bf) {
        build_features.insert(bf.begin(), bf.end());
      });

  TracerResult tracer_result = {
      root_ops,
      traced_operators,
      called_kernel_tags,
      loaded_classes,
      build_features,
      enabled_backends};

  return tracer_result;
}

} // namespace mobile
} // namespace jit
} // namespace torch