pytorch

Форк
0
/
test_custom_backend.cpp 
39 строк · 1.3 Кб
1
#include <torch/cuda.h>
2
#include <torch/script.h>
3

4
#include <string>
5

6
#include "custom_backend.h"
7

8
// Load a module lowered for the custom backend from \p path and test that
9
// it can be executed and produces correct results.
10
void load_serialized_lowered_module_and_execute(const std::string& path) {
11
  torch::jit::Module module = torch::jit::load(path);
12
  // The custom backend is hardcoded to compute f(a, b) = (a + b, a - b).
13
  auto tensor = torch::ones(5);
14
  std::vector<torch::jit::IValue> inputs{tensor, tensor};
15
  auto output = module.forward(inputs);
16
  AT_ASSERT(output.isTuple());
17
  auto output_elements = output.toTupleRef().elements();
18
  for (auto& e : output_elements) {
19
    AT_ASSERT(e.isTensor());
20
  }
21
  AT_ASSERT(output_elements.size(), 2);
22
  AT_ASSERT(output_elements[0].toTensor().allclose(tensor + tensor));
23
  AT_ASSERT(output_elements[1].toTensor().allclose(tensor - tensor));
24
}
25

26
int main(int argc, const char* argv[]) {
27
  if (argc != 2) {
28
    std::cerr
29
        << "usage: test_custom_backend <path-to-exported-script-module>\n";
30
    return -1;
31
  }
32
  const std::string path_to_exported_script_module = argv[1];
33

34
  std::cout << "Testing " << torch::custom_backend::getBackendName() << "\n";
35
  load_serialized_lowered_module_and_execute(path_to_exported_script_module);
36

37
  std::cout << "OK\n";
38
  return 0;
39
}
40

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.