pytorch

Форк
0
/
optim_baseline.py 
143 строки · 4.5 Кб
1
"""Script to generate baseline values from PyTorch optimization algorithms"""
2

3
import argparse
4
import math
5
import sys
6

7
import torch
8
import torch.optim
9

10

11
HEADER = """
12
#include <torch/types.h>
13

14
#include <vector>
15

16
namespace expected_parameters {
17
"""
18

19
FOOTER = "} // namespace expected_parameters"
20

21
PARAMETERS = "inline std::vector<std::vector<torch::Tensor>> {}() {{"
22

23
OPTIMIZERS = {
24
    "LBFGS": lambda p: torch.optim.LBFGS(p, 1.0),
25
    "LBFGS_with_line_search": lambda p: torch.optim.LBFGS(
26
        p, 1.0, line_search_fn="strong_wolfe"
27
    ),
28
    "Adam": lambda p: torch.optim.Adam(p, 1.0),
29
    "Adam_with_weight_decay": lambda p: torch.optim.Adam(p, 1.0, weight_decay=1e-2),
30
    "Adam_with_weight_decay_and_amsgrad": lambda p: torch.optim.Adam(
31
        p, 1.0, weight_decay=1e-6, amsgrad=True
32
    ),
33
    "AdamW": lambda p: torch.optim.AdamW(p, 1.0),
34
    "AdamW_without_weight_decay": lambda p: torch.optim.AdamW(p, 1.0, weight_decay=0),
35
    "AdamW_with_amsgrad": lambda p: torch.optim.AdamW(p, 1.0, amsgrad=True),
36
    "Adagrad": lambda p: torch.optim.Adagrad(p, 1.0),
37
    "Adagrad_with_weight_decay": lambda p: torch.optim.Adagrad(
38
        p, 1.0, weight_decay=1e-2
39
    ),
40
    "Adagrad_with_weight_decay_and_lr_decay": lambda p: torch.optim.Adagrad(
41
        p, 1.0, weight_decay=1e-6, lr_decay=1e-3
42
    ),
43
    "RMSprop": lambda p: torch.optim.RMSprop(p, 0.1),
44
    "RMSprop_with_weight_decay": lambda p: torch.optim.RMSprop(
45
        p, 0.1, weight_decay=1e-2
46
    ),
47
    "RMSprop_with_weight_decay_and_centered": lambda p: torch.optim.RMSprop(
48
        p, 0.1, weight_decay=1e-6, centered=True
49
    ),
50
    "RMSprop_with_weight_decay_and_centered_and_momentum": lambda p: torch.optim.RMSprop(
51
        p, 0.1, weight_decay=1e-6, centered=True, momentum=0.9
52
    ),
53
    "SGD": lambda p: torch.optim.SGD(p, 0.1),
54
    "SGD_with_weight_decay": lambda p: torch.optim.SGD(p, 0.1, weight_decay=1e-2),
55
    "SGD_with_weight_decay_and_momentum": lambda p: torch.optim.SGD(
56
        p, 0.1, momentum=0.9, weight_decay=1e-2
57
    ),
58
    "SGD_with_weight_decay_and_nesterov_momentum": lambda p: torch.optim.SGD(
59
        p, 0.1, momentum=0.9, weight_decay=1e-6, nesterov=True
60
    ),
61
}
62

63

64
def weight_init(module):
65
    if isinstance(module, torch.nn.Linear):
66
        stdev = 1.0 / math.sqrt(module.weight.size(1))
67
        for p in module.parameters():
68
            p.data.uniform_(-stdev, stdev)
69

70

71
def run(optimizer_name, iterations, sample_every):
72
    torch.manual_seed(0)
73
    model = torch.nn.Sequential(
74
        torch.nn.Linear(2, 3),
75
        torch.nn.Sigmoid(),
76
        torch.nn.Linear(3, 1),
77
        torch.nn.Sigmoid(),
78
    )
79
    model = model.to(torch.float64).apply(weight_init)
80

81
    optimizer = OPTIMIZERS[optimizer_name](model.parameters())
82

83
    input = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=torch.float64)
84

85
    values = []
86
    for i in range(iterations):
87
        optimizer.zero_grad()
88

89
        output = model.forward(input)
90
        loss = output.sum()
91
        loss.backward()
92

93
        def closure():
94
            return torch.tensor([10.0])
95

96
        optimizer.step(closure)
97

98
        if i % sample_every == 0:
99
            values.append(
100
                [p.clone().flatten().data.numpy() for p in model.parameters()]
101
            )
102

103
    return values
104

105

106
def emit(optimizer_parameter_map):
107
    # Don't write generated with an @ in front, else this file is recognized as generated.
108
    print("// @{} from {}".format("generated", __file__))
109
    print(HEADER)
110
    for optimizer_name, parameters in optimizer_parameter_map.items():
111
        print(PARAMETERS.format(optimizer_name))
112
        print("  return {")
113
        for sample in parameters:
114
            print("    {")
115
            for parameter in sample:
116
                parameter_values = "{{{}}}".format(", ".join(map(str, parameter)))
117
                print(f"      torch::tensor({parameter_values}),")
118
            print("    },")
119
        print("  };")
120
        print("}\n")
121
    print(FOOTER)
122

123

124
def main():
125
    parser = argparse.ArgumentParser(
126
        "Produce optimization output baseline from PyTorch"
127
    )
128
    parser.add_argument("-i", "--iterations", default=1001, type=int)
129
    parser.add_argument("-s", "--sample-every", default=100, type=int)
130
    options = parser.parse_args()
131

132
    optimizer_parameter_map = {}
133
    for optimizer in OPTIMIZERS.keys():
134
        sys.stderr.write(f"Evaluating {optimizer} ...\n")
135
        optimizer_parameter_map[optimizer] = run(
136
            optimizer, options.iterations, options.sample_every
137
        )
138

139
    emit(optimizer_parameter_map)
140

141

142
if __name__ == "__main__":
143
    main()
144

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.