1
#include <ATen/Layout.h>
2
#include <c10/core/ScalarType.h>
3
#include <torch/csrc/DynamicTypes.h>
4
#include <torch/csrc/Exceptions.h>
5
#include <torch/csrc/Layout.h>
6
#include <torch/csrc/python_headers.h>
7
#include <torch/csrc/utils/object_ptr.h>
8
#include <torch/csrc/utils/tensor_layouts.h>
13
// REGISTER_LAYOUT(layout, LAYOUT)
//
// Creates the Python object for at::Layout::LAYOUT via THPLayout_New, named
// "torch.<layout>". It then publishes that object as the attribute `layout`
// on `torch_module` and records it with registerLayoutObject().
//
//   layout - lowercase identifier. It is used for:
//              * the local variable name (layout##_layout),
//              * the module attribute name ("" #layout),
//              * the object's name string ("torch." #layout).
//   LAYOUT - the at::Layout enumerator to wrap.
//
// Requirements at the expansion site:
//   * A `torch_module` usable as PyModule_AddObject's first argument must be
//     in scope.
//   * The expansion declares `layout##_layout` in the enclosing scope, so a
//     given layout can only be registered once per scope.
//
// The expansion already ends in `;`, so a trailing semicolon at the call site
// is optional (it only adds an empty statement).
//
// Reference counting: PyModule_AddObject steals a reference on success. The
// extra Py_INCREF presumably keeps a reference alive for the pointer handed
// to registerLayoutObject (TODO confirm). On failure nothing is stolen, and
// the object is not released before throwing.
//
// Throws python_error if PyModule_AddObject reports failure.
#define REGISTER_LAYOUT(layout, LAYOUT)                                     \
  PyObject* layout##_layout =                                               \
      THPLayout_New(at::Layout::LAYOUT, "torch." #layout);                  \
  Py_INCREF(layout##_layout);                                               \
  if (PyModule_AddObject(torch_module, "" #layout, layout##_layout) != 0) { \
    throw python_error();                                                   \
  }                                                                         \
  registerLayoutObject((THPLayout*)layout##_layout, at::Layout::LAYOUT);
22
void initializeLayouts() {
23
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
27
PyObject* strided_layout =
28
THPLayout_New(at::Layout::Strided, "torch.strided");
29
Py_INCREF(strided_layout);
30
if (PyModule_AddObject(torch_module, "strided", strided_layout) != 0) {
33
registerLayoutObject((THPLayout*)strided_layout, at::Layout::Strided);
35
PyObject* sparse_coo_layout =
36
THPLayout_New(at::Layout::Sparse, "torch.sparse_coo");
37
Py_INCREF(sparse_coo_layout);
38
if (PyModule_AddObject(torch_module, "sparse_coo", sparse_coo_layout) != 0) {
41
registerLayoutObject((THPLayout*)sparse_coo_layout, at::Layout::Sparse);
43
REGISTER_LAYOUT(sparse_csr, SparseCsr)
44
REGISTER_LAYOUT(sparse_csc, SparseCsc)
45
REGISTER_LAYOUT(sparse_bsr, SparseBsr)
46
REGISTER_LAYOUT(sparse_bsc, SparseBsc)
48
PyObject* mkldnn_layout = THPLayout_New(at::Layout::Mkldnn, "torch._mkldnn");
49
Py_INCREF(mkldnn_layout);
50
if (PyModule_AddObject(torch_module, "_mkldnn", mkldnn_layout) != 0) {
53
registerLayoutObject((THPLayout*)mkldnn_layout, at::Layout::Mkldnn);
55
REGISTER_LAYOUT(jagged, Jagged);