google-research
80 строк · 2.8 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Forecast model class with training and evaluation functions for the
Temporal Fusion Transformers (TFT) model:
https://arxiv.org/pdf/1912.09363.pdf.
"""
21from models import base
22from models import losses
23from models import tft_layers
24import tensorflow as tf
25
26
class ForecastModel(base.ForecastModel):
  """Forecast model wrapping the Temporal Fusion Transformer (TFT).

  Builds the TFT architecture from `tft_layers` and provides `tf.function`
  train/test steps. See https://arxiv.org/pdf/1912.09363.pdf.
  """

  def __init__(self, loss_object, hparams, quantile_targets=None):
    """Initializes the TFT forecast model.

    Args:
      loss_object: Keras-style loss callable used when a single (point)
        target is predicted.
      hparams: dict of hyperparameters; read keys include "num_features",
        "num_static", "static_index_cutoff", "temporal_batch_size_eval" and
        "learning_rate". Keys "num_future_features",
        "num_historical_features" and "num_static_features" are written.
      quantile_targets: list of quantiles to predict. Defaults to [0.5]
        (a single point forecast). A `None` sentinel is used instead of a
        mutable default argument.
    """
    super().__init__(hparams)

    if quantile_targets is None:
      quantile_targets = [0.5]

    # TFT expects known future inputs; none are available here, so a single
    # all-zero future feature is fed at both train and eval batch sizes.
    self.future_features_train = tf.zeros(
        [self.batch_size, self.forecast_horizon, 1])
    self.future_features_eval = tf.zeros(
        [hparams["temporal_batch_size_eval"], self.forecast_horizon, 1])

    hparams["num_future_features"] = 1
    hparams["num_historical_features"] = hparams["num_features"]
    # The last `static_index_cutoff` static features are excluded.
    hparams["num_static_features"] = (
        hparams["num_static"] - hparams["static_index_cutoff"])

    self.loss_object = loss_object
    self.optimizer = tf.keras.optimizers.Adam(
        learning_rate=hparams["learning_rate"])

    # Model layers.
    self.quantile_targets = quantile_targets
    tft_architecture = tft_layers.TFTModel(hparams, self.quantile_targets)
    self.tft_model = tft_architecture.return_baseline_model()

  @tf.function
  def train_step(self, input_sequence, input_static, target):
    """Runs one optimizer step and returns the scalar training loss."""
    with tf.GradientTape() as tape:
      # Invoke via __call__ (not .call()) so Keras runs its standard
      # bookkeeping (building, input casting, name scoping).
      predictions = self.tft_model(
          [input_sequence, self.future_features_train, input_static],
          training=True)
      if len(self.quantile_targets) == 1:
        # Point forecast: apply the configured loss to the sole output slice.
        loss = self.loss_object(target, predictions[:, :, 0])
      else:
        loss = losses.quantile_loss(target, predictions, self.quantile_targets)

    trainable_weights = self.tft_model.trainable_weights
    gradients = tape.gradient(loss, trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, trainable_weights))
    return loss

  @tf.function
  def test_step(self, input_sequence, input_static):
    """Returns eval-time predictions.

    For multi-quantile models, the middle quantile (the median when the
    quantile list is symmetric) is returned.
    """
    predictions = self.tft_model(
        [input_sequence, self.future_features_eval, input_static],
        training=False)

    if len(self.quantile_targets) == 1:
      return predictions[:, :, 0]
    return predictions[:, :, (len(self.quantile_targets) - 1) // 2]
81