# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A base time-series forecasting model to help reduce copy-paste code."""

import abc
import logging
import time

from models import losses

import numpy as np
import tensorflow as tf


class ForecastModel(abc.ABC):
  """Base time series forecasting model."""

  def __init__(self, hparams):
    self.batch_size = hparams["batch_size"]
    self.display_iterations = hparams["display_iterations"]
    self.forecast_horizon = hparams["forecast_horizon"]
    self.num_iterations = hparams["iterations"]
    self.num_encode = hparams["num_encode"]
    self.num_features = hparams["num_features"]
    self.num_test_splits = hparams["num_test_splits"]
    self.num_val_splits = hparams["num_val_splits"]
    self.static_index_cutoff = hparams["static_index_cutoff"]
    self.target_index = hparams["target_index"]
    self.temporal_batch_size_eval = hparams["temporal_batch_size_eval"]

    self.future_features_train = tf.zeros(
        [self.batch_size, self.forecast_horizon, 1])
    self.future_features_eval = tf.zeros(
        [self.temporal_batch_size_eval, self.forecast_horizon, 1])

  @abc.abstractmethod
  def train_step(self, input_sequence, input_static, target):
    pass

  @abc.abstractmethod
  def test_step(self, input_sequence, input_static):
    pass

  @tf.function
  def val_step(self, input_sequence, input_static):
    # By default call the test step for validation.
    return self.test_step(input_sequence, input_static)

  def run_train_eval_pipeline(self, batched_train_dataset,
                              batched_valid_dataset, batched_test_dataset):
    """Runs training and testing for batched time-series data."""
    val_mae = []
    val_mape = []
    val_wmape = []
    val_mse = []
    val_q_score = []
    test_mae = []
    test_mape = []
    test_wmape = []
    test_mse = []
    test_q_score = []
    train_losses = []
    display_iterations = []

    val_mae_per_split = np.zeros(self.num_val_splits)
    test_mae_per_split = np.zeros(self.num_test_splits)
    val_mape_per_split = np.zeros(self.num_val_splits)
    test_mape_per_split = np.zeros(self.num_test_splits)
    val_wmape_den_per_split = np.zeros(self.num_val_splits)
    test_wmape_den_per_split = np.zeros(self.num_test_splits)
    val_mse_per_split = np.zeros(self.num_val_splits)
    test_mse_per_split = np.zeros(self.num_test_splits)
    val_q_score_den_per_split = np.zeros(self.num_val_splits)
    test_q_score_den_per_split = np.zeros(self.num_test_splits)

    total_train_time = 0
    for iteration in range(self.num_iterations):

      if (iteration % self.display_iterations == 0 or
          iteration == self.num_iterations - 1):
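        # Run an evaluation pass every `display_iterations` training steps and
        # on the final iteration; each pass reads one batch per validation and
        # test split from the corresponding iterators.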
        # Validation stage

        for split_ind in range(self.num_val_splits):
          (input_sequence_valid_batch, input_static_valid_batch,
           target_valid_batch) = batched_valid_dataset.get_next()

          # Do not use the last two static features, as they are for
          # unnormalizing data.
          output_shift = tf.expand_dims(input_static_valid_batch[:, -2], -1)
          output_scale = tf.expand_dims(input_static_valid_batch[:, -1], -1)
          input_static_valid_batch = (
              input_static_valid_batch[:, :-self.static_index_cutoff])

          # Slice the encoding window
          input_sequence_valid_batch = input_sequence_valid_batch[
              :, -self.num_encode:, :]

          valid_predictions = self.val_step(input_sequence_valid_batch,
                                            input_static_valid_batch)

          # Apply denormalization
          valid_predictions = (valid_predictions * output_scale + output_shift)
          target_valid_batch = (
              target_valid_batch * output_scale + output_shift)

          val_mae_per_split[split_ind] = losses.mae_per_batch(
              valid_predictions, target_valid_batch)
          val_mape_per_split[split_ind] = losses.mape_per_batch(
              valid_predictions, target_valid_batch)
          val_wmape_den_per_split[split_ind] = tf.reduce_mean(
              target_valid_batch)
          val_mse_per_split[split_ind] = losses.mse_per_batch(
              valid_predictions, target_valid_batch)
          val_q_score_den_per_split[
              split_ind] = losses.q_score_denominator_per_batch(
                  target_valid_batch)
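
        # Aggregate per-split metrics. WMAPE and the Q-score are ratios of
        # means: the MAE numerators and their denominators are averaged over
        # splits separately before dividing.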
        val_mae.append(np.mean(val_mae_per_split))
        val_mape.append(np.mean(val_mape_per_split))
        val_wmape.append(100 * np.mean(val_mae_per_split) /
                         np.mean(val_wmape_den_per_split))
        val_mse.append(np.mean(val_mse_per_split))
        val_q_score.append(
            np.mean(val_mae_per_split) / np.mean(val_q_score_den_per_split))

        # Test stage

        for split_ind in range(self.num_test_splits):
          (input_sequence_test_batch, input_static_test_batch,
           target_test_batch) = batched_test_dataset.get_next()

          # Do not use the last two static features, as they are for
          # unnormalizing data.
          output_shift = tf.expand_dims(input_static_test_batch[:, -2], -1)
          output_scale = tf.expand_dims(input_static_test_batch[:, -1], -1)
          input_static_test_batch = input_static_test_batch[
              :, :-self.static_index_cutoff]

          # Slice the encoding window
          input_sequence_test_batch = input_sequence_test_batch[
              :, -self.num_encode:, :]

          test_predictions = self.test_step(input_sequence_test_batch,
                                            input_static_test_batch)

          # Apply denormalization
          test_predictions = (test_predictions * output_scale + output_shift)
          target_test_batch = (target_test_batch * output_scale + output_shift)

          test_mae_per_split[split_ind] = losses.mae_per_batch(
              test_predictions, target_test_batch)
          test_mape_per_split[split_ind] = losses.mape_per_batch(
              test_predictions, target_test_batch)
          test_wmape_den_per_split[split_ind] = tf.reduce_mean(
              target_test_batch)
          test_mse_per_split[split_ind] = losses.mse_per_batch(
              test_predictions, target_test_batch)
          test_q_score_den_per_split[
              split_ind] = losses.q_score_denominator_per_batch(
                  target_test_batch)
        test_mae.append(np.mean(test_mae_per_split))
        test_mape.append(np.mean(test_mape_per_split))
        test_wmape.append(100 * np.mean(test_mae_per_split) /
                          np.mean(test_wmape_den_per_split))
        test_mse.append(np.mean(test_mse_per_split))
        test_q_score.append(
            np.mean(test_mae_per_split) / np.mean(test_q_score_den_per_split))
        display_iterations.append(iteration)

        # Early stopping condition is defined as no improvement on the
        # validation scores between consecutive self.display_iterations
        # iterations.
        if len(val_mae) > 1 and val_mae[-2] < val_mae[-1]:
          break

      # Training stage
      t = time.perf_counter()
      (input_sequence_train_batch, input_static_train_batch,
       target_train_batch) = batched_train_dataset.get_next()

      # Do not use the last two static features, as they are for unnormalizing
      # data.
      input_static_train_batch = input_static_train_batch[
          :, :-self.static_index_cutoff]

      # Slice the encoding window
      input_sequence_train_batch = input_sequence_train_batch[
          :, -self.num_encode:, :]

      train_loss = self.train_step(input_sequence_train_batch,
                                   input_static_train_batch,
                                   target_train_batch)
      train_losses.append(train_loss)
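
      # Accumulate per-step training time. The first iteration is excluded,
      # presumably to keep one-time tf.function tracing overhead out of the
      # reported average.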
      step_time = time.perf_counter() - t
      if iteration > 0:
        total_train_time += step_time
      if (iteration % self.display_iterations == 0 or
          iteration == self.num_iterations - 1):
        logging.debug("Iteration %d took %0.3g seconds (ave %0.3g)", iteration,
                      step_time, total_train_time / max(iteration, 1))

    evaluation_metrics = {
        "train_losses": train_losses,
        "display_iterations": display_iterations,
        "val_mae": val_mae,
        "val_mape": val_mape,
        "val_wmape": val_wmape,
        "val_mse": val_mse,
        "val_q_score": val_q_score,
        "test_mae": test_mae,
        "test_mape": test_mape,
        "test_wmape": test_wmape,
        "test_mse": test_mse,
        "test_q_score": test_q_score
    }

    return evaluation_metrics
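

# Illustrative sketch (not part of the original module): a hypothetical
# minimal subclass showing how `train_step` and `test_step` might be
# implemented on top of `ForecastModel`. The dense architecture, optimizer
# choice, and the `_ExampleDenseForecastModel` name are assumptions made for
# illustration only; the datasets passed to `run_train_eval_pipeline` are
# expected to be iterators whose `get_next()` yields
# (input_sequence, input_static, target) batches.
class _ExampleDenseForecastModel(ForecastModel):
  """Example: one dense layer mapping the encode window to the horizon."""

  def __init__(self, hparams):
    super().__init__(hparams)
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(self.forecast_horizon),
    ])
    self.optimizer = tf.keras.optimizers.Adam()

  @tf.function
  def train_step(self, input_sequence, input_static, target):
    # One gradient step on the mean squared error of the horizon forecast.
    with tf.GradientTape() as tape:
      predictions = self.model(input_sequence, training=True)
      loss = tf.reduce_mean(tf.square(predictions - target))
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self.optimizer.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    return loss

  @tf.function
  def test_step(self, input_sequence, input_static):
    # Predictions are returned in normalized units; run_train_eval_pipeline
    # denormalizes them with the shift/scale stored in the static features.
    return self.model(input_sequence, training=False)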