pytorch

Форк
0
/
tensor_list.cpp 
70 строк · 2.0 Кб
1
#include <torch/csrc/utils/tensor_list.h>
2

3
#include <c10/util/irange.h>
4
#include <pybind11/pybind11.h>
5
#include <torch/csrc/Exceptions.h>
6
#include <torch/csrc/autograd/python_variable.h>
7
#include <torch/csrc/utils/pybind.h>
8
#include <torch/csrc/utils/python_scalars.h>
9

10
using namespace at;
11

12
namespace torch {
13
namespace utils {
14

15
static PyObject* recursive_to_list(
16
    const char* data,
17
    IntArrayRef sizes,
18
    IntArrayRef strides,
19
    int64_t dim,
20
    ScalarType scalarType,
21
    size_t elementSize) {
22
  int64_t ndim = static_cast<int64_t>(sizes.size());
23
  if (dim == ndim) {
24
    return torch::utils::load_scalar(data, scalarType);
25
  }
26
  auto n = sizes[dim];
27
  auto list = THPObjectPtr(PyList_New(n));
28
  if (!list)
29
    throw python_error();
30
  for (const auto i : c10::irange(n)) {
31
    PyObject* obj = recursive_to_list(
32
        data, sizes, strides, dim + 1, scalarType, elementSize);
33
    if (!obj)
34
      throw python_error();
35
    PyList_SET_ITEM(list.get(), i, obj);
36
    auto advance_data_ptr = strides[dim] * elementSize;
37
    TORCH_INTERNAL_ASSERT(data || (advance_data_ptr == 0));
38
    data += advance_data_ptr;
39
  }
40
  return list.release();
41
}
42

43
PyObject* tensor_to_list(const Tensor& tensor) {
44
  {
45
    py::object pytensor =
46
        py::reinterpret_steal<py::object>(THPVariable_Wrap(tensor));
47
    TORCH_CHECK(
48
        !tensor.unsafeGetTensorImpl()->is_python_dispatch(),
49
        ".tolist() is not supported for tensor subclasses, got ",
50
        Py_TYPE(pytensor.ptr())->tp_name);
51
  }
52
  Tensor data = tensor.resolve_conj().resolve_neg();
53
  if (!data.device().is_cpu()) {
54
    pybind11::gil_scoped_release no_gil;
55
    data = data.toBackend(Backend::CPU);
56
  }
57
  TORCH_CHECK(
58
      tensor.numel() == 0 || data.const_data_ptr(),
59
      "tolist() shouldn't be called on a tensor with unallocated storage");
60
  return recursive_to_list(
61
      (const char*)data.const_data_ptr(),
62
      data.sizes(),
63
      data.strides(),
64
      0,
65
      data.scalar_type(),
66
      tensor.numel() == 0 ? 0 : data.dtype().itemsize());
67
}
68

69
} // namespace utils
70
} // namespace torch
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.