pytorch / builtin_functions.cpp (185 lines · 6.0 KB)
#include <torch/csrc/jit/frontend/builtin_functions.h>

#include <ATen/code_template.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/frontend/resolver.h>

namespace torch::jit {

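// Scalar-first overloads (e.g. aten::mul(int, Tensor)) expressed in terms of
// the Tensor-first operators. ${Scalar} is substituted with "float", "int", or
// "complex" when the template is formatted in loadBuiltinFunctions() below.
// sub and div are not commutative, so they are rewritten via neg/reciprocal
// instead of simply swapping the operands.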
auto scalar_operators_source = at::jit::CodeTemplate(
    R"SCRIPT(
def mul(a : ${Scalar}, b : Tensor) -> Tensor:
  return b * a
def add(a : ${Scalar}, b : Tensor) -> Tensor:
  return b + a
def ne(a : ${Scalar}, b : Tensor) -> Tensor:
  return b != a
def eq(a : ${Scalar}, b : Tensor) -> Tensor:
  return b == a
def sub(a : ${Scalar}, b : Tensor) -> Tensor:
  return torch.neg(b) + a
def div(a : ${Scalar}, b : Tensor) -> Tensor:
  return torch.reciprocal(b) * a
)SCRIPT");

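// Ordered comparisons live in a separate template because they are not defined
// for complex scalars; loadBuiltinFunctions() only formats this one for
// "float" and "int". Swapping the operands flips the comparison direction
// (lt on the scalar side becomes > on the Tensor side, and so on).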
auto scalar_operators_no_complex_source = at::jit::CodeTemplate(
    R"SCRIPT(
def lt(a : ${Scalar}, b : Tensor) -> Tensor:
  return b > a
def le(a : ${Scalar}, b : Tensor) -> Tensor:
  return b >= a
def gt(a : ${Scalar}, b : Tensor) -> Tensor:
  return b < a
def ge(a : ${Scalar}, b : Tensor) -> Tensor:
  return b <= a
)SCRIPT");

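// Helpers _single/_pair/_triple/_quadruple: BroadcastingList${Length}[T] lets
// the caller pass either a single value or a list of length ${Length}, and the
// function returns the broadcast list unchanged (presumably mirroring the
// torch.nn.modules.utils helpers of the same names).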
auto _ntuple_ops = at::jit::CodeTemplate(
    R"SCRIPT(
def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
  return x
)SCRIPT");

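// Floor division on Tensors: ${Rhs_Type} is formatted as "number" and "Tensor"
// below, giving both Tensor // scalar and Tensor // Tensor overloads backed by
// torch.floor_divide.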
auto floordiv = at::jit::CodeTemplate(
    R"SCRIPT(
def floordiv(self : Tensor, other : ${Rhs_Type}) -> Tensor:
  return torch.floor_divide(self, other)
)SCRIPT");

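// TorchScript bodies for the Tensor properties ndim, T, H, mT, mH and shape.
// These are registered under the `prim` namespace rather than `aten` (see the
// note at the bottom of loadBuiltinFunctions()).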
auto tensor_properties =
    R"SCRIPT(
def ndim(a : Tensor) -> int:
  return a.dim()
def T(a : Tensor) -> Tensor:
  return a.numpy_T()
def H(a : Tensor) -> Tensor:
  return a.matrix_H()
def mT(a : Tensor) -> Tensor:
  return a.mT
def mH(a : Tensor) -> Tensor:
  return a.mH
def shape(a : Tensor) -> List[int]:
  return a.size()
)SCRIPT";

// _assert_int_or_pair is only here for backwards-compatibility with the
// aten::_assert_int_or_pair op which was removed once we were able to compile
// torch.nn.functional.assert_int_or_pair
// list_with_default also needs to be here for BC
auto aten_ops =
    R"SCRIPT(
def _assert_int_or_pair(vals: List[int], name: str, message: str):
  pass
def list_with_default(out_size: List[int], defaults: List[int]):
  assert len(defaults) > len(out_size)
  return out_size
def _assert(condition : bool, message : str):
  assert condition, message
# existing device operator is registered with input name `a`, which prevents
# torch.device(type="cuda") from working. add shim-layer here
def device(type: str):
  return torch.device(type)
def type(self: Tensor, dtype: int, non_blocking: bool=False, copy: bool=False) -> Tensor:
  return self.to(dtype, non_blocking, copy)
)SCRIPT";

// an additional overload for the Tensor variant of _assert, plus
// str.__contains__ (implements `key in self` for strings via str.find)
const auto aten_ops_additional =
    R"SCRIPT(
def _assert(condition : Tensor, message : str):
  assert bool(condition), message
def __contains__(self: str, key: str):
    return self.find(key, 0, len(self)) != -1
)SCRIPT";

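// Lazily compiles the TorchScript snippets above into CompilationUnits the
// first time a builtin is looked up, and indexes the resulting Functions by
// qualified name (e.g. aten::add) so the compiler can fetch all overloads for
// a symbol in one call.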
struct BuiltinFunctionRegistry {
  const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
    const static std::vector<Function*> empty;
    // when initializing the builtin function library, we will re-enter
    // getAllBuiltinFunctionsFor, since it is called in the compiler to
    // look up builtins, and initializing the builtin functions calls the
    // compiler. To avoid deadlocking, we use a recursive mutex (the same
    // thread can re-lock the mutex without waiting) and report no loaded
    // builtins during init.
    std::lock_guard<std::recursive_mutex> guard(mutex);
    if (state == INITIALIZING) {
      return empty;
    } else if (state == UNINITIALIZED) {
      state = INITIALIZING;
      loadBuiltinFunctions();
      state = INITIALIZED;
    }
    AT_ASSERT(state == INITIALIZED);
    auto it = builtins_by_name_.find(name);
    if (it == builtins_by_name_.end())
      return empty;
    return it->second;
  }
 private:
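  // Compiles one source string into a fresh CompilationUnit and registers
  // every function it defines under `the_namespace::<function name>`.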
  void loadSource(const std::string& source, const std::string& the_namespace) {
    std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
    modules.emplace_back(cu);
    cu->define(c10::nullopt, source, nativeResolver(), /*self=*/nullptr);
    for (auto& method : cu->get_functions()) {
      builtins_by_name_[Symbol::fromQualString(
                            the_namespace + "::" + method->name())]
          .push_back(method);
    }
  }

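  // Instantiates each template above for every supported substitution and
  // compiles the plain-string snippets. Everything is registered under `aten`
  // except the tensor property getters, which go under `prim`.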
  void loadBuiltinFunctions() {
    for (auto scalar : {"float", "int", "complex"}) {
      at::jit::TemplateEnv env;
      env.s("Scalar", scalar);
      loadSource(scalar_operators_source.format(env), "aten");
    }

    for (auto scalar : {"float", "int"}) {
      at::jit::TemplateEnv env;
      env.s("Scalar", scalar);
      loadSource(scalar_operators_no_complex_source.format(env), "aten");
    }

    using str_pair = std::pair<std::string, std::string>;
    const std::vector<str_pair> name_len = {
        str_pair("single", "1"),
        str_pair("pair", "2"),
        str_pair("triple", "3"),
        str_pair("quadruple", "4"),
    };
    for (const auto scalar : {"float", "int"}) {
      for (const auto& pair : name_len) {
        at::jit::TemplateEnv env;
        env.s("Scalar", scalar);
        env.s("name", pair.first);
        env.s("Length", pair.second);
        loadSource(_ntuple_ops.format(env), "aten");
      }
    }
    for (auto rhs : {"number", "Tensor"}) {
      at::jit::TemplateEnv env;
      env.s("Rhs_Type", rhs);
      loadSource(floordiv.format(env), "aten");
    }

    loadSource(aten_ops, "aten");
    loadSource(aten_ops_additional, "aten");

    // These are under `prim` instead of `aten` since they exist to bind certain
    // tensor property getters to corresponding methods
    loadSource(tensor_properties, "prim");
  }
  enum { UNINITIALIZED, INITIALIZING, INITIALIZED } state = UNINITIALIZED;
  std::recursive_mutex mutex;
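  // `modules` keeps the CompilationUnits alive so that the raw Function*
  // pointers stored in builtins_by_name_ remain valid for the registry's
  // lifetime.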
  std::vector<std::shared_ptr<CompilationUnit>> modules;
  std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name_;
};

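// Public entry point; the registry is a function-local static, constructed on
// first use. Illustrative lookup (a sketch, not code from this file):
//   const auto& overloads = getAllBuiltinFunctionsFor(Symbol::aten("mul"));
// The compiler then schema-matches the returned overloads against a call site.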
const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
  static BuiltinFunctionRegistry registry;
  return registry.getAllBuiltinFunctionsFor(name);
}

} // namespace torch::jit