google-research

Форк
0
/
lstm_seq2seq_saf.py 
142 строки · 5.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Forecast model class that contains the functions for training and evaluation for non-stationary time-series modeling."""
17

18
from models import architectures
19
from models import base_saf
20
import tensorflow as tf
21

22

23
class ForecastModel(base_saf.ForecastModel):
  """Self adapting model that uses a shared LSTM encoder.

  A single LSTM encoder feeds two heads:
    * a backcast head (`architectures.LSTMBackcast`), used for the
      self-supervised reconstruction objective, and
    * a forecast head (`architectures.LSTMDecoder`), used for the supervised
      forecasting objective.

  Both `train_step` and `test_step` first call `_self_adaptation_step`, which
  applies one `optimizer_adaptation` (SGD) update to the encoder and backcast
  head. When `hparams["use_backcast_errors"]` is set, the resulting backcast
  errors are concatenated to the input features before forecasting.

  NOTE(review): `self.num_encode`, `self.future_features_train` and
  `self.future_features_eval` are not set here; presumably they are provided by
  `base_saf.ForecastModel.__init__` — confirm against the base class.
  """

  def __init__(self, loss_object, self_supervised_loss_object, hparams):
    """Creates the optimizers and the encoder/backcast/forecast networks.

    Args:
      loss_object: Callable `(target, predictions) -> loss` for the supervised
        forecasting objective.
      self_supervised_loss_object: Callable `(input_sequence, reconstructed)
        -> loss` for the self-supervised reconstruction objective.
      hparams: Dict of hyperparameters. This class reads "learning_rate",
        "learning_rate_adaptation" and "use_backcast_errors"; the whole dict is
        also forwarded to the base class and the architecture constructors.
    """
    super().__init__(loss_object, self_supervised_loss_object, hparams)

    # Adam drives the supervised update; plain SGD drives the single
    # self-adaptation step taken before every train/test step.
    self.optimizer = tf.keras.optimizers.Adam(
        learning_rate=hparams["learning_rate"])
    self.optimizer_adaptation = tf.keras.optimizers.SGD(
        learning_rate=hparams["learning_rate_adaptation"])

    # Model layers
    self.use_backcast_errors = hparams["use_backcast_errors"]
    self.encoder_architecture = architectures.LSTMEncoder(
        hparams, return_state=True)
    self.backcast_architecture = architectures.LSTMBackcast(hparams)
    self.forecast_architecture = architectures.LSTMDecoder(hparams)

    # NOTE(review): presumably consumed by the base class (e.g. to save and
    # restore state around adaptation); the adaptation optimizer is listed
    # alongside the networks — confirm against base_saf.
    self._keras_layers = [
        self.backcast_architecture,
        self.forecast_architecture,
        self.encoder_architecture,
        self.optimizer_adaptation,
    ]

  @tf.function
  def _self_adaptation_step(self,
                            input_sequence,
                            input_static,
                            is_training=False):
    """Takes one self-supervised SGD step on the encoder and backcast head.

    Args:
      input_sequence: Observed input window, indexed below as
        (batch, time, features); assumed to have `self.num_encode` time steps
        — TODO confirm.
      input_static: Static covariates forwarded to the sub-networks.
      is_training: Unused; kept for interface compatibility.

    Returns:
      A tuple `(self_adaptation_loss, backcast_errors)`: the reconstruction
      loss re-evaluated with the adapted weights, and
      `input_sequence - reconstructed` from the adapted networks.
    """
    del is_training  # unused

    half = self.num_encode // 2
    # Backcasts are decoded from all-zero "future" features. Computed once and
    # reused for both forward passes (and as the backcast-error placeholder).
    zero_features = tf.zeros_like(input_sequence)

    # We mask the first half of the input window, replacing the observed values
    # with the first value after the mask, repeated. The self-adaptation
    # objective is reconstruction of the entire window. Without masking the
    # reconstruction is trivial (copy the input); with masking, the model has to
    # learn the structure of the data to backcast accurately. These ops touch no
    # variables, so they need not be recorded by the tape.
    repeated_mid_sequence = tf.tile(
        tf.expand_dims(input_sequence[:, half, :], 1), [1, half, 1])
    padded_input_sequence = tf.concat(
        [repeated_mid_sequence, input_sequence[:, half:, :]], axis=1)

    if self.use_backcast_errors:
      # No backcast errors exist yet at this point, so zeros fill the extra
      # feature channels that the encoder receives in train/test steps.
      augmented_input_sequence = tf.concat(
          (padded_input_sequence, zero_features), axis=2)
    else:
      augmented_input_sequence = padded_input_sequence

    with tf.GradientTape() as tape:
      encoded_representation, encoder_states = self.encoder_architecture.forward(
          augmented_input_sequence, input_static)
      reconstructed = self.backcast_architecture.forward(
          encoded_representation, input_static, zero_features, encoder_states)
      # The reconstruction target is the full, unmasked window.
      loss = self.self_supervised_loss_object(input_sequence, reconstructed)

    # Collected after the forward pass in case the sub-networks create their
    # variables lazily on first call. Only trainable variables are watched by
    # the tape (non-trainable ones would get `None` gradients), which also
    # matches the variable selection used in `train_step`.
    adaptation_trainable_variables = (
        self.encoder_architecture.trainable_weights +
        self.backcast_architecture.trainable_weights)

    gradients = tape.gradient(loss, adaptation_trainable_variables)
    self.optimizer_adaptation.apply_gradients(
        zip(gradients, adaptation_trainable_variables))

    # Re-run the (still masked) forward pass with the adapted weights.
    encoded_representation, encoder_states = self.encoder_architecture.forward(
        augmented_input_sequence, input_static)
    reconstructed = self.backcast_architecture.forward(
        encoded_representation, input_static, zero_features, encoder_states)

    self_adaptation_loss = self.self_supervised_loss_object(
        input_sequence, reconstructed)
    backcast_errors = input_sequence - reconstructed

    return self_adaptation_loss, backcast_errors

  @tf.function
  def train_step(self, input_sequence, input_static, target):
    """Adapts on the input window, then takes one supervised training step.

    Args:
      input_sequence: Observed input window.
      input_static: Static covariates forwarded to the sub-networks.
      target: Ground truth compared against the forecasts by `loss_object`.

    Returns:
      A tuple `(prediction_loss, self_adaptation_loss)`.
    """
    self_adaptation_loss, backcast_errors = self._self_adaptation_step(
        input_sequence, input_static, is_training=False)

    if self.use_backcast_errors:
      # Backcast errors are appended as extra feature channels. They are
      # computed outside the tape below, so no gradient flows into them.
      input_sequence = tf.concat((input_sequence, backcast_errors), axis=2)

    with tf.GradientTape() as tape:
      encoded_representation, encoder_states = self.encoder_architecture.forward(
          input_sequence, input_static)
      predictions = self.forecast_architecture.forward(
          encoded_representation, input_static, self.future_features_train,
          encoder_states)
      prediction_loss = self.loss_object(target, predictions)

    # The backcast head is deliberately excluded: in this class it is updated
    # only by the self-adaptation step.
    all_trainable_weights = (
        self.encoder_architecture.trainable_weights +
        self.forecast_architecture.trainable_weights)
    gradients = tape.gradient(prediction_loss, all_trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, all_trainable_weights))
    return prediction_loss, self_adaptation_loss

  @tf.function
  def test_step(self, input_sequence, input_static):
    """Adapts on the input window, then forecasts with the adapted encoder.

    Note that evaluation still mutates the encoder and backcast weights through
    `_self_adaptation_step`; only the supervised update is skipped.

    Args:
      input_sequence: Observed input window.
      input_static: Static covariates forwarded to the sub-networks.

    Returns:
      A tuple `(predictions, self_adaptation_loss)`.
    """
    self_adaptation_loss, backcast_errors = self._self_adaptation_step(
        input_sequence, input_static, is_training=False)

    if self.use_backcast_errors:
      input_sequence = tf.concat((input_sequence, backcast_errors), axis=2)

    encoded_representation, encoder_states = self.encoder_architecture.forward(
        input_sequence, input_static)
    predictions = self.forecast_architecture.forward(encoded_representation,
                                                     input_static,
                                                     self.future_features_eval,
                                                     encoder_states)
    return predictions, self_adaptation_loss
143

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.