# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LSTM Encoder-Decoder architecture."""

# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
# pylint: disable=invalid-name

import random

from data_formatting.favorita_embedder import FavoritaEmbedder
import torch
from torch import nn


class LSTMEncoder(nn.Module):
  """LSTM Encoder."""

  def __init__(
      self,
      input_size,
      hidden_size,
      num_layers,
      batch_first=True,
      device=torch.device('cpu'),
  ):
    super().__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.batch_first = batch_first
    self.device = device
    self.hidden = None

    self.lstm = (
        nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
        )
        .float()
        .to(device)
    )

  def forward(self, x_input):
    # self.hidden is a (hidden state, cell state) tuple
    output, self.hidden = self.lstm(x_input)
    return output, self.hidden

  def init_hidden(self, batch_size):
    # zero-initialized (hidden state, cell state), each of shape
    # (num_layers, batch_size, hidden_size)
    return (
        torch.zeros(self.num_layers, batch_size, self.hidden_size).to(
            self.device
        ),
        torch.zeros(self.num_layers, batch_size, self.hidden_size).to(
            self.device
        ),
    )
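

# A minimal usage sketch for LSTMEncoder (illustrative; the sizes below are
# hypothetical and assume the default batch_first=True):
#
#   enc = LSTMEncoder(input_size=8, hidden_size=32, num_layers=2)
#   x = torch.randn(4, 12, 8)  # (batch N=4, time T=12, features D=8)
#   output, (h, c) = enc(x)
#   # output: (4, 12, 32); h and c: (num_layers=2, 4, 32)
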
class LSTMDecoder(nn.Module):
  """LSTM Decoder."""

  def __init__(
      self,
      input_size,
      hidden_size,
      num_layers,
      output_size,
      batch_first=True,
      device=torch.device('cpu'),
  ):
    super().__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.batch_first = batch_first
    self.device = device

    self.lstm = (
        nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
        )
        .float()
        .to(device)
    )
    # maps LSTM hidden states to the desired output dimension
    self.linear = nn.Linear(hidden_size, output_size).to(device)

  def forward(self, x_input, encoder_hidden_states):
    lstm_out, self.hidden = self.lstm(x_input.float(), encoder_hidden_states)
    output = self.linear(lstm_out)
    return output, self.hidden
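

# Illustrative sketch (hypothetical sizes): the decoder runs one timestep at
# a time, seeded with the encoder's final (hidden, cell) state:
#
#   enc = LSTMEncoder(input_size=8, hidden_size=32, num_layers=2)
#   dec = LSTMDecoder(
#       input_size=8, hidden_size=32, num_layers=2, output_size=8)
#   _, enc_hidden = enc(torch.randn(4, 12, 8))
#   step = torch.randn(4, 1, 8)            # one timestep per decoder call
#   y, dec_hidden = dec(step, enc_hidden)  # y: (4, 1, 8)
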
class LstmEncoderLstmDecoder(nn.Module):
  """LSTM Encoder-Decoder architecture."""

  def __init__(
      self,
      input_size,
      hidden_size,
      num_layers,
      forecasting_horizon,
      training_prediction='teacher_forcing',
      teacher_forcing_ratio=0.5,
      scale01=False,
      device=torch.device('cpu'),
      target_dims=(0,),
      embedding_dim=None,
  ):
    """Constructor.

    Args:
      input_size: size of the input dimension
      hidden_size: number of hidden units
      num_layers: number of layers in both the encoder and decoder
      forecasting_horizon: number of timepoints to forecast
      training_prediction: whether to use teacher_forcing, recursive, or
        mixed_teacher_forcing during training
      teacher_forcing_ratio: probability that teacher forcing is used
      scale01: whether predictions are scaled between 0 and 1
      device: device to perform computations on
      target_dims: dimensions of the input corresponding to the desired targets
      embedding_dim: size of the embeddings
    """
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.forecasting_horizon = forecasting_horizon
    self.training_prediction = training_prediction
    self.teacher_forcing_ratio = teacher_forcing_ratio
    self.scale01 = scale01
    self.device = device
    self.target_dims = target_dims
    self.embedder = None
    if embedding_dim is not None:
      self.embedder = FavoritaEmbedder(
          embedding_dim=embedding_dim, device=device
      ).float()

    self.encoder = LSTMEncoder(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=True,
        device=device,
    )

    self.decoder = LSTMDecoder(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        output_size=input_size,
        batch_first=True,
        device=device,
    )

    self.output_layer = (
        torch.nn.Linear(input_size, len(target_dims)).float().to(device)
    )
    self.sigmoid = torch.nn.Sigmoid().to(device)

  def _embed(self, tensor):
    # embedder returns one embedding per (feature, timestep); axes arrive as
    # (N, D, T, E) and are rearranged to (N, T, D, E)
    embedding = self.embedder(tensor.float())
    embedding = embedding.permute(0, 2, 1, 3)
    # flatten the per-feature embeddings into one vector per timestep
    N, T, D, E = embedding.shape
    embedding = embedding.reshape(N, T, D * E)
    return embedding
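
  # Worked shape example for _embed (illustrative, hypothetical sizes): with
  # N=4 series, D=5 features, T=12 timesteps, and E=10 embedding dims, the
  # embedder's (4, 5, 12, 10) output becomes (4, 12, 5, 10) after the permute
  # and (4, 12, 50) after the reshape: one flat feature vector per timestep.
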
  def forward(self, batch, in_eval=False):
    if in_eval:
      inputs = batch['eval_inputs']
      targets = batch['eval_targets']
    else:
      inputs = batch['model_inputs']
      targets = batch['model_targets']

    x_input = inputs.float()

    if self.embedder is not None:
      x_input = self._embed(x_input)

    # pass the input window through the encoder
    _, enc_hidden = self.encoder(x_input)

    # seed the decoder with the last observation of the input window
    dec_input = x_input[:, -1, :].unsqueeze(1)
    dec_hidden = enc_hidden

    assert x_input.shape[-1] == self.input_size

    # decoding under different modes
    outputs = torch.zeros(
        targets.shape[0], self.forecasting_horizon, self.input_size
    ).to(self.device)
    mode = self.training_prediction
    if mode == 'teacher_forcing':  # each sequence entirely teacher or recursive
      if random.random() < self.teacher_forcing_ratio:
        mode = 'teacher'
      else:
        mode = 'recursive'

    for t in range(self.forecasting_horizon):
      dec_output, dec_hidden = self.decoder(dec_input, dec_hidden)
      assert dec_output.shape[1] == 1
      outputs[:, t] = dec_output.squeeze(1)

      if mode == 'recursive':
        # feed the model's own prediction back in
        dec_input = dec_output
      elif mode == 'teacher':
        # feed the ground-truth target back in
        dec_input = targets[:, t, :].unsqueeze(1)
        if self.embedder is not None:
          dec_input = self._embed(dec_input.float())
      elif mode == 'mixed_teacher_forcing':
        # each sequence is a per-step mix of teacher and recursive
        if random.random() < self.teacher_forcing_ratio:
          dec_input = targets[:, t, :].unsqueeze(1)
          if self.embedder is not None:
            dec_input = self._embed(dec_input)
        else:
          dec_input = dec_output

    if self.embedder is not None:
      outputs = self.output_layer(outputs.float())
    else:
      outputs = outputs[:, :, self.target_dims]

    if self.scale01:
      outputs = self.sigmoid(outputs)
    return outputs
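

if __name__ == '__main__':
  # Minimal smoke test (illustrative; sizes are hypothetical). Without an
  # embedder, targets must carry all input features so that teacher forcing
  # can feed them back into the decoder.
  model = LstmEncoderLstmDecoder(
      input_size=8,
      hidden_size=32,
      num_layers=2,
      forecasting_horizon=6,
  )
  batch = {
      'model_inputs': torch.randn(4, 12, 8),  # (N, T, D)
      'model_targets': torch.randn(4, 6, 8),  # (N, horizon, D)
  }
  preds = model(batch)
  print(preds.shape)  # torch.Size([4, 6, 1]): len(target_dims) per step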