1
#include <c10/util/flat_hash_map.h>
2
#include <torch/csrc/Exceptions.h>
3
#include <torch/csrc/python_dimname.h>
4
#include <torch/csrc/utils/python_strings.h>
8
// Cache from Python string objects to Dimnames, keyed by PyObject* pointer
// identity (see the flat_hash_map member at the end of this struct).
// All copy and move operations are deleted, so the table can be neither
// copied nor moved.
struct InternedStringsTable {
9
InternedStringsTable() = default;
11
~InternedStringsTable();
12
InternedStringsTable(const InternedStringsTable&) = delete;
13
InternedStringsTable& operator=(InternedStringsTable const&) = delete;
14
InternedStringsTable(InternedStringsTable&&) = delete;
15
InternedStringsTable& operator=(InternedStringsTable&&) = delete;
17
// Returns the Dimname registered for `obj`, if any.
at::optional<at::Dimname> lookup(PyObject* obj);
19
// Registers `obj` -> `dimname` in the table.
void addMapping(PyObject* obj, at::Dimname dimname);
22
// Keys are raw PyObject* pointers, so lookups compare object identity, not
// string contents.
// NOTE(review): an entry is only meaningful while the keyed object stays
// alive — confirm the table holds a reference to each key it stores.
ska::flat_hash_map<PyObject*, at::Dimname> py_interned_string_to_dimname_;
25
// Shared table instance: THPDimname_parse looks names up in it and adds new
// mappings to it.
InternedStringsTable kPyInternedStringToDimname;
28
// Destructor: walks the stored entries only while the Python interpreter is
// still initialized, and holds the GIL while doing so.
// NOTE(review): presumably the loop releases the Python references held for
// each key — confirm against addMapping.
InternedStringsTable::~InternedStringsTable() {
30
// Python C-API calls are not allowed after the interpreter has been
// finalized, so the entries are left untouched in that case.
if (Py_IsInitialized()) {
31
pybind11::gil_scoped_acquire gil;
32
for (auto it = py_interned_string_to_dimname_.begin();
33
it != py_interned_string_to_dimname_.end();
41
// Looks up the Dimname registered for `obj`. The key comparison is by
// PyObject* pointer, so `obj` must be the same object that was registered
// (not merely an equal string).
// NOTE(review): presumably returns an empty optional when `obj` has no entry.
at::optional<at::Dimname> InternedStringsTable::lookup(PyObject* obj) {
42
auto it = py_interned_string_to_dimname_.find(obj);
43
if (it == py_interned_string_to_dimname_.end()) {
49
// Associates `obj` with `dimname`. Insertion is via emplace, so if `obj`
// already has an entry, the existing Dimname is kept.
void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) {
55
py_interned_string_to_dimname_.emplace(obj, dimname);
60
// Returns true if `obj` is acceptable as a Dimname argument: either Py_None
// or a Python string (as determined by THPUtils_checkString).
bool THPUtils_checkDimname(PyObject* obj) {
61
return obj == Py_None || THPUtils_checkString(obj);
66
// Decides whether `obj` should be treated as a list of Dimnames. Only tuples
// and lists qualify, and the decision is based on the first element alone,
// checked with THPUtils_checkDimname.
bool THPUtils_checkDimnameList(PyObject* obj) {
67
auto tuple = PyTuple_Check(obj);
68
if (!tuple && !PyList_Check(obj)) {
72
// NOTE(review): element 0 is read below. Confirm that an empty tuple or
// list is handled (e.g. using `size`) before that access.
const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
77
tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
78
return THPUtils_checkDimname(first_elt);
81
// Converts a Python object into an at::Dimname.
// - Can return the wildcard Dimname (presumably for Py_None, since the
//   failure message below accepts "None or string").
// - Otherwise `obj` must be a Python string; the failure message reports the
//   actual type name.
// - The string is interned so that equal names share one PyObject, which makes
//   the pointer-keyed kPyInternedStringToDimname cache usable. On a cache miss
//   the Dimname is built from the string contents and added to the cache.
at::Dimname THPDimname_parse(PyObject* obj) {
83
return at::Dimname::wildcard();
87
THPUtils_checkString(obj),
88
"expected None or string for Dimname but got ",
89
Py_TYPE(obj)->tp_name);
91
// Interning makes equal strings share a single object, so the pointer-keyed
// cache lookup below can find them.
// NOTE(review): interning in place may replace `obj` and adjust reference
// counts (as PyUnicode_InternInPlace does). `obj` appears to be a borrowed
// argument here — confirm the refcounts are balanced around this call.
if (!THPUtils_isInterned(obj)) {
97
THPUtils_internStringInPlace(&obj);
101
auto maybeDimname = torch::kPyInternedStringToDimname.lookup(obj);
103
return *maybeDimname;
106
// Not cached yet: build the Dimname from the string contents and cache it
// under the (interned) object.
const auto name = THPUtils_unpackString(obj);
107
auto dimname = at::Dimname::fromSymbol(at::Symbol::dimname(name));
108
torch::kPyInternedStringToDimname.addMapping(obj, dimname);