1
#include <ATen/core/dynamic_type.h>
2
#include <ATen/core/type_factory.h>
3
#include <torch/csrc/jit/mobile/register_ops_common_utils.h>
8
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
10
// Handle negative indexing
11
idx = list_size + idx;
16
IValue tensorToListRecursive(
19
int64_t num_tensor_dims,
21
at::ScalarType scalar_ty,
22
at::IntArrayRef sizes,
23
at::IntArrayRef strides,
24
size_t element_size) {
25
// If ty is a ListType, get the element type.
26
if (auto list_type = ty->cast<at::ListType>()) {
27
ty = list_type->getElementType();
29
// If the output type is a scalar, read and push one scalar of
30
// the right type onto the stack.
31
if (ty == at::IntType::get()) {
32
int64_t scalar = *(int64_t*)data;
33
return IValue(scalar);
34
} else if (ty == at::FloatType::get()) {
35
TORCH_INTERNAL_ASSERT(
36
scalar_ty == at::ScalarType::Float ||
37
scalar_ty == at::ScalarType::Double,
38
"Unexpected scalar type for Tensor");
40
scalar_ty == at::ScalarType::Float ? *(float*)data : *(double*)data;
41
return IValue(scalar);
42
} else if (ty == at::ComplexType::get()) {
43
TORCH_INTERNAL_ASSERT(
44
scalar_ty == at::ScalarType::ComplexFloat ||
45
scalar_ty == at::ScalarType::ComplexDouble,
46
"Unexpected scalar type for Tensor");
47
c10::complex<double> scalar = scalar_ty == at::ScalarType::ComplexFloat
48
? *(c10::complex<float>*)data
49
: *(c10::complex<double>*)data;
50
return IValue(scalar);
51
} else if (ty == at::BoolType::get()) {
52
bool scalar = *(bool*)data;
53
return IValue(scalar);
58
" is not one of the supported types for tolist: int, float, bool");
62
// Make the result list consisting of elements of type ty. Since this
63
// invocation is processing dimension cur_dim, there will be sizes[cur_dim]
65
auto result = c10::impl::GenericList(ty);
66
result.reserve(sizes[cur_dim]);
68
// Since ty was a list type, tensorToListRecursive needs to be called
69
// recursively on each slice of the tensor in the current dimension.
70
for (int64_t i = 0, e = sizes[cur_dim]; i < e; ++i) {
71
auto inner_result = tensorToListRecursive(
81
if (inner_result.isList()) {
82
result.emplace_back(inner_result.toList());
83
} else if (inner_result.isComplexDouble()) {
84
result.emplace_back(inner_result.toComplexDouble());
85
} else if (inner_result.isDouble()) {
86
result.emplace_back(inner_result.toDouble());
87
} else if (inner_result.isInt()) {
88
result.emplace_back(inner_result.toInt());
89
} else if (inner_result.isBool()) {
90
result.emplace_back(inner_result.toBool());
92
TORCH_INTERNAL_ASSERT(
93
false && "Unknown return type for tensorToListRecursive");
96
data += strides[cur_dim] * element_size;