pytorch

Форк
0
/
backend_detail.cpp 
413 строк · 15.8 Кб
1
#include <torch/csrc/jit/backends/backend_detail.h>
2

3
#include <ATen/code_template.h>
4
#include <ATen/core/jit_type.h>
5
#include <torch/csrc/jit/backends/backend.h>
6
#include <torch/csrc/jit/backends/backend_debug_handler.h>
7
#include <torch/csrc/jit/backends/backend_debug_info.h>
8
#include <torch/csrc/jit/backends/backend_resolver.h>
9

10
#include <memory>
11
#include <stack>
12
#include <unordered_map>
13

14
namespace torch {
15
namespace jit {
16
namespace detail {
17
namespace {
18

19
/*
20
 * This is the API via which backend's preprocess function will obtain debug
21
 * handles corresponding to the nodes of the graph for the lowered methods of
22
 * the module.
23
 * Implementation: Given graph
24
 * For each node of the graph, request debug handle via debug_info_recorder.
25
 * debug_info_recorder returns the next debug handle and record node with
26
 * corresponding debug info, such as source range and inlined callstack.
27
 *
28
 * Backend code for lowering module, preprocess, calls
29
 * generate_debug_handles(graph)) which will return debug handles corresponding
30
 * to the Node* of the said graph.
31
 *
32
 * In to_backend, after lowering, stopRecording is called on
33
 * BackendModuleDebugInfoRecorder: It will extract debug map. This map gets
34
 * stored as part of the lowered module.
35
 * During serialization, specifically for bytecode serialization, check is made
36
 * to see if the model being serialized has any lowered modules. If so
37
 * corresponding debug map is extracted and serialized.
38
 */
39

40
NodeToDebugHandle generate_debug_handles(
41
    BackendDebugInfoRecorder& debug_info_recorder,
42
    const std::shared_ptr<Graph>& graph) {
43
  NodeToDebugHandle node_to_debug_handles;
44

45
  std::stack<Block*> blocks_to_visit;
46
  // TODO: Look into using DepthFirstGraphNodeIterator
47
  // At the moment it takes non-const graph but maybe we can make it
48
  // general such that it can work with both.
49
  blocks_to_visit.push(graph->block());
50
  while (!blocks_to_visit.empty()) {
51
    Block* b = blocks_to_visit.top();
52
    blocks_to_visit.pop();
53
    for (Node* n : b->nodes()) {
54
      DebugHandleType debug_handle = debug_info_recorder.getNextDebugHandle(n);
55
      node_to_debug_handles.emplace(n, debug_handle);
56
      for (Block* subblock : n->blocks()) {
57
        blocks_to_visit.push(subblock);
58
      }
59
    }
60
  }
61
  return node_to_debug_handles;
62
}
63

64
std::unordered_map<std::string, BackendPreprocessFunction>&
65
backendPreprocessFunctions() {
66
  static std::unordered_map<std::string, BackendPreprocessFunction>
67
      preprocess_functions;
68
  return preprocess_functions;
69
}
70
} // namespace
71

72
bool hasBackendPreprocessFunction(const std::string& name) {
73
  return backendPreprocessFunctions().count(name);
74
}
75

76
void registerBackendPreprocessFunction(
77
    const std::string& name,
78
    const BackendPreprocessFunction& preprocess) {
79
  TORCH_CHECK(
80
      !detail::hasBackendPreprocessFunction(name),
81
      "Preprocessing function for backend ",
82
      name,
83
      " is already registered. Ensure that registration is only called once.");
84
  detail::backendPreprocessFunctions()[name] = preprocess;
85
}
86

87
// Returns a copy of the preprocess function registered for backend `name`.
// Fails the TORCH_CHECK below if nothing is registered under that name.
BackendPreprocessFunction getBackendPreprocessFunction(
    const std::string& name) {
  TORCH_CHECK(
      hasBackendPreprocessFunction(name),
      "Preprocessing function for backend ",
      name,
      " is not registered.");
  // Presence was verified above, so at() finds the entry.
  auto& registry = backendPreprocessFunctions();
  return registry.at(name);
}
96

97
/*
 * Generates the lowered counterpart of `orig_module` for backend
 * `backend_name`.
 *
 * Steps, in order:
 *  - clone `orig_module` and run the backend's registered preprocess function
 *    on the clone, handing it a debug-handle generator backed by a
 *    BackendDebugInfoRecorder;
 *  - create a LoweredModule holding the attributes __processed_module,
 *    __method_compile_spec, __backend, __handles and __backend_debug_info,
 *    plus the helper methods __create_backend, __is_available,
 *    __create_backend_debug_info, __getstate__ and __setstate__;
 *  - for every key of `method_compile_spec`, define a LoweredModule method
 *    that forwards to backend.execute and a wrapper method that forwards to
 *    the LoweredModule;
 *  - create the backend instance and, if it reports itself available, run
 *    __setstate__ so the module is ready to execute; otherwise emit a warning;
 *  - stop debug info recording and store the recorded map in
 *    __backend_debug_info.
 *
 * `any_dict_ty` is used as the type of the __method_compile_spec and
 * __handles attributes.
 *
 * Returns the wrapper module, which has the LoweredModule registered as its
 * `__loweredModule__` submodule.
 */
Module codegen_backend_module(
    const std::string& backend_name,
    const Module& orig_module,
    const c10::Dict<IValue, IValue>& method_compile_spec,
    const c10::DictTypePtr& any_dict_ty) {
  const c10::QualifiedName qual_backend_name(
      {"__torch__", "torch", "classes", kBackendsNamespace, backend_name});
  // TODO: Validate method_compile_spec.

  // The backend transformation works on a clone so that it stays functional
  // with respect to orig_module.
  auto module_copy = orig_module.clone();
  const std::string name_suffix =
      backend_name + "." + orig_module.type()->name()->qualifiedName();

  // The LoweredModule carries the backend state and the generated methods.
  Module lowered(
      "torch.jit.LoweredModule." + name_suffix,
      std::make_shared<CompilationUnit>(),
      /*shouldMangle=*/true);

  // The wrapper module is what gets returned to the caller.
  Module wrapper_module(
      "torch.jit.LoweredWrapper." + name_suffix,
      std::make_shared<CompilationUnit>(),
      /*shouldMangle=*/true);

  // The recorder collects debug info while preprocess runs; stopRecording()
  // is called near the end of this function and its result is saved in
  // __backend_debug_info.
  BackendDebugInfoRecorder debug_info_recorder;
  BackendDebugHandleGenerator handle_generator =
      [&debug_info_recorder](const std::shared_ptr<Graph>& g) {
        return generate_debug_handles(debug_info_recorder, g);
      };

  // ---- Attributes ----

  // The preprocessed module.
  // For backwards compatibility, for backends that implement preprocessing in
  // the backend interface rather than as a separate function, we just pass
  // the cloned original Module.
  auto processed_module = detail::getBackendPreprocessFunction(backend_name)(
      module_copy, method_compile_spec, handle_generator);
  lowered.register_attribute(
      "__processed_module",
      AnyType::get(),
      std::move(processed_module),
      /*is_param=*/false);

  // The method_compile_spec passed in to to_<backend> or loaded from an
  // exported model.
  lowered.register_attribute(
      "__method_compile_spec",
      any_dict_ty,
      method_compile_spec,
      /*is_param=*/false);

  // Slot for the backend instance that provides compile and execute. It
  // starts out as a null capsule and is filled in by __create_backend().
  const std::string& backend_class_name = qual_backend_name.qualifiedName();
  auto cls = getCustomClass(backend_class_name);
  TORCH_INTERNAL_ASSERT(cls);
  c10::intrusive_ptr<torch::CustomClassHolder> null_backend;
  lowered.register_attribute(
      "__backend", cls, IValue::make_capsule(null_backend));

  // The opaque backend handles returned by backend.compile; initially an
  // empty dict.
  lowered.register_attribute(
      "__handles",
      any_dict_ty,
      c10::impl::GenericDict(
          any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
      /*is_param=*/false);

  // ---- Methods ----

  // Helper that instantiates the backend class.
  static const auto create_backend_ct = at::jit::CodeTemplate(R"(
            def __create_backend(self):
                self.__backend = $name()
            )");
  at::jit::TemplateEnv create_backend_env;
  create_backend_env.s("name", backend_class_name);
  lowered.define(
      create_backend_ct.format(create_backend_env), loweredModuleResolver());

  // Exposes backend.is_available() to the module generation code below.
  // Assumes self.__backend exists (i.e. __create_backend() has already been
  // invoked).
  lowered.define(
      R"(
            def __is_available(self):
                return self.__backend.is_available()
            )",
      loweredModuleResolver());

  // __backend_debug_info holds an instance of BackendDebugInfo that stores
  // the debug information.
  // It exists so that the debug information is available at model saving
  // time for serialization outside of the lowered module while still being
  // tied to the module's lifetime (so it gets destroyed along with it).
  // Although this information is not serialized as part of the lowered
  // module, a valid BackendDebugInfo instance is still required when the
  // lowered module is deserialized. Deserialized modules do not need the
  // information itself, so a "dummy" instance with no extra code
  // dependencies (to avoid overhead) is created when the backend is created
  // in __setstate__.
  const c10::QualifiedName backend_debug_info_class_name(
      {"__torch__",
       "torch",
       "classes",
       kBackendUtilsNamespace,
       kBackendDebugInfoClass});
  const std::string& debug_info_class_name =
      backend_debug_info_class_name.qualifiedName();
  auto debug_info_cls = getCustomClass(debug_info_class_name);
  TORCH_CHECK(debug_info_cls, "BackendDebugInfo class must be available.");
  c10::intrusive_ptr<torch::CustomClassHolder> null_debug_info;
  lowered.register_attribute(
      "__backend_debug_info",
      OptionalType::create(debug_info_cls),
      IValue::make_capsule(null_debug_info));
  static const auto create_backend_debug_info_ct = at::jit::CodeTemplate(R"(
            def __create_backend_debug_info(self):
                self.__backend_debug_info = $backend_debug_info()
            )");
  at::jit::TemplateEnv create_debug_info_env;
  create_debug_info_env.s("backend_debug_info", debug_info_class_name);
  lowered.define(
      create_backend_debug_info_ct.format(create_debug_info_env),
      loweredModuleResolver());

  // __getstate__ / __setstate__ implement serialization and deserialization
  // of the LoweredModule. __setstate__ is responsible for initializing
  // self.__backend by invoking __create_backend().
  lowered.define(
      R"(
            def __getstate__(self):
                # The third parameter indicates whether __setstate__ must create
                # the backend instance. It's hardcoded to True since the only
                # case it can be false is when __setstate__ is called from
                # outside the module (at module creation time), because
                # __create_backed has been called already (also directly).
                return self.__method_compile_spec, self.__processed_module, True
            )",
      loweredModuleResolver());

  lowered.define(
      R"(
            def __setstate__(self, state):
                self.__method_compile_spec = state[0]
                self.__processed_module = state[1]
                # state[2] indicates whether to create the backend instance.
                if state[2]:
                    self.__create_backend()
                    self.__create_backend_debug_info()
                if self.__backend.is_available() :
                    self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
                else:
                    raise Exception("Backend is not available.")
            )",
      loweredModuleResolver());

  // Template of a LoweredModule method: forwards the inputs to
  // backend.execute, unpacks the outputs and asserts their types.
  static const auto method_ct = at::jit::CodeTemplate(R"(
            def $method(self${,def_inputs}):
                typed_inputs: List[Any] = [${fwd_inputs,}]
                if self.__backend.is_available() :
                  $unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs)
                  ${refine,}
                  return $ret
                else:
                  raise Exception("Backend is not available.")
            )");
  // Template of a wrapper method: forwards to the LoweredModule method of the
  // same name.
  static const auto wrapper_method_ct = at::jit::CodeTemplate(R"(
            def $method(self${,def_inputs}):
                return self.__loweredModule__.$method(${fwd_inputs})
            )");

  // Builds the statement asserting that output `_<index>` has type `type`.
  const auto isinstance_check = [](size_t index, const auto& type) {
    std::stringstream check_ss;
    check_ss << "assert isinstance(_" << index << ", "
             << type->annotation_str() << ")";
    return check_ss.str();
  };

  // One LoweredModule method (and one wrapper method) is generated for every
  // key in method_compile_spec.
  std::vector<std::string> wrapper_methods;
  for (const auto& spec_entry : method_compile_spec) {
    const std::string method_name = spec_entry.key().toStringRef();
    auto method = orig_module.get_method(method_name);
    const auto& schema = method.function().getSchema();

    at::jit::TemplateEnv method_te, wrapper_method_te;
    method_te.s("method", method_name);
    wrapper_method_te.s("method", method_name);

    // def_inputs: parameters as they appear in the generated signature.
    // fwd_inputs: the same parameters as passed on to backend.execute.
    std::vector<std::string> def_inputs, fwd_inputs;
    for (const auto& arg : schema.arguments()) {
      const std::string& arg_name = arg.name();

      // self is always present in the signature template already.
      if (arg_name == "self") {
        continue;
      }

      // Every parameter is emitted with its type annotation.
      const std::string annotated =
          arg_name + ": " + arg.type()->annotation_str(nullptr);

      if (!arg.kwarg_only()) {
        // Positional: emitted as is in the signature and in the call.
        def_inputs.emplace_back(annotated);
        fwd_inputs.emplace_back(arg_name);
        continue;
      }

      // Keyword-only: emitted as keyword=value in the definition and as
      // keyword=keyword in the call to backend_execute.
      const auto& default_value = arg.default_value();
      TORCH_INTERNAL_ASSERT(default_value.has_value());
      std::stringstream def_ss;
      def_ss << annotated << "=";
      default_value->repr(
          def_ss, [](std::ostream&, const IValue&) -> bool { return false; });
      def_inputs.emplace_back(def_ss.str());
      fwd_inputs.emplace_back(arg_name + "=" + arg_name);
    }

    // Build the comma-delimited identifiers used to unpack the outputs and
    // the isinstance checks verifying that the backend returned the types it
    // was supposed to.
    TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
    auto out_ty = schema.returns().at(0).type();
    auto out_tuple_ty = out_ty->cast<TupleType>();

    std::vector<std::string> type_checks;
    std::stringstream unpack_ss;
    unpack_ss << "_0";
    if (out_tuple_ty) {
      auto tuple_elements = out_tuple_ty->elements();
      type_checks.emplace_back(isinstance_check(0, tuple_elements[0]));
      const size_t element_count = tuple_elements.size();
      for (size_t idx = 1; idx < element_count; ++idx) {
        unpack_ss << ", _" << idx;
        type_checks.emplace_back(isinstance_check(idx, tuple_elements[idx]));
      }
    } else {
      type_checks.emplace_back(isinstance_check(0, out_ty));
    }
    const std::string unpack_targets = unpack_ss.str();

    method_te.v("def_inputs", def_inputs);
    method_te.v("fwd_inputs", fwd_inputs);
    method_te.v("refine", type_checks);
    method_te.s("unpack", unpack_targets);

    wrapper_method_te.v("def_inputs", def_inputs);
    wrapper_method_te.v("fwd_inputs", fwd_inputs);
    wrapper_methods.push_back(wrapper_method_ct.format(wrapper_method_te));

    // A single-element tuple needs a trailing comma in the return expression
    // so that the final output keeps its tuple type.
    std::string return_expr = unpack_targets;
    if (out_tuple_ty && out_tuple_ty->elements().size() == 1) {
      return_expr += ",";
    }
    method_te.s("ret", return_expr);

    lowered.define(method_ct.format(method_te), loweredModuleResolver());
  }

  // If the backend is available, call __setstate__ so that the returned
  // Module is ready to run. Otherwise warn that the resulting Module cannot
  // execute until it is loaded on a device with the backend.
  lowered.run_method("__create_backend");
  const bool backend_available = lowered.run_method("__is_available").toBool();
  if (backend_available) {
    // The backend instance was just created above, so __setstate__ is told
    // not to create it again.
    auto state = at::ivalue::Tuple::create(
        method_compile_spec,
        lowered.attr("__processed_module"),
        /*create_backend*/ false);
    lowered.run_method("__setstate__", state);
  } else {
    TORCH_WARN(
        "Backend [",
        backend_name,
        "] is not available. Execution of this Module is still possible by "
        "saving and loading on a device where the backend is available.");
  }

  // Stop debug info recording and hand the recorded map to the module's
  // BackendDebugInfo instance.
  auto debug_info_map = debug_info_recorder.stopRecording();
  lowered.run_method("__create_backend_debug_info");
  auto backend_debug_info = lowered.attr("__backend_debug_info")
                                .toCustomClass<PyTorchBackendDebugInfo>();
  backend_debug_info->setDebugInfoMap(std::move(debug_info_map));

  // Wrap the lowered module to obfuscate its custom serialization logic.
  wrapper_module.register_module("__loweredModule__", lowered);
  for (const auto& wrapper_src : wrapper_methods) {
    wrapper_module.define(wrapper_src);
  }

  return wrapper_module;
}
411
} // namespace detail
412
} // namespace jit
413
} // namespace torch
414

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.