pytorch

Форк
0
/
structseq.cpp 
76 строк · 2.0 Кб
1
/* Copyright Python Software Foundation
2
 *
3
 * This file is copy-pasted from CPython source code with modifications:
4
 * https://github.com/python/cpython/blob/master/Objects/structseq.c
5
 * https://github.com/python/cpython/blob/2.7/Objects/structseq.c
6
 *
7
 * The purpose of this file is to overwrite the default behavior
8
 * of repr of structseq to provide better printting for returned
9
 * structseq objects from operators, aka torch.return_types.*
10
 *
11
 * For more information on copyright of CPython, see:
12
 * https://github.com/python/cpython#copyright-and-license-information
13
 */
14

15
#include <torch/csrc/utils/six.h>
16
#include <torch/csrc/utils/structseq.h>
17
#include <sstream>
18

19
#include <structmember.h>
20

21
namespace torch {
22
namespace utils {
23

24
// NOTE: The built-in repr method from PyStructSequence was updated in
25
// https://github.com/python/cpython/commit/c70ab02df2894c34da2223fc3798c0404b41fd79
26
// so this function might not be required in Python 3.8+.
27
PyObject* returned_structseq_repr(PyStructSequence* obj) {
28
  PyTypeObject* typ = Py_TYPE(obj);
29
  THPObjectPtr tup = six::maybeAsTuple(obj);
30
  if (tup == nullptr) {
31
    return nullptr;
32
  }
33

34
  std::stringstream ss;
35
  ss << typ->tp_name << "(\n";
36
  Py_ssize_t num_elements = Py_SIZE(obj);
37

38
  for (Py_ssize_t i = 0; i < num_elements; i++) {
39
    const char* cname = typ->tp_members[i].name;
40
    if (cname == nullptr) {
41
      PyErr_Format(
42
          PyExc_SystemError,
43
          "In structseq_repr(), member %zd name is nullptr"
44
          " for type %.500s",
45
          i,
46
          typ->tp_name);
47
      return nullptr;
48
    }
49

50
    PyObject* val = PyTuple_GetItem(tup.get(), i);
51
    if (val == nullptr) {
52
      return nullptr;
53
    }
54

55
    auto repr = THPObjectPtr(PyObject_Repr(val));
56
    if (repr == nullptr) {
57
      return nullptr;
58
    }
59

60
    const char* crepr = PyUnicode_AsUTF8(repr);
61
    if (crepr == nullptr) {
62
      return nullptr;
63
    }
64

65
    ss << cname << '=' << crepr;
66
    if (i < num_elements - 1) {
67
      ss << ",\n";
68
    }
69
  }
70
  ss << ")";
71

72
  return PyUnicode_FromString(ss.str().c_str());
73
}
74

75
} // namespace utils
76
} // namespace torch
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.