pytorch

Форк
0
/
test_jit_hooks.cpp 
136 строк · 6.1 Кб
1
#include <torch/script.h>
2

3
#include <memory>
4
#include <string>
5
#include <sstream>
6
#include <vector>
7

8
#include <iostream>
9

10
void test_module_forward_invocation_no_hooks_run(
11
    const std::string &path_to_exported_script_module) {
12
  std::cout << "testing: "
13
            << "test_module_forward_invocation_no_hooks_run" << std::endl;
14
  torch::jit::Module module =
15
      torch::jit::load(path_to_exported_script_module + "_" +
16
                       "test_module_forward_multiple_inputs" + ".pt");
17
  std::vector<torch::jit::IValue> inputs = {torch::List<std::string>({"a"}),
18
                                            torch::jit::IValue("no_pre_hook")};
19

20
  auto output = module(inputs);
21
  auto output_forward = module.forward(inputs);
22
  torch::jit::IValue correct_direct_output =
23
      std::tuple<torch::List<std::string>, std::string>(
24
          {"a", "outer_mod_name", "inner_mod_name"}, "no_pre_hook_");
25
  std::cout << "----- module output: " << output << std::endl;
26
  std::cout << "----- module forward output: " << output_forward << std::endl;
27
  AT_ASSERT(correct_direct_output == output_forward);
28
}
29

30
void test_submodule_called_directly_with_hooks(
31
    const std::string &path_to_exported_script_module) {
32
  std::cout << "testing: "
33
            << "test_submodule_to_call_directly_with_hooks" << std::endl;
34
  torch::jit::Module module =
35
      torch::jit::load(path_to_exported_script_module + "_" +
36
                       "test_submodule_to_call_directly_with_hooks" + ".pt");
37
  torch::jit::Module submodule = *module.modules().begin();
38
  std::vector<torch::jit::IValue> inputs = {"a"};
39

40
  auto output = submodule(inputs);
41
  torch::jit::IValue correct_output = "pre_hook_override_name_inner_mod_fh";
42
  std::cout << "----- submodule's output: " << output << std::endl;
43
  std::cout << "----- expected output   : " << correct_output << std::endl;
44
  AT_ASSERT(correct_output == correct_output);
45
}
46

47
struct HooksTestCase {
48
  std::string name;
49
  std::vector<torch::jit::IValue> inputs;
50
  torch::jit::IValue output;
51
  HooksTestCase(std::string name, std::vector<torch::jit::IValue> inputs,
52
                torch::jit::IValue output)
53
      : name(name), inputs(std::move(inputs)), output(std::move(output)) {}
54
};
55

56
int main(int argc, const char *argv[]) {
57
  if (argc != 2) {
58
    std::cerr << "usage: test_jit_hooks <path-to-exported-script-module>\n";
59
    return -1;
60
  }
61
  const std::string path_to_exported_script_module = argv[1];
62
  std::cout << "path to exported module:" << path_to_exported_script_module
63
            << std::endl;
64
  std::cout << "Tesing JIT Hooks in CPP" << std::endl;
65

66
  // Note: Modules loaded in this file are produced in /test/jit_hooks/model.py
67

68
  std::vector<HooksTestCase> test_cases = {
69
      HooksTestCase("test_submodule_multiple_hooks_single_input",
70
                    {torch::jit::IValue("a")},
71
                    "pre_hook_override_name2_inner_mod_fwh1"),
72
      HooksTestCase("test_submodule_hook_return_nothing",
73
                    {torch::jit::IValue("a")}, "a_outermod_inner_mod"),
74
      HooksTestCase("test_submodule_same_hook_repeated",
75
                    {torch::jit::IValue("a")},
76
                    "a_outermod_ph_ph_inner_mod_fh_fh"),
77
      HooksTestCase("test_submodule_forward_single_input",
78
                    {torch::jit::IValue("a")},
79
                    "pre_hook_override_name_inner_mod"),
80
      HooksTestCase(
81
          "test_submodule_multiple_hooks_multiple_inputs",
82
          {torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
83
          std::tuple<torch::List<std::string>, std::string>(
84
              {"pre_hook_override_name", "inner_mod_name"},
85
              "pre_hook_override2_fh1_fh2")),
86
      HooksTestCase(
87
          "test_submodule_forward_multiple_inputs",
88
          {torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
89
          std::tuple<torch::List<std::string>, std::string>(
90
              {"pre_hook_override_name", "inner_mod_name"},
91
              "pre_hook_override_fh")),
92
      HooksTestCase("test_module_forward_single_input",
93
                    {torch::jit::IValue("a")},
94
                    "pre_hook_override_name_outermod_inner_mod_fh"),
95
      HooksTestCase("test_module_multiple_hooks_single_input",
96
                    {torch::jit::IValue("a")},
97
                    "pre_hook_override_name2_outermod_inner_mod_fh1_fh2"),
98
      HooksTestCase("test_module_hook_return_nothing",
99
                    {torch::jit::IValue("a")}, "a_outermod_inner_mod"),
100
      HooksTestCase("test_module_same_hook_repeated", {torch::jit::IValue("a")},
101
                    "a_ph_ph_outermod_inner_mod_fh_fh"),
102
      HooksTestCase(
103
          "test_module_forward_multiple_inputs",
104
          {torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
105
          std::tuple<torch::List<std::string>, std::string>(
106
              {"pre_hook_override_name", "outer_mod_name", "inner_mod_name"},
107
              "pre_hook_override_fh")),
108
      HooksTestCase(
109
          "test_module_multiple_hooks_multiple_inputs",
110
          {torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
111
          std::tuple<torch::List<std::string>, std::string>(
112
              {"pre_hook_override_name2", "outer_mod_name", "inner_mod_name"},
113
              "pre_hook_override_fh1_fh2")),
114
      HooksTestCase("test_module_no_forward_input", {}, torch::jit::IValue()),
115
      HooksTestCase("test_forward_tuple_input", {std::tuple<int>(11)},
116
                    {std::tuple<int>(11)}),
117
  };
118

119
  for (HooksTestCase &test_case : test_cases) {
120
    std::cout << "testing: " << test_case.name << std::endl;
121
    torch::jit::Module module = torch::jit::load(
122
        path_to_exported_script_module + "_" + test_case.name + ".pt");
123
    torch::jit::IValue output = module(test_case.inputs);
124
    std::cout << "----- module's output: " << output << std::endl;
125
    std::cout << "----- expected output: " << test_case.output << std::endl;
126
    AT_ASSERT(output == test_case.output);
127
  }
128

129
  // special test cases that don't call the imported module directly
130
  test_module_forward_invocation_no_hooks_run(path_to_exported_script_module);
131
  test_submodule_called_directly_with_hooks(path_to_exported_script_module);
132

133
  std::cout << "JIT CPP Hooks okay!" << std::endl;
134

135
  return 0;
136
}
137

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.