google-research

Форк
0
239 строк · 9.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""A base timeseries forecasting model to help reduce copy paste code."""
17

18
import abc
19
import logging
20
import time
21

22
from models import losses
23

24
import numpy as np
25
import tensorflow as tf
26

27

28
class ForecastModel(abc.ABC):
  """Base time series forecasting model.

  Subclasses implement `train_step` and `test_step`; this base class provides
  the shared train / validation / test loop in `run_train_eval_pipeline`.
  """

  def __init__(self, hparams):
    """Reads the hyperparameters shared by all forecasting models.

    Args:
      hparams: Mapping containing every key read below; a missing key raises
        KeyError.
    """
    self.batch_size = hparams["batch_size"]
    self.display_iterations = hparams["display_iterations"]
    self.forecast_horizon = hparams["forecast_horizon"]
    self.num_iterations = hparams["iterations"]
    self.num_encode = hparams["num_encode"]
    self.num_features = hparams["num_features"]
    self.num_test_splits = hparams["num_test_splits"]
    self.num_val_splits = hparams["num_val_splits"]
    self.static_index_cutoff = hparams["static_index_cutoff"]
    self.target_index = hparams["target_index"]
    self.temporal_batch_size_eval = hparams["temporal_batch_size_eval"]

    # All-zero future-feature tensors (one feature channel) sized for training
    # and evaluation batches. Not used in this base class; presumably consumed
    # by subclasses.
    self.future_features_train = tf.zeros(
        [self.batch_size, self.forecast_horizon, 1])
    self.future_features_eval = tf.zeros(
        [self.temporal_batch_size_eval, self.forecast_horizon, 1])

  @abc.abstractmethod
  def train_step(self, input_sequence, input_static,
                 target):
    """Runs one optimization step and returns the training loss."""
    pass

  @abc.abstractmethod
  def test_step(self, input_sequence,
                input_static):
    """Returns predictions (in normalized units) for one batch."""
    pass

  @tf.function
  def val_step(self, input_sequence,
               input_static):
    """Returns validation predictions; defaults to `test_step`."""
    # By default call the test step for validation.
    return self.test_step(input_sequence, input_static)

  def _prepare_inputs(self, input_sequence, input_static):
    """Restricts a raw batch to the model's inputs.

    Args:
      input_sequence: Rank-3 tensor indexed as [batch, time, feature].
      input_static: Rank-2 tensor indexed as [batch, static feature].

    Returns:
      `(input_sequence, input_static)`: the sequence sliced to its last
      `num_encode` time steps, and the static features with the trailing
      `static_index_cutoff` columns removed.
    """
    # The trailing static columns carry normalization parameters, not model
    # features.
    input_static = input_static[:, :-self.static_index_cutoff]
    # Slice the encoding window.
    input_sequence = input_sequence[:, -self.num_encode:, :]
    return input_sequence, input_static

  def _evaluate_splits(self, batched_dataset, num_splits, step_fn):
    """Evaluates `step_fn` on `num_splits` batches and aggregates metrics.

    Args:
      batched_dataset: Iterator whose `get_next()` returns
        `(input_sequence, input_static, target)`.
      num_splits: Number of batches to draw from `batched_dataset`.
      step_fn: Callable `(input_sequence, input_static) -> predictions`,
        e.g. `self.val_step` or `self.test_step`.

    Returns:
      Dict with keys "mae", "mape", "wmape", "mse" and "q_score". Each value
      is aggregated over the splits on denormalized predictions and targets.
    """
    mae_per_split = np.zeros(num_splits)
    mape_per_split = np.zeros(num_splits)
    wmape_den_per_split = np.zeros(num_splits)
    mse_per_split = np.zeros(num_splits)
    q_score_den_per_split = np.zeros(num_splits)

    for split_ind in range(num_splits):
      input_sequence, input_static, target = batched_dataset.get_next()

      # The last two static columns hold the shift and scale used to
      # unnormalize the data; read them before they are sliced away.
      # NOTE(review): assumes static_index_cutoff >= 2 so these columns are
      # excluded from the model inputs -- confirm against the data pipeline.
      output_shift = tf.expand_dims(input_static[:, -2], -1)
      output_scale = tf.expand_dims(input_static[:, -1], -1)
      input_sequence, input_static = self._prepare_inputs(
          input_sequence, input_static)

      predictions = step_fn(input_sequence, input_static)

      # Apply denormalization so the metrics are computed in data units.
      predictions = predictions * output_scale + output_shift
      target = target * output_scale + output_shift

      mae_per_split[split_ind] = losses.mae_per_batch(predictions, target)
      mape_per_split[split_ind] = losses.mape_per_batch(predictions, target)
      wmape_den_per_split[split_ind] = tf.reduce_mean(target)
      mse_per_split[split_ind] = losses.mse_per_batch(predictions, target)
      q_score_den_per_split[split_ind] = (
          losses.q_score_denominator_per_batch(target))

    mean_mae = np.mean(mae_per_split)
    return {
        "mae": mean_mae,
        "mape": np.mean(mape_per_split),
        "wmape": 100 * mean_mae / np.mean(wmape_den_per_split),
        "mse": np.mean(mse_per_split),
        # The q-score uses the MAE of this same dataset. Previously the test
        # q-score was computed from the validation MAE by mistake.
        "q_score": mean_mae / np.mean(q_score_den_per_split),
    }

  def run_train_eval_pipeline(self, batched_train_dataset,
                              batched_valid_dataset, batched_test_dataset):
    """Runs training and testing for batched time-series data.

    On every `display_iterations`-th iteration, and on the final iteration,
    the model is first evaluated on `num_val_splits` validation batches and
    `num_test_splits` test batches, and only then trained. Training stops
    early as soon as the mean validation MAE increases between two
    consecutive evaluations. In that case the loop exits before that
    iteration's training step.

    Args:
      batched_train_dataset: Iterator yielding training batches via
        `get_next()` as `(input_sequence, input_static, target)`.
      batched_valid_dataset: Same, for validation batches.
      batched_test_dataset: Same, for test batches.

    Returns:
      Dict with "train_losses" (one entry per training step),
      "display_iterations" (the iteration index of each evaluation), and
      "val_*" / "test_*" lists for mae, mape, wmape, mse and q_score (one
      entry per evaluation).
    """
    metric_names = ("mae", "mape", "wmape", "mse", "q_score")
    val_history = {name: [] for name in metric_names}
    test_history = {name: [] for name in metric_names}
    train_losses = []
    display_iterations = []

    total_train_time = 0
    for iteration in range(self.num_iterations):
      is_display_step = (
          iteration % self.display_iterations == 0 or
          iteration == self.num_iterations - 1)

      if is_display_step:
        # Validation stage, then test stage.
        val_metrics = self._evaluate_splits(batched_valid_dataset,
                                            self.num_val_splits,
                                            self.val_step)
        test_metrics = self._evaluate_splits(batched_test_dataset,
                                             self.num_test_splits,
                                             self.test_step)
        for name in metric_names:
          val_history[name].append(val_metrics[name])
          test_history[name].append(test_metrics[name])
        display_iterations.append(iteration)

        # Early stopping condition is defined as no improvement on the
        # validation scores between consecutive self.display_iterations
        # iterations.
        val_mae = val_history["mae"]
        if len(val_mae) > 1 and val_mae[-2] < val_mae[-1]:
          break

      # Training stage
      t = time.perf_counter()
      input_sequence, input_static, target = batched_train_dataset.get_next()
      input_sequence, input_static = self._prepare_inputs(
          input_sequence, input_static)
      train_losses.append(
          self.train_step(input_sequence, input_static, target))

      step_time = time.perf_counter() - t
      # The first step is excluded from the running average (presumably
      # because it includes one-off graph tracing).
      if iteration > 0:
        total_train_time += step_time
      if is_display_step:
        logging.debug("Iteration %d took %0.3g seconds (ave %0.3g)", iteration,
                      step_time, total_train_time / max(iteration, 1))

    evaluation_metrics = {
        "train_losses": train_losses,
        "display_iterations": display_iterations,
    }
    for prefix, history in (("val", val_history), ("test", test_history)):
      for name in metric_names:
        evaluation_metrics[prefix + "_" + name] = history[name]

    return evaluation_metrics
240

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.