#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <vector>

/**
 * The tracer.cpp generates a binary that accepts multiple Torch Mobile
 * Model(s) (with bytecode.pkl), each of which has at least one bundled
 * input. This binary then feeds the bundled input(s) into each corresponding
 * model and executes it using the lite interpreter.
 *
 * Both root operators as well as called operators are recorded and saved
 * into a YAML file (whose path is provided on the command line).
 *
 * Note: Root operators may include primary and other operators that
 * are not invoked using the dispatcher, and hence they may not show
 * up in the Traced Operator list.
 */
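
// A minimal example invocation (the binary name and the paths are
// illustrative; the two flags are the C10_DEFINE_string flags declared
// below):
//
//   ./tracer \
//       --model_input_path /tmp/model_a.ptl,/tmp/model_b.ptl \
//       --build_yaml_path /tmp/selected_operators.yaml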

#include <ATen/core/dispatch/ObservedOperators.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h>
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/script.h>

using kt_type = std::map<std::string, std::set<std::string>>;

C10_DEFINE_string(
    model_input_path,
    "",
    "A comma-separated list of path(s) to the input model file(s) (.ptl).");

C10_DEFINE_string(
    build_yaml_path,
    "",
    "The path of the output YAML file containing traced operator information.");

#define REQUIRE_STRING_ARG(name)                            \
  if (FLAGS_##name.empty()) {                               \
    std::cerr << "You must specify the flag --" #name "\n"; \
    return 1;                                               \
  }

#define REQUIRE_INT_ARG(name)                               \
  if (FLAGS_##name == -1) {                                 \
    std::cerr << "You must specify the flag --" #name "\n"; \
    return 1;                                               \
  }

void printOpYAML(
    std::ostream& out,
    int indent,
    const std::string& op_name,
    bool is_used_for_training,
    bool is_root_operator,
    bool include_all_overloads) {
  out << std::string(indent, ' ') << op_name << ":" << std::endl;
  out << std::string(indent + 2, ' ')
      << "is_used_for_training: " << (is_used_for_training ? "true" : "false")
      << std::endl;
  out << std::string(indent + 2, ' ')
      << "is_root_operator: " << (is_root_operator ? "true" : "false")
      << std::endl;
  out << std::string(indent + 2, ' ')
      << "include_all_overloads: " << (include_all_overloads ? "true" : "false")
      << std::endl;
}
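
// For reference, a call such as
//   printOpYAML(out, 2, "aten::add.Tensor", false, true, false)
// writes this fragment (the operator name is only an example):
//
//   aten::add.Tensor:
//     is_used_for_training: false
//     is_root_operator: true
//     include_all_overloads: false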

void printOpsYAML(
    std::ostream& out,
    const std::set<std::string>& operator_list,
    bool is_used_for_training,
    bool is_root_operator,
    bool include_all_overloads) {
  for (const auto& it : operator_list) {
    printOpYAML(
        out,
        2,
        it,
        is_used_for_training,
        is_root_operator,
        include_all_overloads);
  }
}

void printDTypeYAML(
    std::ostream& out,
    int indent,
    const std::string& kernel_tag_name,
    const std::set<std::string>& dtypes) {
  std::string indent_str = std::string(indent, ' ');
  out << indent_str << kernel_tag_name << ":" << std::endl;
  for (const auto& dtype : dtypes) {
    out << indent_str << "- " << dtype << std::endl;
  }
}
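
// For reference, a call such as
//   printDTypeYAML(out, 2, "add_kernel", {"Float", "Int"})
// writes (the kernel tag and dtypes are only examples; note that the list
// items are emitted at the same indent as the tag):
//
//   add_kernel:
//   - Float
//   - Int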

void printDTypesYAML(
    std::ostream& out,
    const torch::jit::mobile::KernelDTypeTracer::kernel_tags_type&
        kernel_tags) {
  for (const auto& it : kernel_tags) {
    printDTypeYAML(out, 2, it.first, it.second);
  }
}

void printCustomClassesYAML(
    std::ostream& out,
    const torch::jit::mobile::CustomClassTracer::custom_classes_type&
        loaded_classes) {
  for (const auto& class_name : loaded_classes) {
    out << "- " << class_name << std::endl;
  }
}

/**
 * Runs multiple PyTorch lite interpreter models, and additionally writes
 * out a list of root and called operators, kernel dtypes, and loaded/used
 * TorchBind custom classes.
 */
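
// The emitted YAML has the following overall shape (values illustrative,
// assembled from the print helpers above):
//
//   include_all_non_op_selectives: false
//   build_features: []
//   operators:
//     aten::add.Tensor:
//       is_used_for_training: false
//       is_root_operator: true
//       include_all_overloads: false
//   kernel_metadata:
//     add_kernel:
//     - Float
//   custom_classes: []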
int main(int argc, char* argv[]) {
  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cerr << "Failed to parse command line flags!" << std::endl;
    return 1;
  }

  REQUIRE_STRING_ARG(model_input_path);
  REQUIRE_STRING_ARG(build_yaml_path);

  std::istringstream sin(FLAGS_model_input_path);
  std::ofstream yaml_out(FLAGS_build_yaml_path);
  if (!yaml_out) {
    std::cerr << "Could not open " << FLAGS_build_yaml_path
              << " for writing!" << std::endl;
    return 1;
  }

  std::cout << "Output: " << FLAGS_build_yaml_path << std::endl;
  torch::jit::mobile::TracerResult tracer_result;
  std::vector<std::string> model_input_paths;

  for (std::string model_input_path;
       std::getline(sin, model_input_path, ',');) {
    std::cout << "Processing: " << model_input_path << std::endl;
    model_input_paths.push_back(model_input_path);
  }

  try {
    tracer_result = torch::jit::mobile::trace_run(model_input_paths);
  } catch (const std::exception& ex) {
    std::cerr
        << "ModelTracer has not been able to load the module for the following reasons:\n"
        << ex.what()
        << "\nPlease consider opening an issue at https://github.com/pytorch/pytorch/issues "
        << "with the detailed error message." << std::endl;

    throw;
  }

  if (tracer_result.traced_operators.size() <=
      torch::jit::mobile::always_included_traced_ops.size()) {
    std::cerr
        << c10::str(
               "Error traced_operators size: ",
               tracer_result.traced_operators.size(),
               ". Expected the traced operator list to be bigger than the default size ",
               torch::jit::mobile::always_included_traced_ops.size(),
               ". Please report a bug in PyTorch.")
        << std::endl;
  }

  // If an op exists in both traced_operators and root_ops, keep it in
  // root_ops only.
  for (const auto& root_op : tracer_result.root_ops) {
    if (tracer_result.traced_operators.find(root_op) !=
        tracer_result.traced_operators.end()) {
      tracer_result.traced_operators.erase(root_op);
    }
  }

  yaml_out << "include_all_non_op_selectives: false" << std::endl;
  yaml_out << "build_features: []" << std::endl;
  yaml_out << "operators:" << std::endl;
  printOpsYAML(
      yaml_out,
      tracer_result.root_ops,
      false /* is_used_for_training */,
      true /* is_root_operator */,
      false /* include_all_overloads */);
  printOpsYAML(
      yaml_out,
      tracer_result.traced_operators,
      false /* is_used_for_training */,
      false /* is_root_operator */,
      false /* include_all_overloads */);

  yaml_out << "kernel_metadata:";
  if (tracer_result.called_kernel_tags.empty()) {
    yaml_out << " []";
  }
  yaml_out << std::endl;
  printDTypesYAML(yaml_out, tracer_result.called_kernel_tags);

  yaml_out << "custom_classes:";
  if (tracer_result.loaded_classes.empty()) {
    yaml_out << " []";
  }
  yaml_out << std::endl;
  printCustomClassesYAML(yaml_out, tracer_result.loaded_classes);

  return 0;
}