1
#include <ATen/ScalarOps.h>
2
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
6
// prim::TupleIndex: pops an integer index and a tuple, then pushes
// tuple[index]. Negative indices are normalized against the tuple size;
// indices still out of range after normalization throw std::out_of_range.
void tupleIndex(Stack& stack) {
  int64_t index = pop(stack).toInt();
  auto tuple = pop(stack).toTuple();
  auto norm_index = normalizeIndex(index, tuple->elements().size());
  if (norm_index < 0 ||
      norm_index >= static_cast<int64_t>(tuple->elements().size())) {
    throw std::out_of_range("Tuple list index out of range");
  }
  stack.emplace_back(tuple->elements()[norm_index]);
}
// prim::RaiseException (single-argument form): pops the error message string
// and rethrows it as a JITException.
void raiseException(Stack& stack) {
  throw JITException(pop(stack).toStringRef());
}
// prim::RaiseException (two-argument form): pops an optional qualified
// exception class name and a message string, then throws a JITException
// carrying both. Note the class name is on top of the stack, above the
// message, so it is popped first.
void raiseExceptionWithMessage(Stack& stack) {
  c10::optional<std::string> qualified_class_name =
      pop(stack).toOptional<std::string>();
  std::string message;
  pop(stack, message);
  throw JITException(message, qualified_class_name);
}
// aten::__is__: pops two IValues and pushes whether they are identical
// (reference/identity comparison, not value equality).
void is(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, self.is(obj));
}
// prim::Uninitialized: pushes a sentinel uninitialized IValue onto the stack.
void unInitialized(Stack& stack) {
  push(stack, IValue::uninitialized());
}
// aten::__isnot__: pops two IValues and pushes the negation of their
// identity comparison (complement of `is` above).
void isNot(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, !self.is(obj));
}
// aten::format: pops the input count, then delegates to format(), which
// consumes that many stack entries and pushes the formatted string.
void aten_format(Stack& stack) {
  size_t num_inputs = pop(stack).toInt();
  format(stack, num_inputs);
}
// aten::size: pops a tensor and pushes its sizes as a concrete int list.
void size(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sizes().vec());
}
// Symbolic variant of size(): pushes the tensor's symbolic sizes so shapes
// stay symbolic under tracing/dynamic shapes.
void sym_size(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sym_sizes().vec());
}
// Pops a dimension index then a tensor; pushes the symbolic size of that
// single dimension. Note the dim is on top of the stack (popped first).
void sym_size_int(Stack& stack) {
  auto dim = pop(stack).toInt();
  auto t = pop(stack).toTensor();
  push(stack, t.sym_sizes()[dim]);
}
// Pops a dimension index then a tensor; pushes the symbolic stride of that
// single dimension (mirrors sym_size_int).
void sym_stride_int(Stack& stack) {
  auto dim = pop(stack).toInt();
  auto t = pop(stack).toTensor();
  push(stack, t.sym_strides()[dim]);
}
// Pops a tensor and pushes its symbolic element count.
void sym_numel(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  push(stack, t.sym_numel());
}
// Pops a tensor and pushes its symbolic storage offset.
void sym_storage_offset(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  push(stack, t.sym_storage_offset());
}
// Pops a tensor and pushes its full symbolic stride list.
void sym_stride(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sym_strides().vec());
}
// prim::device: pops a tensor and pushes its device.
void device(Stack& stack) {
  push(stack, pop(stack).toTensor().device());
}
// Builds a c10::Device from a device-type string and an index, e.g.
// ("cuda", 0) -> Device("cuda:0"), and pushes it. The type string is on top
// of the stack (popped first), then the index.
void device_with_index(Stack& stack) {
  std::string type = pop(stack).toStringRef();
  int index = pop(stack).toInt();
  std::string device_str = type + ":" + std::to_string(index);
  auto device = c10::Device(device_str);
  push(stack, device);
}
// prim::dtype: pops a tensor and pushes its scalar type as an int64, the
// representation the interpreter uses for dtype values.
void dtype(Stack& stack) {
  at::Tensor a = pop(stack).toTensor();
  push(stack, static_cast<int64_t>(a.scalar_type()));
}
// prim::layout: pops a tensor and pushes its layout (strided/sparse/...).
void layout(Stack& stack) {
  push(stack, pop(stack).toTensor().layout());
}
// aten::to-style dtype cast. Stack (top to bottom): copy flag, non_blocking
// flag, optional target scalar type, self tensor. The device is deliberately
// left unset (nullopt) — this op only changes dtype, never device.
void toPrimDType(Stack& stack) {
  bool non_blocking;
  bool copy;
  pop(stack, non_blocking, copy);
  c10::optional<at::ScalarType> scalarType =
      pop(stack).toOptional<at::ScalarType>();
  c10::optional<c10::Device> device = c10::nullopt;
  at::Tensor self = pop(stack).toTensor();
  push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
}
// aten::dim: pops a tensor and pushes its number of dimensions.
void dim(Stack& stack) {
  at::Tensor arg = pop(stack).toTensor();
  push(stack, arg.dim());
}
// aten::__not__: pops a bool and pushes its logical negation.
void _not(Stack& stack) {
  push(stack, !pop(stack).toBool());
}
// aten::Bool.Tensor: pops a tensor and pushes whether it is nonzero
// (the truthiness test for single-element tensors).
void boolTensor(Stack& stack) {
  at::Tensor a = pop(stack).toTensor();
  push(stack, at::native::is_nonzero(a));
}
// prim::tolist: converts a tensor to a (possibly nested) list. Stack (top to
// bottom): element-type code, expected dimension count, tensor.
// Element-type codes: 0 = int, 1 = float, 2 = bool, 3 = complex.
// Validates that the annotated element type and dimension match the runtime
// tensor before recursing element-by-element.
void toList(Stack& stack) {
  int elem_ty_val;
  int dim_val;
  at::Tensor t;

  pop(stack, elem_ty_val);
  pop(stack, dim_val);
  t = pop(stack).toTensor();

  // If the Tensor is not on the CPU, transfer it so data_ptr() below reads
  // host memory.
  if (!t.device().is_cpu()) {
    t = t.cpu();
  }

  // Rebuild the output type using elem_ty_val and dim_val. Start with the
  // element type corresponding to elem_ty_val.
  at::TypePtr out_ty;
  if (elem_ty_val == 0) {
    out_ty = at::IntType::get();
  } else if (elem_ty_val == 1) {
    out_ty = at::FloatType::get();
  } else if (elem_ty_val == 2) {
    out_ty = at::BoolType::get();
  } else if (elem_ty_val == 3) {
    out_ty = at::ComplexType::get();
  } else {
    TORCH_CHECK(
        false,
        "Unsupported element type for tolist; only int, float, complex and bool are supported");
  }

  // Check that the tensor's scalar type matches the annotation. Float and
  // complex annotations accept any floating-point/complex tensor dtype; the
  // elements are widened during the recursive conversion.
  TORCH_CHECK(
      (out_ty == at::FloatType::get() && t.is_floating_point()) ||
          (out_ty == at::ComplexType::get() && t.is_complex()) ||
          tryScalarTypeFromJitType(*out_ty) == t.scalar_type(),
      "Output annotation element type and runtime tensor element type must match for tolist(): ",
      *tryScalarTypeFromJitType(*out_ty),
      " vs ",
      t.scalar_type());

  // Check that the dimension of the tensor matches the annotation.
  TORCH_CHECK(
      dim_val == t.dim(),
      "Output annotation list dimension and runtime tensor dimension must match for tolist()");

  // Wrap out_ty in a ListType once per dimension to form the nested list type.
  for (const auto i : c10::irange(dim_val)) {
    (void)i; // Suppress unused variable warning
    out_ty = at::ListType::create(out_ty);
  }

  int64_t dim = t.dim();
  auto sizes = t.sizes();
  auto strides = t.strides();
  size_t element_size = t.element_size();
  char* data = static_cast<char*>(t.data_ptr());
  auto result = tensorToListRecursive(
      data, 0, dim, out_ty, t.scalar_type(), sizes, strides, element_size);
  push(stack, std::move(result));
}
// prim::NumToTensor.Scalar: pops a scalar and pushes a 0-dim tensor holding
// its value.
void numToTensorScalar(Stack& stack) {
  at::Scalar s = pop(stack).toScalar();
  push(stack, c10::scalar_to_tensor(s));
}
// prim::is_cuda: pops a tensor and pushes whether it lives on a CUDA device.
void isCuda(Stack& stack) {
  at::Tensor a = pop(stack).toTensor();
  push(stack, a.is_cuda());
}
// prim::NumToTensor.bool: pops a bool and pushes a 0-dim tensor holding it.
void numToTensorBool(Stack& stack) {
  bool b = pop(stack).toBool();
  push(stack, c10::scalar_to_tensor(b));
}
// aten::__getitem__.Dict_str: pops a key then a dict, pushes dict[key].
// Raises a KeyError-style runtime error when the key is absent.
void dictIndex(Stack& stack) {
  auto key = pop(stack);
  auto dict = pop(stack).toGenericDict();
  auto value = dict.find(key);
  if (value == dict.end()) {
    AT_ERROR("KeyError: ", key);
  }
  push(stack, value->value());
}
static const C10_UNUSED std::array<mobile::prim_op_fn_register, 16> op_reg = {
240
mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex),
241
mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor),
242
mobile::prim_op_fn_register("aten::format", aten_format),
243
mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar),
244
mobile::prim_op_fn_register(
245
"prim::RaiseException",
246
raiseExceptionWithMessage),
247
mobile::prim_op_fn_register("prim::device", device),
248
mobile::prim_op_fn_register("prim::dtype", dtype),
249
mobile::prim_op_fn_register("prim::layout", layout),
250
mobile::prim_op_fn_register("aten::__not__", _not),
251
mobile::prim_op_fn_register("aten::__is__", is),
252
mobile::prim_op_fn_register("aten::__isnot__", isNot),
253
mobile::prim_op_fn_register("aten::dim", dim),
254
mobile::prim_op_fn_register("prim::Uninitialized", unInitialized),
255
mobile::prim_op_fn_register("prim::is_cuda", isCuda),
256
mobile::prim_op_fn_register("aten::__getitem__.Dict_str", dictIndex),
257
mobile::prim_op_fn_register("prim::unchecked_cast", noop),