# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""A TFT forecast model that self-adapts and uses backcast errors as features."""
17
18from models import base_saf
19from models import tft_layers
20
21import tensorflow as tf
22
23
class ForecastModel(base_saf.ForecastModel):
  """TFT that uses backcast errors as features."""

  def __init__(self,
               loss_object,
               self_supervised_loss_object,
               hparams,
               quantile_targets=(0.5,)):
    # For now, all of the backcast errors are included as features.

    self.use_backcast_errors = hparams["use_backcast_errors"]
    if self.use_backcast_errors:
      self.num_error_features = hparams["num_features"]
    else:
      self.num_error_features = 0

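    # When backcast errors are enabled, one error feature is appended per
    # input feature, which widens the historical input accordingly.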
40if "num_historical_features" not in hparams:
41hparams["num_historical_features"] = (
42hparams["num_features"] + self.num_error_features)
43hparams["num_future_features"] = 1
44hparams["num_static_features"] = hparams["num_static"] - hparams[
45"static_index_cutoff"]

    super().__init__(loss_object, self_supervised_loss_object, hparams)

    # Model layers
    self.quantile_targets = list(quantile_targets)
    tft_architecture = tft_layers.TFTModel(hparams, self.quantile_targets)
    self.tft_model = tft_architecture.return_self_adapting_model()
    self._keras_layers = [self.tft_model, self.optimizer_adaptation]

  @tf.function
  def _self_adaptation_step(self,
                            input_sequence,
                            input_static,
                            is_training=False):
    # We mask the first half of the input window, replacing the observed
    # values with the first value after the mask, repeated. The
    # self-adaptation objective is reconstruction of the entire window.
    # Without masking, reconstruction would be trivial (copy the input);
    # with masking, the model has to learn the structure of the data to
    # produce accurate backcasts.

    repeated_mid_sequence = tf.tile(
        tf.expand_dims(input_sequence[:, self.num_encode // 2, :], 1),
        [1, self.num_encode // 2, 1])
    padded_input_sequence = tf.concat(
        [repeated_mid_sequence, input_sequence[:, self.num_encode // 2:, :]],
        axis=1)
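    # E.g. (illustrative): with num_encode=6 and a univariate window
    # [x0, x1, x2, x3, x4, x5], the masked window becomes
    # [x3, x3, x3, x3, x4, x5] -- the first half is overwritten with x3,
    # the first value after the mask.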

    with tf.GradientTape() as tape:
      future_features = (
          self.future_features_train
          if is_training else self.future_features_eval)

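      # The backcast errors are not available yet at this point, so when they
      # are enabled the historical input is zero-padded in their place.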
      if self.use_backcast_errors:
        augmented_input_sequence = tf.concat(
            (padded_input_sequence, tf.zeros_like(input_sequence)), axis=2)
      else:
        augmented_input_sequence = padded_input_sequence
      backcasts, _ = self.tft_model.call(
          inputs=[augmented_input_sequence, future_features, input_static],
          training=is_training)
      if self.use_backcast_errors:
        # Drop the backcast outputs for the appended error features.
        backcasts = backcasts[:, :, :-self.num_error_features]

      self_adaptation_loss = self.self_supervised_loss_object(
          input_sequence, backcasts)

    adaptation_trainable_variables = self.tft_model.trainable_weights

    gradients = tape.gradient(self_adaptation_loss,
                              adaptation_trainable_variables)

    self.optimizer_adaptation.apply_gradients(
        zip(gradients, adaptation_trainable_variables))

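    # After the single adaptation step, recompute the backcasts with the
    # updated weights; the residuals below become the error features for the
    # forecasting pass.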
    updated_backcasts, _ = self.tft_model.call(
        inputs=[augmented_input_sequence, future_features, input_static],
        training=is_training)

    if self.use_backcast_errors:
      # Drop the backcast outputs for the appended error features.
      updated_backcasts = updated_backcasts[:, :, :-self.num_error_features]

    backcast_errors = (input_sequence - updated_backcasts)

    return self_adaptation_loss, backcast_errors

  @tf.function
  def train_step(self, input_sequence, input_static, target):
    self_adaptation_loss, backcast_errors = self._self_adaptation_step(
        input_sequence, input_static, is_training=True)

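    # Append the residuals from the adaptation pass as extra input features
    # for the supervised forecasting pass.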
    if self.use_backcast_errors:
      input_sequence = tf.concat((input_sequence, backcast_errors), axis=2)

    with tf.GradientTape() as tape:
      _, predictions = self.tft_model.call(
          inputs=[input_sequence, self.future_features_train, input_static],
          training=True)
      prediction_loss = self.loss_object(target, predictions[:, :, 0])

    all_trainable_weights = self.tft_model.trainable_weights
    gradients = tape.gradient(prediction_loss, all_trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, all_trainable_weights))
    return prediction_loss, self_adaptation_loss

  @tf.function
  def test_step(self, input_sequence, input_static):
    self_adaptation_loss, backcast_errors = self._self_adaptation_step(
        input_sequence, input_static, is_training=False)

    if self.use_backcast_errors:
      input_sequence = tf.concat((input_sequence, backcast_errors), axis=2)

    _, predictions = self.tft_model.call(
        inputs=[input_sequence, self.future_features_eval, input_static],
        training=False)
    return predictions[:, :, 0], self_adaptation_loss
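

# A minimal usage sketch (illustrative only; the hparams keys beyond the ones
# read above, e.g. the encoding length and optimizer settings, are defined by
# base_saf.ForecastModel and are assumptions here, as are the input tensors):
#
#   hparams = {
#       "use_backcast_errors": True,
#       "num_features": 6,
#       "num_static": 4,
#       "static_index_cutoff": 1,
#       # ... plus the training keys expected by base_saf.ForecastModel.
#   }
#   model = ForecastModel(
#       loss_object=tf.keras.losses.MeanSquaredError(),
#       self_supervised_loss_object=tf.keras.losses.MeanSquaredError(),
#       hparams=hparams)
#   prediction_loss, adaptation_loss = model.train_step(
#       input_sequence, input_static, target)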