# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forecast model class that contains the functions for training
and evaluation of the Temporal Fusion Transformers (TFT) model:
https://arxiv.org/pdf/1912.09363.pdf.
"""
from models import base
from models import losses
from models import tft_layers
import tensorflow as tf


class ForecastModel(base.ForecastModel):
  """TFT forecast model with training and evaluation steps.

  Wraps the Temporal Fusion Transformers architecture built by
  `tft_layers.TFTModel` and exposes `train_step`/`test_step` as
  `tf.function`-compiled graph calls.
  """

  def __init__(self, loss_object, hparams, quantile_targets=None):
    """Builds the TFT model and its optimizer.

    Args:
      loss_object: Callable loss used when a single quantile (point
        forecast) is trained.
      hparams: Dict of hyperparameters. Mutated in place to add the
        feature-count entries the TFT architecture expects.
      quantile_targets: List of quantiles to predict. Defaults to [0.5]
        (a point forecast).
    """
    super().__init__(hparams)

    # None-sentinel instead of a mutable default argument, so no list is
    # shared across calls (the original needed a pylint disable for this).
    if quantile_targets is None:
      quantile_targets = [0.5]

    # TFT consumes known future inputs; none are available here, so feed
    # all-zero placeholders shaped for the train and eval batch sizes.
    self.future_features_train = tf.zeros(
        [self.batch_size, self.forecast_horizon, 1])
    self.future_features_eval = tf.zeros(
        [hparams["temporal_batch_size_eval"], self.forecast_horizon, 1])

    hparams["num_future_features"] = 1
    hparams["num_historical_features"] = hparams["num_features"]
    # NOTE(review): the last `static_index_cutoff` static columns appear to
    # be excluded upstream — confirm against the data pipeline.
    hparams["num_static_features"] = hparams["num_static"] - hparams[
        "static_index_cutoff"]

    self.loss_object = loss_object
    self.optimizer = tf.keras.optimizers.Adam(
        learning_rate=hparams["learning_rate"])

    # Model layers
    self.quantile_targets = quantile_targets
    tft_architecture = tft_layers.TFTModel(hparams, self.quantile_targets)
    self.tft_model = tft_architecture.return_baseline_model()

  @tf.function
  def train_step(self, input_sequence, input_static, target):
    """Runs one optimizer step and returns the batch loss tensor."""
    with tf.GradientTape() as tape:
      predictions = self.tft_model.call(
          inputs=[input_sequence, self.future_features_train, input_static],
          training=True)
      if len(self.quantile_targets) == 1:
        # Point forecast: drop the singleton output channel.
        loss = self.loss_object(target, predictions[:, :, 0])
      else:
        loss = losses.quantile_loss(target, predictions, self.quantile_targets)

    all_trainable_weights = self.tft_model.trainable_weights
    gradients = tape.gradient(loss, all_trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, all_trainable_weights))
    return loss

  @tf.function
  def test_step(self, input_sequence, input_static):
    """Returns predictions; the middle quantile when several are trained."""
    predictions = self.tft_model.call(
        inputs=[input_sequence, self.future_features_eval, input_static],
        training=False)

    if len(self.quantile_targets) == 1:
      return predictions[:, :, 0]
    # Middle channel is the median for a symmetric quantile list.
    return predictions[:, :, (len(self.quantile_targets) - 1) // 2]