# Put the directory two levels above this file (expected to be PyTorch's
# ``test`` directory) at the end of the import path, so that the
# ``jit.test_hooks_modules`` import below can be resolved.
pytorch_test_dir = os.path.split(
    os.path.split(os.path.realpath(__file__))[0]
)[0]
sys.path.append(pytorch_test_dir)
11
from jit.test_hooks_modules import (
12
create_forward_tuple_input,
13
create_module_forward_multiple_inputs,
14
create_module_forward_single_input,
15
create_module_hook_return_nothing,
16
create_module_multiple_hooks_multiple_inputs,
17
create_module_multiple_hooks_single_input,
18
create_module_no_forward_input,
19
create_module_same_hook_repeated,
20
create_submodule_forward_multiple_inputs,
21
create_submodule_forward_single_input,
22
create_submodule_hook_return_nothing,
23
create_submodule_multiple_hooks_multiple_inputs,
24
create_submodule_multiple_hooks_single_input,
25
create_submodule_same_hook_repeated,
26
create_submodule_to_call_directly_with_hooks,
# Command-line interface: the only option is the path prefix under which
# the serialized modules are written.
parser = argparse.ArgumentParser(
    description="Serialize script modules with hooks attached"
)
parser.add_argument("--export-script-module-to", required=True)
options = parser.parse_args()

# Each module is written to "<prefix>_<test name>.pt" (see the loop below).
save_name = options.export_script_module_to + "_"

# (test name, module instance) pairs. Each module comes from a factory in
# jit.test_hooks_modules. The test name becomes part of the output file name.
tests = [
    (
        "test_submodule_forward_single_input",
        create_submodule_forward_single_input(),
    ),
    (
        "test_submodule_forward_multiple_inputs",
        create_submodule_forward_multiple_inputs(),
    ),
    (
        "test_submodule_multiple_hooks_single_input",
        create_submodule_multiple_hooks_single_input(),
    ),
    (
        "test_submodule_multiple_hooks_multiple_inputs",
        create_submodule_multiple_hooks_multiple_inputs(),
    ),
    ("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()),
    ("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()),
    ("test_module_forward_single_input", create_module_forward_single_input()),
    (
        "test_module_forward_multiple_inputs",
        create_module_forward_multiple_inputs(),
    ),
    (
        "test_module_multiple_hooks_single_input",
        create_module_multiple_hooks_single_input(),
    ),
    (
        "test_module_multiple_hooks_multiple_inputs",
        create_module_multiple_hooks_multiple_inputs(),
    ),
    ("test_module_hook_return_nothing", create_module_hook_return_nothing()),
    ("test_module_same_hook_repeated", create_module_same_hook_repeated()),
    ("test_module_no_forward_input", create_module_no_forward_input()),
    ("test_forward_tuple_input", create_forward_tuple_input()),
    (
        "test_submodule_to_call_directly_with_hooks",
        create_submodule_to_call_directly_with_hooks(),
    ),
]

# Compile each module with TorchScript and serialize it to its own file.
for name, model in tests:
    m_scripted = torch.jit.script(model)
    filename = save_name + name + ".pt"
    torch.jit.save(m_scripted, filename)

print("OK: completed saving modules with hooks!")
90
if __name__ == "__main__":