pytorch

Форк
0
/
extension.cpp 
61 строка · 1.9 Кб
1
#include <torch/extension.h>
2

3
// test include_dirs in setuptools.setup with relative path
4
#include <tmp.h>
5
#include <ATen/OpMathType.h>
6

7
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
8
  return x.sigmoid() + y.sigmoid();
9
}
10

11
struct MatrixMultiplier {
12
  MatrixMultiplier(int A, int B) {
13
    tensor_ =
14
        torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
15
  }
16
  torch::Tensor forward(torch::Tensor weights) {
17
    return tensor_.mm(weights);
18
  }
19
  torch::Tensor get() const {
20
    return tensor_;
21
  }
22

23
 private:
24
  torch::Tensor tensor_;
25
};
26

27
bool function_taking_optional(std::optional<torch::Tensor> tensor) {
28
  return tensor.has_value();
29
}
30

31
torch::Tensor random_tensor() {
32
  return torch::randn({1});
33
}
34

35
at::ScalarType get_math_type(at::ScalarType other) {
36
  return at::toOpMathType(other);
37
}
38

39
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
40
  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
41
  m.def(
42
      "function_taking_optional",
43
      &function_taking_optional,
44
      "function_taking_optional");
45
  py::class_<MatrixMultiplier>(m, "MatrixMultiplier")
46
      .def(py::init<int, int>())
47
      .def("forward", &MatrixMultiplier::forward)
48
      .def("get", &MatrixMultiplier::get);
49

50
  m.def("get_complex", []() { return c10::complex<double>(1.0, 2.0); });
51
  m.def("get_device", []() { return at::device_of(random_tensor()).value(); });
52
  m.def("get_generator", []() { return at::detail::getDefaultCPUGenerator(); });
53
  m.def("get_intarrayref", []() { return at::IntArrayRef({1, 2, 3}); });
54
  m.def("get_memory_format", []() { return c10::get_contiguous_memory_format(); });
55
  m.def("get_storage", []() { return random_tensor().storage(); });
56
  m.def("get_symfloat", []() { return c10::SymFloat(1.0); });
57
  m.def("get_symint", []() { return c10::SymInt(1); });
58
  m.def("get_symintarrayref", []() { return at::SymIntArrayRef({1, 2, 3}); });
59
  m.def("get_tensor", []() { return random_tensor(); });
60
  m.def("get_math_type", &get_math_type);
61
}
62

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.