# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RealFormer: https://arxiv.org/abs/2012.11747."""

import collections
import copy
import json
import math
import re
import numpy as np
import six
import tensorflow.compat.v1 as tf
import tf_slim.layers as tf_slim_layers

class BertConfig(object):
  """Configuration for `BertModel`."""

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_hidden_layers=12,
               num_attention_heads=12,
               intermediate_size=3072,
               hidden_act="gelu",
               hidden_dropout_prob=0.1,
               attention_probs_dropout_prob=0.1,
               max_position_embeddings=512,
               type_vocab_size=16,
               initializer_range=0.02,
               use_running_mean=False):
    """Constructs BertConfig.

    Args:
      vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
      hidden_size: Size of the encoder layers and the pooler layer.
      num_hidden_layers: Number of hidden layers in the Transformer encoder.
      num_attention_heads: Number of attention heads for each attention layer
        in the Transformer encoder.
      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: The non-linear activation function (function or string) in
        the encoder and pooler.
      hidden_dropout_prob: The dropout probability for all fully connected
        layers in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: The dropout ratio for the attention
        probabilities.
      max_position_embeddings: The maximum sequence length that this model
        might ever be used with. Typically set this to something large just in
        case (e.g., 512 or 1024 or 2048).
      type_vocab_size: The vocabulary size of the `token_type_ids` passed into
        `BertModel`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
      use_running_mean: Whether to softmax the running mean instead of the
        running sum of attention scores to compute attention probabilities.
        Running mean is found to be helpful for deep RealFormer models with
        30+ layers.
    """
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.hidden_act = hidden_act
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.initializer_range = initializer_range
    self.use_running_mean = use_running_mean

  @classmethod
  def from_dict(cls, json_object):
    """Constructs a `BertConfig` from a Python dictionary of parameters."""
    config = cls(vocab_size=None)
    for (key, value) in six.iteritems(json_object):
      config.__dict__[key] = value
    return config

  @classmethod
  def from_json_file(cls, json_file):
    """Constructs a `BertConfig` from a json file of parameters."""
    with tf.io.gfile.GFile(json_file, "r") as reader:
      text = reader.read()
    return cls.from_dict(json.loads(text))

  def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    output = copy.deepcopy(self.__dict__)
    return output

  def to_json_string(self):
    """Serializes this instance to a JSON string."""
    return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

class BertModel(object):
  """BERT model with RealFormer as the backbone Transformer.

  Example usage:

  ```python
  # Already been converted into WordPiece token ids
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])

  config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
    num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)

  model = modeling.BertModel(config=config, is_training=True,
    input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)

  label_embeddings = tf.get_variable(...)
  pooled_output = model.get_pooled_output()
  logits = tf.matmul(pooled_output, label_embeddings)
  ...
  ```
  """

  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               use_one_hot_embeddings=False,
               scope=None):
    """Constructor for BertModel.

    Args:
      config: `BertConfig` instance.
      is_training: bool. true for training model, false for eval model.
        Controls whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    config = copy.deepcopy(config)
    if not is_training:
      config.hidden_dropout_prob = 0.0
      config.attention_probs_dropout_prob = 0.0

    input_shape = get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
      input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    if token_type_ids is None:
      token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        (self.word_embedding_output, self.embedding_table) = embedding_lookup(
            input_ids=input_ids,
            vocab_size=config.vocab_size,
            embedding_size=config.hidden_size,
            initializer_range=config.initializer_range,
            word_embedding_name="word_embeddings",
            use_one_hot_embeddings=use_one_hot_embeddings)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = embedding_postprocessor(
            input_tensor=self.word_embedding_output,
            use_token_type=True,
            token_type_ids=token_type_ids,
            token_type_vocab_size=config.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=config.initializer_range,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)

      with tf.variable_scope("encoder"):
        # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
        # mask of shape [batch_size, seq_length, seq_length] which is used
        # for the attention scores.
        attention_mask = create_attention_mask_from_input_mask(
            input_ids, input_mask)

        # Run the stacked transformer.
        # `sequence_output` shape = [batch_size, seq_length, hidden_size].
        self.all_encoder_layers = realformer_model(
            input_tensor=self.embedding_output,
            attention_mask=attention_mask,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            intermediate_act_fn=get_activation(config.hidden_act),
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            initializer_range=config.initializer_range,
            use_running_mean=config.use_running_mean,
            do_return_all_layers=True)

      self.sequence_output = self.all_encoder_layers[-1]
      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_size, seq_length, hidden_size] to a tensor of shape
      # [batch_size, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained.
        first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
        self.pooled_output = tf.layers.dense(
            first_token_tensor,
            config.hidden_size,
            activation=tf.tanh,
            kernel_initializer=create_initializer(config.initializer_range))

  def get_pooled_output(self):
    return self.pooled_output

  def get_sequence_output(self):
    """Gets final hidden layer of encoder.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the final hidden layer of the transformer encoder.
    """
    return self.sequence_output

  def get_all_encoder_layers(self):
    return self.all_encoder_layers

  def get_word_embedding_output(self):
    """Gets output of the word(piece) embedding lookup.

    This is BEFORE positional embeddings and token type embeddings have been
    added.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the output of the word(piece) embedding layer.
    """
    return self.word_embedding_output

  def get_embedding_output(self):
    """Gets output of the embedding lookup (i.e., input to the transformer).

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the output of the embedding layer, after summing the word
      embeddings with the positional embeddings and the token type embeddings,
      then performing layer normalization. This is the input to the
      transformer.
    """
    return self.embedding_output

  def get_embedding_table(self):
    return self.embedding_table


def gelu(x):
  """Gaussian Error Linear Unit.

  This is a smoother version of the RELU.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  cdf = 0.5 * (1.0 + tf.tanh(
      (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
  return x * cdf


def get_activation(activation_string):
  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

  Args:
    activation_string: String name of the activation function.

  Returns:
    A Python function corresponding to the activation function. If
    `activation_string` is None, empty, or "linear", this will return None.
    If `activation_string` is not a string, it will return `activation_string`.

  Raises:
    ValueError: The `activation_string` does not correspond to a known
      activation.
  """

  # We assume that anything that's not a string is already an activation
  # function, so we just return it.
  if not isinstance(activation_string, six.string_types):
    return activation_string

  if not activation_string:
    return None

  act = activation_string.lower()
  if act == "linear":
    return None
  elif act == "relu":
    return tf.nn.relu
  elif act == "gelu":
    return gelu
  elif act == "tanh":
    return tf.tanh
  elif act == "tanh":
    return tf.tanh
  else:
    raise ValueError("Unsupported activation: %s" % act)

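# Illustrative examples for get_activation (assumed, not part of the original
# module): get_activation("gelu") returns the `gelu` function defined above,
# get_activation("linear") and get_activation(None) return None, and a
# callable such as tf.nn.relu is passed through unchanged.
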

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
  """Compute the union of the current variables and checkpoint variables."""
  assignment_map = {}
  initialized_variable_names = {}

  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  init_vars = tf.train.list_variables(init_checkpoint)

  assignment_map = collections.OrderedDict()
  for x in init_vars:
    (name, var) = (x[0], x[1])
    if name not in name_to_variable:
      continue
    assignment_map[name] = name
    initialized_variable_names[name] = 1
    initialized_variable_names[name + ":0"] = 1

  return (assignment_map, initialized_variable_names)


def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT of
      *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied.
  """
  if dropout_prob is None or dropout_prob == 0.0:
    return input_tensor

  output = tf.nn.dropout(input_tensor, rate=dropout_prob)
  return output


def layer_norm(input_tensor, name=None):
  """Run layer normalization on the last dimension of the tensor."""
  return tf_slim_layers.layer_norm(
      inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs layer normalization followed by dropout."""
  output_tensor = layer_norm(input_tensor, name)
  output_tensor = dropout(output_tensor, dropout_prob)
  return output_tensor


def create_initializer(initializer_range=0.02):
  """Creates a `truncated_normal_initializer` with the given range."""
  return tf.truncated_normal_initializer(stddev=initializer_range)


def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
  """Looks up word embeddings for an id tensor.

  Args:
    input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
      ids.
    vocab_size: int. Size of the embedding vocabulary.
    embedding_size: int. Width of the word embeddings.
    initializer_range: float. Embedding initialization range.
    word_embedding_name: string. Name of the embedding table.
    use_one_hot_embeddings: bool. If True, use one-hot method for word
      embeddings. If False, use `tf.nn.embedding_lookup()`.

  Returns:
    float Tensor of shape [batch_size, seq_length, embedding_size].
  """
  # This function assumes that the input is of shape [batch_size, seq_length,
  # num_inputs].
  #
  # If the input is a 2D tensor of shape [batch_size, seq_length], we
  # reshape to [batch_size, seq_length, 1].
  if input_ids.shape.ndims == 2:
    input_ids = tf.expand_dims(input_ids, axis=[-1])

  embedding_table = tf.get_variable(
      name=word_embedding_name,
      shape=[vocab_size, embedding_size],
      initializer=create_initializer(initializer_range))

  if use_one_hot_embeddings:
    flat_input_ids = tf.reshape(input_ids, [-1])
    one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
    output = tf.matmul(one_hot_input_ids, embedding_table)
  else:
    output = tf.nn.embedding_lookup(embedding_table, input_ids)

  input_shape = get_shape_list(input_ids)

  output = tf.reshape(output,
                      input_shape[0:-1] + [input_shape[-1] * embedding_size])
  return (output, embedding_table)


def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
  """Performs various post-processing on a word embedding tensor.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length,
      embedding_size].
    use_token_type: bool. Whether to add embeddings for `token_type_ids`.
    token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      Must be specified if `use_token_type` is True.
    token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
    token_type_embedding_name: string. The name of the embedding table variable
      for token type ids.
    use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
    position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
    initializer_range: float. Range of the weight initialization.
    max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter.
    dropout_prob: float. Dropout probability applied to the final output
      tensor.

  Returns:
    float tensor with same shape as `input_tensor`.

  Raises:
    ValueError: One of the tensor shapes or input values is invalid.
  """
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = input_tensor

  if use_token_type:
    if token_type_ids is None:
      raise ValueError("`token_type_ids` must be specified if "
                       "`use_token_type` is True.")
    token_type_table = tf.get_variable(
        name=token_type_embedding_name,
        shape=[token_type_vocab_size, width],
        initializer=create_initializer(initializer_range))
    # This vocab will be small so we always do one-hot here, since it is always
    # faster for a small vocabulary.
    flat_token_type_ids = tf.reshape(token_type_ids, [-1])
    one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
    token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
    token_type_embeddings = tf.reshape(token_type_embeddings,
                                       [batch_size, seq_length, width])
    output += token_type_embeddings

  if use_position_embeddings:
    # Create the variable outside the assertion to avoid TF2 compatibility
    # issues.
    full_position_embeddings = tf.get_variable(
        name=position_embedding_name,
        shape=[max_position_embeddings, width],
        initializer=create_initializer(initializer_range))

    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
    with tf.control_dependencies([assert_op]):
      # Since the position embedding table is a learned variable, we create it
      # using a (long) sequence length `max_position_embeddings`. The actual
      # sequence length might be shorter than this, for faster training of
      # tasks that do not have long sequences.
      #
      # So `full_position_embeddings` is effectively an embedding table
      # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
      # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
      # perform a slice.
      position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                     [seq_length, -1])
      num_dims = len(output.shape.as_list())

      # Only the last two dimensions are relevant (`seq_length` and `width`),
      # so we broadcast among the first dimensions, which is typically just
      # the batch size.
      position_broadcast_shape = []
      for _ in range(num_dims - 2):
        position_broadcast_shape.append(1)
      position_broadcast_shape.extend([seq_length, width])
      position_embeddings = tf.reshape(position_embeddings,
                                       position_broadcast_shape)
      output += position_embeddings

  output = layer_norm_and_dropout(output, dropout_prob)
  return output


def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = get_shape_list(to_mask, expected_rank=2)
  to_seq_length = to_shape[1]

  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding
  # tokens) so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=[batch_size, from_seq_length, 1], dtype=tf.float32)

  # Here we broadcast along two dimensions to create the mask.
  mask = broadcast_ones * to_mask

  return mask


def dense_layer_3d(input_tensor,
                   num_attention_heads,
                   size_per_head,
                   initializer,
                   activation,
                   name=None):
  """A dense layer with 3D kernel.

  Args:
    input_tensor: float Tensor of shape [batch, seq_length, hidden_size].
    num_attention_heads: Number of attention heads.
    size_per_head: The size per attention head.
    initializer: Kernel initializer.
    activation: Activation function.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """

  last_dim = get_shape_list(input_tensor)[-1]

  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel",
        shape=[last_dim, num_attention_heads * size_per_head],
        initializer=initializer)
    w = tf.reshape(w, [last_dim, num_attention_heads, size_per_head])
    b = tf.get_variable(
        name="bias",
        shape=[num_attention_heads * size_per_head],
        initializer=tf.zeros_initializer)
    b = tf.reshape(b, [num_attention_heads, size_per_head])
    ret = tf.einsum("abc,cde->abde", input_tensor, w)
    ret += b
    if activation is not None:
      return activation(ret)
    else:
      return ret


def dense_layer_3d_proj(input_tensor,
                        hidden_size,
                        num_attention_heads,
                        head_size,
                        initializer,
                        activation,
                        name=None):
  """A dense layer with 3D kernel for projection.

  Args:
    input_tensor: float Tensor of shape [batch, from_seq_length,
      num_attention_heads, size_per_head].
    hidden_size: The size of hidden layer.
    num_attention_heads: The number of attention heads.
    head_size: The size of head.
    initializer: Kernel initializer.
    activation: Activation function.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """
  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel",
        shape=[hidden_size, hidden_size],
        initializer=initializer)
    w = tf.reshape(w, [num_attention_heads, head_size, hidden_size])
    b = tf.get_variable(
        name="bias", shape=[hidden_size], initializer=tf.zeros_initializer)

    ret = tf.einsum("BFNH,NHD->BFD", input_tensor, w)
    ret += b
    if activation is not None:
      return activation(ret)
    else:
      return ret


def dense_layer_2d(input_tensor,
                   output_size,
                   initializer,
                   activation,
                   name=None):
  """A dense layer with 2D kernel.

  Args:
    input_tensor: Float tensor with rank 3.
    output_size: The size of output dimension.
    initializer: Kernel initializer.
    activation: Activation function.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """
  last_dim = get_shape_list(input_tensor)[-1]
  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel", shape=[last_dim, output_size], initializer=initializer)
    b = tf.get_variable(
        name="bias", shape=[output_size], initializer=tf.zeros_initializer)

    ret = tf.einsum("abc,cd->abd", input_tensor, w)
    ret += b
    if activation is not None:
      return activation(ret)
    else:
      return ret


def residual_attention_layer(from_tensor,
                             to_tensor,
                             attention_mask=None,
                             num_attention_heads=1,
                             size_per_head=512,
                             query_act=None,
                             key_act=None,
                             value_act=None,
                             attention_probs_dropout_prob=0.0,
                             initializer_range=0.02,
                             batch_size=None,
                             from_seq_length=None,
                             to_seq_length=None,
                             prev_attention=None,
                             num_prev_layers=0,
                             use_running_mean=False):
  r"""Performs multi-headed attention from `from_tensor` to `to_tensor`.

  This is an implementation of multi-headed attention based on "Attention
  is all you Need" with a residual edge added to connect attention modules in
  adjacent layers. If `from_tensor` and `to_tensor` are the same, then
  this is self-attention. Each timestep in `from_tensor` attends to the
  corresponding sequence in `to_tensor`, and returns a fixed-width vector.

  This function first projects `from_tensor` into a "query" tensor and
  `to_tensor` into "key" and "value" tensors. These are (effectively) a list
  of tensors of length `num_attention_heads`, where each tensor is of shape
  [batch_size, seq_length, size_per_head].

  Then, the query and key tensors are dot-producted and scaled. The running sum
  of these dot-products is optionally rescaled (to be a running mean) before it
  is softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor and returned.

  In practice, the multi-headed attention is done with tf.einsum as follows:
    Input_tensor: [BFD]
    Wq, Wk, Wv: [DNH]
    Q:[BFNH] = einsum('BFD,DNH->BFNH', Input_tensor, Wq)
    K:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wk)
    V:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wv)
    attention_scores:[BNFT] = einsum('BFNH,BTNH->BNFT', Q, K) / sqrt(H)
    attention_logits:[BNFT] = \sum_{l=0}^{cur} attention_scores_l
    (optional) attention_logits:[BNFT] = attention_logits / (cur + 1)
    attention_probs:[BNFT] = softmax(attention_logits)
    context_layer:[BFNH] = einsum('BNFT,BTNH->BFNH', attention_probs, V)
    Wout:[DNH]
    Output:[BFD] = einsum('BFNH,DNH->BFD', context_layer, Wout)

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions
      in the mask that are 0, and will be unchanged for positions that are 1.
    num_attention_heads: int. Number of attention heads.
    size_per_head: int. Size of each attention head.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.
    initializer_range: float. Range of the weight initializer.
    batch_size: (Optional) int. If the input is 2D, this might be the batch
      size of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq
      length of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `to_tensor`.
    prev_attention: (Optional) float Tensor of shape [batch_size,
      num_attention_heads, from_seq_length, to_seq_length]. Running sum of
      attention scores before the current layer.
    num_prev_layers: int. Number of previous layers that have contributed to
      `prev_attention`.
    use_running_mean: bool. Whether to softmax the running mean instead of the
      running sum of attention scores to compute attention probabilities.
      Running mean is found to be helpful for deep RealFormer models with 30+
      layers.

  Returns:
    float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
      size_per_head].
    float Tensor of shape [batch_size, num_attention_heads, from_seq_length,
      to_seq_length].

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

  if len(from_shape) != len(to_shape):
    raise ValueError(
        "The rank of `from_tensor` must match the rank of `to_tensor`.")

  if len(from_shape) == 3:
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]
    to_seq_length = to_shape[1]
  elif len(from_shape) == 2:
    if (batch_size is None or from_seq_length is None or to_seq_length is None):
      raise ValueError(
          "When passing in rank 2 tensors to attention_layer, the values "
          "for `batch_size`, `from_seq_length`, and `to_seq_length` "
          "must all be specified.")

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`

  # `query_layer` = [B, F, N, H]
  query_layer = dense_layer_3d(from_tensor, num_attention_heads, size_per_head,
                               create_initializer(initializer_range), query_act,
                               "query")

  # `key_layer` = [B, T, N, H]
  key_layer = dense_layer_3d(to_tensor, num_attention_heads, size_per_head,
                             create_initializer(initializer_range), key_act,
                             "key")

  # `value_layer` = [B, T, N, H]
  value_layer = dense_layer_3d(to_tensor, num_attention_heads, size_per_head,
                               create_initializer(initializer_range), value_act,
                               "value")

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_layer, query_layer)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / math.sqrt(float(size_per_head)))

  cur_attention = attention_scores
  if prev_attention is not None:
    cur_attention += prev_attention

  attention_logits = cur_attention
  if use_running_mean:
    attention_logits /= (num_prev_layers + 1.0)

  if attention_mask is not None:
    # `attention_mask` = [B, 1, F, T]
    attention_mask = tf.expand_dims(attention_mask, axis=[1])

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

    # Since we are adding it to the raw logits before the softmax, this is
    # effectively the same as removing these entirely.
    attention_logits += adder

  # Normalize the attention logits to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = tf.nn.softmax(attention_logits)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

  # `context_layer` = [B, F, N, H]
  context_layer = tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_layer)

  return context_layer, cur_attention


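# Illustration (assumed, not part of the original code) of the residual
# attention edge that residual_attention_layer threads through the stack in
# realformer_model below. With per-layer raw scores s_l of shape [B, N, F, T]:
#
#   cur_attention_l  = s_0 + s_1 + ... + s_l         # running sum
#   attention_logits = cur_attention_l               # default
#   attention_logits = cur_attention_l / (l + 1)     # if use_running_mean
#   attention_probs  = softmax(attention_logits + mask_adder)
#
# `cur_attention_l` is returned alongside the context layer so that layer l+1
# can extend the running sum via its `prev_attention` argument.
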
def realformer_model(input_tensor,
                     attention_mask=None,
                     hidden_size=768,
                     num_hidden_layers=12,
                     num_attention_heads=12,
                     intermediate_size=3072,
                     intermediate_act_fn=gelu,
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     initializer_range=0.02,
                     use_running_mean=False,
                     do_return_all_layers=False):
  """Multi-headed, multi-layer RealFormer from https://arxiv.org/abs/2012.11747.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    use_running_mean: Whether to softmax the running mean instead of the
      running sum of attention scores to compute attention probabilities.
      Running mean is found to be helpful for deep RealFormer models with 30+
      layers.
    do_return_all_layers: Whether to also return all layers or just the final
      layer.

  Returns:
    float Tensor of shape [batch_size, seq_length, hidden_size], the final
    hidden layer of the Transformer.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if hidden_size % num_attention_heads != 0:
    raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (hidden_size, num_attention_heads))

  attention_head_size = int(hidden_size / num_attention_heads)
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  input_width = input_shape[2]

  # The Transformer performs sum residuals on all layers so the input needs
  # to be the same as the hidden size.
  if input_width != hidden_size:
    raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                     (input_width, hidden_size))

  prev_output = input_tensor
  prev_attention = None
  all_layer_outputs = []
  for layer_idx in range(num_hidden_layers):
    with tf.variable_scope("layer_%d" % layer_idx):
      layer_input = prev_output

      with tf.variable_scope("attention"):
        with tf.variable_scope("self"):
          attention_output, prev_attention = residual_attention_layer(
              from_tensor=layer_input,
              to_tensor=layer_input,
              attention_mask=attention_mask,
              num_attention_heads=num_attention_heads,
              size_per_head=attention_head_size,
              attention_probs_dropout_prob=attention_probs_dropout_prob,
              initializer_range=initializer_range,
              prev_attention=prev_attention,
              num_prev_layers=layer_idx,
              use_running_mean=use_running_mean)

        # Run a linear projection of `hidden_size` then add a residual
        # with `layer_input`.
        with tf.variable_scope("output"):
          attention_output = dense_layer_3d_proj(
              attention_output, hidden_size,
              num_attention_heads, attention_head_size,
              create_initializer(initializer_range), None, "dense")
          attention_output = dropout(attention_output, hidden_dropout_prob)
          attention_output = layer_norm(attention_output + layer_input)

      # The activation is only applied to the "intermediate" hidden layer.
      with tf.variable_scope("intermediate"):
        intermediate_output = dense_layer_2d(
            attention_output, intermediate_size,
            create_initializer(initializer_range), intermediate_act_fn, "dense")

      # Down-project back to `hidden_size` then add the residual.
      with tf.variable_scope("output"):
        layer_output = dense_layer_2d(intermediate_output, hidden_size,
                                      create_initializer(initializer_range),
                                      None, "dense")
        layer_output = dropout(layer_output, hidden_dropout_prob)
        layer_output = layer_norm(layer_output + attention_output)
        prev_output = layer_output
        all_layer_outputs.append(layer_output)

  if do_return_all_layers:
    return all_layer_outputs
  else:
    return all_layer_outputs[-1]


def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, an exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    # Tensor.name is not supported in Eager mode.
    if tf.executing_eagerly():
      name = "get_shape_list"
    else:
      name = tensor.name

  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)

  shape = tensor.shape.as_list()

  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  if not non_static_indexes:
    return shape

  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape

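# Illustrative note (assumed, not part of the original code): for a graph-mode
# placeholder of shape [None, 128, 768], get_shape_list returns
# [<int32 scalar Tensor>, 128, 768] -- the dynamic batch dimension as a slice
# of tf.shape(tensor) and the static dimensions as Python ints.
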

def reshape_to_matrix(input_tensor):
  """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
  ndims = input_tensor.shape.ndims
  if ndims < 2:
    raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                     (input_tensor.shape))
  if ndims == 2:
    return input_tensor

  width = input_tensor.shape[-1]
  output_tensor = tf.reshape(input_tensor, [-1, width])
  return output_tensor


def reshape_from_matrix(output_tensor, orig_shape_list):
  """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
  if len(orig_shape_list) == 2:
    return output_tensor

  output_shape = get_shape_list(output_tensor)

  orig_dims = orig_shape_list[0:-1]
  width = output_shape[-1]

  return tf.reshape(output_tensor, orig_dims + [width])


def assert_rank(tensor, expected_rank, name=None):
  """Raises an exception if the tensor rank is not of the expected rank.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or list of integers, expected rank.
    name: Optional name of the tensor for the error message.

  Raises:
    ValueError: If the expected shape doesn't match the actual shape.
  """
  if name is None:
    name = tensor.name

  expected_rank_dict = {}
  if isinstance(expected_rank, six.integer_types):
    expected_rank_dict[expected_rank] = True
  else:
    for x in expected_rank:
      expected_rank_dict[x] = True

  actual_rank = tensor.shape.ndims
  if actual_rank not in expected_rank_dict:
    scope_name = tf.get_variable_scope().name
    raise ValueError(
        "For the tensor `%s` in scope `%s`, the actual rank "
        "`%d` (shape = %s) is not equal to the expected rank `%s`" %
        (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
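

# The following is a minimal, hedged smoke-test sketch (not part of the
# original module) showing how the classes above fit together. It assumes
# tensorflow.compat.v1 graph mode and uses toy, made-up hyperparameters; the
# input shapes mirror the example in the BertModel docstring.
if __name__ == "__main__":
  tf.disable_eager_execution()

  # Already been converted into WordPiece token ids (toy values).
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]], dtype=tf.int32)
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]], dtype=tf.int32)
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]], dtype=tf.int32)

  toy_config = BertConfig(
      vocab_size=128,
      hidden_size=36,
      num_hidden_layers=2,
      num_attention_heads=6,
      intermediate_size=64,
      type_vocab_size=4,
      use_running_mean=False)
  toy_model = BertModel(
      config=toy_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=token_type_ids)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    pooled = sess.run(toy_model.get_pooled_output())
    # Expected shape: (batch_size, hidden_size) == (2, 36).
    print("pooled_output shape:", pooled.shape)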