# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The main BERT model and related functions. Copied from bert."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import json
import math
import re
import numpy as np
import six
import tensorflow.compat.v1 as tf
from tf_slim.layers import layers

class BertConfig(object):
  """Configuration for `BertModel`."""

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_hidden_layers=12,
               num_attention_heads=12,
               intermediate_size=3072,
               hidden_act="gelu",
               hidden_dropout_prob=0.1,
               attention_probs_dropout_prob=0.1,
               max_position_embeddings=512,
               type_vocab_size=16,
               initializer_range=0.02):
    """Constructs BertConfig.

    Args:
      vocab_size: Vocabulary size of `input_ids` in `BertModel`.
      hidden_size: Size of the encoder layers and the pooler layer.
      num_hidden_layers: Number of hidden layers in the Transformer encoder.
      num_attention_heads: Number of attention heads for each attention layer
        in the Transformer encoder.
      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: The non-linear activation function (function or string) in
        the encoder and pooler.
      hidden_dropout_prob: The dropout probability for all fully connected
        layers in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: The dropout ratio for the attention
        probabilities.
      max_position_embeddings: The maximum sequence length that this model
        might ever be used with. Typically set this to something large just in
        case (e.g., 512 or 1024 or 2048).
      type_vocab_size: The vocabulary size of the `token_type_ids` passed into
        `BertModel`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
    """
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.hidden_act = hidden_act
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.initializer_range = initializer_range

  @classmethod
  def from_dict(cls, json_object, strict=False):
    """Constructs a `BertConfig` from a Python dictionary of parameters."""
    config = cls(vocab_size=None)
    for (key, value) in six.iteritems(json_object):
      if strict and key not in config.__dict__:
        raise ValueError("BertConfig has no field '{}'".format(key))
      config.__dict__[key] = value
    if strict and config.vocab_size is None:
      raise ValueError("BertConfig field 'vocab_size' is unset")
    return config

  @classmethod
  def from_json_file(cls, json_file, strict=False):
    """Constructs a `BertConfig` from a json file of parameters."""
    with tf.io.gfile.GFile(json_file, "r") as reader:
      text = reader.read()
    return cls.from_dict(json.loads(text), strict=strict)

  def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    output = copy.deepcopy(self.__dict__)
    return output

  def to_json_string(self):
    """Serializes this instance to a JSON string."""
    return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

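# Illustrative usage sketch (not part of the original BERT code):
# round-tripping a config through a dict with strict field checking. The vocab
# size of 30522 and the other sizes are made-up example values.
def _example_bert_config_roundtrip():
  config = BertConfig(vocab_size=30522, hidden_size=256, num_hidden_layers=4)
  restored = BertConfig.from_dict(config.to_dict(), strict=True)
  assert restored.hidden_size == 256
  return restored.to_json_string()
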
class BertModel(object):
  """BERT model ("Bidirectional Encoder Representations from Transformers").

  Example usage:

  ```python
  # Already been converted into WordPiece token ids
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])

  config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
    num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)

  model = modeling.BertModel(config=config, is_training=True,
    input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)

  label_embeddings = tf.get_variable(...)
  pooled_output = model.get_pooled_output()
  logits = tf.matmul(pooled_output, label_embeddings)
  ...
  ```
  """
  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               use_one_hot_embeddings=False,
               scope=None):
    """Constructor for BertModel.

    Args:
      config: `BertConfig` instance.
      is_training: bool. true for training model, false for eval model.
        Controls whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    config = copy.deepcopy(config)
    if not is_training:
      config.hidden_dropout_prob = 0.0
      config.attention_probs_dropout_prob = 0.0

    input_shape = get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
      input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    if token_type_ids is None:
      token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        (self.word_embedding_output, self.embedding_table) = embedding_lookup(
            input_ids=input_ids,
            vocab_size=config.vocab_size,
            embedding_size=config.hidden_size,
            initializer_range=config.initializer_range,
            word_embedding_name="word_embeddings",
            use_one_hot_embeddings=use_one_hot_embeddings)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = embedding_postprocessor(
            input_tensor=self.word_embedding_output,
            use_token_type=True,
            token_type_ids=token_type_ids,
            token_type_vocab_size=config.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=config.initializer_range,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)

      with tf.variable_scope("encoder"):
        # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
        # mask of shape [batch_size, seq_length, seq_length] which is used
        # for the attention scores.
        attention_mask = create_attention_mask_from_input_mask(
            input_ids, input_mask)

        # Run the stacked transformer.
        # `sequence_output` shape = [batch_size, seq_length, hidden_size].
        self.all_encoder_layers = transformer_model(
            input_tensor=self.embedding_output,
            attention_mask=attention_mask,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            intermediate_act_fn=get_activation(config.hidden_act),
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            initializer_range=config.initializer_range,
            do_return_all_layers=True)

      self.sequence_output = self.all_encoder_layers[-1]
      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_size, seq_length, hidden_size] to a tensor of shape
      # [batch_size, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained.
        first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
        self.pooled_output = tf.layers.dense(
            first_token_tensor,
            config.hidden_size,
            activation=tf.tanh,
            kernel_initializer=create_initializer(config.initializer_range))

  def get_pooled_output(self):
    return self.pooled_output

  def get_sequence_output(self):
    """Gets final hidden layer of encoder.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the final hidden layer of the transformer encoder.
    """
    return self.sequence_output

  def get_all_encoder_layers(self):
    return self.all_encoder_layers

  def get_word_embedding_output(self):
    """Get output of the word(piece) embedding lookup.

    This is BEFORE positional embeddings and token type embeddings have been
    added.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the output of the word(piece) embedding layer.
    """
    return self.word_embedding_output

  def get_embedding_output(self):
    """Gets output of the embedding lookup (i.e., input to the transformer).

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the output of the embedding layer, after summing the word
      embeddings with the positional embeddings and the token type embeddings,
      then performing layer normalization. This is the input to the
      transformer.
    """
    return self.embedding_output

  def get_embedding_table(self):
    return self.embedding_table

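# Illustrative sketch (not part of the original BERT code): running a tiny
# BertModel in inference mode and inspecting the per-layer outputs exposed by
# get_all_encoder_layers(). All sizes here are small, made-up example values.
def _example_run_tiny_bert():
  with tf.Graph().as_default():
    input_ids = tf.constant([[31, 51, 99], [15, 5, 0]], dtype=tf.int32)
    input_mask = tf.constant([[1, 1, 1], [1, 1, 0]], dtype=tf.int32)
    config = BertConfig(
        vocab_size=128, hidden_size=32, num_hidden_layers=2,
        num_attention_heads=4, intermediate_size=64, max_position_embeddings=8)
    model = BertModel(
        config=config, is_training=False, input_ids=input_ids,
        input_mask=input_mask)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      layer_outputs, pooled = sess.run(
          [model.get_all_encoder_layers(), model.get_pooled_output()])
    # Two encoder layers, each [batch_size=2, seq_length=3, hidden_size=32];
    # the pooled output is [2, 32].
    return len(layer_outputs), pooled.shape
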
def gelu(x):
  """Gaussian Error Linear Unit.

  This is a smoother version of the RELU.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  cdf = 0.5 * (1.0 + tf.tanh(
      (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
  return x * cdf

def get_activation(activation_string):
  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

  Args:
    activation_string: String name of the activation function.

  Returns:
    A Python function corresponding to the activation function. If
    `activation_string` is None, empty, or "linear", this will return None.
    If `activation_string` is not a string, it will return `activation_string`.

  Raises:
    ValueError: The `activation_string` does not correspond to a known
      activation.
  """

  # We assume that anything that's not a string is already an activation
  # function, so we just return it.
  if not isinstance(activation_string, six.string_types):
    return activation_string

  if not activation_string:
    return None

  act = activation_string.lower()
  if act == "linear":
    return None
  elif act == "relu":
    return tf.nn.relu
  elif act == "gelu":
    return gelu
  elif act == "tanh":
    return tf.tanh
  else:
    raise ValueError("Unsupported activation: %s" % act)

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
  """Compute the union of the current variables and checkpoint variables."""
  assignment_map = {}
  initialized_variable_names = {}

  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  init_vars = tf.train.list_variables(init_checkpoint)

  assignment_map = collections.OrderedDict()
  for x in init_vars:
    (name, var) = (x[0], x[1])
    if name not in name_to_variable:
      continue
    assignment_map[name] = name
    initialized_variable_names[name] = 1
    initialized_variable_names[name + ":0"] = 1

  return (assignment_map, initialized_variable_names)

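# Illustrative sketch (not part of the original BERT code): wiring the
# assignment map into tf.train.init_from_checkpoint inside an existing graph.
# The default checkpoint path is a made-up placeholder.
def _example_init_from_checkpoint(init_checkpoint="/tmp/bert/bert_model.ckpt"):
  tvars = tf.trainable_variables()
  (assignment_map,
   initialized_variable_names) = get_assignment_map_from_checkpoint(
       tvars, init_checkpoint)
  tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  return initialized_variable_names
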
def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT of
      *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied.
  """
  if dropout_prob is None or dropout_prob == 0.0:
    return input_tensor

  output = tf.nn.dropout(input_tensor, rate=dropout_prob)
  return output


def layer_norm(input_tensor, name=None):
  """Run layer normalization on the last dimension of the tensor."""
  return layers.layer_norm(
      inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs layer normalization followed by dropout."""
  output_tensor = layer_norm(input_tensor, name)
  output_tensor = dropout(output_tensor, dropout_prob)
  return output_tensor


def create_initializer(initializer_range=0.02):
  """Creates a `truncated_normal_initializer` with the given range."""
  return tf.truncated_normal_initializer(stddev=initializer_range)


def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
  """Looks up word embeddings for an id tensor.

  Args:
    input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
      ids.
    vocab_size: int. Size of the embedding vocabulary.
    embedding_size: int. Width of the word embeddings.
    initializer_range: float. Embedding initialization range.
    word_embedding_name: string. Name of the embedding table.
    use_one_hot_embeddings: bool. If True, use one-hot method for word
      embeddings. If False, use `tf.nn.embedding_lookup()`.

  Returns:
    float Tensor of shape [batch_size, seq_length, embedding_size].
  """
  # This function assumes that the input is of shape [batch_size, seq_length,
  # num_inputs].
  #
  # If the input is a 2D tensor of shape [batch_size, seq_length], we
  # reshape to [batch_size, seq_length, 1].
  if input_ids.shape.ndims == 2:
    input_ids = tf.expand_dims(input_ids, axis=[-1])

  embedding_table = tf.get_variable(
      name=word_embedding_name,
      shape=[vocab_size, embedding_size],
      initializer=create_initializer(initializer_range))

  if use_one_hot_embeddings:
    flat_input_ids = tf.reshape(input_ids, [-1])
    one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
    output = tf.matmul(one_hot_input_ids, embedding_table)
  else:
    output = tf.nn.embedding_lookup(embedding_table, input_ids)

  input_shape = get_shape_list(input_ids)

  output = tf.reshape(output,
                      input_shape[0:-1] + [input_shape[-1] * embedding_size])
  return (output, embedding_table)

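# Illustrative sketch (not part of the original BERT code): looking up word
# embeddings for a toy batch. Sizes are made-up example values.
def _example_embedding_lookup():
  with tf.Graph().as_default():
    input_ids = tf.constant([[1, 2, 3], [4, 5, 0]], dtype=tf.int32)
    with tf.variable_scope("example_embeddings"):
      output, table = embedding_lookup(
          input_ids=input_ids, vocab_size=16, embedding_size=8)
    # `output` is [batch_size=2, seq_length=3, embedding_size=8];
    # `table` is [vocab_size=16, embedding_size=8].
    return output, table
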
def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
  """Performs various post-processing on a word embedding tensor.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length,
      embedding_size].
    use_token_type: bool. Whether to add embeddings for `token_type_ids`.
    token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      Must be specified if `use_token_type` is True.
    token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
    token_type_embedding_name: string. The name of the embedding table variable
      for token type ids.
    use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
    position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
    initializer_range: float. Range of the weight initialization.
    max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter.
    dropout_prob: float. Dropout probability applied to the final output
      tensor.

  Returns:
    float tensor with same shape as `input_tensor`.

  Raises:
    ValueError: One of the tensor shapes or input values is invalid.
  """
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = input_tensor

  if use_token_type:
    if token_type_ids is None:
      raise ValueError("`token_type_ids` must be specified if "
                       "`use_token_type` is True.")
    token_type_table = tf.get_variable(
        name=token_type_embedding_name,
        shape=[token_type_vocab_size, width],
        initializer=create_initializer(initializer_range))
    # This vocab will be small so we always do one-hot here, since it is always
    # faster for a small vocabulary.
    flat_token_type_ids = tf.reshape(token_type_ids, [-1])
    one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
    token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
    token_type_embeddings = tf.reshape(token_type_embeddings,
                                       [batch_size, seq_length, width])
    output += token_type_embeddings

  if use_position_embeddings:
    # Create the variable outside the assertion to avoid TF2 compatibility
    # issues.
    full_position_embeddings = tf.get_variable(
        name=position_embedding_name,
        shape=[max_position_embeddings, width],
        initializer=create_initializer(initializer_range))

    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
    with tf.control_dependencies([assert_op]):
      # Since the position embedding table is a learned variable, we create it
      # using a (long) sequence length `max_position_embeddings`. The actual
      # sequence length might be shorter than this, for faster training of
      # tasks that do not have long sequences.
      #
      # So `full_position_embeddings` is effectively an embedding table
      # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
      # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
      # perform a slice.
      position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                     [seq_length, -1])
      num_dims = len(output.shape.as_list())

      # Only the last two dimensions are relevant (`seq_length` and `width`),
      # so we broadcast among the first dimensions, which is typically just
      # the batch size.
      position_broadcast_shape = []
      for _ in range(num_dims - 2):
        position_broadcast_shape.append(1)
      position_broadcast_shape.extend([seq_length, width])
      position_embeddings = tf.reshape(position_embeddings,
                                       position_broadcast_shape)
      output += position_embeddings

  output = layer_norm_and_dropout(output, dropout_prob)
  return output

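# Illustrative sketch (not part of the original BERT code): adding token type
# and position embeddings to the output of embedding_lookup. Sizes are made-up
# example values; seq_length must not exceed max_position_embeddings.
def _example_embedding_postprocessor():
  with tf.Graph().as_default():
    word_embeddings = tf.ones([2, 4, 8])  # [batch, seq_length, width]
    token_type_ids = tf.constant([[0, 0, 1, 1], [0, 0, 0, 0]], dtype=tf.int32)
    with tf.variable_scope("example_postprocess"):
      output = embedding_postprocessor(
          input_tensor=word_embeddings,
          use_token_type=True,
          token_type_ids=token_type_ids,
          token_type_vocab_size=2,
          max_position_embeddings=16,
          dropout_prob=0.0)
    # `output` keeps the input shape [2, 4, 8].
    return output
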
def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = get_shape_list(to_mask, expected_rank=2)
  to_seq_length = to_shape[1]

  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding
  # tokens) so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=[batch_size, from_seq_length, 1], dtype=tf.float32)

  # Here we broadcast along two dimensions to create the mask.
  mask = broadcast_ones * to_mask

  return mask

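# Illustrative sketch (not part of the original BERT code): for a batch with
# one padded position, the resulting mask broadcasts the padding pattern of
# `to_mask` across every query position.
def _example_attention_mask():
  with tf.Graph().as_default():
    input_ids = tf.constant([[7, 8, 9], [7, 8, 0]], dtype=tf.int32)
    input_mask = tf.constant([[1, 1, 1], [1, 1, 0]], dtype=tf.int32)
    mask = create_attention_mask_from_input_mask(input_ids, input_mask)
    with tf.Session() as sess:
      mask_value = sess.run(mask)
    # mask_value[1] is [[1., 1., 0.], [1., 1., 0.], [1., 1., 0.]]: no query in
    # the second sequence may attend to its padded last position.
    return mask_value
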
def dense_layer_3d(input_tensor,
                   num_attention_heads,
                   size_per_head,
                   initializer,
                   activation,
                   name=None):
  """A dense layer with 3D kernel.

  Args:
    input_tensor: float Tensor of shape [batch, seq_length, hidden_size].
    num_attention_heads: Number of attention heads.
    size_per_head: The size per attention head.
    initializer: Kernel initializer.
    activation: Activation function.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """

  last_dim = get_shape_list(input_tensor)[-1]

  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel",
        shape=[last_dim, num_attention_heads * size_per_head],
        initializer=initializer)
    w = tf.reshape(w, [last_dim, num_attention_heads, size_per_head])
    b = tf.get_variable(
        name="bias",
        shape=[num_attention_heads * size_per_head],
        initializer=tf.zeros_initializer)
    b = tf.reshape(b, [num_attention_heads, size_per_head])
    ret = tf.einsum("abc,cde->abde", input_tensor, w)
    ret += b
    if activation is not None:
      return activation(ret)
    else:
      return ret


def dense_layer_3d_proj(input_tensor,
                        hidden_size,
                        num_attention_heads,
                        head_size,
                        initializer,
                        activation,
                        name=None):
  """A dense layer with 3D kernel for projection.

  Args:
    input_tensor: float Tensor of shape [batch, from_seq_length,
      num_attention_heads, size_per_head].
    hidden_size: The size of hidden layer.
    num_attention_heads: The size of output dimension.
    head_size: The size of head.
    initializer: Kernel initializer.
    activation: Activation function.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """
  head_size = hidden_size // num_attention_heads
  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel",
        shape=[hidden_size, hidden_size],
        initializer=initializer)
    w = tf.reshape(w, [num_attention_heads, head_size, hidden_size])
    b = tf.get_variable(
        name="bias", shape=[hidden_size], initializer=tf.zeros_initializer)

    ret = tf.einsum("BFNH,NHD->BFD", input_tensor, w)
    ret += b
    if activation is not None:
      return activation(ret)
    else:
      return ret


def dense_layer_2d(input_tensor,
                   output_size,
                   initializer,
                   activation,
                   name=None):
  """A dense layer with 2D kernel.

  Args:
    input_tensor: Float tensor with rank 3.
    output_size: The size of output dimension.
    initializer: Kernel initializer.
    activation: Activation function.
    name: The name scope of this layer.

  Returns:
    float logits Tensor.
  """
  last_dim = get_shape_list(input_tensor)[-1]
  with tf.variable_scope(name):
    w = tf.get_variable(
        name="kernel", shape=[last_dim, output_size], initializer=initializer)
    b = tf.get_variable(
        name="bias", shape=[output_size], initializer=tf.zeros_initializer)

    ret = tf.einsum("abc,cd->abd", input_tensor, w)
    ret += b
    if activation is not None:
      return activation(ret)
    else:
      return ret

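# Illustrative sketch (not part of the original BERT code): the shapes produced
# by the einsum-based dense layers above, using small made-up sizes.
def _example_dense_layers():
  with tf.Graph().as_default():
    x = tf.ones([2, 5, 16])  # [batch, seq_length, hidden_size]
    init = create_initializer(0.02)
    with tf.variable_scope("example_dense"):
      # Per-head projection: [2, 5, 16] -> [2, 5, num_heads=4, head_size=4].
      heads = dense_layer_3d(x, 4, 4, init, None, "qkv_like")
      # Projection back to the model width: [2, 5, 4, 4] -> [2, 5, 16].
      merged = dense_layer_3d_proj(heads, 16, 4, 4, init, None, "proj")
      # Ordinary position-wise dense layer: [2, 5, 16] -> [2, 5, 64].
      expanded = dense_layer_2d(merged, 64, init, gelu, "ffn")
    return heads.shape, merged.shape, expanded.shape
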
def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None):
  """Performs multi-headed attention from `from_tensor` to `to_tensor`.

  This is an implementation of multi-headed attention based on "Attention
  Is All You Need". If `from_tensor` and `to_tensor` are the same, then
  this is self-attention. Each timestep in `from_tensor` attends to the
  corresponding sequence in `to_tensor`, and returns a fixed-width vector.

  This function first projects `from_tensor` into a "query" tensor and
  `to_tensor` into "key" and "value" tensors. These are (effectively) a list
  of tensors of length `num_attention_heads`, where each tensor is of shape
  [batch_size, seq_length, size_per_head].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor and returned.

  In practice, the multi-headed attention is done with tf.einsum as follows:
    Input_tensor: [BFD]
    Wq, Wk, Wv: [DNH]
    Q: [BFNH] = einsum('BFD,DNH->BFNH', Input_tensor, Wq)
    K: [BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wk)
    V: [BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wv)
    attention_scores: [BNFT] = einsum('BFNH,BTNH->BNFT', Q, K) / sqrt(H)
    attention_probs: [BNFT] = softmax(attention_scores)
    context_layer: [BFNH] = einsum('BNFT,BTNH->BFNH', attention_probs, V)
    Wout: [DNH]
    Output: [BFD] = einsum('BFNH,DNH->BFD', context_layer, Wout)

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions
      in the mask that are 0, and will be unchanged for positions that are 1.
    num_attention_heads: int. Number of attention heads.
    size_per_head: int. Size of each attention head.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.
    initializer_range: float. Range of the weight initializer.
    batch_size: (Optional) int. If the input is 2D, this might be the batch
      size of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq
      length of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `to_tensor`.

  Returns:
    float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
      size_per_head].

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

  if len(from_shape) != len(to_shape):
    raise ValueError(
        "The rank of `from_tensor` must match the rank of `to_tensor`.")

  if len(from_shape) == 3:
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]
    to_seq_length = to_shape[1]
  elif len(from_shape) == 2:
    if (batch_size is None or from_seq_length is None or to_seq_length is None):
      raise ValueError(
          "When passing in rank 2 tensors to attention_layer, the values "
          "for `batch_size`, `from_seq_length`, and `to_seq_length` "
          "must all be specified.")

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`

  # `query_layer` = [B, F, N, H]
  query_layer = dense_layer_3d(from_tensor, num_attention_heads, size_per_head,
                               create_initializer(initializer_range), query_act,
                               "query")

  # `key_layer` = [B, T, N, H]
  key_layer = dense_layer_3d(to_tensor, num_attention_heads, size_per_head,
                             create_initializer(initializer_range), key_act,
                             "key")

  # `value_layer` = [B, T, N, H]
  value_layer = dense_layer_3d(to_tensor, num_attention_heads, size_per_head,
                               create_initializer(initializer_range), value_act,
                               "value")

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_layer, query_layer)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / math.sqrt(float(size_per_head)))

  if attention_mask is not None:
    # `attention_mask` = [B, 1, F, T]
    attention_mask = tf.expand_dims(attention_mask, axis=[1])

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    attention_scores += adder

  # Normalize the attention scores to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = tf.nn.softmax(attention_scores)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

  # `context_layer` = [B, F, N, H]
  context_layer = tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_layer)

  return context_layer

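# Illustrative sketch (not part of the original BERT code): self-attention over
# a toy batch, with the 3D mask built by create_attention_mask_from_input_mask
# expanded internally to [B, 1, F, T]. All sizes are made-up example values.
def _example_attention_layer():
  with tf.Graph().as_default():
    layer_input = tf.ones([2, 5, 16])  # [batch, seq_length, width]
    input_mask = tf.constant([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]], dtype=tf.int32)
    attention_mask = create_attention_mask_from_input_mask(
        layer_input, input_mask)
    with tf.variable_scope("example_attention"):
      context = attention_layer(
          from_tensor=layer_input,
          to_tensor=layer_input,
          attention_mask=attention_mask,
          num_attention_heads=4,
          size_per_head=4)
    # `context` has shape [batch=2, from_seq_length=5, heads=4, size_per_head=4].
    return context
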
def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
  """Multi-headed, multi-layer Transformer from "Attention is All You Need".

  This is almost an exact implementation of the original Transformer encoder.

  See the original paper:
  https://arxiv.org/abs/1706.03762

  Also see:
  https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    do_return_all_layers: Whether to also return all layers or just the final
      layer.

  Returns:
    float Tensor of shape [batch_size, seq_length, hidden_size], the final
    hidden layer of the Transformer.

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if hidden_size % num_attention_heads != 0:
    raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (hidden_size, num_attention_heads))

  attention_head_size = int(hidden_size / num_attention_heads)
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  input_width = input_shape[2]

  # The Transformer performs sum residuals on all layers so the input needs
  # to be the same as the hidden size.
  if input_width != hidden_size:
    raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                     (input_width, hidden_size))

  prev_output = input_tensor
  all_layer_outputs = []
  for layer_idx in range(num_hidden_layers):
    with tf.variable_scope("layer_%d" % layer_idx):
      layer_input = prev_output

      with tf.variable_scope("attention"):
        with tf.variable_scope("self"):
          attention_output = attention_layer(
              from_tensor=layer_input,
              to_tensor=layer_input,
              attention_mask=attention_mask,
              num_attention_heads=num_attention_heads,
              size_per_head=attention_head_size,
              attention_probs_dropout_prob=attention_probs_dropout_prob,
              initializer_range=initializer_range)

        # Run a linear projection of `hidden_size` then add a residual
        # with `layer_input`.
        with tf.variable_scope("output"):
          attention_output = dense_layer_3d_proj(
              attention_output, hidden_size,
              num_attention_heads, attention_head_size,
              create_initializer(initializer_range), None, "dense")
          attention_output = dropout(attention_output, hidden_dropout_prob)
          attention_output = layer_norm(attention_output + layer_input)

      # The activation is only applied to the "intermediate" hidden layer.
      with tf.variable_scope("intermediate"):
        intermediate_output = dense_layer_2d(
            attention_output, intermediate_size,
            create_initializer(initializer_range), intermediate_act_fn, "dense")

      # Down-project back to `hidden_size` then add the residual.
      with tf.variable_scope("output"):
        layer_output = dense_layer_2d(intermediate_output, hidden_size,
                                      create_initializer(initializer_range),
                                      None, "dense")
        layer_output = dropout(layer_output, hidden_dropout_prob)
        layer_output = layer_norm(layer_output + attention_output)
        prev_output = layer_output
        all_layer_outputs.append(layer_output)

  if do_return_all_layers:
    return all_layer_outputs
  else:
    return all_layer_outputs[-1]

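# Illustrative sketch (not part of the original BERT code): stacking a tiny
# two-layer encoder directly, without the BertModel wrapper. Sizes are made-up
# example values.
def _example_transformer_model():
  with tf.Graph().as_default():
    embeddings = tf.ones([2, 5, 16])  # [batch, seq_length, hidden_size]
    with tf.variable_scope("example_encoder"):
      all_layers = transformer_model(
          input_tensor=embeddings,
          hidden_size=16,
          num_hidden_layers=2,
          num_attention_heads=4,
          intermediate_size=32,
          do_return_all_layers=True)
    # `all_layers` is a list of two tensors, each of shape [2, 5, 16].
    return all_layers
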
def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, an exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    # Tensor.name is not supported in Eager mode.
    if tf.executing_eagerly():
      name = "get_shape_list"
    else:
      name = tensor.name

  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)

  shape = tensor.shape.as_list()

  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  if not non_static_indexes:
    return shape

  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape

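# Illustrative sketch (not part of the original BERT code): with a placeholder
# whose batch dimension is unknown, the static seq_length comes back as a
# Python int while the batch size comes back as a scalar Tensor.
def _example_get_shape_list():
  with tf.Graph().as_default():
    x = tf.placeholder(tf.int32, shape=[None, 128])
    batch_size, seq_length = get_shape_list(x, expected_rank=2)
    # `seq_length` is the Python int 128; `batch_size` is a tf.Tensor scalar
    # that must be evaluated inside a session to get its value.
    return batch_size, seq_length
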
def reshape_to_matrix(input_tensor):
  """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
  ndims = input_tensor.shape.ndims
  if ndims < 2:
    raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                     (input_tensor.shape))
  if ndims == 2:
    return input_tensor

  width = input_tensor.shape[-1]
  output_tensor = tf.reshape(input_tensor, [-1, width])
  return output_tensor


def reshape_from_matrix(output_tensor, orig_shape_list):
  """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
  if len(orig_shape_list) == 2:
    return output_tensor

  output_shape = get_shape_list(output_tensor)

  orig_dims = orig_shape_list[0:-1]
  width = output_shape[-1]

  return tf.reshape(output_tensor, orig_dims + [width])


def assert_rank(tensor, expected_rank, name=None):
  """Raises an exception if the tensor rank is not of the expected rank.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or list of integers, expected rank.
    name: Optional name of the tensor for the error message.

  Raises:
    ValueError: If the expected shape doesn't match the actual shape.
  """
  if name is None:
    name = tensor.name

  expected_rank_dict = {}
  if isinstance(expected_rank, six.integer_types):
    expected_rank_dict[expected_rank] = True
  else:
    for x in expected_rank:
      expected_rank_dict[x] = True

  actual_rank = tensor.shape.ndims
  if actual_rank not in expected_rank_dict:
    scope_name = tf.get_variable_scope().name
    raise ValueError(
        "For the tensor `%s` in scope `%s`, the actual rank "
        "`%d` (shape = %s) is not equal to the expected rank `%s`" %
        (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))