google-research
65 строк · 2.4 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Time-series forecasting with encoding-decoding architecture."""
17
18from models import architectures19from models import base20
21import tensorflow as tf22
23
class ForecastModel(base.ForecastModel):
    """Baseline seq2seq forecaster built from an LSTM encoder and decoder.

    The encoder turns the input sequence (plus static inputs) into an encoded
    representation and a set of encoder states; the decoder consumes both,
    together with future features, to produce the predictions.
    """

    def __init__(self, loss_object, hparams):
        """Builds the optimizer and the encoder/decoder architectures.

        Args:
          loss_object: Callable invoked as `loss_object(target, predictions)`
            during training.
          hparams: Hyperparameter mapping. `hparams["learning_rate"]` is read
            here; the whole mapping is also handed to the base class and to
            both architectures.
        """
        super().__init__(hparams)

        self.loss_object = loss_object
        step_size = hparams["learning_rate"]
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=step_size)

        # Encoder is asked to return its states so the decoder can use them.
        self.encoder_architecture = architectures.LSTMEncoder(
            hparams, return_state=True)
        self.forecast_architecture = architectures.LSTMDecoder(hparams)

    @tf.function
    def train_step(self, input_sequence, input_static, target):
        """Runs one optimization step and returns the resulting loss.

        Args:
          input_sequence: Sequence input passed to the encoder.
          input_static: Static input passed to both encoder and decoder.
          target: Ground truth compared against the predictions by the loss.

        Returns:
          The value produced by `self.loss_object(target, predictions)`.
        """
        encoder = self.encoder_architecture
        decoder = self.forecast_architecture

        with tf.GradientTape() as tape:
            context, states = encoder.forward(input_sequence, input_static)
            # NOTE(review): `future_features_train` is not assigned in this
            # class; presumably the base class provides it -- confirm.
            forecast = decoder.forward(
                context, input_static, self.future_features_train, states)
            loss = self.loss_object(target, forecast)

        # Encoder and decoder weights are updated jointly from the same loss.
        variables = encoder.trainable_weights + decoder.trainable_weights
        grads = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))
        return loss

    @tf.function
    def test_step(self, input_sequence, input_static):
        """Computes predictions without touching the weights.

        Args:
          input_sequence: Sequence input passed to the encoder.
          input_static: Static input passed to both encoder and decoder.

        Returns:
          The decoder output for the given inputs.
        """
        context, states = self.encoder_architecture.forward(
            input_sequence, input_static)
        # NOTE(review): `future_features_eval` is not assigned in this class;
        # presumably the base class provides it -- confirm.
        return self.forecast_architecture.forward(
            context, input_static, self.future_features_eval, states)