experiment_synthetic_autoregressive.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Running training and evaluation on 4 synthetic autoregressive datasets."""

import datetime
import os
import random

from absl import app
from absl import flags
import analyze_experiments
import datasets
import matplotlib.pyplot as plt
import model_library
from models import model_utils
import numpy as np
import tensorflow as tf

FLAGS = flags.FLAGS

now = datetime.datetime.now()
launch_time = now.strftime("%H:%M:%S")
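
# The launch time is embedded in the default output filename below so that
# repeated runs write distinct artifacts.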

flags.DEFINE_integer(
    "gpu_index", -1,
    "GPU index to run the job on among the available GPUs; if -1, use CPUs.")
flags.DEFINE_integer("num_trials", 100,
                     "Number of hyperparameter trials to search for.")
flags.DEFINE_integer("seed", 2, "Random seed.")
flags.DEFINE_string("model_type", "tft_saf", "Proposed forecasting method.")
flags.DEFINE_bool("display_all_models", True,
                  "Whether to print all the models or only the best model.")
flags.DEFINE_integer("len_total", 750, "Number of samples.")
flags.DEFINE_integer("synthetic_data_option", 1, "Synthetic data choice.")
flags.DEFINE_string("filename", "experiment_synthetic" + launch_time,
                    "Filename to save the model artifacts.")
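
# Example invocation (a sketch; flag values are illustrative):
#   python experiment_synthetic_autoregressive.py --model_type=tft_saf \
#     --synthetic_data_option=1 --num_trials=100 --gpu_index=0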


def main(args):
  """Orchestrates dataset creation, model training and evaluation.

  Args:
    args: Not used.
  """
  del args  # Not used.

  tf.keras.backend.set_floatx("float32")
  tf.autograph.set_verbosity(0)

  if not os.path.exists("figures"):
    os.makedirs("figures")

  model_utils.set_seed(FLAGS.seed)

  # Set the GPU index.
  if FLAGS.gpu_index >= 0:
    gpus = tf.config.experimental.list_physical_devices(device_type="GPU")
    tf.config.experimental.set_visible_devices(
        devices=gpus[FLAGS.gpu_index], device_type="GPU")
    tf.config.experimental.set_memory_growth(
        device=gpus[FLAGS.gpu_index], enable=True)
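
  # Build the train/validation/test splits of the chosen synthetic
  # autoregressive dataset; dataset_params carries the dataset metadata
  # (e.g. num_items, forecast_horizon) consumed below.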
  (train_dataset, valid_dataset, test_dataset,
   dataset_params) = datasets.synthetic_autoregressive(
       synthetic_data_option=FLAGS.synthetic_data_option,
       len_total=FLAGS.len_total)

  # Hyperparameter search.
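  # Random search: each trial independently samples one value from each of
  # the candidate lists below.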
  use_nowcast_errors_candidates = [True, False]
  temporal_batch_size_eval = dataset_params["num_items"]
  batch_size_candidates = [32, 64, 128, 256]
  learning_rate_candidates = [0.0001, 0.0003, 0.001, 0.003]
  learning_rate_adaptation_candidates = [0.0003, 0.001, 0.003, 0.01]
  num_units_candidates = [16, 32, 64]
  iterations_candidates = [3000]
  num_encode_candidates = [10, 30, 50]
  keep_prob_candidates = [0.5, 0.8, 1.0]
  num_heads_candidates = [1, 2]
  representation_combination_candidates = ["concatenation", "addition"]
  reset_weights_each_eval_step_candidates = [True]

  best_valid_metric = 1e128
  best_hparams = []
  if FLAGS.display_all_models:
    all_val_mae = []
    all_test_mae = []
    all_val_mape = []
    all_test_mape = []
    all_val_wmape = []
    all_test_wmape = []
    all_val_mse = []
    all_test_mse = []

  for ni in range(FLAGS.num_trials):
    # Re-seed at each trial so that the sampled hyperparameters are
    # reproducible; appending a new trial at the end leaves the earlier
    # trials' draws unchanged.
    model_utils.set_seed(FLAGS.seed + ni)

    chosen_hparams = {
        "batch_size":
            random.sample(batch_size_candidates, 1)[0],
        "learning_rate":
            random.sample(learning_rate_candidates, 1)[0],
        "learning_rate_adaptation":
            random.sample(learning_rate_adaptation_candidates, 1)[0],
        "num_units":
            random.sample(num_units_candidates, 1)[0],
        "iterations":
            random.sample(iterations_candidates, 1)[0],
        "num_encode":
            random.sample(num_encode_candidates, 1)[0],
        "keep_prob":
            random.sample(keep_prob_candidates, 1)[0],
        "num_heads":
            random.sample(num_heads_candidates, 1)[0],
        "representation_combination":
            random.sample(representation_combination_candidates, 1)[0],
        "reset_weights_each_eval_step":
            random.sample(reset_weights_each_eval_step_candidates, 1)[0],
        "use_nowcast_errors":
            random.sample(use_nowcast_errors_candidates, 1)[0],
        "target_index":
            dataset_params["target_index"],
        "static_index_cutoff":
            dataset_params["static_index_cutoff"],
        "display_iterations":
            250,
        "forecast_horizon":
            dataset_params["forecast_horizon"],
        "num_features":
            dataset_params["num_features"],
        "num_static":
            dataset_params["num_static"],
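        # Since temporal_batch_size_eval == num_items, the two split counts
        # below reduce to len_val and len_test, respectively.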
        "num_val_splits":
            (dataset_params["num_items"] * dataset_params["len_val"] //
             temporal_batch_size_eval),
        "num_test_splits":
            (dataset_params["num_items"] * dataset_params["len_test"] //
             temporal_batch_size_eval),
        "temporal_batch_size_eval":
            temporal_batch_size_eval,
    }

    model = model_library.get_model_type(
        FLAGS.model_type, chosen_hparams, loss_form="MSE")
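
    # The datasets are repeated effectively indefinitely so the train/eval
    # loops can keep drawing batches; evaluation batches are sized to
    # num_items (temporal_batch_size_eval), presumably so that one batch
    # covers all items at a single evaluation step.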
    batched_train_dataset = iter(
        train_dataset.shuffle(1000).repeat(100000000).batch(
            chosen_hparams["batch_size"]))
    batched_valid_dataset = iter(
        valid_dataset.repeat(100000000).batch(temporal_batch_size_eval))
    batched_test_dataset = iter(
        test_dataset.repeat(100000000).batch(temporal_batch_size_eval))

    eval_metrics = model.run_train_eval_pipeline(batched_train_dataset,
                                                 batched_valid_dataset,
                                                 batched_test_dataset)

    # Select the model iteration based on the validation performance. Computed
    # before the display branch because model_selection_index is also needed
    # below to compare against the best validation metric, even when
    # display_all_models is disabled.
    model_selection_index = np.argmin(eval_metrics["val_mse"])

    if FLAGS.display_all_models and not np.isnan(eval_metrics["val_mse"][-1]):
      print("Best hyperparameter combination: ", flush=True)
      print(best_hparams, flush=True)

      all_val_mae.append(eval_metrics["val_mae"][model_selection_index])
      all_test_mae.append(eval_metrics["test_mae"][model_selection_index])
      all_val_mape.append(eval_metrics["val_mape"][model_selection_index])
      all_test_mape.append(eval_metrics["test_mape"][model_selection_index])
      all_val_wmape.append(eval_metrics["val_wmape"][model_selection_index])
      all_test_wmape.append(eval_metrics["test_wmape"][model_selection_index])
      all_val_mse.append(eval_metrics["val_mse"][model_selection_index])
      all_test_mse.append(eval_metrics["test_mse"][model_selection_index])
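
      # Plot and print a running summary over the completed trials: for each
      # metric, the test value at the trial with the best validation value,
      # plus the correlation between validation and test MSE.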
      analyze_experiments.display_metrics(
          all_val_mse, all_test_mse, "MSE", 100,
          "figures/" + FLAGS.filename + "_all_hparam_runs_MSE.png")

      print("Best test mae: ", flush=True)
      print(all_test_mae[np.argmin(all_val_mae)], flush=True)

      print("Best test mape: ", flush=True)
      print(all_test_mape[np.argmin(all_val_mape)], flush=True)

      print("Best test wmape: ", flush=True)
      print(all_test_wmape[np.argmin(all_val_wmape)], flush=True)

      print("Best test mse: ", flush=True)
      print(all_test_mse[np.argmin(all_val_mse)], flush=True)

      print("Correlation: ", flush=True)
      print(str(np.corrcoef(all_val_mse, all_test_mse)[0, 1]), flush=True)

      print("Average val/test performance: ", flush=True)
      print(
          np.mean(np.asarray(all_val_mae) / np.asarray(all_test_mae)),
          flush=True)

    current_valid_metric = eval_metrics["val_mse"][model_selection_index]
    if current_valid_metric < best_valid_metric:
      best_hparams = chosen_hparams
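
      # New best trial so far: save convergence curves of the validation and
      # test errors over the training iterations.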
      plt.figure()
      plt.plot(
          eval_metrics["display_iterations"],
          eval_metrics["val_mae"],
          "-r",
          label="Val")
      plt.plot(
          eval_metrics["display_iterations"],
          eval_metrics["test_mae"],
          "-b",
          label="Test")
      plt.xlabel("Iterations")
      plt.ylabel("MAE")
      plt.legend()
      plt.savefig("figures/" + FLAGS.filename + "_mae_convergence.png")

      plt.figure()
      plt.plot(
          eval_metrics["display_iterations"],
          eval_metrics["val_mse"],
          "-r",
          label="Val")
      plt.plot(
          eval_metrics["display_iterations"],
          eval_metrics["test_mse"],
          "-b",
          label="Test")
      plt.xlabel("Iterations")
      plt.ylabel("MSE")
      plt.legend()
      plt.savefig("figures/" + FLAGS.filename + "_mse_convergence.png")

      if "meta_self_adapting" in FLAGS.model_type:
        plt.figure()
        plt.plot(
            eval_metrics["display_iterations"],
            eval_metrics["val_self_adaptation"],
            "-b",
            label="Valid self-adapting")
        plt.plot(
            eval_metrics["display_iterations"],
            eval_metrics["test_self_adaptation"],
            "-r",
            label="Test self-adapting")
        plt.xlabel("Iterations")
        plt.ylabel("MSE loss")
        plt.legend()
        plt.savefig("figures/" + FLAGS.filename +
                    "_self_adaptation_convergence.png")

      best_valid_metric = current_valid_metric

if __name__ == "__main__":
  app.run(main)