// Listing header captured from the hosting site (translated from Russian):
// repository "pytorch" (Fork: 0) / Size.cpp — 284 lines · 8.5 KB
#include <c10/util/irange.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Size.h>
#include <torch/csrc/utils/pybind.h>

#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_tuples.h>
#include <limits>
#include <string>

#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/frontend/tracer.h>

struct THPSize {
17
  PyTupleObject tuple;
18
};
19

20
// Builds a torch.Size describing the shape of `var`.
//
// When JIT tracing is not active the entries are plain Python ints (built by
// THPSize_NewFromSizes from var.sizes()). While tracing, entry `d` is instead
// the Python wrapper of tracer::getSizeOf(var, d).
// Throws python_error if allocating the tuple or wrapping an entry fails.
PyObject* THPSize_New(const torch::autograd::Variable& var) {
  if (torch::jit::tracer::isTracing()) {
    const auto ndim = var.dim();
    THPObjectPtr result(THPSizeType.tp_alloc(&THPSizeType, ndim));
    if (!result) {
      throw python_error();
    }
    for (const auto dim : c10::irange(ndim)) {
      PyObject* traced_size =
          THPVariable_Wrap(torch::jit::tracer::getSizeOf(var, dim));
      if (!traced_size) {
        throw python_error();
      }
      // PyTuple_SET_ITEM steals the new reference held in traced_size.
      PyTuple_SET_ITEM(result.get(), dim, traced_size);
    }
    return result.release();
  }
  return THPSize_NewFromSizes(var.dim(), var.sizes().data());
}

// Allocates a torch.Size with `dim` entries and fills it from
// sizes[0 .. dim) via THPUtils_packInt64Array.
// Throws python_error if the allocation fails.
PyObject* THPSize_NewFromSizes(int64_t dim, const int64_t* sizes) {
  THPObjectPtr size_obj{THPSizeType.tp_alloc(&THPSizeType, dim)};
  if (size_obj.get() == nullptr) {
    throw python_error();
  }
  THPUtils_packInt64Array(size_obj.get(), dim, sizes);
  return size_obj.release();
}

// Builds a torch.Size from the (possibly symbolic) sizes of `self_`.
//
// Entry `i` is chosen, in order of precedence, as:
//  * a Python SymInt, if sym_sizes()[i] is symbolic (JIT tracing of SymInts
//    is rejected with a c10::Error);
//  * the Python wrapper of tracer::getSizeOf(self_, i), if tracing is active;
//  * a plain Python int otherwise.
// Throws python_error if any Python object cannot be created; in that case no
// NULL item is ever stored in the (discarded) tuple.
PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
  auto sym_sizes = self_.sym_sizes();

  auto ret = THPObjectPtr(THPSizeType.tp_alloc(
      &THPSizeType, static_cast<Py_ssize_t>(sym_sizes.size())));
  if (!ret)
    throw python_error();

  for (auto i : c10::irange(sym_sizes.size())) {
    const auto& si = sym_sizes[i];
    PyObject* item = nullptr;
    if (si.is_symbolic()) {
      // Check for actual symbolic values first so that they are not
      // implicitly replaced by their integer value.
      TORCH_CHECK(
          !torch::jit::tracer::isTracing(),
          "JIT Tracing of SymInts isn't supported");
      item = py::cast(si).release().ptr();
    } else if (torch::jit::tracer::isTracing()) {
      item = THPVariable_Wrap(
          torch::jit::tracer::getSizeOf(self_, static_cast<int64_t>(i)));
    } else {
      // Not symbolic, so this is an actual integer value.
      const auto m = si.maybe_as_int();
      TORCH_INTERNAL_ASSERT(m.has_value());
      item = THPUtils_packInt64(*m);
    }
    // Every branch can fail (e.g. MemoryError); never store NULL into the
    // tuple and return it with an exception pending.
    if (!item)
      throw python_error();
    PyTuple_SET_ITEM(ret.get(), static_cast<Py_ssize_t>(i), item);
  }
  return ret.release();
}

// True when `item` is a tensor with zero dimensions for which
// tracer::getValueTrace yields a (non-null) value; false for non-tensors.
static bool isTracedZeroDimVar(PyObject* item) {
  if (THPVariable_Check(item)) {
    const auto& tensor = THPVariable_Unpack(item);
    return tensor.dim() == 0 && torch::jit::tracer::getValueTrace(tensor);
  }
  return false;
}

// tp_new for torch.Size.
//
// Builds the underlying tuple with tuple.__new__, then checks every entry.
// Python ints, SymInts and (only while JIT tracing) traced 0-dim tensors are
// kept as they are. Any other entry is converted with __index__ and replaced
// in place by the resulting int. If that is impossible, a TypeError naming
// the offending position and type is raised.
static PyObject* THPSize_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  THPObjectPtr self(PyTuple_Type.tp_new(type, args, kwargs));
  if (!self) {
    return nullptr;
  }
  for (Py_ssize_t idx = 0; idx < PyTuple_Size(self); ++idx) {
    PyObject* elem = PyTuple_GET_ITEM(self.get(), idx);
    const bool accepted_as_is = THPUtils_checkLong(elem) ||
        torch::is_symint(elem) ||
        (torch::jit::tracer::isTracing() && isTracedZeroDimVar(elem));
    if (accepted_as_is) {
      continue;
    }

    // item.__index__() works with 0-dim tensors and tensors with one element
    THPObjectPtr as_int(PyNumber_Index(elem));
    if (!as_int || !THPUtils_checkLong(as_int.get())) {
      return PyErr_Format(
          PyExc_TypeError,
          "torch.Size() takes an iterable of 'int' (item %zd is '%s')",
          idx,
          Py_TYPE(elem)->tp_name);
    }

    // PyTuple_SetItem steals one reference while `as_int` releases its own
    // on scope exit, so hand the tuple an extra one.
    Py_INCREF(as_int.get());
    if (PyTuple_SetItem(self, idx, as_int.get()) != 0) {
      throw python_error();
    }
  }
  return self.release();
  END_HANDLE_TH_ERRORS
}

// tp_repr for torch.Size, producing text such as "torch.Size([2, 3])".
// SymInt entries are rendered with str(); every other entry goes through
// THPUtils_unpackLong and is printed as a decimal integer.
static PyObject* THPSize_repr(THPSize* self) {
  HANDLE_TH_ERRORS
  auto* const tuple = reinterpret_cast<PyObject*>(self);
  std::string repr("torch.Size([");
  const char* separator = "";
  for (Py_ssize_t idx = 0; idx < PyTuple_Size(tuple); ++idx) {
    repr += separator;
    separator = ", ";
    const py::handle entry(PyTuple_GET_ITEM(tuple, idx));
    if (torch::is_symint(entry)) {
      repr += std::string(py::str(entry));
    } else {
      repr += std::to_string(THPUtils_unpackLong(entry.ptr()));
    }
  }
  repr += "])";
  return THPUtils_packString(repr);
  END_HANDLE_TH_ERRORS
}

extern PyTypeObject THPSizeType;
153

154
template <typename FnType, FnType fn, typename... Args>
155
static PyObject* wrap_tuple_fn(Args... args) {
156
  THPObjectPtr result((*fn)(std::forward<Args>(args)...));
157
  if (!result)
158
    return nullptr;
159
  if (PyTuple_Check(result.get())) {
160
    return PyObject_CallFunctionObjArgs(
161
        (PyObject*)&THPSizeType, result.get(), nullptr);
162
  }
163
  return result.release();
164
}
165

166
// We use an anonymous namespace instead of static to work around
167
// (what @peterjc123 think is) a bug in Visual Studio
168
namespace {
169
auto sq_concat = PyTuple_Type.tp_as_sequence->sq_concat;
170
auto sq_repeat = PyTuple_Type.tp_as_sequence->sq_repeat;
171
binaryfunc mp_subscript = PyTuple_Type.tp_as_mapping->mp_subscript;
172
} // namespace
173

174
// Sequence protocol for torch.Size. Only `+` (sq_concat) and `*` (sq_repeat)
// are set here, routed through wrap_tuple_fn so tuple results come back as
// torch.Size. All remaining slots are value-initialized to null and are
// expected to be inherited from tuple by PyType_Ready.
static PySequenceMethods THPSize_as_sequence = {
    /* sq_length */ nullptr,
    /* sq_concat */ wrap_tuple_fn<decltype(&sq_concat), &sq_concat>,
    /* sq_repeat */ wrap_tuple_fn<decltype(&sq_repeat), &sq_repeat>,
};

// Mapping protocol for torch.Size. Subscripting is routed through
// wrap_tuple_fn, so a tuple result (e.g. from slicing) is returned as
// torch.Size while a non-tuple result is returned unchanged.
static PyMappingMethods THPSize_as_mapping = {
    /* mp_length */ nullptr,
    /* mp_subscript */ wrap_tuple_fn<decltype(&mp_subscript), &mp_subscript>,
    /* mp_ass_subscript */ nullptr,
};

// Returns true when a * b is not representable as int64_t (overflow in either
// direction). Only division-based bounds are evaluated, so no overflowing
// multiplication is ever performed (signed overflow is undefined behavior).
static bool THPSize_mulOverflows(int64_t a, int64_t b) {
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  if (a == 0 || b == 0) {
    return false;
  }
  if (a > 0) {
    return b > 0 ? a > kMax / b : b < kMin / a;
  }
  return b > 0 ? a < kMin / b : b < kMax / a;
}

// torch.Size.numel(): product of all entries, returned as a Python int.
//
// Every entry is unpacked with THPUtils_unpackLong, so a non-integer entry
// raises from there. The product is accumulated in int64_t with explicit
// overflow detection. A zero entry makes the result 0 even if an earlier
// partial product overflowed. Otherwise an overflow raises OverflowError
// instead of invoking undefined behavior.
static PyObject* THPSize_numel(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = (THPSize*)_self;
  int64_t numel = 1;
  bool overflowed = false;
  for (Py_ssize_t i = 0; i < PyTuple_Size((PyObject*)self); ++i) {
    const int64_t size = THPUtils_unpackLong(PyTuple_GET_ITEM(self, i));
    if (size == 0) {
      numel = 0;
      overflowed = false;
    } else if (!overflowed) {
      if (THPSize_mulOverflows(numel, size)) {
        overflowed = true;
      } else {
        numel *= size;
      }
    }
  }
  if (overflowed) {
    PyErr_SetString(
        PyExc_OverflowError,
        "torch.Size.numel(): product of sizes overflows int64");
    return nullptr;
  }
  return THPUtils_packInt64(numel);
  END_HANDLE_TH_ERRORS
}

// __reduce__ for pickling. Returns (torch.Size, (entries,)), where `entries`
// is a plain tuple holding the same objects as this Size, so unpickling calls
// torch.Size(entries). Throws python_error on any allocation failure.
static PyObject* THPSize_reduce(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  const Py_ssize_t ndim = PyTuple_Size(_self);

  // Plain-tuple copy of the entries (new references to the same objects).
  THPObjectPtr entries(PyTuple_New(ndim));
  if (!entries)
    throw python_error();
  for (Py_ssize_t idx = 0; idx < ndim; ++idx) {
    PyObject* entry = PyTuple_GET_ITEM(_self, idx);
    Py_INCREF(entry);
    PyTuple_SET_ITEM(entries.get(), idx, entry);
  }

  // Constructor arguments for unpickling: a 1-tuple wrapping `entries`.
  THPObjectPtr ctor_args(PyTuple_Pack(1, entries.get()));
  if (!ctor_args)
    throw python_error();

  // PyTuple_Pack takes its own references to both items.
  PyObject* reduced = PyTuple_Pack(
      2, reinterpret_cast<PyObject*>(&THPSizeType), ctor_args.get());
  if (!reduced)
    throw python_error();
  return reduced;
  END_HANDLE_TH_ERRORS
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
231
static PyMethodDef THPSize_methods[] = {
232
    {"numel", THPSize_numel, METH_NOARGS, nullptr},
233
    {"__reduce__", THPSize_reduce, METH_NOARGS, nullptr},
234
    {nullptr}};
235

236
// Type object for torch.Size: a tuple subclass (tp_base = tuple) with a
// validating tp_new (THPSize_pynew), a custom repr (THPSize_repr), the
// THPSize_as_sequence / THPSize_as_mapping slot tables and the methods in
// THPSize_methods. Slots listed as null/0 are left for PyType_Ready to fill
// or inherit. Only Py_TPFLAGS_DEFAULT is set, so Py_TPFLAGS_BASETYPE is
// absent and the type cannot be subclassed from Python.
PyTypeObject THPSizeType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    /* tp_name              */ "torch.Size",
    /* tp_basicsize         */ sizeof(THPSize),
    /* tp_itemsize          */ 0,
    /* tp_dealloc           */ nullptr,
    /* tp_vectorcall_offset */ 0,
    /* tp_getattr           */ nullptr,
    /* tp_setattr           */ nullptr,
    /* tp_reserved          */ nullptr,
    /* tp_repr              */ reinterpret_cast<reprfunc>(THPSize_repr),
    /* tp_as_number         */ nullptr,
    /* tp_as_sequence       */ &THPSize_as_sequence,
    /* tp_as_mapping        */ &THPSize_as_mapping,
    /* tp_hash              */ nullptr,
    /* tp_call              */ nullptr,
    /* tp_str               */ nullptr,
    /* tp_getattro          */ nullptr,
    /* tp_setattro          */ nullptr,
    /* tp_as_buffer         */ nullptr,
    /* tp_flags             */ Py_TPFLAGS_DEFAULT,
    /* tp_doc               */ nullptr,
    /* tp_traverse          */ nullptr,
    /* tp_clear             */ nullptr,
    /* tp_richcompare       */ nullptr,
    /* tp_weaklistoffset    */ 0,
    /* tp_iter              */ nullptr,
    /* tp_iternext          */ nullptr,
    /* tp_methods           */ THPSize_methods,
    /* tp_members           */ nullptr,
    /* tp_getset            */ nullptr,
    /* tp_base              */ &PyTuple_Type,
    /* tp_dict              */ nullptr,
    /* tp_descr_get         */ nullptr,
    /* tp_descr_set         */ nullptr,
    /* tp_dictoffset        */ 0,
    /* tp_init              */ nullptr,
    /* tp_alloc             */ nullptr,
    /* tp_new               */ THPSize_pynew,
};

// Finalizes THPSizeType and registers it on `module` under the name "Size".
// Throws python_error if PyType_Ready or PyModule_AddObject fails.
void THPSize_init(PyObject* module) {
  if (PyType_Ready(&THPSizeType) < 0) {
    throw python_error();
  }
  // PyModule_AddObject steals the reference only on success, so take one for
  // it up front and give it back if registration fails (otherwise it leaks).
  Py_INCREF(&THPSizeType);
  if (PyModule_AddObject(
          module, "Size", reinterpret_cast<PyObject*>(&THPSizeType)) < 0) {
    Py_DECREF(reinterpret_cast<PyObject*>(&THPSizeType));
    throw python_error();
  }
}

// Page footer captured from the hosting site (cookie notice, translated from Russian):
// "Use of cookies.
// We use cookies in accordance with the Privacy Policy and the Cookie Policy.
// By clicking "Accept", you consent to SberTech JSC processing your personal data in order to improve our website and the GitVerse service and to make them more convenient to use.
// You can disable cookies yourself in your browser settings."