# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forecast model with training and evaluation for non-stationary time-series modeling."""

from models import architectures
from models import base_saf
import tensorflow as tf


class ForecastModel(base_saf.ForecastModel):
  """Self-adapting model that uses a shared LSTM encoder."""

  def __init__(self, loss_object, self_supervised_loss_object, hparams):
    super().__init__(loss_object, self_supervised_loss_object, hparams)

    self.optimizer = tf.keras.optimizers.Adam(
        learning_rate=hparams["learning_rate"])
    self.optimizer_adaptation = tf.keras.optimizers.SGD(
        learning_rate=hparams["learning_rate_adaptation"])

    # Model layers
    self.use_backcast_errors = hparams["use_backcast_errors"]
    self.encoder_architecture = architectures.LSTMEncoder(
        hparams, return_state=True)
    self.backcast_architecture = architectures.LSTMBackcast(hparams)
    self.forecast_architecture = architectures.LSTMDecoder(hparams)

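    # Stateful objects tracked for the base class (presumably so base_saf can
    # checkpoint and restore their state around the self-adaptation step).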
    self._keras_layers = [
        self.backcast_architecture,
        self.forecast_architecture,
        self.encoder_architecture,
        self.optimizer_adaptation,
    ]

  @tf.function
  def _self_adaptation_step(self,
                            input_sequence,
                            input_static,
                            is_training=False):
    del is_training  # unused
    with tf.GradientTape() as tape:

      # We mask the first half of the input window, replacing the observed
      # values with repeated copies of the first value after the mask. The
      # self-adaptation objective is reconstruction of the entire window.
      # Without masking, reconstruction would be trivial (copy the input);
      # with masking, the model has to learn the structure of the data to
      # produce accurate backcasts.
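      # For illustration, with num_encode = 8 an input window [x0, ..., x7]
      # becomes [x4, x4, x4, x4, x4, x5, x6, x7].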

      repeated_mid_sequence = tf.tile(
          tf.expand_dims(input_sequence[:, self.num_encode // 2, :], 1),
          [1, self.num_encode // 2, 1])
      padded_input_sequence = tf.concat(
          [repeated_mid_sequence, input_sequence[:, self.num_encode // 2:, :]],
          axis=1)

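      # The backcast-error feature channel is zeroed here: no residuals exist
      # yet at self-adaptation time, but the encoder still expects the extra
      # input dimensions.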
      if self.use_backcast_errors:
        augmented_input_sequence = tf.concat(
            (padded_input_sequence, tf.zeros_like(input_sequence)), axis=2)
      else:
        augmented_input_sequence = padded_input_sequence

      encoded_representation, encoder_states = self.encoder_architecture.forward(
          augmented_input_sequence, input_static)
      reconstructed = self.backcast_architecture.forward(
          encoded_representation, input_static, tf.zeros_like(input_sequence),
          encoder_states)

      loss = self.self_supervised_loss_object(input_sequence, reconstructed)

    adaptation_trainable_variables = (
        self.encoder_architecture.weights + self.backcast_architecture.weights)

    gradients = tape.gradient(loss, adaptation_trainable_variables)

    self.optimizer_adaptation.apply_gradients(
        zip(gradients, adaptation_trainable_variables))

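    # Run the encoder and backcast again with the adapted weights so that the
    # returned loss and residuals reflect the post-adaptation state.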
    encoded_representation, encoder_states = self.encoder_architecture.forward(
        augmented_input_sequence, input_static)
    reconstructed = self.backcast_architecture.forward(
        encoded_representation, input_static, tf.zeros_like(input_sequence),
        encoder_states)

    self_adaptation_loss = self.self_supervised_loss_object(
        input_sequence, reconstructed)

    backcast_errors = input_sequence - reconstructed

    return self_adaptation_loss, backcast_errors

  @tf.function
  def train_step(self, input_sequence, input_static, target):
    self_adaptation_loss, backcast_errors = self._self_adaptation_step(
        input_sequence, input_static, is_training=False)

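    # Append the backcast residuals from self-adaptation as extra input
    # features for the forecasting pass.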
    if self.use_backcast_errors:
      input_sequence = tf.concat((input_sequence, backcast_errors), axis=2)

    with tf.GradientTape() as tape:
      encoded_representation, encoder_states = self.encoder_architecture.forward(
          input_sequence, input_static)
      predictions = self.forecast_architecture.forward(
          encoded_representation, input_static, self.future_features_train,
          encoder_states)
      prediction_loss = self.loss_object(target, predictions)

    all_trainable_weights = (
        self.encoder_architecture.trainable_weights +
        self.forecast_architecture.trainable_weights)
    gradients = tape.gradient(prediction_loss, all_trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, all_trainable_weights))
    return prediction_loss, self_adaptation_loss

  @tf.function
  def test_step(self, input_sequence, input_static):
    self_adaptation_loss, backcast_errors = self._self_adaptation_step(
        input_sequence, input_static, is_training=False)

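    # As in train_step, optionally append the backcast residuals as features.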
    if self.use_backcast_errors:
      input_sequence = tf.concat((input_sequence, backcast_errors), axis=2)

    encoded_representation, encoder_states = self.encoder_architecture.forward(
        input_sequence, input_static)
    predictions = self.forecast_architecture.forward(encoded_representation,
                                                     input_static,
                                                     self.future_features_eval,
                                                     encoder_states)
    return predictions, self_adaptation_loss
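
# A minimal usage sketch, not part of the module: the loss choices, the
# contents of `hparams`, and the tensor shapes are assumptions; only the
# train_step/test_step signatures come from the class above.
#
#   model = ForecastModel(
#       loss_object=tf.keras.losses.MeanSquaredError(),
#       self_supervised_loss_object=tf.keras.losses.MeanSquaredError(),
#       hparams=hparams)
#   prediction_loss, self_adaptation_loss = model.train_step(
#       input_sequence, input_static, target)
#   predictions, self_adaptation_loss = model.test_step(
#       input_sequence, input_static)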