1
#include <ATen/core/ivalue.h>
2
#include <c10/util/irange.h>
3
#include <pybind11/detail/common.h>
4
#include <pybind11/pytypes.h>
5
#include <torch/csrc/jit/python/pybind_utils.h>
6
#include <torch/csrc/jit/python/python_list.h>
7
#include <torch/csrc/utils/pybind.h>
12
IValue ScriptListIterator::next() {
14
throw py::stop_iteration();
17
IValue result = *iter_;
19
// Advance the iterator for next time.
25
bool ScriptListIterator::done() const {
30
py::list scriptListToPyList(const ScriptList& src) {
31
py::list out(src.len());
32
auto iter = src.iter();
35
while (!iter.done()) {
36
auto val = iter.next();
37
// TODO: Handle nested dictionaries.
39
out[i] = scriptListToPyList(val);
41
out[i] = toPyObject(val);
50
void initScriptListBindings(PyObject* module) {
51
auto m = py::handle(module).cast<py::module>();
53
py::class_<ScriptListIterator>(m, "ScriptListIterator")
56
[](ScriptListIterator& iter) {
57
auto result = iter.next();
58
return toPyObject(result);
60
.def("__iter__", [](ScriptListIterator& iter) { return iter; });
62
py::class_<ScriptList, std::shared_ptr<ScriptList>>(m, "ScriptList")
63
.def(py::init([](py::list list) {
64
TypePtr type = nullptr;
67
// If the source list is nonempty, try to infer its type.
68
auto inferred_type = tryToInferType(list);
70
if (!inferred_type.success()) {
72
ss << "Unable to infer type of list: " << inferred_type.reason();
73
throw JITException(ss.str());
76
type = inferred_type.type();
78
// If is empty, assume the type is List[Tensor] as is done in
80
type = ListType::create(TensorType::getInferred());
83
auto data = toIValue(std::move(list), type);
84
return std::make_shared<ScriptList>(data);
88
[](const std::shared_ptr<ScriptList>& self) {
89
return toPyObject(self->repr());
93
[](const std::shared_ptr<ScriptList>& self) {
94
return toPyObject(self->toBool());
98
[](const std::shared_ptr<ScriptList>& self) {
99
return toPyObject(static_cast<int64_t>(self->len()));
103
[](const std::shared_ptr<ScriptList>& self, py::object elem) {
105
return toPyObject(self->contains(
106
toIValue(std::move(elem), self->type()->getElementType())));
107
} catch (const py::cast_error& e) {
108
throw py::type_error();
113
[](const std::shared_ptr<ScriptList>& self,
114
ScriptList::diff_type idx) {
116
auto value = self->getItem(idx);
117
return toPyObject(value);
118
} catch (const std::out_of_range& e) {
119
throw py::index_error();
122
py::return_value_policy::
123
reference_internal) // Return value is a reference to an object
124
// that resides in the ScriptList
127
[](const std::shared_ptr<ScriptList>& self, const py::slice& slice) {
128
size_t start = 0, stop = 0, step = 0, slicelength = 0;
131
self->len(), &start, &stop, &step, &slicelength)) {
132
throw py::error_already_set();
135
auto seq = std::make_shared<ScriptList>(self->type());
137
for (const auto i : c10::irange(slicelength)) {
138
(void)i; // Suppress unused variable warning
139
seq->append(self->getItem(start));
147
[](const std::shared_ptr<ScriptList>& self,
148
ScriptList::diff_type idx,
153
toIValue(std::move(value), self->type()->getElementType()));
154
} catch (const std::out_of_range& e) {
155
throw py::index_error();
156
} catch (const py::cast_error& e) {
157
throw py::type_error();
162
[](const std::shared_ptr<ScriptList>& self,
163
const py::slice& slice,
164
const py::list& value) {
165
size_t start = 0, stop = 0, step = 0, slicelength = 0;
168
self->len(), &start, &stop, &step, &slicelength)) {
169
throw py::error_already_set();
172
if (slicelength != value.size()) {
173
throw std::runtime_error(
174
"Left and right hand size of slice assignment have different sizes");
177
for (const auto i : c10::irange(slicelength)) {
180
start, toIValue(value[i], self->type()->getElementType()));
181
} catch (const py::cast_error& e) {
182
throw py::type_error();
189
[](const std::shared_ptr<ScriptList>& self,
190
ScriptList::diff_type idx) {
193
} catch (const std::out_of_range& e) {
194
throw py::index_error();
199
[](const std::shared_ptr<ScriptList>& self) { return self->iter(); },
200
py::keep_alive<0, 1>()) // ScriptList needs to be alive at least as
201
// long as the iterator
204
[](const std::shared_ptr<ScriptList>& self, py::object value) {
207
toIValue(std::move(value), self->type()->getElementType()));
209
} catch (const py::cast_error& e) {
210
throw py::type_error();
215
[](const std::shared_ptr<ScriptList>& self, py::object value) {
218
toIValue(std::move(value), self->type()->getElementType()));
219
} catch (const py::cast_error& e) {
220
throw py::type_error();
225
[](const std::shared_ptr<ScriptList>& self, py::object value) {
228
toIValue(std::move(value), self->type()->getElementType()));
229
} catch (const py::cast_error& e) {
230
throw py::type_error();
235
[](const std::shared_ptr<ScriptList>& self) { self->clear(); })
238
[](const std::shared_ptr<ScriptList>& self, py::list list) {
240
self->extend(toIValue(std::move(list), self->type()));
241
} catch (const py::cast_error& e) {
242
throw py::type_error();
247
[](const std::shared_ptr<ScriptList>& self,
248
const py::iterable& iter) {
249
ScriptList iter_list(self->type());
252
for (py::handle obj : iter) {
253
iter_list.append(toIValue(
254
py::reinterpret_borrow<py::object>(obj),
255
self->type()->getElementType()));
257
} catch (const py::cast_error& e) {
258
throw py::type_error();
261
self->extend(toIValue(py::cast(iter_list), self->type()));
265
[](const std::shared_ptr<ScriptList>& self) {
266
return toPyObject(self->pop());
270
[](const std::shared_ptr<ScriptList>& self,
271
ScriptList::diff_type idx) { return toPyObject(self->pop(idx)); })
274
[](const std::shared_ptr<ScriptList>& self,
275
ScriptList::diff_type idx,
279
toIValue(std::move(obj), self->type()->getElementType()),
281
} catch (const py::cast_error& e) {
282
throw py::type_error();
286
[](const ScriptList& data) { // __getstate__
287
return scriptListToPyList(data);
289
[](py::list list) { // __setstate__
290
TypePtr type = nullptr;
293
// If the source list is nonempty, try to infer its type.
294
auto inferred_type = tryToInferType(list);
296
if (!inferred_type.success()) {
297
std::stringstream ss;
298
ss << "Unable to infer type of list: "
299
<< inferred_type.reason();
300
throw JITException(ss.str());
303
type = inferred_type.type();
305
// If is empty, assume the type is List[Tensor] as is done in
307
type = ListType::create(TensorType::getInferred());
310
auto data = toIValue(std::move(list), type);
311
return std::make_shared<ScriptList>(data);
315
} // namespace torch::jit