pytorch
28 строк · 827.0 Байт
1#include <torch/extension.h>
2#include <torch/script.h>
3
4using torch::List;
5using torch::Tensor;
6
7Tensor consume(Tensor a) {
8return a;
9}
10
11List<Tensor> consume_list(List<Tensor> a) {
12return a;
13}
14
// When JIT tracing is applied to a function containing a for loop with
// constant bounds, the loop is optimized away by dead code elimination.
// That broke the op benchmark, which needs to run an op in a loop and report
// its execution time. This change resolves the issue by registering the
// consume ops with the correct alias information, which is DEFAULT.
20TORCH_LIBRARY_FRAGMENT(operator_benchmark, m) {
21m.def("_consume", &consume);
22m.def("_consume.list", &consume_list);
23}
24
25PYBIND11_MODULE(benchmark_cpp_extension, m) {
26m.def("_consume", &consume, "consume");
27m.def("_consume_list", &consume_list, "consume_list");
28}
29