google-research
272 lines · 9.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Running training and evaluation on 4 synthetic autoregressive datasets."""
17
18import datetime
19import os
20import random
21
22from absl import app
23from absl import flags
24import analyze_experiments
25import datasets
26import matplotlib.pyplot as plt
27import model_library
28from models import model_utils
29import numpy as np
30import tensorflow as tf
31
FLAGS = flags.FLAGS

# Launch timestamp, used to make output artifact filenames unique per run.
now = datetime.datetime.now()
launch_time = now.strftime("%H:%M:%S")

flags.DEFINE_integer(
    "gpu_index", -1,
    "GPU index to run the job among the available GPUs, if it is -1, use CPUs.")
flags.DEFINE_integer("num_trials", 100,
                     "Number of hyperparameter trials to search for.")
flags.DEFINE_integer("seed", 2, "Random seed")
flags.DEFINE_string("model_type", "tft_saf", "Proposed forecasting method.")
# Fix: DEFINE_bool expects a boolean default, not the string "True".
flags.DEFINE_bool("display_all_models", True,
                  "Whether to print all the models or only the best model.")
flags.DEFINE_integer("len_total", 750, "Number of samples.")
flags.DEFINE_integer("synthetic_data_option", 1, "Synthetic data choice.")
flags.DEFINE_string("filename", "experiment_synthetic" + launch_time,
                    "Filename to save the model artifacts.")
50
51
def main(args):
  """Orchestrates dataset creation, model training and evaluation.

  Runs a random hyperparameter search on a synthetic autoregressive
  forecasting dataset: each trial samples one combination from the candidate
  grids, trains the model selected by --model_type, tracks validation/test
  metrics, and saves convergence plots under ./figures for the best trial so
  far (selected by validation MSE).

  Args:
    args: Not used.
  """
  del args  # Not used.

  tf.keras.backend.set_floatx("float32")
  tf.autograph.set_verbosity(0)

  # All plots produced below are written under ./figures.
  if not os.path.exists("figures"):
    os.makedirs("figures")

  model_utils.set_seed(FLAGS.seed)

  # Set the GPU index; with -1 (default) TensorFlow falls back to CPU.
  if FLAGS.gpu_index >= 0:
    gpus = tf.config.experimental.list_physical_devices(device_type="GPU")
    tf.config.experimental.set_visible_devices(
        devices=gpus[FLAGS.gpu_index], device_type="GPU")
    # Grow GPU memory on demand instead of reserving it all up front.
    tf.config.experimental.set_memory_growth(
        device=gpus[FLAGS.gpu_index], enable=True)

  (train_dataset, valid_dataset, test_dataset,
   dataset_params) = datasets.synthetic_autoregressive(
       synthetic_data_option=FLAGS.synthetic_data_option,
       len_total=FLAGS.len_total)

  # Hyperparameter search: candidate grids, one value sampled per trial.
  use_nowcast_errors_candidates = [True, False]
  # Evaluation batches span all items at once.
  temporal_batch_size_eval = dataset_params["num_items"]
  batch_size_candidates = [32, 64, 128, 256]
  learning_rate_candidates = [0.0001, 0.0003, 0.001, 0.003]
  learning_rate_adaptation_candidates = [0.0003, 0.001, 0.003, 0.01]
  num_units_candidates = [16, 32, 64]
  iterations_candidates = [3000]
  num_encode_candidates = [10, 30, 50]
  keep_prob_candidates = [0.5, 0.8, 1.0]
  num_heads_candidates = [1, 2]
  representation_combination_candidates = ["concatenation", "addition"]
  reset_weights_each_eval_step_candidates = [True]

  # Sentinel larger than any realistic validation MSE.
  best_valid_metric = 1e128
  best_hparams = []
  if FLAGS.display_all_models:
    # Per-trial metric histories used for summary statistics and plots.
    all_val_mae = []
    all_test_mae = []
    all_val_mape = []
    all_test_mape = []
    all_val_wmape = []
    all_test_wmape = []
    all_val_mse = []
    all_test_mse = []

  for ni in range(FLAGS.num_trials):
    # Try setting the random seed each trial so that we tend to get repeatable
    # hyper-parameters. If you add one at the end the previous ones should
    # remain unchanged.
    model_utils.set_seed(FLAGS.seed + ni)

    # Sample one hyperparameter combination; dataset-derived entries
    # (target_index, horizon, feature counts, split sizes) are fixed.
    chosen_hparams = {
        "batch_size":
            random.sample(batch_size_candidates, 1)[0],
        "learning_rate":
            random.sample(learning_rate_candidates, 1)[0],
        "learning_rate_adaptation":
            random.sample(learning_rate_adaptation_candidates, 1)[0],
        "num_units":
            random.sample(num_units_candidates, 1)[0],
        "iterations":
            random.sample(iterations_candidates, 1)[0],
        "num_encode":
            random.sample(num_encode_candidates, 1)[0],
        "keep_prob":
            random.sample(keep_prob_candidates, 1)[0],
        "num_heads":
            random.sample(num_heads_candidates, 1)[0],
        "representation_combination":
            random.sample(representation_combination_candidates, 1)[0],
        "reset_weights_each_eval_step":
            random.sample(reset_weights_each_eval_step_candidates, 1)[0],
        "use_nowcast_errors":
            random.sample(use_nowcast_errors_candidates, 1)[0],
        "target_index":
            dataset_params["target_index"],
        "static_index_cutoff":
            dataset_params["static_index_cutoff"],
        "display_iterations":
            250,
        "forecast_horizon":
            dataset_params["forecast_horizon"],
        "num_features":
            dataset_params["num_features"],
        "num_static":
            dataset_params["num_static"],
        "num_val_splits":
            (dataset_params["num_items"] * dataset_params["len_val"] //
             temporal_batch_size_eval),
        "num_test_splits":
            (dataset_params["num_items"] * dataset_params["len_test"] //
             temporal_batch_size_eval),
        "temporal_batch_size_eval":
            temporal_batch_size_eval,
    }

    model = model_library.get_model_type(
        FLAGS.model_type, chosen_hparams, loss_form="MSE")

    # Effectively-infinite repeat counts; the number of steps consumed is
    # bounded by the "iterations" hyperparameter inside the pipeline.
    batched_train_dataset = iter(
        train_dataset.shuffle(1000).repeat(100000000).batch(
            chosen_hparams["batch_size"]))
    batched_valid_dataset = iter(
        valid_dataset.repeat(100000000).batch(temporal_batch_size_eval))
    batched_test_dataset = iter(
        test_dataset.repeat(100000000).batch(temporal_batch_size_eval))

    eval_metrics = model.run_train_eval_pipeline(batched_train_dataset,
                                                 batched_valid_dataset,
                                                 batched_test_dataset)

    # Skip bookkeeping if the trial diverged (final validation MSE is NaN).
    if FLAGS.display_all_models and not np.isnan(eval_metrics["val_mse"][-1]):
      # NOTE(review): best_hparams is only updated at the end of the loop
      # body, so this prints the best combination from *previous* trials
      # (an empty list on the first successful trial) — confirm intended.
      print("Best hyperparameter combination: ", flush=True)
      print(best_hparams, flush=True)

      # Select the model iteration based on the validation performance
      model_selection_index = np.argmin(eval_metrics["val_mse"])

      all_val_mae.append(eval_metrics["val_mae"][model_selection_index])
      all_test_mae.append(eval_metrics["test_mae"][model_selection_index])
      all_val_mape.append(eval_metrics["val_mape"][model_selection_index])
      all_test_mape.append(eval_metrics["test_mape"][model_selection_index])
      all_val_wmape.append(eval_metrics["val_wmape"][model_selection_index])
      all_test_wmape.append(eval_metrics["test_wmape"][model_selection_index])
      all_val_mse.append(eval_metrics["val_mse"][model_selection_index])
      all_test_mse.append(eval_metrics["test_mse"][model_selection_index])

      analyze_experiments.display_metrics(
          all_val_mse, all_test_mse, "MSE", 100,
          "figures/" + FLAGS.filename + "_all_hparam_runs_MSE.png")

      # For each metric, report the test value of the trial that is best on
      # the corresponding validation metric.
      print("Best test mae: ", flush=True)
      print(all_test_mae[np.argmin(all_val_mae)], flush=True)

      print("Best test mape: ", flush=True)
      print(all_test_mape[np.argmin(all_val_mape)], flush=True)

      print("Best test wmape: ", flush=True)
      print(all_test_wmape[np.argmin(all_val_wmape)], flush=True)

      print("Best test mse: ", flush=True)
      print(all_test_mse[np.argmin(all_val_mse)], flush=True)

      # Correlation between validation and test MSE across trials.
      print("Correlation: ", flush=True)
      print(str(np.corrcoef(all_val_mse, all_test_mse)[0, 1]), flush=True)

      print("Average val/test performance: ", flush=True)
      print(
          np.mean(np.asarray(all_val_mae) / np.asarray(all_test_mae)),
          flush=True)

      current_valid_metric = eval_metrics["val_mse"][model_selection_index]
      # New best trial: remember its hyperparameters and save plots.
      if current_valid_metric < best_valid_metric:

        best_hparams = chosen_hparams

        # MAE convergence (validation vs. test) over training iterations.
        plt.figure()
        plt.plot(
            eval_metrics["display_iterations"],
            eval_metrics["val_mae"],
            "-r",
            label="Val")
        plt.plot(
            eval_metrics["display_iterations"],
            eval_metrics["test_mae"],
            "-b",
            label="Test")
        plt.xlabel("Iterations")
        plt.ylabel("MAE")
        plt.legend()
        plt.savefig("figures/" + FLAGS.filename + "_mae_convergence.png")

        # MSE convergence (validation vs. test) over training iterations.
        plt.figure()
        plt.plot(
            eval_metrics["display_iterations"],
            eval_metrics["val_mse"],
            "-r",
            label="Val")
        plt.plot(
            eval_metrics["display_iterations"],
            eval_metrics["test_mse"],
            "-b",
            label="Test")
        plt.xlabel("Iterations")
        plt.ylabel("MSE")
        plt.legend()
        plt.savefig("figures/" + FLAGS.filename + "_mse_convergence.png")

        # Self-adaptation loss curves exist only for meta self-adapting
        # model variants.
        if "meta_self_adapting" in FLAGS.model_type:
          plt.figure()
          plt.plot(
              eval_metrics["display_iterations"],
              eval_metrics["val_self_adaptation"],
              "-b",
              label="Valid self-adapting")
          plt.plot(
              eval_metrics["display_iterations"],
              eval_metrics["test_self_adaptation"],
              "-r",
              label="Test self-adapting")
          plt.xlabel("Iterations")
          plt.ylabel("MSE loss")
          plt.legend()
          plt.savefig("figures/" + FLAGS.filename +
                      "_self_adaptation_convergence.png")

        best_valid_metric = current_valid_metric
269
270
if __name__ == "__main__":
  # absl's app.run parses the flags defined above, then invokes main().
  app.run(main)
273