pytorch

Форк
0
/
register_ops_common_utils.cpp 
103 строки · 3.2 Кб
1
#include <ATen/core/dynamic_type.h>
2
#include <ATen/core/type_factory.h>
3
#include <torch/csrc/jit/mobile/register_ops_common_utils.h>
4

5
namespace torch {
6
namespace jit {
7

8
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
9
  if (idx < 0) {
10
    // Handle negative indexing
11
    idx = list_size + idx;
12
  }
13
  return idx;
14
}
15

16
IValue tensorToListRecursive(
17
    char* data,
18
    int64_t cur_dim,
19
    int64_t num_tensor_dims,
20
    at::TypePtr ty,
21
    at::ScalarType scalar_ty,
22
    at::IntArrayRef sizes,
23
    at::IntArrayRef strides,
24
    size_t element_size) {
25
  // If ty is a ListType, get the element type.
26
  if (auto list_type = ty->cast<at::ListType>()) {
27
    ty = list_type->getElementType();
28
  } else {
29
    // If the output type is a scalar, read and push one scalar of
30
    // the right type onto the stack.
31
    if (ty == at::IntType::get()) {
32
      int64_t scalar = *(int64_t*)data;
33
      return IValue(scalar);
34
    } else if (ty == at::FloatType::get()) {
35
      TORCH_INTERNAL_ASSERT(
36
          scalar_ty == at::ScalarType::Float ||
37
              scalar_ty == at::ScalarType::Double,
38
          "Unexpected scalar type for Tensor");
39
      double scalar =
40
          scalar_ty == at::ScalarType::Float ? *(float*)data : *(double*)data;
41
      return IValue(scalar);
42
    } else if (ty == at::ComplexType::get()) {
43
      TORCH_INTERNAL_ASSERT(
44
          scalar_ty == at::ScalarType::ComplexFloat ||
45
              scalar_ty == at::ScalarType::ComplexDouble,
46
          "Unexpected scalar type for Tensor");
47
      c10::complex<double> scalar = scalar_ty == at::ScalarType::ComplexFloat
48
          ? *(c10::complex<float>*)data
49
          : *(c10::complex<double>*)data;
50
      return IValue(scalar);
51
    } else if (ty == at::BoolType::get()) {
52
      bool scalar = *(bool*)data;
53
      return IValue(scalar);
54
    } else {
55
      TORCH_CHECK(
56
          false,
57
          ty->repr_str(),
58
          " is not one of the supported types for tolist: int, float, bool");
59
    }
60
  }
61

62
  // Make the result list consisting of elements of type ty. Since this
63
  // invocation is processing dimension cur_dim, there will be sizes[cur_dim]
64
  // output elements.
65
  auto result = c10::impl::GenericList(ty);
66
  result.reserve(sizes[cur_dim]);
67

68
  // Since ty was a list type, tensorToListRecursive needs to be called
69
  // recursively on each slice of the tensor in the current dimension.
70
  for (int64_t i = 0, e = sizes[cur_dim]; i < e; ++i) {
71
    auto inner_result = tensorToListRecursive(
72
        data,
73
        cur_dim + 1,
74
        num_tensor_dims,
75
        ty,
76
        scalar_ty,
77
        sizes,
78
        strides,
79
        element_size);
80

81
    if (inner_result.isList()) {
82
      result.emplace_back(inner_result.toList());
83
    } else if (inner_result.isComplexDouble()) {
84
      result.emplace_back(inner_result.toComplexDouble());
85
    } else if (inner_result.isDouble()) {
86
      result.emplace_back(inner_result.toDouble());
87
    } else if (inner_result.isInt()) {
88
      result.emplace_back(inner_result.toInt());
89
    } else if (inner_result.isBool()) {
90
      result.emplace_back(inner_result.toBool());
91
    } else {
92
      TORCH_INTERNAL_ASSERT(
93
          false && "Unknown return type for tensorToListRecursive");
94
    }
95

96
    data += strides[cur_dim] * element_size;
97
  }
98

99
  return result;
100
}
101

102
} // namespace jit
103
} // namespace torch
104

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.