# google-research
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LSTM Encoder-Decoder architecture."""

# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
# pylint: disable=invalid-name

import random

from data_formatting.favorita_embedder import FavoritaEmbedder
import torch
from torch import nn

class LSTMEncoder(nn.Module):
  """Encodes an input sequence with a multi-layer LSTM.

  Wraps ``nn.LSTM`` and, after every forward pass, keeps the most recent
  (hidden, cell) state pair in ``self.hidden``.
  """

  def __init__(
      self,
      input_size,
      hidden_size,
      num_layers,
      batch_first=True,
      device=torch.device('cpu'),
  ):
    """Constructor.

    Args:
      input_size: dimensionality of each input timestep
      hidden_size: number of hidden units per LSTM layer
      num_layers: number of stacked LSTM layers
      batch_first: whether inputs are shaped (batch, time, features)
      device: device on which the module's parameters live
    """
    super().__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.batch_first = batch_first
    self.device = device
    # Most recent (hidden, cell) state; populated by forward().
    self.hidden = None

    lstm = nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=batch_first,
    )
    self.lstm = lstm.float().to(device)

  def forward(self, x_input):
    """Runs the LSTM over the sequence.

    Args:
      x_input: input tensor, (batch, time, features) when batch_first.

    Returns:
      Tuple of (per-timestep outputs, final (hidden, cell) state).
    """
    output, state = self.lstm(x_input)
    self.hidden = state
    return output, state

  def init_hidden(self, batch_size):
    """Returns zero-initialized (hidden, cell) states for a batch."""
    shape = (self.num_layers, batch_size, self.hidden_size)
    hidden = torch.zeros(*shape).to(self.device)
    cell = torch.zeros(*shape).to(self.device)
    return hidden, cell

class LSTMDecoder(nn.Module):
  """LSTM Decoder.

  Runs an LSTM seeded with the encoder's final (hidden, cell) state and
  projects its outputs through a linear layer.
  """

  def __init__(
      self,
      input_size,
      hidden_size,
      num_layers,
      output_size,
      batch_first=True,
      device=torch.device('cpu'),
  ):
    """Constructor.

    Args:
      input_size: dimensionality of each input timestep
      hidden_size: number of hidden units per LSTM layer
      num_layers: number of stacked LSTM layers
      output_size: dimensionality of the linear projection of the outputs
      batch_first: whether inputs are shaped (batch, time, features)
      device: device on which the module's parameters live
    """
    super().__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.batch_first = batch_first
    self.device = device
    # Most recent (hidden, cell) state; populated by forward(). Initialized
    # here so the attribute exists before the first forward pass, matching
    # LSTMEncoder (previously missing, which made `.hidden` raise
    # AttributeError until forward() ran).
    self.hidden = None

    self.lstm = (
        nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
        )
        .float()
        .to(device)
    )
    self.linear = nn.Linear(hidden_size, output_size).to(device)

  def forward(self, x_input, encoder_hidden_states):
    """Decodes x_input starting from the encoder's state.

    Args:
      x_input: input tensor; cast to float32 before the LSTM.
      encoder_hidden_states: (hidden, cell) tuple used as the initial state.

    Returns:
      Tuple of (linear projection of the LSTM outputs, final (hidden, cell)).
    """
    lstm_out, self.hidden = self.lstm(x_input.float(), encoder_hidden_states)
    output = self.linear(lstm_out)
    return output, self.hidden

class LstmEncoderLstmDecoder(nn.Module):
  """LSTM Encoder-Decoder architecture."""

  def __init__(
      self,
      input_size,
      hidden_size,
      num_layers,
      forecasting_horizon,
      training_prediction='teacher_forcing',
      teacher_forcing_ratio=0.5,
      scale01=False,
      device=torch.device('cpu'),
      target_dims=(0,),
      embedding_dim=None,
  ):
    """Constructor.

    Args:
      input_size: size of input dimension
      hidden_size: number of hidden units
      num_layers: number of layers in both the encoder and decoder
      forecasting_horizon: number of timepoints to forecast
      training_prediction: whether to use teacher_forcing, recursive, or
        mixed_teacher_forcing in training
      teacher_forcing_ratio: probability teacher forcing is used
      scale01: whether predictions are scaled between 0 and 1
      device: device to perform computations on
      target_dims: dimensions of input corresponding to the desired target
      embedding_dim: size of embeddings; None disables the embedder
    """
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.forecasting_horizon = forecasting_horizon
    self.training_prediction = training_prediction
    self.teacher_forcing_ratio = teacher_forcing_ratio
    self.scale01 = scale01
    self.device = device
    self.target_dims = target_dims
    self.embedder = None
    if embedding_dim is not None:
      # Bug fix: previously hard-coded embedding_dim=10, silently ignoring
      # the constructor argument.
      self.embedder = FavoritaEmbedder(
          embedding_dim=embedding_dim, device=device
      ).float()

    self.encoder = LSTMEncoder(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=True,
        device=device,
    )

    # Decoder predicts in the same feature space as its input so its output
    # can be fed back in recursively.
    self.decoder = LSTMDecoder(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        output_size=input_size,
        batch_first=True,
        device=device,
    )

    # Projects raw decoder outputs onto the target dimensions (used on the
    # embedder path). Bug fix: now moved to `device` for consistency with the
    # other submodules (previously left on CPU, breaking non-CPU runs).
    self.output_layer = (
        torch.nn.Linear(input_size, len(target_dims)).float().to(device)
    )
    self.sigmoid = torch.nn.Sigmoid().to(device)

  def _embed(self, tensor):
    """Embeds raw features and flattens the embedding axes.

    Flattens the per-dimension embeddings into a single feature axis:
    (N, ?, ?, E) -> (N, T, D * E). The permute assumes the embedder returns
    (N, D, T, E) -- TODO confirm against FavoritaEmbedder's output layout.
    """
    embedding = self.embedder(tensor.float())
    embedding = embedding.permute(0, 2, 1, 3)
    N, T, D, E = embedding.shape
    embedding = embedding.reshape(N, T, D * E)
    return embedding

  def forward(self, batch, in_eval=False):
    """Computes forecasts for a batch.

    Args:
      batch: dict holding 'model_inputs'/'model_targets' (training) or
        'eval_inputs'/'eval_targets' (evaluation) tensors of shape
        (batch, time, features).
      in_eval: whether to read the eval_* keys instead of the model_* keys.

    Returns:
      Forecast tensor of shape (batch, forecasting_horizon, len(target_dims)).
    """
    if in_eval:
      inputs = batch['eval_inputs']
      targets = batch['eval_targets']
    else:
      inputs = batch['model_inputs']
      targets = batch['model_targets']

    x_input = inputs.float()

    if self.embedder is not None:
      x_input = self._embed(x_input)

    # pass through encoder
    _, enc_hidden = self.encoder(x_input)

    # prepare first input for decoder:
    # take the last observation of the window
    dec_input = x_input[:, -1, :].unsqueeze(1)
    dec_hidden = enc_hidden

    assert x_input.shape[-1] == self.input_size

    # decoding under different modes
    outputs = torch.zeros(
        targets.shape[0], self.forecasting_horizon, self.input_size
    ).to(self.device)
    mode = self.training_prediction
    if mode == 'teacher_forcing':  # each sequence entirely teacher or recursive
      if random.random() < self.teacher_forcing_ratio:
        mode = 'teacher'
      else:
        mode = 'recursive'

    for t in range(self.forecasting_horizon):
      dec_output, dec_hidden = self.decoder(dec_input, dec_hidden)
      assert dec_output.shape[1] == 1
      outputs[:, t] = dec_output.squeeze(1)

      if mode == 'recursive':
        # feed the prediction back in as the next decoder input
        dec_input = dec_output
      elif mode == 'teacher':
        dec_input = targets[:, t, :].unsqueeze(1)
        if self.embedder is not None:
          dec_input = self._embed(dec_input.float())
      elif (
          mode == 'mixed_teacher_forcing'
      ):  # each timestep independently teacher or recursive
        if random.random() < self.teacher_forcing_ratio:
          dec_input = targets[:, t, :].unsqueeze(1)
          if self.embedder is not None:
            dec_input = self._embed(dec_input)
        else:
          dec_input = dec_output

    if self.embedder is not None:
      # project from the embedded feature space onto the targets
      outputs = self.output_layer(outputs.float())
    else:
      # select the target dimensions directly from the predicted features
      outputs = outputs[:, :, self.target_dims]

    if self.scale01:
      outputs = self.sigmoid(outputs)
    return outputs