/**
 * The tracer.cpp generates a binary that accepts multiple Torch Mobile
 * Model(s) (with bytecode.pkl), each of which has at least 1 bundled
 * input. This binary then feeds the bundled input(s) into each corresponding
 * model and executes it using the lite interpreter.
 *
 * Both root operators as well as called operators are recorded and saved
 * into a YAML file (whose path is provided on the command line).
 *
 * Note: Root operators may include primary and other operators that
 * are not invoked using the dispatcher, and hence they may not show
 * up in the Traced Operator list.
 */
20
#include <ATen/core/dispatch/ObservedOperators.h>
21
#include <torch/csrc/autograd/grad_mode.h>
22
#include <torch/csrc/jit/mobile/import.h>
23
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
24
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
25
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
26
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
27
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
28
#include <torch/csrc/jit/mobile/module.h>
29
#include <torch/csrc/jit/mobile/parse_operators.h>
30
#include <torch/script.h>
// Alias for kernel-tag -> set of dtype names (mirrors the shape of
// KernelDTypeTracer::kernel_tags_type used below). `using` preferred over
// `typedef` in modern C++.
using kt_type = std::map<std::string, std::set<std::string>>;
37
"A comma separated list of path(s) to the input model file(s) (.ptl).");
42
"The path of the output YAML file containing traced operator information.");
// Fail fast from the enclosing main() (exit code 1) when a required string
// flag was left empty. NOTE(review): the continuation lines were lost in
// extraction; the `return 1;` tail is reconstructed — the macro must return
// from the enclosing function, which is why this is a macro and not a helper.
#define REQUIRE_STRING_ARG(name)                            \
  if (FLAGS_##name.empty()) {                               \
    std::cerr << "You must specify the flag --" #name "\n"; \
    return 1;                                               \
  }

// Same as above for integer flags, where -1 means "unset".
#define REQUIRE_INT_ARG(name)                               \
  if (FLAGS_##name == -1) {                                 \
    std::cerr << "You must specify the flag --" #name "\n"; \
    return 1;                                               \
  }
/**
 * Emits one operator entry in YAML form:
 *
 *   <op_name>:
 *     is_used_for_training: <bool>
 *     is_root_operator: <bool>
 *     include_all_overloads: <bool>
 *
 * @param out    Stream to write the YAML fragment to.
 * @param indent Number of leading spaces before op_name; the attribute lines
 *               get indent + 2.
 * NOTE(review): the signature head and trailing std::endl/brace lines were
 * lost in extraction and are reconstructed from the visible body.
 */
void printOpYAML(
    std::ostream& out,
    int indent,
    const std::string& op_name,
    bool is_used_for_training,
    bool is_root_operator,
    bool include_all_overloads) {
  out << std::string(indent, ' ') << op_name << ":" << std::endl;
  out << std::string(indent + 2, ' ')
      << "is_used_for_training: " << (is_used_for_training ? "true" : "false")
      << std::endl;
  out << std::string(indent + 2, ' ')
      << "is_root_operator: " << (is_root_operator ? "true" : "false")
      << std::endl;
  out << std::string(indent + 2, ' ')
      << "include_all_overloads: " << (include_all_overloads ? "true" : "false")
      << std::endl;
}
77
const std::set<std::string>& operator_list,
78
bool is_used_for_training,
79
bool is_root_operator,
80
bool include_all_overloads) {
81
for (auto& it : operator_list) {
82
printOpYAML(out, 2, it, false, is_root_operator, false);
/**
 * Emits one kernel-tag entry as a YAML sequence of dtype names:
 *
 *   <kernel_tag_name>:
 *   - <dtype>
 *   - ...
 *
 * (Both the tag line and the "- " items share the same indent, matching the
 * original output byte-for-byte.)
 *
 * Fix: take dtypes by const& instead of by value (the original copied the
 * whole std::set on every call).
 * NOTE(review): the signature head was lost in extraction; the name and
 * argument shapes are reconstructed from the call in printDTypesYAML.
 */
void printDTypeYAML(
    std::ostream& out,
    int indent,
    const std::string& kernel_tag_name,
    const std::set<std::string>& dtypes) {
  std::string indent_str = std::string(indent, ' ');
  out << indent_str << kernel_tag_name << ":" << std::endl;
  for (auto& dtype : dtypes) {
    out << indent_str << "- " << dtype << std::endl;
  }
}
100
const torch::jit::mobile::KernelDTypeTracer::kernel_tags_type&
102
for (auto& it : kernel_tags) {
103
printDTypeYAML(out, 2, it.first, it.second);
107
void printCustomClassesYAML(
109
const torch::jit::mobile::CustomClassTracer::custom_classes_type&
111
for (auto& class_name : loaded_classes) {
112
out << "- " << class_name << std::endl;
/**
 * Runs multiple PyTorch lite interpreter models, and additionally writes
 * out a list of root and called operators, kernel dtypes, and loaded/used
 * TorchBind custom classes.
 */
int main(int argc, char* argv[]) {
122
if (!c10::ParseCommandLineFlags(&argc, &argv)) {
123
std::cerr << "Failed to parse command line flags!" << std::endl;
127
REQUIRE_STRING_ARG(model_input_path);
128
REQUIRE_STRING_ARG(build_yaml_path);
130
std::istringstream sin(FLAGS_model_input_path);
131
std::ofstream yaml_out(FLAGS_build_yaml_path);
133
std::cout << "Output: " << FLAGS_build_yaml_path << std::endl;
134
torch::jit::mobile::TracerResult tracer_result;
135
std::vector<std::string> model_input_paths;
137
for (std::string model_input_path;
138
std::getline(sin, model_input_path, ',');) {
139
std::cout << "Processing: " << model_input_path << std::endl;
140
model_input_paths.push_back(model_input_path);
144
tracer_result = torch::jit::mobile::trace_run(model_input_paths);
145
} catch (std::exception& ex) {
147
<< "ModelTracer has not been able to load the module for the following reasons:\n"
149
<< "\nPlease consider opening an issue at https://github.com/pytorch/pytorch/issues "
150
<< "with the detailed error message." << std::endl;
155
if (tracer_result.traced_operators.size() <=
156
torch::jit::mobile::always_included_traced_ops.size()) {
159
"Error traced_operators size: ",
160
tracer_result.traced_operators.size(),
161
". Expected the traced operator list to be bigger then the default size ",
162
torch::jit::mobile::always_included_traced_ops.size(),
163
". Please report a bug in PyTorch.")
167
// If the op exist in both traced_ops and root_ops, leave it in root_ops only
168
for (const auto& root_op : tracer_result.root_ops) {
169
if (tracer_result.traced_operators.find(root_op) !=
170
tracer_result.traced_operators.end()) {
171
tracer_result.traced_operators.erase(root_op);
175
yaml_out << "include_all_non_op_selectives: false" << std::endl;
176
yaml_out << "build_features: []" << std::endl;
177
yaml_out << "operators:" << std::endl;
180
tracer_result.root_ops,
181
false /* is_used_for_training */,
182
true /* is_root_operator */,
183
false /* include_all_overloads */);
186
tracer_result.traced_operators,
187
false /* is_used_for_training */,
188
false /* is_root_operator */,
189
false /* include_all_overloads */);
191
yaml_out << "kernel_metadata:";
192
if (tracer_result.called_kernel_tags.empty()) {
195
yaml_out << std::endl;
196
printDTypesYAML(yaml_out, tracer_result.called_kernel_tags);
198
yaml_out << "custom_classes:";
199
if (tracer_result.loaded_classes.empty()) {
202
yaml_out << std::endl;
203
printCustomClassesYAML(yaml_out, tracer_result.loaded_classes);