google-research

Форк
0
/
tft_layers.py 
584 строки · 19.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""A Temporal Fusion Transformer (TFT) implementation for time series.

TFT is an attention-based architecture which combines high-performance
multi-horizon forecasting with interpretable insights into temporal dynamics.
Please see https://arxiv.org/pdf/1912.09363.pdf for details.

The code is adapted from:
https://github.com/google-research/google-research/blob/master/tft/libs/tft_model.py
"""
25

26
import tensorflow as tf
27

28

29
def _dense_layer(size, activation=None, time_distributed=False, use_bias=True):
  """Builds a Keras dense layer, optionally applied per time step.

  Args:
    size: Number of output units of the dense layer.
    activation: Activation applied to the linear output; None keeps it linear.
    time_distributed: When True, the dense layer is wrapped in
      `tf.keras.layers.TimeDistributed` so that it is applied independently to
      every temporal slice of its input.
    use_bias: Whether the dense layer includes a bias term.

  Returns:
    A `tf.keras.layers.Dense` layer, or a `tf.keras.layers.TimeDistributed`
    wrapper around one when `time_distributed` is True.
  """
  layer = tf.keras.layers.Dense(size, activation=activation, use_bias=use_bias)
  return tf.keras.layers.TimeDistributed(layer) if time_distributed else layer
43

44

45
def _apply_gating_layer(x,
                        hidden_layer_size,
                        dropout_rate=None,
                        time_distributed=True,
                        activation=None):
  """Applies a Gated Linear Unit (GLU) to an input.

  The GLU multiplies a linear (or `activation`-ed) projection of the input by
  a sigmoid gate computed from the same input. Both projections have
  `hidden_layer_size` units.

  Args:
    x: The input to gating layer.
    hidden_layer_size: Number of units of both GLU projections.
    dropout_rate: Dropout rate applied to the input before the projections;
      None disables dropout.
    time_distributed: If True, the dense projections are applied to every
      temporal slice of the input.
    activation: Activation of the non-gate projection; None keeps it linear.

  Returns:
    A tuple `(glu_output, gate)`, where `gate` is the sigmoid branch output.
  """
  inputs = (tf.keras.layers.Dropout(dropout_rate)(x)
            if dropout_rate is not None else x)

  linear_branch = _dense_layer(
      hidden_layer_size,
      activation=activation,
      time_distributed=time_distributed)(inputs)
  gate = _dense_layer(
      hidden_layer_size,
      activation='sigmoid',
      time_distributed=time_distributed)(inputs)

  glu_output = tf.keras.layers.Multiply()([linear_branch, gate])
  return glu_output, gate
75

76

77
def _add_and_norm(x):
  """Sums the given tensors (skip connection) and layer-normalises the result.

  Args:
    x: List of tensors with compatible shapes to be added together.

  Returns:
    A tf.Tensor holding the layer-normalised sum.
  """
  # The normalisation layer is constructed before the Add layer, as in the
  # original single-expression form.
  normalizer = tf.keras.layers.LayerNormalization()
  summed = tf.keras.layers.Add()(x)
  return normalizer(summed)
87

88

89
def _gated_residual_network(x,
                            hidden_layer_size,
                            output_size=None,
                            dropout_rate=None,
                            time_distributed=True,
                            additional_context=None,
                            return_gate=False):
  """Applies the gated residual network (GRN) as defined in paper.

  The network computes
  `LayerNorm(skip + GLU(Dense(ELU(Dense(x) [+ Dense(context)]))))`, where
  `skip` is `x` itself when `output_size` is None and a linear projection of
  `x` to `output_size` otherwise.

  Args:
    x: The input to the GRN.
    hidden_layer_size: Number of units of the internal dense layers.
    output_size: Number of output units; None means `hidden_layer_size` and
      an identity skip connection.
    dropout_rate: Dropout rate applied inside the GLU; None disables it.
    time_distributed: If True, every dense layer is applied to each temporal
      slice of its input.
    additional_context: Optional context tensor whose bias-free projection is
      added to the first hidden activation.
    return_gate: If True, also return the GLU gate for diagnostics.

  Returns:
    The GRN output tensor, or a tuple `(GRN output, GLU gate)` when
    `return_gate` is True.
  """
  # Skip path.
  if output_size is not None:
    skip = _dense_layer(output_size, None, time_distributed)(x)
  else:
    output_size = hidden_layer_size
    skip = x

  # Feed-forward path: Dense -> (+ projected context) -> ELU -> Dense.
  hidden = _dense_layer(hidden_layer_size, None, time_distributed)(x)
  if additional_context is not None:
    # The context projection is bias-free; the first dense layer above
    # already carries a bias.
    projected_context = _dense_layer(
        hidden_layer_size,
        activation=None,
        time_distributed=time_distributed,
        use_bias=False)(additional_context)
    hidden = hidden + projected_context
  hidden = tf.keras.layers.Activation('elu')(hidden)
  hidden = _dense_layer(
      hidden_layer_size, activation=None,
      time_distributed=time_distributed)(hidden)

  glu_output, gate = _apply_gating_layer(
      hidden,
      output_size,
      dropout_rate=dropout_rate,
      time_distributed=time_distributed,
      activation=None)

  output = _add_and_norm([skip, glu_output])
  return (output, gate) if return_gate else output
147

148

149
def _get_decoder_mask(self_attn_inputs, len_s):
  """Builds a batch of lower-triangular (causal) attention masks.

  Entry `[b, i, j]` of the mask is 1.0 when `j <= i` and 0.0 otherwise, so
  position `i` may only attend to itself and earlier positions.

  Args:
    self_attn_inputs: Inputs to the self-attention layer; only the size of its
      leading (batch) dimension is used.
    len_s: Total length of the encoder and decoder sequences.

  Returns:
    A float32 tf.Tensor of shape (batch, len_s, len_s) with the causal mask.
  """
  batch_shape = tf.shape(self_attn_inputs)[:1]
  identity = tf.eye(len_s, batch_shape=batch_shape)
  # A cumulative sum down the rows turns each identity into a lower triangle.
  return tf.cumsum(identity, axis=1)
163

164

165
class ScaledDotProductAttention(object):
  """Defines scaled dot product attention layer.

  Attributes:
    activation: Activation layer used to normalise the attention scores when
      no mask is supplied; softmax by default. When a mask is supplied, a
      numerically-stable masked softmax is always used instead.
  """

  def __init__(self, activation='softmax'):
    self.activation = tf.keras.layers.Activation(activation)

  def __call__(self, q, k, v, mask=None):
    """Applies scaled dot product attention with softmax normalization.

    Args:
      q: Queries of shape (batch, T_q, d).
      k: Keys of shape (batch, T_k, d).
      v: Values of shape (batch, T_k, d_v).
      mask: Optional multiplicative mask of shape (batch, T_q, T_k), 1.0 where
        attention is allowed and 0.0 where it is not. If None, no masking is
        applied and the scores are normalised with `self.activation`.

    Returns:
      A tuple of (outputs, attention weights). Outputs have shape
      (batch, T_q, d_v); the attention weights are returned in time-major
      layout (T_q, batch, T_k).
    """
    normalization_constant = tf.sqrt(tf.cast(tf.shape(k)[-1], dtype='float32'))
    # Switch to time-major layout (T, batch, d) for the einsum contractions.
    q = tf.transpose(q, [1, 0, 2])
    k = tf.transpose(k, [1, 0, 2])
    v = tf.transpose(v, [1, 0, 2])

    attn = tf.einsum('fbd,tbd->fbt', q, k) / normalization_constant

    if mask is None:
      # Previously a None mask crashed in tf.transpose even though callers
      # document it as "no masking"; normalise over the key axis instead.
      attn = self.activation(attn)
    else:
      mask = tf.transpose(mask, [1, 0, 2])
      # Masked, numerically-stable softmax: log(mask + 1e-9) is strongly
      # negative on masked entries, so the subtracted maximum is taken over
      # the allowed scores; masked entries are then zeroed and the rest
      # renormalised.
      attn -= tf.reduce_max(
          attn + tf.math.log(mask + 1e-9), axis=2, keepdims=True)
      attn = mask * tf.exp(attn)
      attn = tf.math.divide(
          attn,
          tf.reduce_sum(attn, axis=2, keepdims=True) + 1e-9,
      )

    output = tf.einsum('fbt,tbd->fbd', attn, v)
    output = tf.transpose(output, [1, 0, 2])

    return output, attn
209

210

211
class InterpretableMultiHeadAttention(object):
  """Defines interpretable multi-head attention layer.

  Every head shares one value projection, and the head outputs are averaged
  rather than concatenated before the output projection.

  Attributes:
    n_head: Number of attention heads.
    d_k: Per-head key and query dimensionality (d_model // n_head).
    d_v: Value dimensionality (d_model // n_head).
    dropout: Dropout rate applied to each head output and to the final output.
    qs_layers: Per-head query projection layers.
    ks_layers: Per-head key projection layers.
    vs_layers: Value projection layers, one entry per head; all entries are
      the same layer object.
    attention: The scaled dot product attention applied inside each head.
    w_o: Output projection back to the original TFT state size (d_model).
  """

  def __init__(self, n_head, d_model, dropout):
    """Initialises layer.

    Args:
      n_head: The number of heads.
      d_model: The dimensionality of TFT state.
      dropout: Dropout rate applied to each head output and to the projected
        output.
    """
    self.n_head = n_head
    head_size = d_model // n_head
    self.d_k = head_size
    self.d_v = head_size
    self.dropout = dropout

    # One value layer shared by all heads to facilitate interpretability.
    shared_value_layer = tf.keras.layers.Dense(head_size, use_bias=False)
    self.qs_layers = [
        _dense_layer(head_size, use_bias=False) for _ in range(n_head)
    ]
    self.ks_layers = [
        _dense_layer(head_size, use_bias=False) for _ in range(n_head)
    ]
    self.vs_layers = [shared_value_layer] * n_head

    self.attention = ScaledDotProductAttention()
    self.w_o = tf.keras.layers.Dense(d_model, use_bias=False)

  def __call__(self, q, k, v, mask=None):
    """Applies interpretable multihead attention.

    T denotes the number of time steps fed into the transformer.

    Args:
      q: The query of tf.tensor with shape=(?, T, d_model).
      k: The key of tf.tensor with shape=(?, T, d_model).
      v: The value of tf.tensor with shape=(?, T, d_model).
      mask: The optional mask of tf.tensor with shape=(?, T, T). If None,
        masking is not applied for the output.

    Returns:
      A tuple of (layer outputs, attention weights). The per-head attention
      weights are stacked along a new leading axis.
    """
    multi_head = self.n_head > 1

    head_outputs = []
    head_attns = []
    for q_layer, k_layer, v_layer in zip(self.qs_layers, self.ks_layers,
                                         self.vs_layers):
      head_out, head_attn = self.attention(
          q_layer(q), k_layer(k), v_layer(v), mask)
      head_outputs.append(tf.keras.layers.Dropout(self.dropout)(head_out))
      head_attns.append(head_attn)

    stacked_heads = tf.stack(head_outputs) if multi_head else head_outputs[0]
    attn = tf.stack(head_attns)

    # Average over heads (axis 0 of the stack) instead of concatenating.
    if multi_head:
      outputs = tf.reduce_mean(stacked_heads, axis=0)
    else:
      outputs = stacked_heads
    outputs = tf.keras.layers.Dropout(self.dropout)(self.w_o(outputs))

    return outputs, attn
285

286

287
# TFT model definitions.
288
class TFTModel(object):
  """Implements Temporal Fusion Transformer.

  Attributes:
    output_size: Number of outputs per forecast step, one per quantile target.
    use_cudnn: If True, the LSTM layers are built with CuDNNLSTM. Always set to
      False in __init__.
    hidden_layer_size: State size used throughout the model
      (hparams['num_units']).
    forecast_horizon: Number of future time steps (hparams['forecast_horizon']).
    keep_prob: Keep probability; training uses a dropout rate of 1 - keep_prob.
    num_encode: Number of historical time steps (hparams['num_encode']).
    num_heads: Number of self-attention heads (hparams['num_heads']).
    num_historical_features: Number of features per historical time step.
    num_future_features: Number of features per future time step.
    num_static_features: Number of static features (no time axis).
    dropout_rate: Dropout rate of the most recently built graph; set by
      _build_base_graph.
  """

  def __init__(self, hparams, quantile_targets=None):
    """Initializes TFT model.

    Args:
      hparams: Mapping providing 'num_units', 'forecast_horizon', 'keep_prob',
        'num_encode', 'num_heads', 'num_historical_features',
        'num_future_features' and 'num_static_features'.
      quantile_targets: Quantiles to forecast; defaults to [0.5]. Only its
        length is used here, to size the model output.
    """

    if quantile_targets is None:
      quantile_targets = [0.5]

    # Consider point forecasting
    self.output_size = len(quantile_targets)

    self.use_cudnn = False
    self.hidden_layer_size = hparams['num_units']
    self.forecast_horizon = hparams['forecast_horizon']
    self.keep_prob = hparams['keep_prob']

    self.num_encode = hparams['num_encode']
    self.num_heads = hparams['num_heads']
    self.num_historical_features = hparams['num_historical_features']
    self.num_future_features = hparams['num_future_features']
    self.num_static_features = hparams['num_static_features']

  def _build_base_graph(self,
                        historical_inputs,
                        future_inputs,
                        static_inputs,
                        training=True):
    """Returns graph defining layers of the TFT.

    Args:
      historical_inputs: Historical inputs of shape
        (batch, T_hist, num_historical_features).
      future_inputs: Known future inputs of shape
        (batch, T_future, num_future_features).
      static_inputs: Static inputs of shape (batch, num_static_features).
      training: If True, dropout uses rate 1 - keep_prob; otherwise 0.0.

    Returns:
      A tf.Tensor of shape (batch, T_hist + T_future, hidden_layer_size). The
      causal attention mask is sized num_encode + forecast_horizon, so
      T_hist + T_future must equal that sum.
    """

    if training:
      self.dropout_rate = 1.0 - self.keep_prob
    else:
      self.dropout_rate = 0.0

    def _static_combine_and_mask(embedding):
      """Applies variable selection network to static inputs.

      Args:
        embedding: Transformed static inputs.

      Returns:
        A tuple (static_vec, sparse_weights): the weighted combination of the
        per-feature GRN outputs, and the softmax selection weights.
      """

      mlp_outputs = _gated_residual_network(
          embedding,
          self.hidden_layer_size,
          output_size=self.num_static_features,
          dropout_rate=self.dropout_rate,
          time_distributed=False,
          additional_context=None)

      sparse_weights = tf.keras.layers.Activation('softmax')(mlp_outputs)
      sparse_weights = tf.expand_dims(sparse_weights, axis=-1)

      trans_emb_list = []
      for i in range(self.num_static_features):
        e = _gated_residual_network(
            tf.expand_dims(embedding[:, i:i + 1], axis=-1),
            self.hidden_layer_size,
            output_size=self.hidden_layer_size,
            dropout_rate=self.dropout_rate,
            time_distributed=False)
        trans_emb_list.append(e)

      transformed_embedding = tf.concat(trans_emb_list, axis=1)

      combined = sparse_weights * transformed_embedding

      static_vec = tf.reduce_sum(combined, axis=1)

      return static_vec, sparse_weights

    def _lstm_combine_and_mask(embedding, static_context_variable_selection,
                               num_features):
      """Applies temporal variable selection networks.

      Args:
        embedding: The inputs for temporal variable selection networks.
        static_context_variable_selection: The static context variable
          selection.
        num_features: Number of features.

      Returns:
        A tuple of tensors that consists of temporal context, sparse weights,
        and static gate.
      """

      expanded_static_context = tf.expand_dims(
          static_context_variable_selection, axis=1)

      # Variable selection weights
      mlp_outputs, static_gate = _gated_residual_network(
          embedding,
          self.hidden_layer_size,
          output_size=num_features,
          dropout_rate=self.dropout_rate,
          time_distributed=True,
          additional_context=expanded_static_context,
          return_gate=True)

      sparse_weights = tf.keras.layers.Activation('softmax')(mlp_outputs)

      sparse_weights = tf.expand_dims(sparse_weights, axis=2)

      trans_emb_list = []
      for i in range(num_features):
        grn_output = _gated_residual_network(
            embedding[:, :, i:i + 1],
            self.hidden_layer_size,
            output_size=self.hidden_layer_size,
            dropout_rate=self.dropout_rate,
            time_distributed=True)
        trans_emb_list.append(grn_output)

      transformed_embedding = tf.stack(trans_emb_list, axis=-1)
      combined = tf.keras.layers.Multiply()(
          [sparse_weights, transformed_embedding])
      temporal_context = tf.reduce_sum(combined, axis=-1)

      return temporal_context, sparse_weights, static_gate

    # LSTM layer
    def _get_lstm(return_state):
      """Returns an LSTM layer initialized with default parameters.

      This function builds CuDNNLSTM or LSTM depending on the self.use_cudnn.

      Args:
        return_state: If True, the output LSTM layer returns output and state
          when called. Otherwise, only the output is returned when called.

      Returns:
        A Keras recurrent layer (CuDNNLSTM or LSTM) returning full sequences.
      """
      if self.use_cudnn:
        # NOTE(review): `tf.keras.layers.CuDNNLSTM` is a TF1-era layer and may
        # not exist under TF2 (where it lives in tf.compat.v1). This branch is
        # unreachable while __init__ hard-codes use_cudnn = False; the LSTM
        # arguments below match the ones TF2's LSTM documents as required for
        # its cuDNN kernel. Confirm before enabling the flag.
        lstm = tf.keras.layers.CuDNNLSTM(
            self.hidden_layer_size,
            return_sequences=True,
            return_state=return_state,
            stateful=False,
        )
      else:
        lstm = tf.keras.layers.LSTM(
            self.hidden_layer_size,
            return_sequences=True,
            return_state=return_state,
            stateful=False,
            activation='tanh',
            recurrent_activation='sigmoid',
            recurrent_dropout=0,
            unroll=False,
            use_bias=True)
      return lstm

    static_encoder, _ = _static_combine_and_mask(static_inputs)

    def _create_static_context():
      """Builds a new static-context GRN (fresh weights) on static_encoder."""
      return _gated_residual_network(
          static_encoder,
          self.hidden_layer_size,
          output_size=self.hidden_layer_size,
          dropout_rate=self.dropout_rate,
          time_distributed=False)

    static_context_variable_selection = _create_static_context()
    static_context_enrichment = _create_static_context()
    static_context_state_h = _create_static_context()
    static_context_state_c = _create_static_context()

    historical_features, _, _ = _lstm_combine_and_mask(
        historical_inputs, static_context_variable_selection,
        self.num_historical_features)
    future_features, _, _ = _lstm_combine_and_mask(
        future_inputs, static_context_variable_selection,
        self.num_future_features)

    # The encoder LSTM starts from static contexts; the decoder LSTM continues
    # from the encoder's final state.
    history_lstm, state_h, state_c = _get_lstm(return_state=True)(
        historical_features,
        initial_state=[static_context_state_h, static_context_state_c])
    future_lstm = _get_lstm(return_state=False)(
        future_features, initial_state=[state_h, state_c])

    lstm_layer = tf.concat([history_lstm, future_lstm], axis=1)

    # Apply gated skip connection
    input_embeddings = tf.concat([historical_features, future_features], axis=1)

    lstm_layer, _ = _apply_gating_layer(
        lstm_layer, self.hidden_layer_size, self.dropout_rate, activation=None)
    temporal_feature_layer = _add_and_norm([lstm_layer, input_embeddings])

    # Static enrichment layers
    expanded_static_context = tf.expand_dims(static_context_enrichment, axis=1)
    enriched, _ = _gated_residual_network(
        temporal_feature_layer,
        self.hidden_layer_size,
        dropout_rate=self.dropout_rate,
        time_distributed=True,
        additional_context=expanded_static_context,
        return_gate=True)

    # Decoder self attention
    self_attn_layer = InterpretableMultiHeadAttention(
        self.num_heads, self.hidden_layer_size, dropout=self.dropout_rate)
    mask = _get_decoder_mask(enriched, self.num_encode + self.forecast_horizon)
    x, _ = self_attn_layer(enriched, enriched, enriched, mask=mask)
    x, _ = _apply_gating_layer(
        x,
        self.hidden_layer_size,
        dropout_rate=self.dropout_rate,
        activation=None)
    x = _add_and_norm([x, enriched])

    # Nonlinear processing on outputs
    decoder = _gated_residual_network(
        x,
        self.hidden_layer_size,
        dropout_rate=self.dropout_rate,
        time_distributed=True)

    # Final skip connection
    decoder, _ = _apply_gating_layer(
        decoder, self.hidden_layer_size, activation=None)
    transformer_layer = _add_and_norm([decoder, temporal_feature_layer])

    return transformer_layer

  def _build_forecast_graph(self):
    """Builds the inputs, TFT body and future predictions shared by models.

    Returns:
      A tuple (inputs, transformer_layer, predictions):
        inputs: [past_features, future_features, static_features] Keras
          inputs.
        transformer_layer: TFT output over all num_encode + forecast_horizon
          steps.
        predictions: Per-step outputs (output_size each) for the last
          forecast_horizon steps.
    """
    # Define the input features.
    past_features = tf.keras.Input(
        shape=(
            self.num_encode,
            self.num_historical_features,
        ))
    future_features = tf.keras.Input(
        shape=(
            self.forecast_horizon,
            self.num_future_features,
        ))
    static_features = tf.keras.Input(shape=(self.num_static_features,))

    transformer_layer = self._build_base_graph(past_features, future_features,
                                               static_features)

    # Get the future predictions from encoded attention representations.
    predictions = _dense_layer(
        self.output_size, time_distributed=True)(
            transformer_layer[:, -self.forecast_horizon:, :])

    inputs = [past_features, future_features, static_features]
    return inputs, transformer_layer, predictions

  def return_baseline_model(self):
    """Returns the Keras model object for the TFT graph.

    The model maps [past_features, future_features, static_features] to
    predictions of shape (batch, forecast_horizon, output_size).
    """
    inputs, _, predictions = self._build_forecast_graph()

    # Define the Keras model.
    tft_model = tf.keras.Model(
        inputs=inputs,
        outputs=predictions,
    )

    return tft_model

  def return_self_adapting_model(self):
    """Returns the Keras TFT model with both backcast and forecast outputs.

    The model maps [past_features, future_features, static_features] to
    [backcasts, predictions]. Backcasts have shape
    (batch, num_encode, num_historical_features) and are read from the
    historical steps; predictions have shape
    (batch, forecast_horizon, output_size).
    """
    inputs, transformer_layer, predictions = self._build_forecast_graph()

    # Get the backcasts from encoded attention representations.
    backcasts = _dense_layer(
        self.num_historical_features, time_distributed=True)(
            transformer_layer[:, :self.num_encode, :])

    # Define the Keras model.
    tft_model = tf.keras.Model(
        inputs=inputs,
        outputs=[backcasts, predictions])

    return tft_model
585

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.