# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A TFT forecast model that self-adapts and uses backcast errors as features."""

from models import base_saf
from models import tft_layers

import tensorflow as tf


class ForecastModel(base_saf.ForecastModel):
  """TFT that uses backcast errors as features."""

  def __init__(self,
               loss_object,
               self_supervised_loss_object,
               hparams,
               quantile_targets=(0.5,)):
    # For now we include all of the backcast errors as features.

    self.use_backcast_errors = hparams["use_backcast_errors"]
    if self.use_backcast_errors:
      self.num_error_features = hparams["num_features"]
    else:
      self.num_error_features = 0

    # Derive the TFT input sizes from the dataset hyperparameters.
    if "num_historical_features" not in hparams:
      hparams["num_historical_features"] = (
          hparams["num_features"] + self.num_error_features)
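      # Illustrative arithmetic (hypothetical values): with num_features = 5
      # and use_backcast_errors = True, the encoder sees 5 + 5 = 10
      # historical channels per timestep.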
    hparams["num_future_features"] = 1
    hparams["num_static_features"] = hparams["num_static"] - hparams[
        "static_index_cutoff"]

    super().__init__(loss_object, self_supervised_loss_object, hparams)

    # Model layers.
    self.quantile_targets = list(quantile_targets)
    tft_architecture = tft_layers.TFTModel(hparams, self.quantile_targets)
    self.tft_model = tft_architecture.return_self_adapting_model()
    self._keras_layers = [self.tft_model, self.optimizer_adaptation]

  @tf.function
  def _self_adaptation_step(self,
                            input_sequence,
                            input_static,
                            is_training=False):
    # We mask the first half of the input window, replacing its observed
    # values with repeated copies of the first value after the mask. The
    # self-adaptation objective is reconstruction of the entire window:
    # without masking, reconstruction is trivial (the model can copy the
    # input), but with masking the model must learn the structure of the
    # data to produce accurate backcasts.
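    # E.g., with num_encode = 8, timesteps 0..3 are all overwritten with the
    # observed value at timestep 4, while timesteps 4..7 remain visible.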

    repeated_mid_sequence = tf.tile(
        tf.expand_dims(input_sequence[:, self.num_encode // 2, :], 1),
        [1, self.num_encode // 2, 1])
    padded_input_sequence = tf.concat(
        [repeated_mid_sequence, input_sequence[:, self.num_encode // 2:, :]],
        axis=1)

    with tf.GradientTape() as tape:
      future_features = (
          self.future_features_train
          if is_training else self.future_features_eval)

      if self.use_backcast_errors:
        # The backcast-error channels are not known yet, so they are filled
        # with zeros as placeholders.
        augmented_input_sequence = tf.concat(
            (padded_input_sequence, tf.zeros_like(input_sequence)), axis=2)
      else:
        augmented_input_sequence = padded_input_sequence
      backcasts, _ = self.tft_model.call(
          inputs=[augmented_input_sequence, future_features, input_static],
          training=is_training)
      if self.use_backcast_errors:
        # Remove the model outputs for the appended error channels.
        backcasts = backcasts[:, :, :-self.num_error_features]

      self_adaptation_loss = self.self_supervised_loss_object(
          input_sequence, backcasts)

    adaptation_trainable_variables = self.tft_model.trainable_weights

    gradients = tape.gradient(self_adaptation_loss,
                              adaptation_trainable_variables)

    self.optimizer_adaptation.apply_gradients(
        zip(gradients, adaptation_trainable_variables))

    # Recompute the backcasts with the adapted weights; the residuals against
    # the observed window become the backcast-error features.
    updated_backcasts, _ = self.tft_model.call(
        inputs=[augmented_input_sequence, future_features, input_static],
        training=is_training)

    if self.use_backcast_errors:
      # Remove the model outputs for the appended error channels.
      updated_backcasts = updated_backcasts[:, :, :-self.num_error_features]

    backcast_errors = (input_sequence - updated_backcasts)

    return self_adaptation_loss, backcast_errors

  @tf.function
  def train_step(self, input_sequence, input_static, target):
    self_adaptation_loss, backcast_errors = self._self_adaptation_step(
        input_sequence, input_static, is_training=True)

    if self.use_backcast_errors:
      input_sequence = tf.concat((input_sequence, backcast_errors), axis=2)
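      # Hypothetical shapes: (batch, num_encode, num_features) concatenated
      # with the same-shaped errors gives (batch, num_encode,
      # 2 * num_features), matching num_historical_features from __init__.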

    with tf.GradientTape() as tape:
      _, predictions = self.tft_model.call(
          inputs=[input_sequence, self.future_features_train, input_static],
          training=True)
      prediction_loss = self.loss_object(target, predictions[:, :, 0])

    all_trainable_weights = self.tft_model.trainable_weights
    gradients = tape.gradient(prediction_loss, all_trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, all_trainable_weights))
    return prediction_loss, self_adaptation_loss

  @tf.function
  def test_step(self, input_sequence, input_static):
    self_adaptation_loss, backcast_errors = self._self_adaptation_step(
        input_sequence, input_static, is_training=False)

    if self.use_backcast_errors:
      input_sequence = tf.concat((input_sequence, backcast_errors), axis=2)

    _, predictions = self.tft_model.call(
        inputs=[input_sequence, self.future_features_eval, input_static],
        training=False)
    return predictions[:, :, 0], self_adaptation_loss
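

# Minimal usage sketch (illustrative only, kept as comments). The hparams
# keys below are the ones referenced in this file; any further keys required
# by base_saf.ForecastModel or tft_layers.TFTModel, and the loss choices,
# are assumptions rather than the canonical pipeline configuration:
#
#   hparams = {
#       "use_backcast_errors": True,
#       "num_features": 5,         # observed channels per timestep
#       "num_static": 4,           # raw static features
#       "static_index_cutoff": 2,  # static features dropped up front
#       # ... remaining TFT / base_saf hyperparameters ...
#   }
#   model = ForecastModel(
#       loss_object=tf.keras.losses.MeanSquaredError(),
#       self_supervised_loss_object=tf.keras.losses.MeanSquaredError(),
#       hparams=hparams)
#   prediction_loss, adaptation_loss = model.train_step(
#       input_sequence, input_static, target)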