pytorch
61 строка · 1.9 Кб
#include <torch/extension.h>

#include <cstdint>

// test include_dirs in setuptools.setup with relative path
#include <tmp.h>
#include <ATen/OpMathType.h>

7torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {8return x.sigmoid() + y.sigmoid();9}
10
11struct MatrixMultiplier {12MatrixMultiplier(int A, int B) {13tensor_ =14torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));15}16torch::Tensor forward(torch::Tensor weights) {17return tensor_.mm(weights);18}19torch::Tensor get() const {20return tensor_;21}22
23private:24torch::Tensor tensor_;25};26
27bool function_taking_optional(std::optional<torch::Tensor> tensor) {28return tensor.has_value();29}
30
31torch::Tensor random_tensor() {32return torch::randn({1});33}
34
35at::ScalarType get_math_type(at::ScalarType other) {36return at::toOpMathType(other);37}
38
39PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {40m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");41m.def(42"function_taking_optional",43&function_taking_optional,44"function_taking_optional");45py::class_<MatrixMultiplier>(m, "MatrixMultiplier")46.def(py::init<int, int>())47.def("forward", &MatrixMultiplier::forward)48.def("get", &MatrixMultiplier::get);49
50m.def("get_complex", []() { return c10::complex<double>(1.0, 2.0); });51m.def("get_device", []() { return at::device_of(random_tensor()).value(); });52m.def("get_generator", []() { return at::detail::getDefaultCPUGenerator(); });53m.def("get_intarrayref", []() { return at::IntArrayRef({1, 2, 3}); });54m.def("get_memory_format", []() { return c10::get_contiguous_memory_format(); });55m.def("get_storage", []() { return random_tensor().storage(); });56m.def("get_symfloat", []() { return c10::SymFloat(1.0); });57m.def("get_symint", []() { return c10::SymInt(1); });58m.def("get_symintarrayref", []() { return at::SymIntArrayRef({1, 2, 3}); });59m.def("get_tensor", []() { return random_tensor(); });60m.def("get_math_type", &get_math_type);61}
62