#include <ATen/ScalarOps.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
namespace torch {
namespace jit {
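
// NOTE: Every op below follows the mobile interpreter's stack calling
// convention: arguments are popped off the Stack (with the last argument on
// top) and results are pushed back in their place.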

void tupleIndex(Stack& stack) {
  int64_t index = pop(stack).toInt();
  auto tuple = pop(stack).toTuple();
  auto norm_index = normalizeIndex(index, tuple->elements().size());
  if (norm_index < 0 ||
      norm_index >= static_cast<int64_t>(tuple->elements().size())) {
    throw std::out_of_range("Tuple list index out of range");
  }
  stack.emplace_back(tuple->elements()[norm_index]);
}
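
// normalizeIndex above implements Python-style negative indexing: for a
// 3-element tuple, index -1 maps to 2. Indices still outside [0, size) after
// normalization hit the out_of_range throw.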

void raiseException(Stack& stack) {
  // This kernel supports RaiseException with only one argument: the error.
  // DEPRECATED as of bytecode_version 8; please do not change it, so that
  // backward compatibility (BC) is preserved.
  throw JITException(pop(stack).toStringRef());
}

void raiseExceptionWithMessage(Stack& stack) {
  // This kernel supports RaiseException with two arguments: the error and the
  // message. Please make changes only to this kernel.
  c10::optional<std::string> qualified_class_name =
      pop(stack).toOptional<std::string>();
  std::string message;
  pop(stack, message);

  throw JITException(message, qualified_class_name);
}
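
// Stack layout on entry (top first): optional qualified class name, then the
// message. Carrying the class name along presumably lets the runtime surface
// the original Python exception type to the caller.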

void is(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, self.is(obj));
}

void unInitialized(Stack& stack) {
  push(stack, IValue::uninitialized());
}

void isNot(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, !self.is(obj));
}

void aten_format(Stack& stack) {
  size_t num_inputs = pop(stack).toInt();
  format(stack, num_inputs);
}

void size(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sizes().vec());
}

void sym_size(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sym_sizes().vec());
}

void sym_size_int(Stack& stack) {
  auto dim = pop(stack).toInt();
  auto t = pop(stack).toTensor();
  push(stack, t.sym_sizes()[dim]);
}

void sym_stride_int(Stack& stack) {
  auto dim = pop(stack).toInt();
  auto t = pop(stack).toTensor();
  push(stack, t.sym_strides()[dim]);
}
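
// The sym_* ops mirror their plain counterparts (size, stride, numel, ...)
// but return c10::SymInt values, so shapes can stay symbolic instead of being
// forced to concrete integers.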

void sym_numel(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  push(stack, t.sym_numel());
}

void sym_storage_offset(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  push(stack, t.sym_storage_offset());
}

void sym_stride(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sym_strides().vec());
}

void device(Stack& stack) {
  push(stack, pop(stack).toTensor().device());
}

void device_with_index(Stack& stack) {
  std::string type = pop(stack).toStringRef();
  int index = pop(stack).toInt();
  std::string device_str = type + ":" + std::to_string(index);
  auto device = c10::Device(device_str);
  push(stack, device);
}
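
// E.g. type "cuda" with index 0 yields c10::Device("cuda:0"); the c10::Device
// string constructor parses the "<type>:<index>" form.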

void dtype(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, static_cast<int64_t>(a.scalar_type()));
}

void layout(Stack& stack) {
  push(stack, pop(stack).toTensor().layout());
}

void toPrimDType(Stack& stack) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool non_blocking;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool copy;
  pop(stack, non_blocking, copy);
  c10::optional<at::ScalarType> scalarType =
      pop(stack).toOptional<at::ScalarType>();
  c10::optional<c10::Device> device = c10::nullopt;
  at::Tensor self = pop(stack).toTensor();
  push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
}
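
// device stays nullopt here, so this overload only converts the dtype and
// never moves the tensor to a different device.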

void dim(Stack& stack) {
  at::Tensor arg = pop(stack).toTensor();
  push(stack, arg.dim());
}

void _not(Stack& stack) {
  push(stack, !pop(stack).toBool());
}

void boolTensor(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, at::native::is_nonzero(a));
}
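
// This backs bool(tensor): at::native::is_nonzero accepts only single-element
// tensors and raises for anything else, matching eager-mode semantics.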

void toList(Stack& stack) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int elem_ty_val;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int dim_val;
  at::Tensor t;

  pop(stack, elem_ty_val);
  pop(stack, dim_val);
  pop(stack, t);

  // If the Tensor is not on the CPU, transfer it.
  if (!t.device().is_cpu()) {
    t = t.cpu();
  }

  // Rebuild the output type using elem_ty_val and dim_val. Start
  // with the element type corresponding to elem_ty_val.
  at::TypePtr out_ty;
  if (elem_ty_val == 0) {
    out_ty = at::IntType::get();
  } else if (elem_ty_val == 1) {
    out_ty = at::FloatType::get();
  } else if (elem_ty_val == 2) {
    out_ty = at::BoolType::get();
  } else if (elem_ty_val == 3) {
    out_ty = at::ComplexType::get();
  } else {
    TORCH_CHECK(
        false,
        "Unsupported element type for tolist; only int, float, complex and bool are supported");
  }

  // Check that the type of the Tensor matches that of the annotation.
  // Make an exception for the case in which the annotated type is
  // float/complex and the Tensor data type is also float/complex;
  // the elements will be cast to double/c10::complex<double>
  // later.
  TORCH_CHECK(
      (out_ty == at::FloatType::get() && t.is_floating_point()) ||
          (out_ty == at::ComplexType::get() && t.is_complex()) ||
          tryScalarTypeFromJitType(*out_ty) == t.scalar_type(),
      "Output annotation element type and runtime tensor element type must match for tolist(): ",
      *tryScalarTypeFromJitType(*out_ty),
      " vs ",
      t.scalar_type());

  // Check that the dimension of the Tensor matches that of the
  // annotation.
  TORCH_CHECK(
      dim_val == t.dim(),
      "Output annotation list dimension and runtime tensor dimension must match for tolist()");

  // Wrap out_ty in a ListType dim times.
  for (const auto i : c10::irange(dim_val)) {
    (void)i; // Suppress unused variable warning
    out_ty = at::ListType::create(out_ty);
  }

  int64_t dim = t.dim();
  auto sizes = t.sizes();
  auto strides = t.strides();
  size_t element_size = t.element_size();
  char* data = static_cast<char*>(t.data_ptr());
  auto result = tensorToListRecursive(
      data, 0, dim, out_ty, t.scalar_type(), sizes, strides, element_size);
  push(stack, std::move(result));
}
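
// tensorToListRecursive walks the raw data pointer using sizes/strides,
// peeling off one ListType per dimension to build the nested lists that
// tolist() returns.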

void numToTensorScalar(Stack& stack) {
  at::Scalar s;
  pop(stack, s);
  push(stack, c10::scalar_to_tensor(s));
}

void isCuda(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, a.is_cuda());
}

void numToTensorBool(Stack& stack) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool b;
  pop(stack, b);
  push(stack, c10::scalar_to_tensor(b));
}

void dictIndex(Stack& stack) {
  auto key = pop(stack);
  auto dict = pop(stack).toGenericDict();
  auto value = dict.find(key);
  if (value == dict.end()) {
    AT_ERROR("KeyError: ", key);
  }
  push(stack, value->value());
}

static const C10_UNUSED std::array<mobile::prim_op_fn_register, 16> op_reg = {
    mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex),
    mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor),
    mobile::prim_op_fn_register("aten::format", aten_format),
    mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar),
    mobile::prim_op_fn_register(
        "prim::RaiseException",
        raiseExceptionWithMessage),
    mobile::prim_op_fn_register("prim::device", device),
    mobile::prim_op_fn_register("prim::dtype", dtype),
    mobile::prim_op_fn_register("prim::layout", layout),
    mobile::prim_op_fn_register("aten::__not__", _not),
    mobile::prim_op_fn_register("aten::__is__", is),
    mobile::prim_op_fn_register("aten::__isnot__", isNot),
    mobile::prim_op_fn_register("aten::dim", dim),
    mobile::prim_op_fn_register("prim::Uninitialized", unInitialized),
    mobile::prim_op_fn_register("prim::is_cuda", isCuda),
    mobile::prim_op_fn_register("aten::__getitem__.Dict_str", dictIndex),
    mobile::prim_op_fn_register("prim::unchecked_cast", noop),
    // TODO: (@pavithran) size is overloaded with int[] and Tensor,
    // so registering it here throws an error expecting int, not Tensor
    // mobile::prim_op_fn_register("aten::size", size)
};
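
// Registration happens as a side effect of constructing each
// prim_op_fn_register at static-initialization time; op_reg itself is never
// read afterwards, hence the C10_UNUSED.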

} // namespace jit
} // namespace torch