pytorch

Форк
0
/
python_dict.cpp 
200 строк · 6.4 Кб
1
#include <ATen/core/ivalue.h>
2
#include <pybind11/cast.h>
3
#include <pybind11/detail/common.h>
4
#include <torch/csrc/jit/python/pybind_utils.h>
5
#include <torch/csrc/jit/python/python_dict.h>
6
#include <torch/csrc/jit/runtime/jit_exception.h>
7
#include <torch/csrc/utils/pybind.h>
8
#include <sstream>
9
#include <stdexcept>
10

11
namespace torch::jit {
12

13
// Yields the entry under the cursor as a (key, value) tuple, which is the
// element type of ScriptDict.items(), then moves the cursor forward.
// Raises Python's StopIteration once every entry has been produced.
IValue ScriptDictIterator::next() {
  if (iter_ == end_) {
    // Nothing left to hand out.
    throw py::stop_iteration();
  }

  // Capture the current entry before the cursor moves past it.
  IValue entry = c10::ivalue::Tuple::create({iter_->key(), iter_->value()});
  ++iter_;
  return entry;
}
27

28
// Yields only the key under the cursor (the element type of
// ScriptDict.keys() and iter(ScriptDict)), then moves the cursor forward.
// Raises Python's StopIteration once every key has been produced.
IValue ScriptDictKeyIterator::next() {
  if (iter_ == end_) {
    // Nothing left to hand out.
    throw py::stop_iteration();
  }

  // Capture the current key before the cursor moves past it.
  IValue key = iter_->key();
  ++iter_;
  return key;
}
41

42
// Registers the Python bindings for ScriptDict and its two iterator types
// (ScriptDictKeyIterator for keys()/__iter__, ScriptDictIterator for items())
// on the given module.
//
// Error mapping used by the ScriptDict methods below:
//  - a key that cannot be converted to the dict's key type raises KeyError
//    from __contains__ and __getitem__ (parity with eager mode), and
//    TypeError from __setitem__ and __delitem__;
//  - a value that cannot be converted to the dict's value type raises
//    TypeError from __setitem__;
//  - a key that is absent raises KeyError from __getitem__ and __delitem__.
void initScriptDictBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // Python's iterator protocol requires iter(it) to return `it` itself.
  // Returning the C++ iterator by value would produce an independent copy in
  // a new Python object: advancing the copy (e.g. a `for` loop over an
  // already-started iterator) would not advance the original. The copy also
  // would not carry the keep_alive link to the ScriptDict that the original
  // iterator object has. The C++ iterator presumably points into the dict's
  // storage (TODO confirm against the header). Hence both __iter__ bindings
  // return the Python self object.
  py::class_<ScriptDictKeyIterator>(m, "ScriptDictKeyIterator")
      .def(
          "__next__",
          [](ScriptDictKeyIterator& iter) {
            auto result = iter.next();
            return toPyObject(result);
          })
      .def("__iter__", [](py::object self) { return self; });

  py::class_<ScriptDictIterator>(m, "ScriptDictIterator")
      .def(
          "__next__",
          [](ScriptDictIterator& iter) {
            auto result = iter.next();
            return toPyObject(result);
          })
      .def("__iter__", [](py::object self) { return self; });

  py::class_<ScriptDict, std::shared_ptr<ScriptDict>>(m, "ScriptDict")
      .def(py::init([](py::dict dict) {
        TypePtr type = nullptr;

        if (!dict.empty()) {
          // If the source dictionary is nonempty, try to infer its type.
          auto inferred_type = tryToInferType(dict);

          if (!inferred_type.success()) {
            std::stringstream ss;
            ss << "Unable to infer type of dictionary: "
               << inferred_type.reason();
            throw JITException(ss.str());
          }

          type = inferred_type.type();
        } else {
          // If is empty, assume the type is Dict[str, Tensor] as is done in
          // TorchScript code.
          type = DictType::create(StringType::get(), TensorType::getInferred());
        }

        auto data = toIValue(std::move(dict), type);
        return std::make_shared<ScriptDict>(data);
      }))
      .def(
          "__repr__",
          [](const std::shared_ptr<ScriptDict>& self) {
            return toPyObject(self->repr());
          })
      .def(
          "__bool__",
          [](const std::shared_ptr<ScriptDict>& self) {
            return toPyObject(self->toBool());
          })
      .def(
          "__len__",
          [](const std::shared_ptr<ScriptDict>& self) {
            return toPyObject(self->len());
          })
      .def(
          "__contains__",
          [](const std::shared_ptr<ScriptDict>& self, py::object key) {
            try {
              return toPyObject(self->contains(
                  toIValue(std::move(key), self->type()->getKeyType())));
            } catch (const py::cast_error&) {
              throw py::key_error();
            }
          })
      .def(
          "__getitem__",
          // The lambda returns a freshly converted py::object, so no
          // return_value_policy applies (pybind11 ignores it for py::object
          // results).
          [](const std::shared_ptr<ScriptDict>& self, py::object key) {
            IValue key_ivalue;

            // Convert key to IValue.
            try {
              key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
            } catch (const py::cast_error&) {
              // It would be nice to throw py::type_error here but py::key_error
              // needs to be thrown for parity with eager mode.
              throw py::key_error();
            }

            // Look the key up; getItem signals a missing key with
            // std::out_of_range.
            IValue value;
            try {
              value = self->getItem(key_ivalue);
            } catch (const std::out_of_range&) {
              throw py::key_error();
            }

            return toPyObject(std::move(value));
          })
      .def(
          "__setitem__",
          [](const std::shared_ptr<ScriptDict>& self,
             py::object key,
             py::object value) {
            IValue key_ivalue;
            IValue value_ivalue;

            // Try to convert the key to an IValue.
            try {
              key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
            } catch (const py::cast_error&) {
              throw py::type_error();
            }

            // Try to convert the value to an IValue.
            try {
              value_ivalue =
                  toIValue(std::move(value), self->type()->getValueType());
            } catch (const py::cast_error&) {
              throw py::type_error();
            }

            self->setItem(key_ivalue, value_ivalue);
          })
      .def(
          "__delitem__",
          [](const std::shared_ptr<ScriptDict>& self, py::object key) {
            IValue key_ivalue;

            // Try to convert the key to an IValue.
            try {
              key_ivalue = toIValue(std::move(key), self->type()->getKeyType());
            } catch (const py::cast_error&) {
              throw py::type_error();
            }

            // If removed = false, that means the key didn't exist in the
            // dictionary.
            bool removed = self->delItem(key_ivalue);

            if (!removed) {
              throw py::key_error();
            }
          })
      .def(
          "__iter__",
          [](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
          py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
                                  // long as the iterator
      .def(
          "items",
          [](const std::shared_ptr<ScriptDict>& self) { return self->items(); },
          py::keep_alive<0, 1>()) // ScriptDict needs to be alive at least as
                                  // long as the iterator
      .def(
          "keys",
          [](const std::shared_ptr<ScriptDict>& self) { return self->iter(); },
          py::keep_alive<0, 1>()); // ScriptDict needs to be alive at least as
                                   // long as the iterator
}
199

200
} // namespace torch::jit
201

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.