pytorch

Форк
0
/
python_list.cpp 
315 строк · 9.8 Кб
1
#include <ATen/core/ivalue.h>
2
#include <c10/util/irange.h>
3
#include <pybind11/detail/common.h>
4
#include <pybind11/pytypes.h>
5
#include <torch/csrc/jit/python/pybind_utils.h>
6
#include <torch/csrc/jit/python/python_list.h>
7
#include <torch/csrc/utils/pybind.h>
8
#include <stdexcept>
9

10
namespace torch::jit {
11

12
// Yields the element under the cursor and moves the cursor forward by one.
// Raises Python's StopIteration (via py::stop_iteration) once every element
// has been consumed.
IValue ScriptListIterator::next() {
  if (done()) {
    throw py::stop_iteration();
  }

  // Take the value before stepping so the caller receives the current slot.
  IValue current = *iter_;
  ++iter_;
  return current;
}
24

25
bool ScriptListIterator::done() const {
26
  return iter_ == end_;
27
}
28

29
namespace {
30
py::list scriptListToPyList(const ScriptList& src) {
31
  py::list out(src.len());
32
  auto iter = src.iter();
33

34
  size_t i = 0;
35
  while (!iter.done()) {
36
    auto val = iter.next();
37
    // TODO: Handle nested dictionaries.
38
    if (val.isList()) {
39
      out[i] = scriptListToPyList(val);
40
    } else {
41
      out[i] = toPyObject(val);
42
    }
43
    ++i;
44
  }
45

46
  return out;
47
}
48
} // namespace
49

50
void initScriptListBindings(PyObject* module) {
51
  auto m = py::handle(module).cast<py::module>();
52

53
  py::class_<ScriptListIterator>(m, "ScriptListIterator")
54
      .def(
55
          "__next__",
56
          [](ScriptListIterator& iter) {
57
            auto result = iter.next();
58
            return toPyObject(result);
59
          })
60
      .def("__iter__", [](ScriptListIterator& iter) { return iter; });
61

62
  py::class_<ScriptList, std::shared_ptr<ScriptList>>(m, "ScriptList")
63
      .def(py::init([](py::list list) {
64
        TypePtr type = nullptr;
65

66
        if (!list.empty()) {
67
          // If the source list is nonempty, try to infer its type.
68
          auto inferred_type = tryToInferType(list);
69

70
          if (!inferred_type.success()) {
71
            std::stringstream ss;
72
            ss << "Unable to infer type of list: " << inferred_type.reason();
73
            throw JITException(ss.str());
74
          }
75

76
          type = inferred_type.type();
77
        } else {
78
          // If is empty, assume the type is List[Tensor] as is done in
79
          // TorchScript code.
80
          type = ListType::create(TensorType::getInferred());
81
        }
82

83
        auto data = toIValue(std::move(list), type);
84
        return std::make_shared<ScriptList>(data);
85
      }))
86
      .def(
87
          "__repr__",
88
          [](const std::shared_ptr<ScriptList>& self) {
89
            return toPyObject(self->repr());
90
          })
91
      .def(
92
          "__bool__",
93
          [](const std::shared_ptr<ScriptList>& self) {
94
            return toPyObject(self->toBool());
95
          })
96
      .def(
97
          "__len__",
98
          [](const std::shared_ptr<ScriptList>& self) {
99
            return toPyObject(static_cast<int64_t>(self->len()));
100
          })
101
      .def(
102
          "__contains__",
103
          [](const std::shared_ptr<ScriptList>& self, py::object elem) {
104
            try {
105
              return toPyObject(self->contains(
106
                  toIValue(std::move(elem), self->type()->getElementType())));
107
            } catch (const py::cast_error& e) {
108
              throw py::type_error();
109
            }
110
          })
111
      .def(
112
          "__getitem__",
113
          [](const std::shared_ptr<ScriptList>& self,
114
             ScriptList::diff_type idx) {
115
            try {
116
              auto value = self->getItem(idx);
117
              return toPyObject(value);
118
            } catch (const std::out_of_range& e) {
119
              throw py::index_error();
120
            }
121
          },
122
          py::return_value_policy::
123
              reference_internal) // Return value is a reference to an object
124
                                  // that resides in the ScriptList
125
      .def(
126
          "__getitem__",
127
          [](const std::shared_ptr<ScriptList>& self, const py::slice& slice) {
128
            size_t start = 0, stop = 0, step = 0, slicelength = 0;
129

130
            if (!slice.compute(
131
                    self->len(), &start, &stop, &step, &slicelength)) {
132
              throw py::error_already_set();
133
            }
134

135
            auto seq = std::make_shared<ScriptList>(self->type());
136

137
            for (const auto i : c10::irange(slicelength)) {
138
              (void)i; // Suppress unused variable warning
139
              seq->append(self->getItem(start));
140
              start += step;
141
            }
142

143
            return seq;
144
          })
145
      .def(
146
          "__setitem__",
147
          [](const std::shared_ptr<ScriptList>& self,
148
             ScriptList::diff_type idx,
149
             py::object value) {
150
            try {
151
              self->setItem(
152
                  idx,
153
                  toIValue(std::move(value), self->type()->getElementType()));
154
            } catch (const std::out_of_range& e) {
155
              throw py::index_error();
156
            } catch (const py::cast_error& e) {
157
              throw py::type_error();
158
            }
159
          })
160
      .def(
161
          "__setitem__",
162
          [](const std::shared_ptr<ScriptList>& self,
163
             const py::slice& slice,
164
             const py::list& value) {
165
            size_t start = 0, stop = 0, step = 0, slicelength = 0;
166

167
            if (!slice.compute(
168
                    self->len(), &start, &stop, &step, &slicelength)) {
169
              throw py::error_already_set();
170
            }
171

172
            if (slicelength != value.size()) {
173
              throw std::runtime_error(
174
                  "Left and right hand size of slice assignment have different sizes");
175
            }
176

177
            for (const auto i : c10::irange(slicelength)) {
178
              try {
179
                self->setItem(
180
                    start, toIValue(value[i], self->type()->getElementType()));
181
              } catch (const py::cast_error& e) {
182
                throw py::type_error();
183
              }
184
              start += step;
185
            }
186
          })
187
      .def(
188
          "__delitem__",
189
          [](const std::shared_ptr<ScriptList>& self,
190
             ScriptList::diff_type idx) {
191
            try {
192
              self->delItem(idx);
193
            } catch (const std::out_of_range& e) {
194
              throw py::index_error();
195
            }
196
          })
197
      .def(
198
          "__iter__",
199
          [](const std::shared_ptr<ScriptList>& self) { return self->iter(); },
200
          py::keep_alive<0, 1>()) // ScriptList needs to be alive at least as
201
                                  // long as the iterator
202
      .def(
203
          "count",
204
          [](const std::shared_ptr<ScriptList>& self, py::object value) {
205
            try {
206
              return self->count(
207
                  toIValue(std::move(value), self->type()->getElementType()));
208

209
            } catch (const py::cast_error& e) {
210
              throw py::type_error();
211
            }
212
          })
213
      .def(
214
          "remove",
215
          [](const std::shared_ptr<ScriptList>& self, py::object value) {
216
            try {
217
              return self->remove(
218
                  toIValue(std::move(value), self->type()->getElementType()));
219
            } catch (const py::cast_error& e) {
220
              throw py::type_error();
221
            }
222
          })
223
      .def(
224
          "append",
225
          [](const std::shared_ptr<ScriptList>& self, py::object value) {
226
            try {
227
              return self->append(
228
                  toIValue(std::move(value), self->type()->getElementType()));
229
            } catch (const py::cast_error& e) {
230
              throw py::type_error();
231
            }
232
          })
233
      .def(
234
          "clear",
235
          [](const std::shared_ptr<ScriptList>& self) { self->clear(); })
236
      .def(
237
          "extend",
238
          [](const std::shared_ptr<ScriptList>& self, py::list list) {
239
            try {
240
              self->extend(toIValue(std::move(list), self->type()));
241
            } catch (const py::cast_error& e) {
242
              throw py::type_error();
243
            }
244
          })
245
      .def(
246
          "extend",
247
          [](const std::shared_ptr<ScriptList>& self,
248
             const py::iterable& iter) {
249
            ScriptList iter_list(self->type());
250

251
            try {
252
              for (py::handle obj : iter) {
253
                iter_list.append(toIValue(
254
                    py::reinterpret_borrow<py::object>(obj),
255
                    self->type()->getElementType()));
256
              }
257
            } catch (const py::cast_error& e) {
258
              throw py::type_error();
259
            }
260

261
            self->extend(toIValue(py::cast(iter_list), self->type()));
262
          })
263
      .def(
264
          "pop",
265
          [](const std::shared_ptr<ScriptList>& self) {
266
            return toPyObject(self->pop());
267
          })
268
      .def(
269
          "pop",
270
          [](const std::shared_ptr<ScriptList>& self,
271
             ScriptList::diff_type idx) { return toPyObject(self->pop(idx)); })
272
      .def(
273
          "insert",
274
          [](const std::shared_ptr<ScriptList>& self,
275
             ScriptList::diff_type idx,
276
             py::object obj) {
277
            try {
278
              self->insert(
279
                  toIValue(std::move(obj), self->type()->getElementType()),
280
                  idx);
281
            } catch (const py::cast_error& e) {
282
              throw py::type_error();
283
            }
284
          })
285
      .def(py::pickle(
286
          [](const ScriptList& data) { // __getstate__
287
            return scriptListToPyList(data);
288
          },
289
          [](py::list list) { // __setstate__
290
            TypePtr type = nullptr;
291

292
            if (!list.empty()) {
293
              // If the source list is nonempty, try to infer its type.
294
              auto inferred_type = tryToInferType(list);
295

296
              if (!inferred_type.success()) {
297
                std::stringstream ss;
298
                ss << "Unable to infer type of list: "
299
                   << inferred_type.reason();
300
                throw JITException(ss.str());
301
              }
302

303
              type = inferred_type.type();
304
            } else {
305
              // If is empty, assume the type is List[Tensor] as is done in
306
              // TorchScript code.
307
              type = ListType::create(TensorType::getInferred());
308
            }
309

310
            auto data = toIValue(std::move(list), type);
311
            return std::make_shared<ScriptList>(data);
312
          }));
313
}
314

315
} // namespace torch::jit
316

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.