lstm_seq2seq.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Time-series forecasting with an LSTM encoder-decoder (seq2seq) architecture."""

from models import architectures
from models import base

import tensorflow as tf


class ForecastModel(base.ForecastModel):
  """Baseline forecasting model based on an LSTM seq2seq architecture."""

  def __init__(self, loss_object, hparams):
    super().__init__(hparams)

    self.loss_object = loss_object
    self.optimizer = tf.keras.optimizers.Adam(
        learning_rate=hparams["learning_rate"])

    # Model layers: an LSTM encoder that also returns its final states, and
    # an LSTM decoder that consumes those states to produce the forecast.
    self.encoder_architecture = architectures.LSTMEncoder(
        hparams, return_state=True)
    self.forecast_architecture = architectures.LSTMDecoder(hparams)

  @tf.function
  def train_step(self, input_sequence, input_static, target):
    """Runs a single optimization step and returns the training loss."""
    with tf.GradientTape() as tape:
      # Encode the past sequence (and static features) into a representation
      # plus the final encoder states.
      encoded_representation, encoder_states = self.encoder_architecture.forward(
          input_sequence, input_static)
      # Decode the future horizon, conditioned on the encoder outputs.
      predictions = self.forecast_architecture.forward(
          encoded_representation, input_static, self.future_features_train,
          encoder_states)
      loss = self.loss_object(target, predictions)
    # Update the encoder and decoder weights jointly.
    all_trainable_weights = (
        self.encoder_architecture.trainable_weights +
        self.forecast_architecture.trainable_weights)
    gradients = tape.gradient(loss, all_trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, all_trainable_weights))
    return loss

  @tf.function
  def test_step(self, input_sequence, input_static):
    """Returns model forecasts over the evaluation horizon."""
    encoded_representation, encoder_states = self.encoder_architecture.forward(
        input_sequence, input_static)
    predictions = self.forecast_architecture.forward(encoded_representation,
                                                     input_static,
                                                     self.future_features_eval,
                                                     encoder_states)
    return predictions
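
# Example usage: a minimal sketch assuming `models.base.ForecastModel` wires
# up `future_features_train` / `future_features_eval` from `hparams`, and that
# the encoder/decoder read any extra hparams they need. Only "learning_rate"
# is consumed directly in this module; the loss choice below is illustrative.
#
#   hparams = {"learning_rate": 1e-3}  # extend with encoder/decoder settings
#   model = ForecastModel(tf.keras.losses.MeanSquaredError(), hparams)
#   for input_sequence, input_static, target in train_batches:
#     loss = model.train_step(input_sequence, input_static, target)
#   predictions = model.test_step(eval_sequence, eval_static)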