pytorch

Форк
0
/
python_arg_flatten.cpp 
196 строк · 6.7 Кб
1
#include <c10/util/irange.h>
2
#include <torch/csrc/jit/python/python_arg_flatten.h>
3
#include <torch/csrc/utils/python_strings.h>
4
#include <torch/csrc/utils/six.h>
5

6
#include <torch/csrc/autograd/grad_mode.h>
7

8
namespace torch::jit::python {
9

10
using namespace torch::autograd;
11
using namespace at;
12

13
// Single-character alphabet used to encode the structure of flattened
// inputs/outputs ("D" is short for "desc"). Containers contribute an
// opening and a closing character; every leaf contributes exactly one.
namespace D {
// Container delimiters.
constexpr char TupleOpen = '(';
constexpr char TupleClose = ')';
constexpr char ListOpen = '[';
constexpr char ListClose = ']';
constexpr char DictOpen = '<';
constexpr char DictClose = '>';

// Leaf tags.
constexpr char Variable = 'v';
constexpr char NoneType = 'n';
constexpr char String = 's';
constexpr char Bool = 'b';
constexpr char Long = 'l';
constexpr char Double = 'd';
} // namespace D
28

29
namespace {
30

31
// Returns true iff `o` is the Python `None` singleton (pointer identity).
inline bool PyNone_Check(PyObject* o) {
  return Py_None == o;
}
34

35
template <typename T>
36
py::object cast_handle_sequence(std::vector<py::handle> objs) {
37
  auto num_objs = objs.size();
38
  T sequence{num_objs};
39
  for (const auto i : c10::irange(num_objs)) {
40
    sequence[i] = py::reinterpret_borrow<py::object>(objs[i]);
41
  }
42
  return sequence;
43
}
44

45
// Recursively walks `obj` and records it into `args`:
//   * every visited node appends descriptor character(s) from namespace D to
//     args.desc.structure (containers append an open and a close character);
//   * tensors, and Python bool/int/float scalars wrapped into tensors, are
//     appended to args.vars with a matching entry in args.desc.metadata;
//   * Python strings are appended to args.desc.strings.
// A dict is recorded as the sequence of its items, each item being flattened
// like any other object.
//
// Throws std::runtime_error if an object of an unsupported type is found, and
// py::error_already_set if the items of a dict cannot be obtained.
void flatten_rec(PyObject* obj, ParsedArgs& args) {
  auto& structure = args.desc.structure;
  if (six::isTuple(obj)) {
    structure.push_back(D::TupleOpen);
    for (auto item : py::reinterpret_borrow<py::tuple>(obj))
      flatten_rec(item.ptr(), args);
    structure.push_back(D::TupleClose);
  } else if (PyList_Check(obj)) {
    structure.push_back(D::ListOpen);
    for (auto item : py::reinterpret_borrow<py::list>(obj))
      flatten_rec(item.ptr(), args);
    structure.push_back(D::ListClose);
  } else if (PyDict_Check(obj)) {
    // PyDict_Items returns a NEW reference (or nullptr on failure). Hand the
    // reference to a py::list so that it is released on every exit path,
    // including when a nested flatten_rec call throws.
    PyObject* raw_items = PyDict_Items(obj);
    if (raw_items == nullptr) {
      throw py::error_already_set();
    }
    auto dict_items = py::reinterpret_steal<py::list>(raw_items);
    structure.push_back(D::DictOpen);
    for (auto item : dict_items) {
      flatten_rec(item.ptr(), args);
    }
    structure.push_back(D::DictClose);
  } else if (THPUtils_checkString(obj)) {
    // Construct the stored string directly from the unpacked temporary.
    args.desc.strings.emplace_back(THPUtils_unpackString(obj));
    structure.push_back(D::String);
  } else if (THPVariable_Check(obj)) {
    auto& var = THPVariable_Unpack(obj);
    args.vars.push_back(var);
    args.desc.metadata.emplace_back(var);
    structure.push_back(D::Variable);
  } else if (PyNone_Check(obj)) {
    structure.push_back(D::NoneType);
  } else if (PyBool_Check(obj)) { // Wrap bools in Bool tensors
    // NB: this check must stay ahead of PyLong_Check, because Python's bool
    // is a subclass of int and would otherwise be recorded as a Long.
    at::Tensor var = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
    args.vars.push_back(var);
    args.desc.metadata.emplace_back(var);
    structure.push_back(D::Bool);
  } else if (PyLong_Check(obj)) { // Wrap longs in Long tensors
    at::Tensor var = scalar_to_tensor(
        at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(obj))));
    args.vars.push_back(var);
    args.desc.metadata.emplace_back(var);
    structure.push_back(D::Long);
  } else if (PyFloat_Check(obj)) { // Wrap floats in Double tensors
    at::Tensor var = scalar_to_tensor(THPUtils_unpackDouble(obj));
    args.vars.push_back(var);
    args.desc.metadata.emplace_back(var);
    structure.push_back(D::Double);
  } else {
    std::string msg =
        "Only tuples, lists and Variables are supported as JIT inputs/outputs. "
        "Dictionaries and strings are also accepted, but their usage is not "
        "recommended. Here, received an input of unsupported type: ";
    msg += THPUtils_typename(obj);
    throw std::runtime_error(msg);
  }
}
101

102
} // anonymous namespace
103

104
// Flattens the Python object `obj` into a ParsedArgs. The descriptor also
// records whether autograd's grad mode is enabled at the time of the call.
// Propagates any exception thrown by flatten_rec.
ParsedArgs flatten(py::handle obj) {
  ParsedArgs parsed;
  parsed.desc.grad_enabled = autograd::GradMode::is_enabled();
  flatten_rec(obj.ptr(), parsed);
  return parsed;
}
110

111
namespace {
112

113
template <typename T>
114
py::object cast_sequence(std::vector<py::object> objs) {
115
  auto num_objs = objs.size();
116
  T sequence{num_objs};
117
  for (const auto i : c10::irange(num_objs)) {
118
    sequence[i] = std::move(objs[i]);
119
  }
120
  return std::move(sequence);
121
}
122

123
// Builds a Python dict from `objs`, where every element is treated as a
// tuple whose item 0 is the key and item 1 is the value. Later entries
// overwrite earlier ones that share a key.
py::object cast_dict(std::vector<py::object> objs) {
  py::dict result;
  for (const auto& entry : objs) {
    py::tuple key_value = py::reinterpret_borrow<py::tuple>(entry);
    result[key_value[0]] = key_value[1];
  }
  return std::move(result);
}
132

133
// Rebuilds a single Python object from a structure descriptor.
//
// Reads one node starting at `desc_it` (recursing for containers) and
// advances `desc_it`, `var_it` and `str_it` past everything it consumes.
//   var_it / var_it_end : remaining Variables and their end.
//   desc_it             : cursor into the descriptor string (namespace D
//                         alphabet). It is NOT bounds-checked, so the
//                         descriptor must be well formed.
//   str_it / str_it_end : remaining strings and their end.
//
// Throws std::runtime_error when the descriptor asks for more Variables or
// strings than were supplied, and py::error_already_set if a Python string
// cannot be created.
py::object unflatten_rec(
    ArrayRef<Variable>::iterator& var_it,
    ArrayRef<Variable>::iterator& var_it_end,
    std::string::const_iterator& desc_it,
    std::vector<string>::const_iterator& str_it,
    std::vector<string>::const_iterator& str_it_end) {
  const char type = *desc_it++;
  if (type == D::TupleOpen) {
    std::vector<py::object> objs;
    while (*desc_it != D::TupleClose)
      objs.push_back(
          unflatten_rec(var_it, var_it_end, desc_it, str_it, str_it_end));
    ++desc_it; // skip the closing delimiter
    return cast_sequence<py::tuple>(std::move(objs));
  } else if (type == D::ListOpen) {
    std::vector<py::object> objs;
    while (*desc_it != D::ListClose)
      objs.push_back(
          unflatten_rec(var_it, var_it_end, desc_it, str_it, str_it_end));
    ++desc_it; // skip the closing delimiter
    return cast_sequence<py::list>(std::move(objs));
  } else if (type == D::DictOpen) {
    // Every child of a dict node is one dict item; cast_dict turns the
    // collected items back into a dict.
    std::vector<py::object> objs;
    while (*desc_it != D::DictClose) {
      objs.push_back(
          unflatten_rec(var_it, var_it_end, desc_it, str_it, str_it_end));
    }
    ++desc_it; // skip the closing delimiter
    return cast_dict(std::move(objs));
  } else if (type == D::String) {
    if (str_it == str_it_end)
      throw std::runtime_error("Not enough strings given to unflatten");
    const auto& str = *str_it++;
    // THPUtils_packString hands back a NEW reference, so ownership must be
    // stolen; borrowing it would leak one reference per string.
    PyObject* packed = THPUtils_packString(str);
    if (packed == nullptr) {
      throw py::error_already_set();
    }
    return py::reinterpret_steal<py::object>(packed);
  } else if (type == D::NoneType) {
    return py::none();
  } else {
    // Remaining tags: D::Variable, D::Long, D::Double and D::Bool. Variables
    // are unwrapped as-is; primitive values were wrapped into tensors when
    // flattened and are handed back as Variables for the tracer.
    if (var_it == var_it_end)
      throw std::runtime_error("Not enough Variables given to unflatten");
    auto var = *var_it++;
    return py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
  }
}
179

180
} // anonymous namespace
181

182
// Reassembles the Python object described by `desc` from the flat list of
// Variables `vars` and returns it as a new reference owned by the caller.
// Throws std::runtime_error if Variables are left over once the descriptor
// has been fully consumed (or if unflatten_rec runs out of inputs).
PyObject* unflatten(ArrayRef<Variable> vars, const IODescriptor& desc) {
  // NB: The descriptor itself is not validated here.
  // It has to be a correct descriptor produced by flatten.
  auto var_cursor = vars.begin();
  auto var_end = vars.end();
  auto structure_cursor = desc.structure.begin();
  std::vector<std::string>::const_iterator string_cursor =
      desc.strings.begin();
  std::vector<std::string>::const_iterator string_end = desc.strings.end();

  py::object result = unflatten_rec(
      var_cursor, var_end, structure_cursor, string_cursor, string_end);

  // Every supplied Variable must have been consumed by the descriptor.
  if (var_cursor != var_end)
    throw std::runtime_error("Too many Variables given to unflatten");

  // Hand ownership of the reference over to the caller.
  return result.release().ptr();
}
195

196
} // namespace torch::jit::python
197

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.