# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A Temporal Fusion Transformer (TFT) implementation for time series.

TFT is an attention-based architecture which combines high-performance
multi-horizon forecasting with interpretable insights into temporal dynamics.
Please see https://arxiv.org/pdf/1912.09363.pdf for details.

The code is adapted from:
https://github.com/google-research/google-research/blob/master/tft/libs/tft_model.py
"""

import tensorflow as tf


def _dense_layer(size, activation=None, time_distributed=False, use_bias=True):
  """Returns a dense keras layer with activation.

  Args:
    size: The output size.
    activation: The activation to be applied to the linear layer output.
    time_distributed: If True, it applies the dense layer for every temporal
      slice of an input.
    use_bias: If True, it includes the bias to the dense layer.
  """
  dense = tf.keras.layers.Dense(size, activation=activation, use_bias=use_bias)
  if time_distributed:
    dense = tf.keras.layers.TimeDistributed(dense)
  return dense


def _apply_gating_layer(x,
                        hidden_layer_size,
                        dropout_rate=None,
                        time_distributed=True,
                        activation=None):
  """Applies a Gated Linear Unit (GLU) to an input.

  Args:
    x: The input to the gating layer.
    hidden_layer_size: The hidden layer size of the GLU.
    dropout_rate: The dropout rate to be applied to the input.
    time_distributed: If True, it applies the dense layer for every temporal
      slice of an input.
    activation: The activation to be applied to the linear layer.

  Returns:
    Tuple of tensors for: (GLU output, gate).
  """
  if dropout_rate is not None:
    x = tf.keras.layers.Dropout(dropout_rate)(x)

  activation_layer = _dense_layer(hidden_layer_size, activation,
                                  time_distributed)(x)
  gated_layer = _dense_layer(hidden_layer_size, 'sigmoid', time_distributed)(x)

  return tf.keras.layers.Multiply()([activation_layer,
                                     gated_layer]), gated_layer
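
# Explanatory note: after the optional dropout, the GLU above computes roughly
#   GLU(x) = (W_a x + b_a) * sigmoid(W_g x + b_g),
# so the sigmoid branch acts as a learned gate that can suppress the layer's
# contribution; the gate tensor is returned alongside the output for
# inspection.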


def _add_and_norm(x):
  """Applies skip connection followed by layer normalisation.

  Args:
    x: The list of inputs to sum for skip connection.

  Returns:
    A tf.tensor output from the skip and layer normalization layer.
  """
  return tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()(x))


def _gated_residual_network(x,
                            hidden_layer_size,
                            output_size=None,
                            dropout_rate=None,
                            time_distributed=True,
                            additional_context=None,
                            return_gate=False):
  """Applies the gated residual network (GRN) as defined in the TFT paper.

  Args:
    x: The input to the GRN.
    hidden_layer_size: The hidden layer size of the GRN.
    output_size: The output layer size.
    dropout_rate: The dropout rate to be applied to the input.
    time_distributed: If True, it applies the output layer for every temporal
      slice of an input.
    additional_context: The additional context vector to use, if it exists.
    return_gate: If True, the function also returns the GLU gate for diagnostic
      purposes. Otherwise, only the GRN output is returned.

  Returns:
    A tuple of tensors for (GRN output, GLU gate) when return_gate is True. If
    return_gate is False, it returns a tf.Tensor of the GRN output.
  """
  # Set up the skip connection.
  if output_size is None:
    output_size = hidden_layer_size
    skip = x
  else:
    skip = _dense_layer(output_size, None, time_distributed)(x)

  # Apply the feedforward network.
  hidden = _dense_layer(hidden_layer_size, None, time_distributed)(x)
  if additional_context is not None:
    context_layer = _dense_layer(
        hidden_layer_size,
        activation=None,
        time_distributed=time_distributed,
        use_bias=False)
    hidden = hidden + context_layer(additional_context)
  hidden = tf.keras.layers.Activation('elu')(hidden)
  hidden_layer = _dense_layer(
      hidden_layer_size, activation=None, time_distributed=time_distributed)
  hidden = hidden_layer(hidden)

  gating_layer, gate = _apply_gating_layer(
      hidden,
      output_size,
      dropout_rate=dropout_rate,
      time_distributed=time_distributed,
      activation=None)

  if return_gate:
    return _add_and_norm([skip, gating_layer]), gate
  else:
    return _add_and_norm([skip, gating_layer])
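
# Explanatory note: with an optional context vector c, the network above
# computes roughly
#   GRN(x, c) = LayerNorm(skip + GLU(W_2 ELU(W_1 x + W_3 c + b_1) + b_2)),
# where skip is x itself, or a linear projection of x when output_size differs
# from hidden_layer_size.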


def _get_decoder_mask(self_attn_inputs, len_s):
  """Returns the causal mask to apply for the self-attention layer.

  Args:
    self_attn_inputs: The inputs to the self-attention layer, used to determine
      the mask shape.
    len_s: Total length of the encoder and decoder sequences.

  Returns:
    A tf.tensor of the causal mask to apply for the attention layer.
  """
  bs = tf.shape(self_attn_inputs)[:1]
  mask = tf.cumsum(tf.eye(len_s, batch_shape=bs), 1)
  return mask
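
# Explanatory note: cumulatively summing the identity matrix along axis 1
# yields a lower-triangular mask. For example, with len_s = 3 each batch entry
# is
#   [[1., 0., 0.],
#    [1., 1., 0.],
#    [1., 1., 1.]],
# so each position can only attend to itself and earlier positions.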


class ScaledDotProductAttention(object):
  """Defines the scaled dot product attention layer.

  Attributes:
    activation: The activation for the scaled dot product attention. By
      default, it is set to softmax.
  """

  def __init__(self, activation='softmax'):
    self.activation = tf.keras.layers.Activation(activation)

  def __call__(self, q, k, v, mask):
    """Applies scaled dot product attention with softmax normalization.

    Args:
      q: The queries to the attention layer.
      k: The keys to the attention layer.
      v: The values to the attention layer.
      mask: The mask applied to the input to the softmax.

    Returns:
      A Tuple of layer outputs and attention weights.
    """
    normalization_constant = tf.sqrt(tf.cast(tf.shape(k)[-1], dtype='float32'))
    q = tf.transpose(q, [1, 0, 2])
    k = tf.transpose(k, [1, 0, 2])
    v = tf.transpose(v, [1, 0, 2])
    mask = tf.transpose(mask, [1, 0, 2])

    attn = tf.einsum('fbd,tbd->fbt', q, k) / normalization_constant

    attn -= tf.reduce_max(
        attn + tf.math.log(mask + 1e-9), axis=2, keepdims=True)
    attn = mask * tf.exp(attn)
    attn = tf.math.divide(
        attn,
        tf.reduce_sum(attn, axis=2, keepdims=True) + 1e-9,
    )

    output = tf.einsum('fbt,tbd->fbd', attn, v)
    output = tf.transpose(output, [1, 0, 2])

    return output, attn
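
# Explanatory note: __call__ above works in a time-major layout
# (time, batch, depth) and evaluates the masked softmax explicitly: subtracting
# the row-wise max of attn + log(mask + 1e-9) keeps the exponentials
# numerically stable, multiplying by the mask zeroes out disallowed (future)
# positions, and the final division renormalizes each row to sum to one.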


class InterpretableMultiHeadAttention(object):
  """Defines the interpretable multi-head attention layer.

  Attributes:
    n_head: The number of heads for the attention layer.
    d_k: The key and query dimensionality per head.
    d_v: The value dimensionality.
    dropout: The dropout rate to apply.
    qs_layers: The list of query layers across heads.
    ks_layers: The list of key layers across heads.
    vs_layers: The list of value layers across heads.
    attention: The scaled dot product attention layer associated with the
      output.
    w_o: The output weight matrix to project the internal state to the original
      TFT state size.
  """

  def __init__(self, n_head, d_model, dropout):
    """Initialises the layer.

    Args:
      n_head: The number of heads.
      d_model: The dimensionality of the TFT state.
      dropout: The dropout rate to be applied to the output.
    """
    self.n_head = n_head
    self.d_k = self.d_v = d_k = d_v = d_model // n_head
    self.dropout = dropout

    # Use the same value layer across heads to facilitate interpretability.
    vs_layer = tf.keras.layers.Dense(d_v, use_bias=False)
    self.qs_layers = [_dense_layer(d_k, use_bias=False) for _ in range(n_head)]
    self.ks_layers = [_dense_layer(d_k, use_bias=False) for _ in range(n_head)]
    self.vs_layers = [vs_layer for _ in range(n_head)]

    self.attention = ScaledDotProductAttention()
    self.w_o = tf.keras.layers.Dense(d_model, use_bias=False)

  def __call__(self, q, k, v, mask=None):
    """Applies interpretable multi-head attention.

    T denotes the number of time steps fed into the transformer.

    Args:
      q: The query tf.tensor with shape=(?, T, d_model).
      k: The key tf.tensor with shape=(?, T, d_model).
      v: The value tf.tensor with shape=(?, T, d_model).
      mask: The optional mask tf.tensor with shape=(?, T, T). If None, masking
        is not applied for the output.

    Returns:
      A Tuple of (layer outputs, attention weights).
    """
    n_head = self.n_head

    heads = []
    attns = []
    for i in range(n_head):
      qs = self.qs_layers[i](q)
      ks = self.ks_layers[i](k)
      vs = self.vs_layers[i](v)
      head, attn = self.attention(qs, ks, vs, mask)

      head_dropout = tf.keras.layers.Dropout(self.dropout)(head)
      heads.append(head_dropout)
      attns.append(attn)
    head = tf.stack(heads) if n_head > 1 else heads[0]
    attn = tf.stack(attns)

    outputs = tf.reduce_mean(head, axis=0) if n_head > 1 else head
    outputs = self.w_o(outputs)
    outputs = tf.keras.layers.Dropout(self.dropout)(outputs)

    return outputs, attn
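
# Explanatory note: unlike standard multi-head attention, the heads above share
# a single value projection and their outputs are averaged rather than
# concatenated, so each head can be read as an attention distribution over the
# same value vectors. This is what makes the returned attention weights
# directly interpretable (see the TFT paper).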


# TFT model definitions.
class TFTModel(object):
  """Implements the Temporal Fusion Transformer."""

  def __init__(self, hparams, quantile_targets=None):
    """Initializes the TFT model."""

    if quantile_targets is None:
      quantile_targets = [0.5]

    # With the default single quantile target, this is point forecasting.
    self.output_size = len(quantile_targets)

    self.use_cudnn = False
    self.hidden_layer_size = hparams['num_units']
    self.forecast_horizon = hparams['forecast_horizon']
    self.keep_prob = hparams['keep_prob']

    self.num_encode = hparams['num_encode']
    self.num_heads = hparams['num_heads']
    self.num_historical_features = hparams['num_historical_features']
    self.num_future_features = hparams['num_future_features']
    self.num_static_features = hparams['num_static_features']

  def _build_base_graph(self,
                        historical_inputs,
                        future_inputs,
                        static_inputs,
                        training=True):
    """Returns the graph defining the layers of the TFT."""

    if training:
      self.dropout_rate = 1.0 - self.keep_prob
    else:
      self.dropout_rate = 0.0

    def _static_combine_and_mask(embedding):
      """Applies the variable selection network to static inputs.

      Args:
        embedding: Transformed static inputs.

      Returns:
        A Tuple of the static context vector and the sparse variable selection
        weights.
      """
      mlp_outputs = _gated_residual_network(
          embedding,
          self.hidden_layer_size,
          output_size=self.num_static_features,
          dropout_rate=self.dropout_rate,
          time_distributed=False,
          additional_context=None)

      sparse_weights = tf.keras.layers.Activation('softmax')(mlp_outputs)
      sparse_weights = tf.expand_dims(sparse_weights, axis=-1)

      trans_emb_list = []
      for i in range(self.num_static_features):
        e = _gated_residual_network(
            tf.expand_dims(embedding[:, i:i + 1], axis=-1),
            self.hidden_layer_size,
            output_size=self.hidden_layer_size,
            dropout_rate=self.dropout_rate,
            time_distributed=False)
        trans_emb_list.append(e)

      transformed_embedding = tf.concat(trans_emb_list, axis=1)

      combined = sparse_weights * transformed_embedding

      static_vec = tf.reduce_sum(combined, axis=1)

      return static_vec, sparse_weights

    def _lstm_combine_and_mask(embedding, static_context_variable_selection,
                               num_features):
      """Applies the temporal variable selection networks.

      Args:
        embedding: The inputs for the temporal variable selection networks.
        static_context_variable_selection: The static context for variable
          selection.
        num_features: Number of features.

      Returns:
        A Tuple of tensors consisting of the temporal context, sparse weights,
        and static gate.
      """
      expanded_static_context = tf.expand_dims(
          static_context_variable_selection, axis=1)

      # Variable selection weights
      mlp_outputs, static_gate = _gated_residual_network(
          embedding,
          self.hidden_layer_size,
          output_size=num_features,
          dropout_rate=self.dropout_rate,
          time_distributed=True,
          additional_context=expanded_static_context,
          return_gate=True)

      sparse_weights = tf.keras.layers.Activation('softmax')(mlp_outputs)
      sparse_weights = tf.expand_dims(sparse_weights, axis=2)

      trans_emb_list = []
      for i in range(num_features):
        grn_output = _gated_residual_network(
            embedding[:, :, i:i + 1],
            self.hidden_layer_size,
            output_size=self.hidden_layer_size,
            dropout_rate=self.dropout_rate,
            time_distributed=True)
        trans_emb_list.append(grn_output)

      transformed_embedding = tf.stack(trans_emb_list, axis=-1)
      combined = tf.keras.layers.Multiply()(
          [sparse_weights, transformed_embedding])
      temporal_context = tf.reduce_sum(combined, axis=-1)

      return temporal_context, sparse_weights, static_gate

    # LSTM layer
    def _get_lstm(return_state):
      """Returns an LSTM layer initialized with default parameters.

      This function builds a CuDNNLSTM or LSTM depending on self.use_cudnn.

      Args:
        return_state: If True, the output LSTM layer returns the output and
          state when called. Otherwise, only the output is returned when
          called.

      Returns:
        A Keras LSTM layer.
      """
      if self.use_cudnn:
        lstm = tf.keras.layers.CuDNNLSTM(
            self.hidden_layer_size,
            return_sequences=True,
            return_state=return_state,
            stateful=False,
        )
      else:
        lstm = tf.keras.layers.LSTM(
            self.hidden_layer_size,
            return_sequences=True,
            return_state=return_state,
            stateful=False,
            activation='tanh',
            recurrent_activation='sigmoid',
            recurrent_dropout=0,
            unroll=False,
            use_bias=True)
      return lstm

    static_encoder, _ = _static_combine_and_mask(static_inputs)

    def _create_static_context():
      """Builds static contexts with the same structure and the same input."""
      return _gated_residual_network(
          static_encoder,
          self.hidden_layer_size,
          output_size=self.hidden_layer_size,
          dropout_rate=self.dropout_rate,
          time_distributed=False)

    static_context_variable_selection = _create_static_context()
    static_context_enrichment = _create_static_context()
    static_context_state_h = _create_static_context()
    static_context_state_c = _create_static_context()

    historical_features, _, _ = _lstm_combine_and_mask(
        historical_inputs, static_context_variable_selection,
        self.num_historical_features)
    future_features, _, _ = _lstm_combine_and_mask(
        future_inputs, static_context_variable_selection,
        self.num_future_features)

    history_lstm, state_h, state_c = _get_lstm(return_state=True)(
        historical_features,
        initial_state=[static_context_state_h, static_context_state_c])
    future_lstm = _get_lstm(return_state=False)(
        future_features, initial_state=[state_h, state_c])

    lstm_layer = tf.concat([history_lstm, future_lstm], axis=1)

    # Apply gated skip connection
    input_embeddings = tf.concat([historical_features, future_features],
                                 axis=1)

    lstm_layer, _ = _apply_gating_layer(
        lstm_layer, self.hidden_layer_size, self.dropout_rate, activation=None)
    temporal_feature_layer = _add_and_norm([lstm_layer, input_embeddings])

    # Static enrichment layers
    expanded_static_context = tf.expand_dims(static_context_enrichment, axis=1)
    enriched, _ = _gated_residual_network(
        temporal_feature_layer,
        self.hidden_layer_size,
        dropout_rate=self.dropout_rate,
        time_distributed=True,
        additional_context=expanded_static_context,
        return_gate=True)

    # Decoder self attention
    self_attn_layer = InterpretableMultiHeadAttention(
        self.num_heads, self.hidden_layer_size, dropout=self.dropout_rate)
    mask = _get_decoder_mask(enriched, self.num_encode + self.forecast_horizon)
    x, _ = self_attn_layer(enriched, enriched, enriched, mask=mask)
    x, _ = _apply_gating_layer(
        x,
        self.hidden_layer_size,
        dropout_rate=self.dropout_rate,
        activation=None)
    x = _add_and_norm([x, enriched])

    # Nonlinear processing on outputs
    decoder = _gated_residual_network(
        x,
        self.hidden_layer_size,
        dropout_rate=self.dropout_rate,
        time_distributed=True)

    # Final skip connection
    decoder, _ = _apply_gating_layer(
        decoder, self.hidden_layer_size, activation=None)
    transformer_layer = _add_and_norm([decoder, temporal_feature_layer])

    return transformer_layer

  def return_baseline_model(self):
    """Returns the Keras model object for the TFT graph."""

    # Define the input features.
    past_features = tf.keras.Input(
        shape=(
            self.num_encode,
            self.num_historical_features,
        ))
    future_features = tf.keras.Input(
        shape=(
            self.forecast_horizon,
            self.num_future_features,
        ))
    static_features = tf.keras.Input(shape=(self.num_static_features,))

    transformer_layer = self._build_base_graph(past_features, future_features,
                                               static_features)

    # Get the future predictions from encoded attention representations.
    predictions = _dense_layer(
        self.output_size, time_distributed=True)(
            transformer_layer[:, -self.forecast_horizon:, :])

    # Define the Keras model.
    tft_model = tf.keras.Model(
        inputs=[past_features, future_features, static_features],
        outputs=predictions,
    )

    return tft_model

  def return_self_adapting_model(self):
    """Returns the Keras model object for the TFT graph with backcasts."""

    # Define the input features.
    past_features = tf.keras.Input(
        shape=(
            self.num_encode,
            self.num_historical_features,
        ))
    future_features = tf.keras.Input(
        shape=(
            self.forecast_horizon,
            self.num_future_features,
        ))
    static_features = tf.keras.Input(shape=(self.num_static_features,))

    transformer_layer = self._build_base_graph(past_features, future_features,
                                               static_features)

    # Get the future predictions from encoded attention representations.
    predictions = _dense_layer(
        self.output_size, time_distributed=True)(
            transformer_layer[:, -self.forecast_horizon:, :])

    # Get the backcasts from encoded attention representations.
    backcasts = _dense_layer(
        self.num_historical_features, time_distributed=True)(
            transformer_layer[:, :self.num_encode, :])

    # Define the Keras model.
    tft_model = tf.keras.Model(
        inputs=[past_features, future_features, static_features],
        outputs=[backcasts, predictions])

    return tft_model
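

# Example usage: a minimal sketch showing how the baseline model can be built
# and run on random data. The hyperparameter values below are purely
# illustrative; only the hparams keys are fixed by TFTModel.__init__.
if __name__ == '__main__':
  example_hparams = {
      'num_units': 16,  # TFT state size.
      'forecast_horizon': 12,  # Number of future steps to predict.
      'keep_prob': 0.9,  # Dropout keep probability during training.
      'num_encode': 48,  # Number of past (encoder) time steps.
      'num_heads': 4,  # Number of attention heads.
      'num_historical_features': 5,
      'num_future_features': 2,
      'num_static_features': 3,
  }
  model = TFTModel(example_hparams).return_baseline_model()

  # Random inputs shaped (batch, time, features) for the temporal inputs and
  # (batch, features) for the static input.
  batch_size = 8
  past = tf.random.normal([batch_size, 48, 5])
  future = tf.random.normal([batch_size, 12, 2])
  static = tf.random.normal([batch_size, 3])

  # One output per future step per quantile target (a single 0.5 quantile by
  # default), so the expected shape is (8, 12, 1).
  predictions = model([past, future, static])
  print(predictions.shape)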