pytorch

Форк
0
/
backend_init.cpp 
193 строки · 7.3 Кб
1
#include <torch/csrc/jit/backends/backend_init.h>
2

3
#include <pybind11/iostream.h>
4
#include <torch/csrc/jit/backends/backend_detail.h>
5
#include <torch/csrc/jit/backends/backend_resolver.h>
6
#include <torch/csrc/jit/python/module_python.h>
7
#include <torch/csrc/jit/python/pybind_utils.h>
8
#include <torch/csrc/utils/pybind.h>
9

10
namespace torch {
11
namespace jit {
12

13
// Get all types that are shared in the module hierarchy rooted at \p mod.
14
std::unordered_set<TypePtr> getSharedModuleTypes(Module& mod) {
15
  // Maintain a set of all TypePtrs.
16
  std::unordered_set<TypePtr> types;
17
  // Maintain another set of TypePtrs that have been encountered more than once.
18
  std::unordered_set<TypePtr> duplicate_types;
19

20
  // Iterate over all modules in the hierarchy, including the root.
21
  for (auto module : mod.modules()) {
22
    auto module_type = module.type();
23
    if (types.count(module_type) > 0) {
24
      duplicate_types.insert(module_type);
25
    }
26

27
    types.insert(module_type);
28
  }
29

30
  return duplicate_types;
31
}
32

33
// Selectively lower \p mod to a backend. \p to_backend
34
// is called to lower modules. \p modules_to_lower contains
35
// qualified names of submodules of \p mod that should be lowered.
36
void toBackendSelectiveImpl(
37
    Module& mod,
38
    const py::function& to_backend,
39
    const std::vector<std::string>& modules_to_lower,
40
    const std::unordered_set<TypePtr>& duplicate_types) {
41
  // This map will be used later to remap types in ancestor module graphs for
42
  // all lowered submodules.
43
  std::unordered_map<TypePtr, TypePtr> type_remap;
44

45
  // For each module that should be lowered:
46
  for (const auto& module_to_lower : modules_to_lower) {
47
    // Use QualifiedName to parse the qualified module names.
48
    c10::QualifiedName qual_module_name(module_to_lower);
49
    auto& atoms = qual_module_name.atoms();
50

51
    // Search through the module hierarchy using the atoms of
52
    // qual_module_name until current points to the module to
53
    // be lowered and parent points to its parent.
54
    Module current = mod;
55
    Module parent;
56

57
    for (size_t i = 0, e = atoms.size(); i < e; ++i) {
58
      IValue submodule = current.attr(atoms[i]);
59
      if (submodule.isModule()) {
60
        if (i == e - 1) {
61
          parent = current;
62
        }
63
        current = submodule.toModule();
64
      } else {
65
        std::stringstream err;
66
        err << "Attribute named " << atoms[i] << " is not a Module";
67
        throw std::runtime_error(err.str());
68
      }
69
    }
70

71
    // Check that the parent type is not shared and therefore can be edited.
72
    if (duplicate_types.count(parent.type()) > 0) {
73
      throw py::cast_error(c10::str(
74
          "Selective lowering is only supported for module hierarchies with unique types for selected modules; ",
75
          parent.type()->repr_str(),
76
          " is shared"));
77
    }
78

79
    // Call to_backend on the module that needs to be lowered. It needs to be
80
    // wrapped before doing so because _to_jit_backend accepts wrapped modules.
81
    // The result needs to be unwrapped in order to access its type below.
82
    auto lowered_submodule =
83
        py::cast<Module>(to_backend(py::module::import("torch.jit._recursive")
84
                                        .attr("wrap_cpp_module")(current))
85
                             .attr("_c"));
86

87
    // Adjust the parent's type so that the type of the submodule matches
88
    // the type of lowered_submodule.
89
    auto parent_type = parent.type();
90

91
    parent_type->unsafeChangeAttributeType(
92
        atoms.back(), lowered_submodule.type());
93
    parent.setattr(atoms.back(), lowered_submodule._ivalue());
94

95
    // Record the type mapping from old type -> lowered type.
96
    type_remap[current.type()] = lowered_submodule.type();
97
  }
98

99
  // Having lowered all of the modules that needed to be lowered, remap types in
100
  // all graphs in the hierarchy so that the graphs all use the new lowered
101
  // type.
102
  auto type_remap_fn = [&type_remap](TypePtr in) {
103
    auto it = type_remap.find(in);
104
    if (it == type_remap.end())
105
      return in;
106
    return it->second;
107
  };
108

109
  // modules() iterates over all modules in the hierarchy including the root.
110
  for (auto module : mod.modules()) {
111
    auto module_type = module.type();
112
    for (auto& fn : module_type->methods()) {
113
      auto method = module.get_method(fn->name());
114
      auto graph = method.graph();
115
      graph->remapTypes(type_remap_fn);
116
      auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn);
117
      fn->setSchema(new_schema);
118
    }
119
  }
120
}
121

122
// Generate the backend-wrapped version of \p orig_module for the backend
// registered under \p backend_name.
//
// The Python dict \p method_compile_spec is converted to an IValue of type
// Dict[str, Any]. That dict and the type describing it are both forwarded to
// detail::codegen_backend_module, whose result is returned.
Module codegen_func(
    const std::string& backend_name,
    const Module& orig_module,
    const py::dict& method_compile_spec) {
  // Type used both to convert the Python dict and to describe it downstream.
  auto spec_dict_type = DictType::create(StringType::get(), AnyType::get());
  auto spec_as_ivalue_dict =
      toIValue(method_compile_spec, spec_dict_type).toGenericDict();

  return detail::codegen_backend_module(
      backend_name, orig_module, spec_as_ivalue_dict, spec_dict_type);
}
134

135
// Register the Python bindings for lowering TorchScript modules to JIT
// backends on the Python module \p module.
//
// _jit_to_backend takes the backend name as its first argument. For example,
// to lower a Module to "example_backend", declared as
//
//  static auto cls = torch::jit::backend<ExampleBackend>("example_backend");
//
// this function must be called like
//
//  torch._C._jit_to_backend("example_backend", module, spec)
//
// _jit_to_backend_selective lowers only the named submodules of a
// ScriptModule, using a Python callable to perform each lowering.
void initJitBackendBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  m.def(
      "_jit_to_backend",
      [](const std::string& backend_name,
         py::handle orig_module,
         const py::dict& method_compile_spec) {
        // While this call runs, send C++ std::cerr / std::cout output to
        // Python's sys.stderr / sys.stdout.
        py::scoped_ostream_redirect redirect_stderr(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect redirect_stdout(
            std::cout, py::module_::import("sys").attr("stdout"));

        auto recursive_mod = py::module_::import("torch.jit._recursive");
        // `_c` holds the underlying C++ Module of the Python ScriptModule.
        Module lowered = codegen_func(
            backend_name,
            py::cast<Module>(orig_module.attr("_c")),
            method_compile_spec);
        return recursive_mod.attr("wrap_cpp_module")(lowered);
      });

  m.def(
      "_jit_to_backend_selective",
      [](py::handle orig_module,
         const py::function& to_backend,
         const std::vector<std::string>& modules_to_lower) {
        py::scoped_ostream_redirect redirect_stderr(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect redirect_stdout(
            std::cout, py::module_::import("sys").attr("stdout"));

        auto maybe_module = as_module(py::cast<py::object>(orig_module));
        if (!maybe_module) {
          throw py::cast_error(c10::str(
              "Object ", py::str(orig_module), " is not a ScriptModule"));
        }

        // Work on a clone so that editing types cannot affect module
        // instances outside this hierarchy.
        auto cloned_mod = maybe_module->clone();
        // Only parents of the lowered modules need unique types; collect the
        // shared ones so the lowering step can check for them.
        auto shared_types = getSharedModuleTypes(cloned_mod);
        toBackendSelectiveImpl(
            cloned_mod, to_backend, modules_to_lower, shared_types);

        // The caller passed in a RecursiveScriptModule; hand one back.
        return py::module_::import("torch.jit._recursive")
            .attr("wrap_cpp_module")(cloned_mod);
      });
}
192
} // namespace jit
193
} // namespace torch
194

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.