1
#include <ATen/core/ivalue.h>
2
#include <pybind11/cast.h>
3
#include <pybind11/detail/common.h>
4
#include <torch/csrc/jit/python/pybind_utils.h>
5
#include <torch/csrc/jit/python/python_dict.h>
6
#include <torch/csrc/jit/runtime/jit_exception.h>
7
#include <torch/csrc/utils/pybind.h>
13
// Produces the next entry for the .items() iterator: the current key and
// value packed together into a tuple IValue. Exhaustion is reported by
// throwing py::stop_iteration.
// NOTE(review): the condition guarding the throw, the statement that advances
// iter_, and the return of `result` are not visible in this excerpt --
// confirm against the full file.
IValue ScriptDictIterator::next() {
15
// pybind11 translates py::stop_iteration into Python's StopIteration.
throw py::stop_iteration();
18
// Since this is the iterator for .items(), the current key and value
19
// should be returned as a tuple.
20
IValue result = c10::ivalue::Tuple::create({iter_->key(), iter_->value()});
22
// Advance the iterator for next time.
28
// Produces the next entry for the key iterator: only the current key is
// returned. Exhaustion is reported by throwing py::stop_iteration.
// NOTE(review): the condition guarding the throw, the statement that advances
// iter_, and the return of `result` are not visible in this excerpt --
// confirm against the full file.
IValue ScriptDictKeyIterator::next() {
30
// pybind11 translates py::stop_iteration into Python's StopIteration.
throw py::stop_iteration();
33
// Since this is the iterator for .keys() and __iter__(), return only the key.
34
IValue result = iter_->key();
36
// Advance the iterator for next time.
42
// Registers the Python bindings for ScriptDictKeyIterator, ScriptDictIterator
// and ScriptDict on the given module.
//
// module: raw PyObject* that is wrapped in a py::handle and cast to a
//         py::module below.
//
// NOTE(review): several lines of this function are not visible in this
// excerpt (the names passed to most .def(...) calls, and a number of
// try / if / closing-brace lines). Wherever a note below names a Python
// method for a lambda, that name is an assumption to confirm against the
// full file.
void initScriptDictBindings(PyObject* module) {
43
auto m = py::handle(module).cast<py::module>();
45
// Key iterator class: exposes next() and returns itself from __iter__.
py::class_<ScriptDictKeyIterator>(m, "ScriptDictKeyIterator")
48
// Presumably bound as __next__ (name not visible here): fetch the next
// IValue from the iterator and convert it to a Python object.
[](ScriptDictKeyIterator& iter) {
49
auto result = iter.next();
50
return toPyObject(result);
52
.def("__iter__", [](ScriptDictKeyIterator& iter) { return iter; });
54
// Item iterator class: same shape as the key iterator above.
py::class_<ScriptDictIterator>(m, "ScriptDictIterator")
57
// Presumably bound as __next__ (name not visible here): fetch the next
// IValue from the iterator and convert it to a Python object.
[](ScriptDictIterator& iter) {
58
auto result = iter.next();
59
return toPyObject(result);
61
.def("__iter__", [](ScriptDictIterator& iter) { return iter; });
63
// ScriptDict itself, held by std::shared_ptr.
py::class_<ScriptDict, std::shared_ptr<ScriptDict>>(m, "ScriptDict")
64
// Constructor from a Python dict: determine the dictionary type, convert the
// dict to an IValue of that type, and wrap it in a ScriptDict.
.def(py::init([](py::dict dict) {
65
TypePtr type = nullptr;
68
// If the source dictionary is nonempty, try to infer its type.
69
auto inferred_type = tryToInferType(dict);
71
// Inference failure is reported as a JITException carrying the reason.
if (!inferred_type.success()) {
73
ss << "Unable to infer type of dictionary: "
74
<< inferred_type.reason();
75
throw JITException(ss.str());
78
type = inferred_type.type();
80
// If the source dictionary is empty, assume the type is Dict[str, Tensor] as is done in
82
type = DictType::create(StringType::get(), TensorType::getInferred());
85
auto data = toIValue(std::move(dict), type);
86
return std::make_shared<ScriptDict>(data);
90
// Presumably bound as __repr__ (name not visible here): forwards to
// ScriptDict::repr().
[](const std::shared_ptr<ScriptDict>& self) {
91
return toPyObject(self->repr());
95
// Presumably bound as __bool__ (name not visible here): forwards to
// ScriptDict::toBool().
[](const std::shared_ptr<ScriptDict>& self) {
96
return toPyObject(self->toBool());
100
// Presumably bound as __len__ (name not visible here): forwards to
// ScriptDict::len().
[](const std::shared_ptr<ScriptDict>& self) {
101
return toPyObject(self->len());
105
// Presumably bound as __contains__ (name not visible here): converts the key
// to an IValue of the dictionary's key type and forwards to
// ScriptDict::contains(). A key that cannot be converted raises
// py::key_error.
[](const std::shared_ptr<ScriptDict>& self, py::object key) {
107
return toPyObject(self->contains(
108
toIValue(std::move(key), self->type()->getKeyType())));
109
} catch (const py::cast_error& e) {
110
throw py::key_error();
115
// Presumably bound as __getitem__ (name not visible here): converts the key,
// looks it up with ScriptDict::getItem(), and returns the value as a Python
// object. Both a conversion failure and a missing key raise py::key_error.
[](const std::shared_ptr<ScriptDict>& self, py::object key) {
118
// Convert key to IValue.
120
value = toIValue(std::move(key), self->type()->getKeyType());
121
} catch (const py::cast_error& e) {
122
// It would be nice to throw py::type_error here but py::key_error
123
// needs to be thrown for parity with eager mode.
124
throw py::key_error();
127
// Call getItem on self.
129
value = self->getItem(value);
130
} catch (const std::out_of_range& e) { // Key doesn't exist.
131
throw py::key_error();
134
return toPyObject(std::move(value));
136
py::return_value_policy::
137
reference_internal) // Return value is a reference to an object
138
// that resides in the ScriptDict
141
// Presumably bound as __setitem__ (name not visible here; the remaining
// parameters of this lambda are also not visible): converts key and value to
// IValues of the dictionary's key/value types and forwards to
// ScriptDict::setItem(). Either conversion failing raises py::type_error.
[](const std::shared_ptr<ScriptDict>& self,
144
IValue key_ivalue, value_ivalue;
146
// Try to convert the key to an IValue.
148
key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
149
} catch (const py::cast_error& e) {
150
throw py::type_error();
153
// Try to convert the value to an IValue.
156
toIValue(std::move(value), self->type()->getValueType());
157
} catch (const py::cast_error& e) {
158
throw py::type_error();
161
self->setItem(key_ivalue, value_ivalue);
165
// Presumably bound as __delitem__ (name not visible here): converts the key
// and forwards to ScriptDict::delItem(). A conversion failure raises
// py::type_error; py::key_error is thrown further down (its guarding
// condition is not visible in this excerpt).
[](const std::shared_ptr<ScriptDict>& self, py::object key) {
168
// Try to convert the key to an IValue.
170
key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
171
} catch (const py::cast_error& e) {
172
throw py::type_error();
175
// If removed = false, that means the key didn't exist in the
177
bool removed = self->delItem(key_ivalue);
180
throw py::key_error();
185
// The three bindings below return iterator objects obtained from
// ScriptDict::iter() / ScriptDict::items(). Their Python names are not
// visible in this excerpt.
[](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
186
py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
187
// long as the iterator
190
[](const std::shared_ptr<ScriptDict>& self) { return self->items(); },
191
py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
192
// long as the iterator
195
[](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
196
py::keep_alive<0, 1>()); // ScriptDict needs to be alive at least as
197
// long as the iterator
200
} // namespace torch::jit