#include <torch/csrc/jit/backends/backend_detail.h>

#include <ATen/code_template.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/backends/backend_debug_info.h>
#include <torch/csrc/jit/backends/backend_resolver.h>

#include <memory>
#include <sstream>
#include <stack>
#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace detail {
/*
 * This is the API via which a backend's preprocess function obtains debug
 * handles corresponding to the nodes of the graph for the lowered methods of
 * the module.
 * Implementation: given a graph, for each node of the graph, request a debug
 * handle via debug_info_recorder.
 * debug_info_recorder returns the next debug handle and records the node with
 * the corresponding debug info, such as source range and inlined callstack.
 *
 * The backend's lowering code, preprocess, calls
 * generate_debug_handles(graph), which returns debug handles corresponding
 * to the Node* of the said graph.
 *
 * In to_backend, after lowering, stopRecording is called on
 * BackendModuleDebugInfoRecorder: it extracts the debug map. This map gets
 * stored as part of the lowered module.
 * During serialization, specifically for bytecode serialization, a check is
 * made to see if the model being serialized has any lowered modules. If so,
 * the corresponding debug map is extracted and serialized.
 */
NodeToDebugHandle generate_debug_handles(
    BackendDebugInfoRecorder& debug_info_recorder,
    const std::shared_ptr<Graph>& graph) {
  NodeToDebugHandle node_to_debug_handles;

  std::stack<Block*> blocks_to_visit;
  // TODO: Look into using DepthFirstGraphNodeIterator.
  // At the moment it takes a non-const graph, but maybe we can make it
  // general enough that it works with both.
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      DebugHandleType debug_handle = debug_info_recorder.getNextDebugHandle(n);
      node_to_debug_handles.emplace(n, debug_handle);
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  return node_to_debug_handles;
}
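// Example: a backend's preprocess function is handed a
// BackendDebugHandleGenerator and would typically invoke it on the graph of
// each method it lowers. A minimal sketch, assuming a hypothetical
// my_preprocess; toGraphFunction is the usual way to get a method's graph:
//
//   c10::IValue my_preprocess(
//       const Module& mod,
//       const c10::Dict<IValue, IValue>& method_compile_spec,
//       const BackendDebugHandleGenerator& generate_debug_handles) {
//     for (const Method& method : mod.get_methods()) {
//       auto graph = toGraphFunction(method.function()).graph();
//       NodeToDebugHandle handles = generate_debug_handles(graph);
//       // ... use handles to tag compiled artifacts with debug info ...
//     }
//     return mod._ivalue();
//   }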
std::unordered_map<std::string, BackendPreprocessFunction>&
backendPreprocessFunctions() {
  static std::unordered_map<std::string, BackendPreprocessFunction>
      preprocess_functions;
  return preprocess_functions;
}
bool hasBackendPreprocessFunction(const std::string& name) {
  return backendPreprocessFunctions().count(name);
}
void registerBackendPreprocessFunction(
    const std::string& name,
    const BackendPreprocessFunction& preprocess) {
  TORCH_CHECK(
      !detail::hasBackendPreprocessFunction(name),
      "Preprocessing function for backend ",
      name,
      " is already registered. Ensure that registration is only called once.");
  detail::backendPreprocessFunctions()[name] = preprocess;
}
BackendPreprocessFunction getBackendPreprocessFunction(
    const std::string& name) {
  TORCH_CHECK(
      hasBackendPreprocessFunction(name),
      "Preprocessing function for backend ",
      name,
      " is not registered.");
  return backendPreprocessFunctions()[name];
}
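// Example: a backend registers its preprocess function once, keyed by the
// backend name. A minimal sketch, assuming the hypothetical my_preprocess
// from above (backends normally do this through the backend_preprocess_register
// helper rather than calling this function directly):
//
//   detail::registerBackendPreprocessFunction("my_backend", my_preprocess);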
Module codegen_backend_module(
    const std::string& backend_name,
    const Module& orig_module,
    const c10::Dict<IValue, IValue>& method_compile_spec,
    const c10::DictTypePtr& any_dict_ty) {
  const c10::QualifiedName qual_backend_name(
      {"__torch__", "torch", "classes", kBackendsNamespace, backend_name});
  // TODO: Validate method_compile_spec.
  // Clone orig_module to make sure the backend transformation is
  // functional.
  auto cloned_module = orig_module.clone();
  auto module_name = orig_module.type()->name()->qualifiedName();

  // Generate LoweredModule.
  Module loweredModule(
      "torch.jit.LoweredModule." + backend_name + "." + module_name,
      std::make_shared<CompilationUnit>(),
      /*shouldMangle=*/true);
  // Generate WrapperModule.
  Module wrapper(
      "torch.jit.LoweredWrapper." + backend_name + "." + module_name,
      std::make_shared<CompilationUnit>(),
      /*shouldMangle=*/true);
  // 1. Initialize the debug info recorder.
  // 2. Later, call debug_info_recorder.stopRecording() to gather the
  //    recorded debug info and save it in __backend_debug_info.
  BackendDebugInfoRecorder debug_info_recorder;
  // Generate attributes.
  // This is the preprocessed module.
  // For backwards compatibility, for backends that implement preprocessing in
  // the backend interface rather than as a separate function, we just pass
  // the cloned original Module.
  BackendDebugHandleGenerator debug_handle_generator =
      [&](const std::shared_ptr<Graph>& g) {
        return generate_debug_handles(debug_info_recorder, g);
      };
  loweredModule.register_attribute(
      "__processed_module",
      AnyType::get(),
      detail::getBackendPreprocessFunction(backend_name)(
          cloned_module, method_compile_spec, debug_handle_generator),
      /*is_param=*/false);
  // This is for the method_compile_spec passed in to to_<backend> or
  // loaded from an exported model.
  loweredModule.register_attribute(
      "__method_compile_spec", any_dict_ty, method_compile_spec);
  // This is a pointer to a backend instance that is used to access
  // compile and execute functions.
  auto cls = getCustomClass(qual_backend_name.qualifiedName());
  TORCH_INTERNAL_ASSERT(cls);
  c10::intrusive_ptr<torch::CustomClassHolder> backend;
  loweredModule.register_attribute(
      "__backend", cls, IValue::make_capsule(backend));
  // This is the dict of opaque backend handles returned by
  // backend.compile, keyed by method name.
  loweredModule.register_attribute(
      "__handles",
      any_dict_ty,
      c10::impl::GenericDict(
          any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
      /*is_param=*/false);
  // Methods.

  // This is a helper function for creating a new instance of the
  // backend class.
  static const auto create_backend_ct = at::jit::CodeTemplate(R"(
            def __create_backend(self):
                self.__backend = $name()
            )");
  at::jit::TemplateEnv create_backend_te;
  create_backend_te.s("name", qual_backend_name.qualifiedName());
  loweredModule.define(
      create_backend_ct.format(create_backend_te), loweredModuleResolver());
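  // For a backend registered as "my_backend" (hypothetical), and assuming
  // kBackendsNamespace resolves to "__backends__", the formatted template
  // expands to roughly:
  //
  //   def __create_backend(self):
  //       self.__backend = __torch__.torch.classes.__backends__.my_backend()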
  // Helper function to expose backend.is_available() to Module generation code.
  // Assumes self.__backend exists (i.e. __create_backend() has already been
  // invoked).
  loweredModule.define(
      R"(
            def __is_available(self):
                return self.__backend.is_available()
            )",
      loweredModuleResolver());
  // backend_debug_info_class is an instance of BackendDebugInfo that
  // stores debug information.
  // The purpose of this class is to make the debug information available
  // at model saving time for serializing it outside of the lowered module,
  // while still tying it to the module's lifetime (so it gets destroyed along
  // with it).
  // Although this information is not serialized as part of the lowered
  // module, we still need to provide a valid instance of the
  // BackendDebugInfo class when the lowered module is deserialized.
  // Since the deserialized module does not need this information,
  // we create a "dummy" instance with no extra code dependencies (to avoid
  // overhead) when the backend is created in __setstate__.
  c10::intrusive_ptr<torch::CustomClassHolder> backend_debug_info_class;
  const c10::QualifiedName backend_debug_info_class_name(
      {"__torch__",
       "torch",
       "classes",
       kBackendUtilsNamespace,
       kBackendDebugInfoClass});
  auto debug_info_cls =
      getCustomClass(backend_debug_info_class_name.qualifiedName());
  TORCH_CHECK(debug_info_cls, "BackendDebugInfo class must be available.");
  loweredModule.register_attribute(
      "__backend_debug_info",
      OptionalType::create(debug_info_cls),
      IValue::make_capsule(backend_debug_info_class));
  static const auto create_backend_debug_info_ct = at::jit::CodeTemplate(R"(
            def __create_backend_debug_info(self):
                self.__backend_debug_info = $backend_debug_info()
            )");
  at::jit::TemplateEnv create_backend_debug_info_te;
  create_backend_debug_info_te.s(
      "backend_debug_info", backend_debug_info_class_name.qualifiedName());
  loweredModule.define(
      create_backend_debug_info_ct.format(create_backend_debug_info_te),
      loweredModuleResolver());
  // getstate and setstate are for serialization/deserialization of
  // the LoweredModule.
  // setstate is in charge of initializing self.__backend by invoking
  // __create_backend().
  loweredModule.define(
      R"(
            def __getstate__(self):
                # The third parameter indicates whether __setstate__ must create
                # the backend instance. It's hardcoded to True since the only
                # case it can be False is when __setstate__ is called from
                # outside the module (at module creation time), because
                # __create_backend has been called already (also directly).
                return self.__method_compile_spec, self.__processed_module, True
            )",
      loweredModuleResolver());
  loweredModule.define(
      R"(
            def __setstate__(self, state):
                self.__method_compile_spec = state[0]
                self.__processed_module = state[1]
                # state[2] indicates whether to create the backend instance.
                if state[2]:
                    self.__create_backend()
                    self.__create_backend_debug_info()
                if self.__backend.is_available() :
                    self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
                else:
                    raise Exception("Backend is not available.")
            )",
      loweredModuleResolver());
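  // Taken together, __getstate__/__setstate__ serialize the LoweredModule as
  // a 3-tuple:
  //
  //   state = (self.__method_compile_spec, self.__processed_module, True)
  //
  // On load, state[2] == True makes __setstate__ re-create the backend and
  // debug-info instances before compiling the processed module.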
  // This loop generates one method on the LoweredModule for every key
  // in method_compile_spec.
  std::vector<std::string> wrapper_methods;
  for (auto& e : method_compile_spec) {
    std::string method_name = e.key().toStringRef();
    static const auto method_ct = at::jit::CodeTemplate(R"(
            def $method(self${,def_inputs}):
                typed_inputs: List[Any] = [${fwd_inputs,}]
                if self.__backend.is_available() :
                  $unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs)
                  ${refine,}
                  return $ret
                else:
                  raise Exception("Backend is not available.")
            )");
    static const auto wrapper_method_ct = at::jit::CodeTemplate(R"(
            def $method(self${,def_inputs}):
                return self.__loweredModule__.$method(${fwd_inputs})
            )");
    at::jit::TemplateEnv method_te, wrapper_method_te;
    method_te.s("method", method_name);
    wrapper_method_te.s("method", method_name);
    auto method = orig_module.get_method(method_name);
    auto& function = method.function();
    auto& schema = function.getSchema();
    // Generate the inputs for the function signature (def_inputs) and
    // for passing to backend.execute (fwd_inputs).
    std::vector<std::string> def_inputs, fwd_inputs;
    for (const auto& arg : schema.arguments()) {
      auto name = arg.name();

      // Skip self since it is only and always present in the
      // signature.
      if (name == "self") {
        continue;
      }

      auto default_value = arg.default_value();

      if (arg.kwarg_only()) {
        // If this is a kwarg, it needs to be emitted as keyword=value
        // in the definition and keyword=keyword in the call to
        // backend_execute.
        TORCH_INTERNAL_ASSERT(default_value.has_value());
        std::stringstream def_ss, fwd_ss;
        // Annotate the type of the arg.
        def_ss << name << ": " << arg.type()->annotation_str(nullptr) << "=";
        fwd_ss << name << "=" << name;
        default_value->repr(
            def_ss, [](std::ostream&, const IValue&) -> bool { return false; });
        def_inputs.emplace_back(def_ss.str());
        fwd_inputs.emplace_back(fwd_ss.str());
      } else {
        // If this is not a kwarg, it should be emitted as-is in the
        // signature and the call to backend_execute.
        std::stringstream def_ss;
        // Annotate the type of the arg.
        def_ss << name << ": " << arg.type()->annotation_str(nullptr);
        def_inputs.emplace_back(def_ss.str());
        fwd_inputs.emplace_back(name);
      }
    }
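    // For example, for a method schema (hypothetical)
    //   forward(self, x: Tensor, *, alpha: float=1.) -> Tensor
    // this loop yields roughly:
    //   def_inputs = {"x: Tensor", "alpha: float=1."}
    //   fwd_inputs = {"x", "alpha=alpha"}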
    // Generate a comma-delimited list of identifiers to unpack
    // outputs, as well as a list of isinstance checks to make sure
    // the backend returned the types it was supposed to.
    std::stringstream out_ss, type_check_ss;
    std::vector<std::string> type_checks;
    TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
    auto out_ty = schema.returns().at(0).type();

    out_ss << "_0";
    type_check_ss << "assert isinstance(_0, ";

    auto out_tuple_ty = out_ty->cast<TupleType>();

    if (out_tuple_ty) {
      auto tuple_elements = out_tuple_ty->elements();
      type_check_ss << tuple_elements[0]->annotation_str() << ")";
      type_checks.emplace_back(type_check_ss.str());
      for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
        type_check_ss.str(std::string());
        type_check_ss.clear();
        out_ss << ", _" << i;
        type_check_ss << "assert isinstance(_" << i << ", "
                      << tuple_elements[i]->annotation_str() << ")";
        type_checks.emplace_back(type_check_ss.str());
      }
    } else {
      type_check_ss << out_ty->annotation_str() << ")";
      type_checks.emplace_back(type_check_ss.str());
    }
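    // For example, a return type of Tuple[Tensor, int] (hypothetical) yields
    // roughly:
    //   unpack: "_0, _1"
    //   refine: {"assert isinstance(_0, Tensor)", "assert isinstance(_1, int)"}
    // while a plain Tensor return yields "_0" and a single isinstance check.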
    method_te.v("def_inputs", def_inputs);
    method_te.v("fwd_inputs", fwd_inputs);
    method_te.v("refine", type_checks);
    method_te.s("unpack", out_ss.str());

    wrapper_method_te.v("def_inputs", def_inputs);
    wrapper_method_te.v("fwd_inputs", fwd_inputs);
    wrapper_methods.push_back(wrapper_method_ct.format(wrapper_method_te));

    // If the output type is a single-element tuple, then add an extra comma
    // to ensure the final output maintains this type.
    if (out_tuple_ty && out_tuple_ty->elements().size() == 1) {
      out_ss << ",";
    }

    method_te.s("ret", out_ss.str());

    loweredModule.define(method_ct.format(method_te), loweredModuleResolver());
  }
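  // For example, for a simple schema forward(self, x: Tensor) -> Tensor
  // (hypothetical), the two templates above expand to roughly:
  //
  //   def forward(self, x: Tensor):
  //       typed_inputs: List[Any] = [x, ]
  //       if self.__backend.is_available() :
  //         _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
  //         assert isinstance(_0, Tensor)
  //         return _0
  //       else:
  //         raise Exception("Backend is not available.")
  //
  // and the wrapper simply forwards the call:
  //
  //   def forward(self, x: Tensor):
  //       return self.__loweredModule__.forward(x)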
  // If the backend is available, call __setstate__ to ensure that the returned
  // Module is ready to run.
  // Otherwise, emit a warning indicating that the resulting Module is not
  // ready for execution until it is loaded to a device with the backend.
  loweredModule.run_method("__create_backend");
  if (loweredModule.run_method("__is_available").toBool()) {
    auto state = at::ivalue::Tuple::create(
        method_compile_spec,
        loweredModule.attr("__processed_module"),
        /*create_backend*/ false);
    loweredModule.run_method("__setstate__", state);
  } else {
    TORCH_WARN(
        "Backend [",
        backend_name,
        "] is not available. Execution of this Module is still possible by "
        "saving and loading on a device where the backend is available.");
  }
  // Stop debug info recording and get the debug info map.
  auto debug_info_map = debug_info_recorder.stopRecording();
  loweredModule.run_method("__create_backend_debug_info");
  auto backend_debug_info = loweredModule.attr("__backend_debug_info")
                                .toCustomClass<PyTorchBackendDebugInfo>();
  backend_debug_info->setDebugInfoMap(std::move(debug_info_map));
  // Wrap the lowered module to obfuscate custom serialization logic.
  wrapper.register_module("__loweredModule__", loweredModule);
  for (auto& method : wrapper_methods) {
    wrapper.define(method);
  }

  return wrapper;
}
} // namespace detail
} // namespace jit
} // namespace torch