#include <ATen/Functions.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/ObservedOperators.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <torch/csrc/jit/mobile/model_tracer/BuildFeatureTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/script.h>

#include <iostream>

namespace torch {
namespace jit {
namespace mobile {

// Fetched from caffe2/aten/src/ATen/native/metal/MetalAten.mm
// Diffusion Link: https://fburl.com/diffusion/atwwmax2
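// These operators are unioned into a model's root operators whenever a
// Metal GPU model is detected (see the Metal branch below).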
const std::vector<std::string> gpu_metal_operators = {
    "aten::empty.memory_format",
    "aten::empty_strided",
    "aten::log_softmax.int",
    "aten::upsample_nearest2d.vec",
    "aten::adaptive_avg_pool2d",
    "aten::flatten.using_ints",
};

/**
 * These are a collection of some common ATen methods that are usually
 * called outside of the Model's forward() run, and they need to be
 * traced to ensure that the used operators are included in the build.
 * If/When this list becomes too long, we can consider making it a
 * per-model list.
 */
void call_setup_methods() {
  at::Tensor t1 = at::empty({7, 7});
  at::Tensor t2 = t1.fill_(3);
  at::Tensor t3 = t1.new_empty_strided(
      /*size=*/{2, 3},
      /*stride=*/{3, 1}); // TODO investigate how this is different from normal empty_strided
  at::narrow(t2, 1, 0, 1);
  const volatile bool nz = at::native::is_nonzero(at::zeros({1}));

  // Create a byte tensor and copy it
  auto zb = at::zeros({10}, at::kByte);
  auto zf = at::zeros({10}, at::kFloat);
  zb.copy_(zf);

  // Typically, failures show up in CopyKernel.cpp, so enumerating
  // common dtypes that may show up.
  const auto all_dtypes_for_copy = {
      at::kBool,
      at::kByte,
      at::kChar,
      at::kShort,
      at::kInt,
      at::kLong,
      at::kFloat,
      at::kDouble};
  for (const auto dtype : all_dtypes_for_copy) {
    auto tensor1 = at::empty({10}, dtype);
    tensor1.copy_(at::zeros({10}, at::kBool));
    tensor1.copy_(at::zeros({10}, at::kFloat));
    tensor1.copy_(at::zeros({10}, at::kInt));
  }

  torch::zeros({0, 0}, torch::ScalarType::Float);
  std::vector<float> storage(20, 1.0);
  std::vector<int64_t> sizes({2, 10});
  torch::from_blob(storage.data(), at::IntArrayRef(sizes), at::kFloat);
}

/**
 * Similar to the setup methods, there is a suite of functions that often
 * appear under certain conditions but may avoid getting called in the trace
 * due to the narrow nature of bundled inputs.
 */
void call_dependent_methods(std::set<std::string>& root_ops) {
  bool is_training = false;
  bool has_batchnorm = false;
  bool has_dropout = false;
  for (const std::string& op : root_ops) {
    if (op.find("backward") != std::string::npos ||
        op.find("requires_grad_") != std::string::npos) {
      is_training = true;
    }
    if (op.find("batch_norm") != std::string::npos) {
      has_batchnorm = true;
    }
    if (op.find("dropout") != std::string::npos) {
      has_dropout = true;
    }
  }

  if (is_training && has_batchnorm) {
    // NOTE: the argument values here are illustrative; the call exists to
    // pull the training-mode batch_norm kernel into the trace.
    at::batch_norm(
        at::ones({2, 2}), {}, {}, {}, {},
        /*training=*/true, /*momentum=*/0.1, /*eps=*/0.1, /*cudnn_enabled=*/false);
  }

  if (is_training && has_dropout) {
    at::dropout(at::ones({20, 20, 20}), 0.2, true);
  }
}

/**
 * Call methods on the Tensor object that we expect to be called
 * in production on this Tensor.
 */
void consume_tensor(const at::Tensor& t) {
  const at::Tensor& c = t;
  // Copy through CPU, mirroring what the JNI / iOS bindings do with results.
  c.copy_(t.cpu());
}
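
/**
 * Build a map from operator name (including overload, e.g. "aten::add.Tensor")
 * to its FunctionSchema, covering both the JIT operator registry and the
 * c10 dispatcher registrations.
 */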
std::unordered_map<std::string, c10::FunctionSchema>
_get_runtime_ops_and_schema() {
  std::unordered_map<std::string, c10::FunctionSchema> result;

  // Grab the jit operators
  auto nonDispatcherOperators = torch::jit::getAllOperators();
  for (const auto& full_op : nonDispatcherOperators) {
    auto op = full_op->schema();
    auto op_name = op.name();
    if (!op.overload_name().empty()) {
      op_name += ("." + op.overload_name());
    }
    result.emplace(op_name, op);
  }

  // Grab the dispatcher operators
  auto dispatcherOperators = c10::Dispatcher::singleton().getAllOpNames();
  for (auto& op : dispatcherOperators) {
    // Skip any operator that was registered without a schema.
    const auto op_handle = c10::Dispatcher::singleton().findOp(op);
    if (op_handle->hasSchema()) {
      auto op_name = op.name;
      if (!op.overload_name.empty()) {
        op_name += ("." + op.overload_name);
      }
      result.emplace(op_name, op_handle->schema());
    }
  }

  return result;
}

/**
 * For the vast majority of use cases the instrumentation in getCustomClass
 * will catch any custom classes referenced by a model. There are, however,
 * niche situations that avoid the getCustomClass instrumentation due to some
 * nuances of mobile model deserialization. To get around that we can search
 * through all the used ops, and inspect their schemas to search for any
 * referenced classes.
 *
 * Example schema: prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None,
 *     Scalar? output_min=None, Scalar? output_max=None) ->
 *     __torch__.torch.classes.xnnpack.LinearOpContext
 */
void recordCustomClassesFromOpSchemas(
    std::set<std::string>& root_ops,
    std::set<std::string>& traced_ops,
    std::set<std::string>& loaded_classes) {
  std::set<std::string> ops;
  ops.insert(root_ops.begin(), root_ops.end());
  ops.insert(traced_ops.begin(), traced_ops.end());
  auto ops_and_schemas = _get_runtime_ops_and_schema();

  auto record_if_class = [&](std::string type_name) {
    // All custom class types start with __torch__; not sure if this is by
    // chance or guaranteed.
    if (type_name.find("__torch__") != std::string::npos) {
      // The name of a customClassType here is its fully qualified name, but
      // in registration only the class name is used, so only record that.
      auto class_name = type_name.substr(type_name.find_last_of('.') + 1);
      // Function schemas can include other type indicators such as [], so we
      // need to trim to just alphanumeric + '_' characters as well.
      class_name = class_name.substr(
          0,
          class_name.find_first_not_of(
              "aAbBcCdDeEfFgGhHiIjJkKlLmMnNoOpPqQrRsStTuUvVwWxXyYzZ_1234567890"));
      loaded_classes.insert(class_name);
    }
  };

  for (auto& op_name : ops) {
    // This check is only necessary because of GPU models.
    // Certain models can only run on a specific backend, say Metal.
    // Those ops will be present in the model's root ops, but likely
    // not in the tracer runtime on Linux.
    if (ops_and_schemas.find(op_name) != ops_and_schemas.end()) {
      auto& schema = ops_and_schemas.at(op_name);
      for (auto& arg : schema.arguments()) {
        record_if_class(arg.type()->annotation_str());
      }
      for (auto& ret : schema.returns()) {
        record_if_class(ret.type()->annotation_str());
      }
    }
  }
}
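
/**
 * Loads a single module, infers which backend it targets (CPU or Metal GPU),
 * and runs its bundled inputs so that the operator, kernel-dtype, and
 * custom-class tracers observe everything the model touches.
 */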
void run_model(
    const std::string& input_module_path,
    std::set<std::string>& root_ops,
    std::set<std::string>& enabled_backends,
    KernelDTypeTracer::kernel_tags_type& called_kernel_tags) {
  // Load the module on CPU with the flag to skip the operator exists check.
  // This is needed so that we can load any TorchBind objects (custom classes)
  // that this model refers to so that any operators being called from those
  // TorchBind objects can be traced by the model tracer.
  torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
  root_ops = module_runner.get_root_operators();
  std::cout << "Got " << root_ops.size() << " Root Operators." << std::endl;

  if (torch::jit::mobile::MobileModelRunner::set_has_metal_gpu_operators(
          root_ops)) {
    std::cout << "Inferred Metal GPU Model." << std::endl;
    root_ops.insert(gpu_metal_operators.begin(), gpu_metal_operators.end());
    called_kernel_tags["__unused__"] = {"Float"};
    enabled_backends.insert("Metal GPU");

    // When we encounter a GPU model, we should call .cpu().copy_() on the
    // tensors in the bundled inputs, since this is what will happen when
    // such a model is executed on an iOS device (to copy the Tensor to Metal
    // memory via a call to .metal()).
    module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);
  } else {
    std::cout << "Inferred CPU Model." << std::endl;
    enabled_backends.insert("CPU");
    torch::jit::mobile::MobileModelRunner mobile_module_runner(
        input_module_path);

    // When we encounter a CPU model, we should call .cpu().copy_() on the
    // tensors in the bundled inputs, since this is what will happen when
    // such a model is executed on an Android device (the PyTorch JNI
    // bindings call .cpu() in JIValue::newJIValueFromAtIValue()).
    module_runner.for_each_tensor_in_bundled_inputs(consume_tensor);

    // If a user has bundled inputs since that API was updated to accept
    // bundled inputs for multiple methods, they should go down this route.
    // Even if they only bundle inputs for forward they will have the new
    // style bundled inputs. Since at this time in tracer.cpp we do not know
    // which functions have bundled inputs, we must call
    // get_bundled_inputs_functions_and_info if it exists to get the set.
    if (mobile_module_runner.has_new_style_bundled_inputs()) {
      auto bundled_inputs_mapping =
          mobile_module_runner.get_many_functions_bundled_inputs();
      for (auto& entry : bundled_inputs_mapping) {
        std::string function_name = entry.first;
        std::vector<std::vector<at::IValue>> bundled_inputs = entry.second;
        std::cout << "Got " << bundled_inputs.size() << " bundled input(s) for "
                  << function_name << "\n\n";
        std::vector<at::IValue> results =
            mobile_module_runner.run_with_inputs(function_name, bundled_inputs);

        for (auto& result : results) {
          // Consume the result Tensor(s) when tracing on CPU since the
          // Android/Java JNI bindings will do the same.
          torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
        }
      }
      // If get_bundled_inputs_functions_and_info does not exist, we default
      // to assuming inputs were bundled before that change was made. If no
      // bundled inputs are found here either, an error will be thrown.
    } else {
      std::vector<std::vector<at::IValue>> bundled_inputs =
          mobile_module_runner.get_all_bundled_inputs();
      std::cout << "Got " << bundled_inputs.size() << " bundled input(s)\n\n";
      std::vector<at::IValue> results =
          mobile_module_runner.run_with_inputs(bundled_inputs);

      for (auto& result : results) {
        // Consume the result Tensor(s) when tracing on CPU since the
        // Android/Java JNI bindings will do the same.
        torch::jit::mobile::for_each_tensor_in_ivalue(result, consume_tensor);
      }
    }
  }
}
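
// A minimal usage sketch for the entry points below (the model path and the
// field access are illustrative):
//
//   TracerResult result = torch::jit::mobile::trace_run("/tmp/model.ptl");
//   for (const auto& op : result.traced_operators) {
//     std::cout << op << std::endl;
//   }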

TracerResult trace_run(const std::string& input_module_path) {
  return trace_run(std::vector<std::string>(1, input_module_path));
}

TracerResult trace_run(const std::vector<std::string>& input_module_paths) {
  at::globalContext().setQEngine(at::QEngine::QNNPACK);
  c10::ObservedOperators::getUnobservedOperatorList().clear();

  torch::jit::mobile::OperatorCallTracer op_tracer;
  torch::jit::mobile::KernelDTypeTracer kdtype_tracer;
  torch::jit::mobile::CustomClassTracer custom_class_tracer;
  torch::jit::mobile::BuildFeatureTracer build_feature_tracer;

  call_setup_methods();

  std::set<std::string> root_ops, traced_operators, enabled_backends,
      loaded_classes, build_features;
  torch::jit::mobile::KernelDTypeTracer::kernel_tags_type called_kernel_tags;

  using torch::jit::MobileModuleLoadOptions;
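
  // Trace each module under a few backend configurations (QNNPACK, FBGEMM,
  // and inference mode) so that as many kernel paths as possible are observed.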
  for (auto& input_module_path : input_module_paths) {
    at::globalContext().setQEngine(at::QEngine::QNNPACK);
    run_model(
        input_module_path, root_ops, enabled_backends, called_kernel_tags);

    // Not every model can be successfully run with fbgemm,
    // but for those that can, this can help broaden the tracer's scope
    // beyond the hyper-optimized QNNPACK paths.
    try {
      at::globalContext().setQEngine(at::QEngine::FBGEMM);
      run_model(
          input_module_path, root_ops, enabled_backends, called_kernel_tags);
    } catch (std::exception& ex) {
      std::cerr
          << "ModelTracer encountered an error while attempting to run the model in FBGEMM mode: "
          << ex.what() << "\n Skipping FBGEMM execution" << std::endl;
    }

    try {
      at::globalContext().setQEngine(at::QEngine::QNNPACK);
      c10::InferenceMode guard(true);
      run_model(
          input_module_path, root_ops, enabled_backends, called_kernel_tags);
    } catch (std::exception& ex) {
      std::cerr
          << "ModelTracer encountered an error while attempting to run the model under an inference guard: "
          << ex.what() << "\n Skipping inference guard execution" << std::endl;
    }
  }

  call_dependent_methods(root_ops);
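
  // Collect everything the tracers recorded while the models, setup methods,
  // and dependent methods ran.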
  op_tracer.getCalledOperators().withLock(
      [&](std::set<std::string>& called_operators) {
        traced_operators = called_operators;
      });

  recordCustomClassesFromOpSchemas(root_ops, traced_operators, loaded_classes);

  kdtype_tracer.getCalledKernelTags().withLock(
      [&](KernelDTypeTracer::kernel_tags_type& kernel_tags) {
        called_kernel_tags.insert(kernel_tags.begin(), kernel_tags.end());
      });

  traced_operators.insert(
      always_included_traced_ops.begin(), always_included_traced_ops.end());

  custom_class_tracer.getLoadedClasses().withLock(
      [&](CustomClassTracer::custom_classes_type& custom_classes) {
        loaded_classes.insert(custom_classes.begin(), custom_classes.end());
      });

  build_feature_tracer.getBuildFeatures().withLock(
      [&](BuildFeatureTracer::build_feature_type& bf) {
        build_features.insert(bf.begin(), bf.end());
      });

  TracerResult tracer_result = {
      root_ops,
      traced_operators,
      called_kernel_tags,
      loaded_classes,
      build_features,
      enabled_backends};

  return tracer_result;
}

} // namespace mobile
} // namespace jit
} // namespace torch