# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The main BERT model and related functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import json
import math
import re
import six
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers


class BertConfig(object):
  """Configuration for `BertModel`."""

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_hidden_layers=12,
               num_attention_heads=12,
               intermediate_size=3072,
               hidden_act="gelu",
               hidden_dropout_prob=0.1,
               attention_probs_dropout_prob=0.1,
               max_position_embeddings=512,
               type_vocab_size=16,
               initializer_range=0.02):
    """Constructs BertConfig.

    Args:
      vocab_size: Vocabulary size of `input_ids` in `BertModel`.
      hidden_size: Size of the encoder layers and the pooler layer.
      num_hidden_layers: Number of hidden layers in the Transformer encoder.
      num_attention_heads: Number of attention heads for each attention layer
        in the Transformer encoder.
      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: The non-linear activation function (function or string) in
        the encoder and pooler.
      hidden_dropout_prob: The dropout probability for all fully connected
        layers in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: The dropout ratio for the attention
        probabilities.
      max_position_embeddings: The maximum sequence length that this model
        might ever be used with. Typically set this to something large just
        in case (e.g., 512 or 1024 or 2048).
      type_vocab_size: The vocabulary size of the `token_type_ids` passed into
        `BertModel`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
    """
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.hidden_act = hidden_act
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.initializer_range = initializer_range

  @classmethod
  def from_dict(cls, json_object):
    """Constructs a `BertConfig` from a Python dictionary of parameters."""
    config = BertConfig(vocab_size=None)
    for (key, value) in six.iteritems(json_object):
      config.__dict__[key] = value
    return config

  @classmethod
  def from_json_file(cls, json_file):
    """Constructs a `BertConfig` from a json file of parameters."""
    with tf.gfile.GFile(json_file, "r") as reader:
      text = reader.read()
    return cls.from_dict(json.loads(text))

  def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    output = copy.deepcopy(self.__dict__)
    return output

  def to_json_string(self):
    """Serializes this instance to a JSON string."""
    return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

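# A minimal usage sketch for `BertConfig` (illustrative comment only, not part
# of the original API surface); the hyperparameter values below are arbitrary
# examples:
#
#   config = BertConfig(vocab_size=32000, hidden_size=256,
#                       num_hidden_layers=4, num_attention_heads=4,
#                       intermediate_size=1024)
#   json_str = config.to_json_string()
#   restored = BertConfig.from_dict(json.loads(json_str))
#   assert restored.hidden_size == 256
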
class BertModel(object):
  """BERT model ("Bidirectional Encoder Representations from Transformers").

  Example usage:

  ```python
  # Already been converted into WordPiece token ids
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])

  config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
      num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)

  model = modeling.BertModel(config=config, is_training=True,
      input_ids=input_ids, input_mask=input_mask,
      token_type_ids=token_type_ids)

  label_embeddings = tf.get_variable(...)
  pooled_output = model.get_pooled_output()
  logits = tf.matmul(pooled_output, label_embeddings)
  ...
  ```
  """

  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               use_one_hot_embeddings=True,
               scope=None):
    """Constructor for BertModel.

    Args:
      config: `BertConfig` instance.
      is_training: bool. true for training model, false for eval model.
        Controls whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      token_type_ids: (optional) int32 Tensor of shape [batch_size,
        seq_length].
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings. On the
        TPU, it is much faster if this is True; on the CPU or GPU, it is
        faster if this is False.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    config = copy.deepcopy(config)
    if not is_training:
      config.hidden_dropout_prob = 0.0
      config.attention_probs_dropout_prob = 0.0

    input_shape = get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
      input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    if token_type_ids is None:
      token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        (self.embedding_output, self.embedding_table) = embedding_lookup(
            input_ids=input_ids,
            vocab_size=config.vocab_size,
            embedding_size=config.hidden_size,
            initializer_range=config.initializer_range,
            word_embedding_name="word_embeddings",
            use_one_hot_embeddings=use_one_hot_embeddings)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = embedding_postprocessor(
            input_tensor=self.embedding_output,
            use_token_type=True,
            token_type_ids=token_type_ids,
            token_type_vocab_size=config.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=config.initializer_range,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)

      with tf.variable_scope("encoder"):
        # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
        # mask of shape [batch_size, seq_length, seq_length] which is used
        # for the attention scores.
        attention_mask = create_attention_mask_from_input_mask(
            input_ids, input_mask)

        # Run the stacked transformer.
        # `sequence_output` shape = [batch_size, seq_length, hidden_size].
        self.all_encoder_layers = transformer_model(
            input_tensor=self.embedding_output,
            attention_mask=attention_mask,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            intermediate_act_fn=get_activation(config.hidden_act),
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            initializer_range=config.initializer_range,
            do_return_all_layers=True)

      self.sequence_output = self.all_encoder_layers[-1]
      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_size, seq_length, hidden_size] to a tensor of shape
      # [batch_size, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained.
        first_token_tensor = tf.squeeze(
            self.sequence_output[:, 0:1, :], axis=1)
        self.pooled_output = tf.layers.dense(
            first_token_tensor,
            config.hidden_size,
            activation=tf.tanh,
            kernel_initializer=create_initializer(config.initializer_range))

  def get_pooled_output(self):
    return self.pooled_output

  def get_sequence_output(self):
    """Gets final hidden layer of encoder.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size]
      corresponding to the final hidden layer of the transformer encoder.
    """
    return self.sequence_output

  def get_all_encoder_layers(self):
    return self.all_encoder_layers

  def get_embedding_output(self):
    """Gets output of the embedding lookup (i.e., input to the transformer).

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size]
      corresponding to the output of the embedding layer, after summing the
      word embeddings with the positional embeddings and the token type
      embeddings, then performing layer normalization. This is the input to
      the transformer.
    """
    return self.embedding_output

  def get_embedding_table(self):
    return self.embedding_table


def gelu(input_tensor):
  """Gaussian Error Linear Unit.

  This is a smoother version of the RELU.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    input_tensor: float Tensor to perform activation.

  Returns:
    `input_tensor` with the GELU activation applied.
  """
  cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
  return input_tensor * cdf


def get_activation(activation_string):
  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

  Args:
    activation_string: String name of the activation function.

  Returns:
    A Python function corresponding to the activation function. If
    `activation_string` is None, empty, or "linear", this will return None.
    If `activation_string` is not a string, it will return
    `activation_string`.

  Raises:
    ValueError: The `activation_string` does not correspond to a known
      activation.
  """

  # We assume that anything that's not a string is already an activation
  # function, so we just return it.
  if not isinstance(activation_string, six.string_types):
    return activation_string

  if not activation_string:
    return None

  act = activation_string.lower()
  if act == "linear":
    return None
  elif act == "relu":
    return tf.nn.relu
  elif act == "gelu":
    return gelu
  elif act == "tanh":
    return tf.tanh
  else:
    raise ValueError("Unsupported activation: %s" % act)
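
# Illustrative behaviour sketch for `get_activation` (comment only):
#
#   get_activation("gelu")      # -> the `gelu` function defined above
#   get_activation("linear")    # -> None (no non-linearity)
#   get_activation(None)        # -> None
#   get_activation(tf.nn.relu)  # -> tf.nn.relu, passed through unchanged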


def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
  """Compute the union of the current variables and checkpoint variables."""
  assignment_map = {}
  initialized_variable_names = {}

  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  init_vars = tf.train.list_variables(init_checkpoint)

  assignment_map = collections.OrderedDict()
  for x in init_vars:
    (name, var) = (x[0], x[1])
    if name not in name_to_variable:
      continue
    assignment_map[name] = name
    initialized_variable_names[name] = 1
    initialized_variable_names[name + ":0"] = 1

  return (assignment_map, initialized_variable_names)


def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT
      of *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied.
  """
  if dropout_prob is None or dropout_prob == 0.0:
    return input_tensor

  output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
  return output


def layer_norm(input_tensor, name=None):
  """Run layer normalization on the last dimension of the tensor."""
  return contrib_layers.layer_norm(
      inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1,
      scope=name)


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs layer normalization followed by dropout."""
  output_tensor = layer_norm(input_tensor, name)
  output_tensor = dropout(output_tensor, dropout_prob)
  return output_tensor


def create_initializer(initializer_range=0.02):
  """Creates a `truncated_normal_initializer` with the given range."""
  return tf.truncated_normal_initializer(stddev=initializer_range)


def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
  """Looks up word embeddings for an id tensor.

  Args:
    input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
      ids.
    vocab_size: int. Size of the embedding vocabulary.
    embedding_size: int. Width of the word embeddings.
    initializer_range: float. Embedding initialization range.
    word_embedding_name: string. Name of the embedding table.
    use_one_hot_embeddings: bool. If True, use one-hot method for word
      embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better
      for TPUs.

  Returns:
    float Tensor of shape [batch_size, seq_length, embedding_size].
  """
  # This function assumes that the input is of shape [batch_size, seq_length,
  # num_inputs].
  #
  # If the input is a 2D tensor of shape [batch_size, seq_length], we
  # reshape to [batch_size, seq_length, 1].
  if input_ids.shape.ndims == 2:
    input_ids = tf.expand_dims(input_ids, axis=[-1])

  embedding_table = tf.get_variable(
      name=word_embedding_name,
      shape=[vocab_size, embedding_size],
      initializer=create_initializer(initializer_range))

  if use_one_hot_embeddings:
    flat_input_ids = tf.reshape(input_ids, [-1])
    one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
    output = tf.matmul(one_hot_input_ids, embedding_table)
  else:
    output = tf.nn.embedding_lookup(embedding_table, input_ids)

  input_shape = get_shape_list(input_ids)

  output = tf.reshape(output,
                      input_shape[0:-1] + [input_shape[-1] * embedding_size])
  return (output, embedding_table)
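
# Shape sketch for `embedding_lookup` (illustrative comment; the example
# values are assumptions, not part of the original file):
#
#   ids = tf.constant([[31, 51, 99], [15, 5, 0]])  # [batch=2, seq=3]
#   emb, table = embedding_lookup(ids, vocab_size=32000, embedding_size=128)
#   # emb has shape [2, 3, 128] and table has shape [32000, 128]. With
#   # use_one_hot_embeddings=True the result is the same; only the lookup
#   # mechanism (matmul against a one-hot matrix) changes.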


def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
  """Performs various post-processing on a word embedding tensor.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length,
      embedding_size].
    use_token_type: bool. Whether to add embeddings for `token_type_ids`.
    token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      Must be specified if `use_token_type` is True.
    token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
    token_type_embedding_name: string. The name of the embedding table
      variable for token type ids.
    use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
    position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
    initializer_range: float. Range of the weight initialization.
    max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter.
    dropout_prob: float. Dropout probability applied to the final output
      tensor.

  Returns:
    float tensor with same shape as `input_tensor`.

  Raises:
    ValueError: One of the tensor shapes or input values is invalid.
  """
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = input_tensor

  if use_token_type:
    if token_type_ids is None:
      raise ValueError("`token_type_ids` must be specified if "
                       "`use_token_type` is True.")
    token_type_table = tf.get_variable(
        name=token_type_embedding_name,
        shape=[token_type_vocab_size, width],
        initializer=create_initializer(initializer_range))
    # This vocab will be small so we always do one-hot here, since it is
    # always faster for a small vocabulary.
    flat_token_type_ids = tf.reshape(token_type_ids, [-1])
    one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
    token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
    token_type_embeddings = tf.reshape(token_type_embeddings,
                                       [batch_size, seq_length, width])
    output += token_type_embeddings

  if use_position_embeddings:
    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
    with tf.control_dependencies([assert_op]):
      full_position_embeddings = tf.get_variable(
          name=position_embedding_name,
          shape=[max_position_embeddings, width],
          initializer=create_initializer(initializer_range))
      # Since the position embedding table is a learned variable, we create it
      # using a (long) sequence length `max_position_embeddings`. The actual
      # sequence length might be shorter than this, for faster training of
      # tasks that do not have long sequences.
      #
      # So `full_position_embeddings` is effectively an embedding table
      # for position [0, 1, 2, ..., max_position_embeddings-1], and the
      # current sequence has positions [0, 1, 2, ... seq_length-1], so we can
      # just perform a slice.
      position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                     [seq_length, -1])
      num_dims = len(output.shape.as_list())

      # Only the last two dimensions are relevant (`seq_length` and `width`),
      # so we broadcast among the first dimensions, which is typically just
      # the batch size.
      position_broadcast_shape = []
      for _ in range(num_dims - 2):
        position_broadcast_shape.append(1)
      position_broadcast_shape.extend([seq_length, width])
      position_embeddings = tf.reshape(position_embeddings,
                                       position_broadcast_shape)
      output += position_embeddings

  output = layer_norm_and_dropout(output, dropout_prob)
  return output


def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = get_shape_list(to_mask, expected_rank=2)
  to_seq_length = to_shape[1]

  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding
  # tokens), so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=[batch_size, from_seq_length, 1], dtype=tf.float32)

  # Here we broadcast along two dimensions to create the mask.
  mask = broadcast_ones * to_mask

  return mask
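
# Worked example for `create_attention_mask_from_input_mask` (illustrative
# comment): for a single sequence of length 3 whose last position is padding,
# to_mask = [[1, 1, 0]], the broadcast produces
#
#   [[[1., 1., 0.],
#     [1., 1., 0.],
#     [1., 1., 0.]]]
#
# i.e. every `from` position may attend to the two real tokens but not to the
# padded one.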


def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None):
  """Performs multi-headed attention from `from_tensor` to `to_tensor`.

  This is an implementation of multi-headed attention based on "Attention
  Is All You Need". If `from_tensor` and `to_tensor` are the same, then
  this is self-attention. Each timestep in `from_tensor` attends to the
  corresponding sequence in `to_tensor`, and returns a fixed-width vector.

  This function first projects `from_tensor` into a "query" tensor and
  `to_tensor` into "key" and "value" tensors. These are (effectively) a list
  of tensors of length `num_attention_heads`, where each tensor is of shape
  [batch_size, seq_length, size_per_head].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor and returned.

  In practice, the multi-headed attention is done with transposes and
  reshapes rather than actual separate tensors.

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions
      in the mask that are 0, and will be unchanged for positions that are 1.
    num_attention_heads: int. Number of attention heads.
    size_per_head: int. Size of each attention head.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of
      the attention probabilities.
    initializer_range: float. Range of the weight initializer.
    do_return_2d_tensor: bool. If True, the output will be of shape
      [batch_size * from_seq_length, num_attention_heads * size_per_head]. If
      False, the output will be of shape [batch_size, from_seq_length,
      num_attention_heads * size_per_head].
    batch_size: (Optional) int. If the input is 2D, this might be the batch
      size of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq
      length of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq
      length of the 3D version of the `to_tensor`.

  Returns:
    float Tensor of shape [batch_size, from_seq_length,
      num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
      true, this will be of shape [batch_size * from_seq_length,
      num_attention_heads * size_per_head]).

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
  """

  def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                           seq_length, width):
    output_tensor = tf.reshape(
        input_tensor, [batch_size, seq_length, num_attention_heads, width])

    output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
    return output_tensor

  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

  if len(from_shape) != len(to_shape):
    raise ValueError(
        "The rank of `from_tensor` must match the rank of `to_tensor`.")

  if len(from_shape) == 3:
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]
    to_seq_length = to_shape[1]
  elif len(from_shape) == 2:
    if (batch_size is None or from_seq_length is None or
        to_seq_length is None):
      raise ValueError(
          "When passing in rank 2 tensors to attention_layer, the values "
          "for `batch_size`, `from_seq_length`, and `to_seq_length` "
          "must all be specified.")

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`

  from_tensor_2d = reshape_to_matrix(from_tensor)
  to_tensor_2d = reshape_to_matrix(to_tensor)

  # `query_layer` = [B*F, N*H]
  query_layer = tf.layers.dense(
      from_tensor_2d,
      num_attention_heads * size_per_head,
      activation=query_act,
      name="query",
      kernel_initializer=create_initializer(initializer_range))

  # `key_layer` = [B*T, N*H]
  key_layer = tf.layers.dense(
      to_tensor_2d,
      num_attention_heads * size_per_head,
      activation=key_act,
      name="key",
      kernel_initializer=create_initializer(initializer_range))

  # `value_layer` = [B*T, N*H]
  value_layer = tf.layers.dense(
      to_tensor_2d,
      num_attention_heads * size_per_head,
      activation=value_act,
      name="value",
      kernel_initializer=create_initializer(initializer_range))

  # `query_layer` = [B, N, F, H]
  query_layer = transpose_for_scores(query_layer, batch_size,
                                     num_attention_heads, from_seq_length,
                                     size_per_head)

  # `key_layer` = [B, N, T, H]
  key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                   to_seq_length, size_per_head)

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  # `attention_scores` = [B, N, F, T]
  attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / math.sqrt(float(size_per_head)))

  if attention_mask is not None:
    # `attention_mask` = [B, 1, F, T]
    attention_mask = tf.expand_dims(attention_mask, axis=[1])

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    attention_scores += adder

  # Normalize the attention scores to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = tf.nn.softmax(attention_scores)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

  # `value_layer` = [B, T, N, H]
  value_layer = tf.reshape(
      value_layer,
      [batch_size, to_seq_length, num_attention_heads, size_per_head])

  # `value_layer` = [B, N, T, H]
  value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

  # `context_layer` = [B, N, F, H]
  context_layer = tf.matmul(attention_probs, value_layer)

  # `context_layer` = [B, F, N, H]
  context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

  if do_return_2d_tensor:
    # `context_layer` = [B*F, N*H]
    context_layer = tf.reshape(
        context_layer,
        [batch_size * from_seq_length, num_attention_heads * size_per_head])
  else:
    # `context_layer` = [B, F, N*H]
    context_layer = tf.reshape(
        context_layer,
        [batch_size, from_seq_length, num_attention_heads * size_per_head])

  return context_layer


def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
  """Multi-headed, multi-layer Transformer from "Attention is All You Need".

  This is almost an exact implementation of the original Transformer encoder.

  See the original paper:
  https://arxiv.org/abs/1706.03762

  Also see:
  https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    do_return_all_layers: Whether to also return all layers or just the final
      layer.

  Returns:
    float Tensor of shape [batch_size, seq_length, hidden_size], the final
    hidden layer of the Transformer.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if hidden_size % num_attention_heads != 0:
    raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (hidden_size, num_attention_heads))

  attention_head_size = int(hidden_size / num_attention_heads)
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  input_width = input_shape[2]

  # The Transformer performs sum residuals on all layers so the input needs
  # to be the same as the hidden size.
  if input_width != hidden_size:
    raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                     (input_width, hidden_size))

  # We keep the representation as a 2D tensor to avoid re-shaping it back and
  # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
  # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
  # help the optimizer.
  prev_output = reshape_to_matrix(input_tensor)

  all_layer_outputs = []
  for layer_idx in range(num_hidden_layers):
    with tf.variable_scope("layer_%d" % layer_idx):
      layer_input = prev_output

      with tf.variable_scope("attention"):
        attention_heads = []
        with tf.variable_scope("self"):
          attention_head = attention_layer(
              from_tensor=layer_input,
              to_tensor=layer_input,
              attention_mask=attention_mask,
              num_attention_heads=num_attention_heads,
              size_per_head=attention_head_size,
              attention_probs_dropout_prob=attention_probs_dropout_prob,
              initializer_range=initializer_range,
              do_return_2d_tensor=True,
              batch_size=batch_size,
              from_seq_length=seq_length,
              to_seq_length=seq_length)
          attention_heads.append(attention_head)

        attention_output = None
        if len(attention_heads) == 1:
          attention_output = attention_heads[0]
        else:
          # In the case where we have other sequences, we just concatenate
          # them to the self-attention head before the projection.
          attention_output = tf.concat(attention_heads, axis=-1)

        # Run a linear projection of `hidden_size` then add a residual
        # with `layer_input`.
        with tf.variable_scope("output"):
          attention_output = tf.layers.dense(
              attention_output,
              hidden_size,
              kernel_initializer=create_initializer(initializer_range))
          attention_output = dropout(attention_output, hidden_dropout_prob)
          attention_output = layer_norm(attention_output + layer_input)

      # The activation is only applied to the "intermediate" hidden layer.
      with tf.variable_scope("intermediate"):
        intermediate_output = tf.layers.dense(
            attention_output,
            intermediate_size,
            activation=intermediate_act_fn,
            kernel_initializer=create_initializer(initializer_range))

      # Down-project back to `hidden_size` then add the residual.
      with tf.variable_scope("output"):
        layer_output = tf.layers.dense(
            intermediate_output,
            hidden_size,
            kernel_initializer=create_initializer(initializer_range))
        layer_output = dropout(layer_output, hidden_dropout_prob)
        layer_output = layer_norm(layer_output + attention_output)
        prev_output = layer_output
        all_layer_outputs.append(layer_output)

  if do_return_all_layers:
    final_outputs = []
    for layer_output in all_layer_outputs:
      final_output = reshape_from_matrix(layer_output, input_shape)
      final_outputs.append(final_output)
    return final_outputs
  else:
    final_output = reshape_from_matrix(prev_output, input_shape)
    return final_output


def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, an exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    name = tensor.name

  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)

  shape = tensor.shape.as_list()

  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  if not non_static_indexes:
    return shape

  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape
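
# Illustrative sketch for `get_shape_list` (comment only): static dimensions
# come back as Python ints and unknown dimensions as scalar Tensors, e.g.
#
#   x = tf.placeholder(tf.float32, shape=[None, 128])
#   dims = get_shape_list(x, expected_rank=2)
#   # dims == [<dynamic batch-size Tensor>, 128]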


def reshape_to_matrix(input_tensor):
  """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
  ndims = input_tensor.shape.ndims
  if ndims < 2:
    raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                     (input_tensor.shape))
  if ndims == 2:
    return input_tensor

  width = input_tensor.shape[-1]
  output_tensor = tf.reshape(input_tensor, [-1, width])
  return output_tensor


def reshape_from_matrix(output_tensor, orig_shape_list):
  """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
  if len(orig_shape_list) == 2:
    return output_tensor

  output_shape = get_shape_list(output_tensor)

  orig_dims = orig_shape_list[0:-1]
  width = output_shape[-1]

  return tf.reshape(output_tensor, orig_dims + [width])


def assert_rank(tensor, expected_rank, name=None):
  """Raises an exception if the tensor rank is not of the expected rank.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or list of integers, expected rank.
    name: Optional name of the tensor for the error message.

  Raises:
    ValueError: If the expected shape doesn't match the actual shape.
  """
  if name is None:
    name = tensor.name

  expected_rank_dict = {}
  if isinstance(expected_rank, six.integer_types):
    expected_rank_dict[expected_rank] = True
  else:
    for x in expected_rank:
      expected_rank_dict[x] = True

  actual_rank = tensor.shape.ndims
  if actual_rank not in expected_rank_dict:
    scope_name = tf.get_variable_scope().name
    raise ValueError(
        "For the tensor `%s` in scope `%s`, the actual rank "
        "`%d` (shape = %s) is not equal to the expected rank `%s`" %
        (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))