pytorch

Форк
0
91 строка · 3.1 Кб
1
import argparse
2
import os
3
import sys
4

5
import torch
6

7

8
# grab modules from test_jit_hooks.cpp
9
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
10
sys.path.append(pytorch_test_dir)
11
from jit.test_hooks_modules import (
12
    create_forward_tuple_input,
13
    create_module_forward_multiple_inputs,
14
    create_module_forward_single_input,
15
    create_module_hook_return_nothing,
16
    create_module_multiple_hooks_multiple_inputs,
17
    create_module_multiple_hooks_single_input,
18
    create_module_no_forward_input,
19
    create_module_same_hook_repeated,
20
    create_submodule_forward_multiple_inputs,
21
    create_submodule_forward_single_input,
22
    create_submodule_hook_return_nothing,
23
    create_submodule_multiple_hooks_multiple_inputs,
24
    create_submodule_multiple_hooks_single_input,
25
    create_submodule_same_hook_repeated,
26
    create_submodule_to_call_directly_with_hooks,
27
)
28

29

30
# Create saved modules for JIT forward hooks and pre-hooks
31
def main():
32
    parser = argparse.ArgumentParser(
33
        description="Serialize a script modules with hooks attached"
34
    )
35
    parser.add_argument("--export-script-module-to", required=True)
36
    options = parser.parse_args()
37
    global save_name
38
    save_name = options.export_script_module_to + "_"
39

40
    tests = [
41
        (
42
            "test_submodule_forward_single_input",
43
            create_submodule_forward_single_input(),
44
        ),
45
        (
46
            "test_submodule_forward_multiple_inputs",
47
            create_submodule_forward_multiple_inputs(),
48
        ),
49
        (
50
            "test_submodule_multiple_hooks_single_input",
51
            create_submodule_multiple_hooks_single_input(),
52
        ),
53
        (
54
            "test_submodule_multiple_hooks_multiple_inputs",
55
            create_submodule_multiple_hooks_multiple_inputs(),
56
        ),
57
        ("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()),
58
        ("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()),
59
        ("test_module_forward_single_input", create_module_forward_single_input()),
60
        (
61
            "test_module_forward_multiple_inputs",
62
            create_module_forward_multiple_inputs(),
63
        ),
64
        (
65
            "test_module_multiple_hooks_single_input",
66
            create_module_multiple_hooks_single_input(),
67
        ),
68
        (
69
            "test_module_multiple_hooks_multiple_inputs",
70
            create_module_multiple_hooks_multiple_inputs(),
71
        ),
72
        ("test_module_hook_return_nothing", create_module_hook_return_nothing()),
73
        ("test_module_same_hook_repeated", create_module_same_hook_repeated()),
74
        ("test_module_no_forward_input", create_module_no_forward_input()),
75
        ("test_forward_tuple_input", create_forward_tuple_input()),
76
        (
77
            "test_submodule_to_call_directly_with_hooks",
78
            create_submodule_to_call_directly_with_hooks(),
79
        ),
80
    ]
81

82
    for name, model in tests:
83
        m_scripted = torch.jit.script(model)
84
        filename = save_name + name + ".pt"
85
        torch.jit.save(m_scripted, filename)
86

87
    print("OK: completed saving modules with hooks!")
88

89

90
if __name__ == "__main__":
91
    main()
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.