google-research

Форк
0
/
realformer.py 
1066 строк · 40.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""RealFormer: https://arxiv.org/abs/2012.11747."""
17

18
import collections
19
import copy
20
import json
21
import math
22
import re
23
import numpy as np
24
import six
25
import tensorflow.compat.v1 as tf
26
import tf_slim.layers as tf_slim_layers
27

28

29
class BertConfig(object):
  """Hyper-parameters consumed by `BertModel`."""

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_hidden_layers=12,
               num_attention_heads=12,
               intermediate_size=3072,
               hidden_act="gelu",
               hidden_dropout_prob=0.1,
               attention_probs_dropout_prob=0.1,
               max_position_embeddings=512,
               type_vocab_size=16,
               initializer_range=0.02,
               use_running_mean=False):
    """Stores every hyper-parameter as a plain instance attribute.

    Args:
      vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
      hidden_size: Width of the encoder layers and of the pooler layer.
      num_hidden_layers: Number of layers in the Transformer encoder.
      num_attention_heads: Number of attention heads in each attention layer
        of the Transformer encoder.
      intermediate_size: Width of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: Non-linear activation (function or string name) used in the
        encoder and pooler.
      hidden_dropout_prob: Dropout probability for all fully connected layers
        in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: Dropout ratio applied to the attention
        probabilities.
      max_position_embeddings: Longest sequence this model may ever see; it
        sizes the position-embedding table (e.g., 512, 1024 or 2048).
      type_vocab_size: Vocabulary size of the `token_type_ids` passed into
        `BertModel`.
      initializer_range: Stddev of the truncated-normal initializer used for
        all weight matrices.
      use_running_mean: If True, softmax the running mean rather than the
        running sum of attention scores when computing attention
        probabilities. Running mean helps deep RealFormer models (30+ layers).
    """
    # Assignment order fixes the key order of `to_dict()`.
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.hidden_act = hidden_act
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.initializer_range = initializer_range
    self.use_running_mean = use_running_mean

  @classmethod
  def from_dict(cls, json_object):
    """Builds a config from a dict; every key becomes an attribute verbatim."""
    # Start from the defaults, then overwrite (or add) whatever the dict holds.
    config = cls(vocab_size=None)
    for attr_name, attr_value in json_object.items():
      config.__dict__[attr_name] = attr_value
    return config

  @classmethod
  def from_json_file(cls, json_file):
    """Builds a config from a JSON file readable through `tf.io.gfile`."""
    with tf.io.gfile.GFile(json_file, "r") as reader:
      parsed = json.loads(reader.read())
    return cls.from_dict(parsed)

  def to_dict(self):
    """Returns a deep copy of the instance attributes as a dict."""
    return copy.deepcopy(self.__dict__)

  def to_json_string(self):
    """Returns the config as pretty-printed, key-sorted JSON plus a newline."""
    serialized = json.dumps(self.to_dict(), indent=2, sort_keys=True)
    return serialized + "\n"
108

109

110
class BertModel(object):
  """BERT model that uses RealFormer as its backbone Transformer.

  Example usage:

  ```python
  # Already been converted into WordPiece token ids
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])

  config = realformer.BertConfig(vocab_size=32000, hidden_size=512,
    num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)

  model = realformer.BertModel(config=config, is_training=True,
    input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)

  label_embeddings = tf.get_variable(...)
  pooled_output = model.get_pooled_output()
  logits = tf.matmul(pooled_output, label_embeddings)
  ...
  ```
  """

  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               use_one_hot_embeddings=False,
               scope=None):
    """Builds the embedding, encoder and pooler sub-graphs.

    Args:
      config: `BertConfig` instance. It is deep-copied, so the caller's object
        is never modified.
      is_training: bool. When False, both dropout probabilities of the copied
        config are forced to 0.0, i.e. dropout is disabled.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
        Defaults to all ones.
      token_type_ids: (optional) int32 Tensor of shape [batch_size,
        seq_length]. Defaults to all zeros.
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    cfg = copy.deepcopy(config)
    if not is_training:
      # Evaluation graph: every dropout layer becomes the identity.
      cfg.hidden_dropout_prob = 0.0
      cfg.attention_probs_dropout_prob = 0.0

    batch_size, seq_length = get_shape_list(input_ids, expected_rank=2)
    default_shape = [batch_size, seq_length]

    if input_mask is None:
      input_mask = tf.ones(shape=default_shape, dtype=tf.int32)
    if token_type_ids is None:
      token_type_ids = tf.zeros(shape=default_shape, dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Word(piece) lookup only; token-type and position embeddings are
        # added by the post-processor below.
        self.word_embedding_output, self.embedding_table = embedding_lookup(
            input_ids=input_ids,
            vocab_size=cfg.vocab_size,
            embedding_size=cfg.hidden_size,
            initializer_range=cfg.initializer_range,
            word_embedding_name="word_embeddings",
            use_one_hot_embeddings=use_one_hot_embeddings)

        # Sum in token-type and position embeddings, then layer norm + dropout.
        self.embedding_output = embedding_postprocessor(
            input_tensor=self.word_embedding_output,
            use_token_type=True,
            token_type_ids=token_type_ids,
            token_type_vocab_size=cfg.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=cfg.initializer_range,
            max_position_embeddings=cfg.max_position_embeddings,
            dropout_prob=cfg.hidden_dropout_prob)

      with tf.variable_scope("encoder"):
        # 2D padding mask [batch_size, seq_length] -> 3D attention mask
        # [batch_size, seq_length, seq_length].
        attention_mask = create_attention_mask_from_input_mask(
            input_ids, input_mask)

        act_fn = get_activation(cfg.hidden_act)
        # With do_return_all_layers=True this holds every layer's output; the
        # last entry is the [batch_size, seq_length, hidden_size] sequence
        # output.
        self.all_encoder_layers = realformer_model(
            input_tensor=self.embedding_output,
            attention_mask=attention_mask,
            hidden_size=cfg.hidden_size,
            num_hidden_layers=cfg.num_hidden_layers,
            num_attention_heads=cfg.num_attention_heads,
            intermediate_size=cfg.intermediate_size,
            intermediate_act_fn=act_fn,
            hidden_dropout_prob=cfg.hidden_dropout_prob,
            attention_probs_dropout_prob=cfg.attention_probs_dropout_prob,
            initializer_range=cfg.initializer_range,
            use_running_mean=cfg.use_running_mean,
            do_return_all_layers=True)

      self.sequence_output = self.all_encoder_layers[-1]

      # Pooler: reduces [batch_size, seq_length, hidden_size] to a fixed-size
      # [batch_size, hidden_size] vector for segment(-pair)-level tasks, by
      # passing the first token's final hidden state through a tanh dense
      # layer. This assumes that first token was pre-trained for that role.
      with tf.variable_scope("pooler"):
        first_token_tensor = tf.squeeze(
            self.sequence_output[:, 0:1, :], axis=1)
        self.pooled_output = tf.layers.dense(
            first_token_tensor,
            cfg.hidden_size,
            activation=tf.tanh,
            kernel_initializer=create_initializer(cfg.initializer_range))

  def get_pooled_output(self):
    """Returns the [batch_size, hidden_size] pooler output."""
    return self.pooled_output

  def get_sequence_output(self):
    """Returns the final encoder hidden layer.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] holding the
      last hidden layer of the transformer encoder.
    """
    return self.sequence_output

  def get_all_encoder_layers(self):
    """Returns the outputs of every encoder layer, last layer last."""
    return self.all_encoder_layers

  def get_word_embedding_output(self):
    """Returns the raw word(piece) embedding lookup result.

    This is taken BEFORE positional and token type embeddings are added.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] with the
      output of the word(piece) embedding layer.
    """
    return self.word_embedding_output

  def get_embedding_output(self):
    """Returns the transformer's input embeddings.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size]: word
      embeddings summed with positional and token type embeddings, then layer
      normalized (and dropped out during training).
    """
    return self.embedding_output

  def get_embedding_table(self):
    """Returns the word embedding table variable."""
    return self.embedding_table
278

279

280
def gelu(x):
  """Gaussian Error Linear Unit, tanh approximation.

  A smoother relative of ReLU (https://arxiv.org/abs/1606.08415), computed as
  0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).

  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
  cdf = 0.5 * (1.0 + tf.tanh(inner))
  return x * cdf
294

295

296
def get_activation(activation_string):
  """Resolves an activation name, e.g. "relu", to its callable (`tf.nn.relu`).

  Args:
    activation_string: String name of the activation function, or an
      already-resolved activation callable.

  Returns:
    The matching activation function. Returns None when `activation_string`
    is empty or "linear" (case-insensitive). Any non-string argument is
    returned unchanged.

  Raises:
    ValueError: `activation_string` names no known activation.
  """
  # Anything that is not a string is assumed to already be an activation
  # function (or None) and is handed back as-is.
  if not isinstance(activation_string, str):
    return activation_string
  if not activation_string:
    return None

  name = activation_string.lower()
  # Lazily resolved so only the requested entry is looked up.
  resolvers = {
      "linear": lambda: None,
      "relu": lambda: tf.nn.relu,
      "gelu": lambda: gelu,
      "tanh": lambda: tf.tanh,
  }
  resolver = resolvers.get(name)
  if resolver is None:
    raise ValueError("Unsupported activation: %s" % name)
  return resolver()
331

332

333
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
  """Matches checkpoint variables to same-named variables in `tvars`.

  The result is the INTERSECTION of the two name sets. A checkpoint entry is
  kept only when some variable in `tvars` has the same name once its
  ":<digits>" suffix is stripped. Shapes are not compared.

  Args:
    tvars: iterable of variables, each exposing a string `.name` (for example
      "bert/embeddings/word_embeddings:0").
    init_checkpoint: checkpoint path, as accepted by
      `tf.train.list_variables`.

  Returns:
    A tuple `(assignment_map, initialized_variable_names)`:
      assignment_map: OrderedDict mapping each matched checkpoint name to
        itself, in the order `tf.train.list_variables` lists them.
      initialized_variable_names: dict whose keys are every matched name both
        bare and with a ":0" suffix, each mapped to 1.
  """
  # Variable names carry an output-index suffix (":0") that checkpoint keys
  # lack; strip it so the two can be compared.
  suffix_re = re.compile(r"^(.*):\d+$")
  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = suffix_re.match(name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  assignment_map = collections.OrderedDict()
  initialized_variable_names = {}
  # Each entry is a (name, shape) pair; only the name is needed here.
  for ckpt_name, _ in tf.train.list_variables(init_checkpoint):
    if ckpt_name not in name_to_variable:
      continue
    assignment_map[ckpt_name] = ckpt_name
    initialized_variable_names[ckpt_name] = 1
    initialized_variable_names[ckpt_name + ":0"] = 1

  return (assignment_map, initialized_variable_names)
358

359

360
def dropout(input_tensor, dropout_prob):
  """Applies dropout to `input_tensor`, or returns it untouched if disabled.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float or None. Probability of dropping each value.
      It is passed to `tf.nn.dropout` as `rate`, not as a keep probability.
      None or 0.0 disables dropout.

  Returns:
    `input_tensor` itself when dropout is disabled; otherwise the result of
    `tf.nn.dropout`.
  """
  if dropout_prob is not None and dropout_prob != 0.0:
    return tf.nn.dropout(input_tensor, rate=dropout_prob)
  return input_tensor
376

377

378
def layer_norm(input_tensor, name=None):
  """Layer-normalizes `input_tensor` over its last axis.

  Both the normalization statistics and the learned scale/offset parameters
  span only the last axis. `name` is used as the tf_slim variable scope.
  """
  return tf_slim_layers.layer_norm(
      inputs=input_tensor,
      begin_norm_axis=-1,
      begin_params_axis=-1,
      scope=name,
  )
382

383

384
def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs `layer_norm` (in scope `name`), then `dropout` with `dropout_prob`."""
  normalized = layer_norm(input_tensor, name)
  return dropout(normalized, dropout_prob)
389

390

391
def create_initializer(initializer_range=0.02):
  """Returns a truncated-normal initializer with stddev `initializer_range`."""
  stddev = initializer_range
  return tf.truncated_normal_initializer(stddev=stddev)
394

395

396
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
  """Looks up word embeddings for a tensor of ids.

  Args:
    input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
      ids. A rank-3 tensor [batch_size, seq_length, num_inputs] is also
      accepted; rank-2 input is treated as num_inputs == 1.
    vocab_size: int. Size of the embedding vocabulary.
    embedding_size: int. Width of the word embeddings.
    initializer_range: float. Stddev of the truncated-normal initializer for
      the embedding table.
    word_embedding_name: string. Name of the embedding table variable.
    use_one_hot_embeddings: bool. If True, embed via a one-hot matmul. If
      False, use `tf.nn.embedding_lookup()`.

  Returns:
    A tuple `(output, embedding_table)`:
      output: float Tensor of shape
        [batch_size, seq_length, num_inputs * embedding_size]. For rank-2
        `input_ids` this is [batch_size, seq_length, embedding_size].
      embedding_table: the [vocab_size, embedding_size] embedding variable.
  """
  # The code below works on a trailing `num_inputs` axis, so rank-2 ids are
  # given one: [batch_size, seq_length] -> [batch_size, seq_length, 1].
  if input_ids.shape.ndims == 2:
    input_ids = tf.expand_dims(input_ids, axis=[-1])

  embedding_table = tf.get_variable(
      name=word_embedding_name,
      shape=[vocab_size, embedding_size],
      initializer=create_initializer(initializer_range))

  if use_one_hot_embeddings:
    # Gather expressed as a matmul: [num_ids, vocab] x [vocab, embedding].
    flat_input_ids = tf.reshape(input_ids, [-1])
    one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
    output = tf.matmul(one_hot_input_ids, embedding_table)
  else:
    output = tf.nn.embedding_lookup(embedding_table, input_ids)

  # Restore the leading dims and fold `num_inputs` into the embedding axis.
  input_shape = get_shape_list(input_ids)
  output = tf.reshape(output,
                      input_shape[0:-1] + [input_shape[-1] * embedding_size])
  return (output, embedding_table)
442

443

444
def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
  """Adds token-type/position embeddings, then applies layer norm + dropout.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length,
      embedding_size].
    use_token_type: bool. Whether to add embeddings for `token_type_ids`.
    token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      Must be specified if `use_token_type` is True.
    token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
    token_type_embedding_name: string. The name of the embedding table variable
      for token type ids.
    use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
    position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
    initializer_range: float. Stddev of the truncated-normal initializer.
    max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter (checked by a graph assertion).
    dropout_prob: float. Dropout probability applied to the final output tensor.

  Returns:
    float tensor with same shape as `input_tensor`.

  Raises:
    ValueError: One of the tensor shapes or input values is invalid.
  """
  # Validate arguments before doing any shape work so the error is explicit.
  if use_token_type and token_type_ids is None:
    raise ValueError("`token_type_ids` must be specified if "
                     "`use_token_type` is True.")

  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = input_tensor

  if use_token_type:
    token_type_table = tf.get_variable(
        name=token_type_embedding_name,
        shape=[token_type_vocab_size, width],
        initializer=create_initializer(initializer_range))
    # The token-type vocabulary is small, so a one-hot matmul is used instead
    # of a gather.
    flat_token_type_ids = tf.reshape(token_type_ids, [-1])
    one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
    token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
    token_type_embeddings = tf.reshape(token_type_embeddings,
                                       [batch_size, seq_length, width])
    output += token_type_embeddings

  if use_position_embeddings:
    # Create the variable outside the assertion to avoid TF2 compatibility
    # issues.
    full_position_embeddings = tf.get_variable(
        name=position_embedding_name,
        shape=[max_position_embeddings, width],
        initializer=create_initializer(initializer_range))

    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
    with tf.control_dependencies([assert_op]):
      # The table covers positions [0, max_position_embeddings); the current
      # sequence only uses [0, seq_length), so take the leading rows.
      position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                     [seq_length, -1])
      num_dims = len(output.shape.as_list())

      # Only the last two dims (`seq_length`, `width`) are real; the leading
      # dims (typically just batch) are broadcast over with size 1.
      position_broadcast_shape = [1] * (num_dims - 2) + [seq_length, width]
      position_embeddings = tf.reshape(position_embeddings,
                                       position_broadcast_shape)
      output += position_embeddings

  output = layer_norm_and_dropout(output, dropout_prob)
  return output
541

542

543
def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Expands a 2D key-side mask into a 3D attention mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
      Only its shape is used.
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float32 Tensor of shape [batch_size, from_seq_length, to_seq_length] in
    which row i of example b equals `to_mask[b]`, for every query position i.
  """
  from_dims = get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size, from_seq_length = from_dims[0], from_dims[1]
  to_seq_length = get_shape_list(to_mask, expected_rank=2)[1]

  # [batch_size, 1, to_seq_length] float view of the key-side mask.
  to_mask_3d = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

  # `from_tensor` is not treated as a mask: attending *from* padding is
  # harmless, only attending *to* padding must be blocked. So the query axis
  # is all ones: [batch_size, from_seq_length, 1].
  from_ones = tf.ones(
      shape=[batch_size, from_seq_length, 1], dtype=tf.float32)

  # Broadcast [B, F, 1] * [B, 1, T] -> [B, F, T].
  return from_ones * to_mask_3d
575

576

577
def dense_layer_3d(input_tensor,
                   num_attention_heads,
                   size_per_head,
                   initializer,
                   activation,
                   name=None):
  """Dense projection whose output is split into attention heads.

  The kernel is stored as a 2D variable of shape
  [hidden_size, num_attention_heads * size_per_head]. The bias has shape
  [num_attention_heads * size_per_head]. Both are reshaped per head before
  use.

  Args:
    input_tensor: float Tensor of shape [batch, seq_length, hidden_size].
    num_attention_heads: Number of attention heads.
    size_per_head: The size per attention head.
    initializer: Kernel initializer.
    activation: Activation function, or None for a linear layer.
    name: Variable scope holding the `kernel` and `bias` variables.

  Returns:
    float Tensor of shape [batch, seq_length, num_attention_heads,
    size_per_head].
  """
  in_width = get_shape_list(input_tensor)[-1]
  out_width = num_attention_heads * size_per_head

  with tf.variable_scope(name):
    kernel = tf.get_variable(
        name="kernel",
        shape=[in_width, out_width],
        initializer=initializer)
    kernel = tf.reshape(kernel, [in_width, num_attention_heads, size_per_head])
    bias = tf.get_variable(
        name="bias",
        shape=[out_width],
        initializer=tf.zeros_initializer)
    bias = tf.reshape(bias, [num_attention_heads, size_per_head])
    projected = tf.einsum("abc,cde->abde", input_tensor, kernel) + bias
    return projected if activation is None else activation(projected)
616

617

618
def dense_layer_3d_proj(input_tensor,
                        hidden_size,
                        num_attention_heads,
                        head_size,
                        initializer,
                        activation,
                        name=None):
  """Dense layer that merges attention heads back into one hidden vector.

  The kernel is stored as a [hidden_size, hidden_size] variable and viewed as
  [num_attention_heads, head_size, hidden_size]. That reshape only succeeds
  when num_attention_heads * head_size == hidden_size.

  Args:
    input_tensor: float Tensor of shape [batch, from_seq_length,
      num_attention_heads, head_size].
    hidden_size: The size of hidden layer (and of the output's last axis).
    num_attention_heads: Number of attention heads in `input_tensor`.
    head_size: The size of head.
    initializer: Kernel initializer.
    activation: Activation function, or None for a linear layer.
    name: Variable scope holding the `kernel` and `bias` variables.

  Returns:
    float Tensor of shape [batch, from_seq_length, hidden_size].
  """
  with tf.variable_scope(name):
    kernel = tf.get_variable(
        name="kernel",
        shape=[hidden_size, hidden_size],
        initializer=initializer)
    kernel = tf.reshape(kernel, [num_attention_heads, head_size, hidden_size])
    bias = tf.get_variable(
        name="bias", shape=[hidden_size], initializer=tf.zeros_initializer)

  # Contract over both the head (N) and per-head (H) axes at once.
  merged = tf.einsum("BFNH,NHD->BFD", input_tensor, kernel) + bias
  if activation is None:
    return merged
  return activation(merged)
655

656

657
def dense_layer_2d(input_tensor,
                   output_size,
                   initializer,
                   activation,
                   name=None):
  """Fully connected layer applied along the last axis of a rank-3 tensor.

  Args:
    input_tensor: Float tensor with rank 3, [batch, seq_length, in_width].
    output_size: The size of the output (last) dimension.
    initializer: Kernel initializer.
    activation: Activation function, or None for a linear layer.
    name: Variable scope holding the `kernel` and `bias` variables.

  Returns:
    float Tensor of shape [batch, seq_length, output_size].
  """
  in_width = get_shape_list(input_tensor)[-1]
  with tf.variable_scope(name):
    kernel = tf.get_variable(
        name="kernel", shape=[in_width, output_size], initializer=initializer)
    bias = tf.get_variable(
        name="bias", shape=[output_size], initializer=tf.zeros_initializer)

  projected = tf.einsum("abc,cd->abd", input_tensor, kernel) + bias
  return projected if activation is None else activation(projected)
687

688

689
def residual_attention_layer(from_tensor,
                             to_tensor,
                             attention_mask=None,
                             num_attention_heads=1,
                             size_per_head=512,
                             query_act=None,
                             key_act=None,
                             value_act=None,
                             attention_probs_dropout_prob=0.0,
                             initializer_range=0.02,
                             batch_size=None,
                             from_seq_length=None,
                             to_seq_length=None,
                             prev_attention=None,
                             num_prev_layers=0,
                             use_running_mean=False):
  r"""Performs multi-headed attention from `from_tensor` to `to_tensor`.

  This is an implementation of multi-headed attention based on "Attention
  is all you Need" with a residual edge added to connect attention modules in
  adjacent layers. If `from_tensor` and `to_tensor` are the same, then
  this is self-attention. Each timestep in `from_tensor` attends to the
  corresponding sequence in `to_tensor`, and returns a fixed-width vector.

  This function first projects `from_tensor` into a "query" tensor and
  `to_tensor` into "key" and "value" tensors. These are (effectively) a list
  of tensors of length `num_attention_heads`, where each tensor is of shape
  [batch_size, seq_length, size_per_head].

  Then, the query and key tensors are dot-producted and scaled. The running
  sum of these scaled dot products (this layer plus `prev_attention`) is
  optionally rescaled to a running mean before it is softmaxed to obtain
  attention probabilities. The value tensors are then interpolated by these
  probabilities and returned per head.

  In practice, the multi-headed attention is done with tf.einsum as follows:
    Input_tensor: [BFD]
    Wq, Wk, Wv: [DNH]
    Q:[BFNH] = einsum('BFD,DNH->BFNH', Input_tensor, Wq)
    K:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wk)
    V:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wv)
    attention_scores:[BNFT] = einsum('BFNH,BTNH->BNFT', Q, K) / sqrt(H)
    attention_logits:[BNFT] = \sum_{l=0}^{cur} attention_scores_l
    (optional) attention_logits:[BNFT] = attention_logits / (cur + 1)
    attention_probs:[BNFT] = softmax(attention_logits)
    context_layer:[BFNH] = einsum('BNFT,BTNH->BFNH', attention_probs, V)
  The output projection (Wout:[DNH],
  Output:[BFD] = einsum('BFNH,DNH->BFD', context_layer, Wout)) is NOT done
  here; the caller applies it to the returned `context_layer`.

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions in
      the mask that are 0, and will be unchanged for positions that are 1.
    num_attention_heads: int. Number of attention heads.
    size_per_head: int. Size of each attention head.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.
    initializer_range: float. Range of the weight initializer.
    batch_size: (Optional) int. If the input is 2D, this might be the batch size
      of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `to_tensor`.
    prev_attention: (Optional) float Tensor of shape [batch_size,
      num_attention_heads, from_seq_length, to_seq_length]. Running sum of
      attention scores before the current layer.
    num_prev_layers: int. Number of previous layers that have contributed to
      `prev_attention`.
    use_running_mean: bool. Whether to softmax running mean instead of running
      sum of attention scores to compute attention probabilities. Running mean
      is found to be helpful for deep RealFormer models with 30+ layers.

  Returns:
    A tuple `(context_layer, cur_attention)`:
    context_layer: float Tensor of shape [batch_size, from_seq_length,
      num_attention_heads, size_per_head].
    cur_attention: float Tensor of shape [batch_size, num_attention_heads,
      from_seq_length, to_seq_length]. The running SUM of scaled attention
      scores including this layer, taken before the running-mean rescaling
      and before the attention mask is applied; pass it as `prev_attention`
      to the next layer.

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

  if len(from_shape) != len(to_shape):
    raise ValueError(
        "The rank of `from_tensor` must match the rank of `to_tensor`.")

  if len(from_shape) == 3:
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]
    to_seq_length = to_shape[1]
  elif len(from_shape) == 2:
    if (batch_size is None or from_seq_length is None or to_seq_length is None):
      raise ValueError(
          "When passing in rank 2 tensors to attention_layer, the values "
          "for `batch_size`, `from_seq_length`, and `to_seq_length` "
          "must all be specified.")
  # NOTE(review): `batch_size`, `from_seq_length` and `to_seq_length` are only
  # validated above; nothing below uses them, and the einsums below assume
  # [B, F, N, H] / [B, T, N, H] projections. TODO confirm how rank-2 inputs
  # are meant to flow through `dense_layer_3d`.

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`

  # `query_layer` = [B, F, N, H]
  query_layer = dense_layer_3d(from_tensor, num_attention_heads, size_per_head,
                               create_initializer(initializer_range), query_act,
                               "query")

  # `key_layer` = [B, T, N, H]
  key_layer = dense_layer_3d(to_tensor, num_attention_heads, size_per_head,
                             create_initializer(initializer_range), key_act,
                             "key")

  # `value_layer` = [B, T, N, H]
  value_layer = dense_layer_3d(to_tensor, num_attention_heads, size_per_head,
                               create_initializer(initializer_range), value_act,
                               "value")

  # Take the dot product between "query" and "key" to get the raw
  # attention scores, scaled by 1/sqrt(H).
  attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_layer, query_layer)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / math.sqrt(float(size_per_head)))

  # Residual edge: accumulate the (unmasked) scores of all layers so far. This
  # running sum is what is returned and handed to the next layer.
  cur_attention = attention_scores
  if prev_attention is not None:
    cur_attention = cur_attention + prev_attention

  # Optionally turn the running sum into a running mean over the
  # `num_prev_layers + 1` layers that contributed to it.
  attention_logits = cur_attention
  if use_running_mean:
    attention_logits = attention_logits / (num_prev_layers + 1.0)

  if attention_mask is not None:
    # `attention_mask` = [B, 1, F, T]
    attention_mask = tf.expand_dims(attention_mask, axis=[1])

    # Since attention_mask is 1 for positions we want to attend and 0 for
    # masked positions, this creates a tensor which is 0.0 for positions we
    # want to attend and -10000.0 for masked positions. Cast to the logits'
    # dtype (rather than hard-coding float32) so the addition also works when
    # the model runs in reduced precision.
    adder = (1.0 - tf.cast(attention_mask, attention_logits.dtype)) * -10000.0

    # Since we are adding it to the raw logits before the softmax, this is
    # effectively the same as removing these entirely. The mask is applied
    # only to this layer's logits, never to the returned running sum.
    attention_logits = attention_logits + adder

  # Normalize the attention logits to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = tf.nn.softmax(attention_logits)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

  # `context_layer` = [B, F, N, H]
  context_layer = tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_layer)

  return context_layer, cur_attention
857

858

859
def realformer_model(input_tensor,
                     attention_mask=None,
                     hidden_size=768,
                     num_hidden_layers=12,
                     num_attention_heads=12,
                     intermediate_size=3072,
                     intermediate_act_fn=gelu,
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     initializer_range=0.02,
                     use_running_mean=False,
                     do_return_all_layers=False):
  """Multi-headed, multi-layer RealFormer from https://arxiv.org/abs/2012.11747.

  Each layer runs residual self-attention (the running sum of attention scores
  is threaded from layer to layer), an output projection with residual +
  layer norm, and a feed-forward block with residual + layer norm.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
      seq_length], with 1 for positions that can be attended to and 0 in
      positions that should not be.
    hidden_size: int. Hidden size of the Transformer.
    num_hidden_layers: int. Number of layers (blocks) in the Transformer.
    num_attention_heads: int. Number of attention heads in the Transformer.
    intermediate_size: int. The size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: function. The non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: float. Dropout probability for the hidden layers.
    attention_probs_dropout_prob: float. Dropout probability of the attention
      probabilities.
    initializer_range: float. Range of the initializer (stddev of truncated
      normal).
    use_running_mean: Whether to softmax running mean instead of running sum of
      attention scores to compute attention probabilities. Running mean is found
      to be helpful for deep RealFormer models with 30+ layers.
    do_return_all_layers: Whether to also return all layers or just the final
      layer.

  Returns:
    If `do_return_all_layers` is True: a list with one float Tensor of shape
    [batch_size, seq_length, hidden_size] per layer (empty when
    `num_hidden_layers` is 0). Otherwise: a single float Tensor of shape
    [batch_size, seq_length, hidden_size], the final hidden layer of the
    Transformer (`input_tensor` itself when `num_hidden_layers` is 0).

  Raises:
    ValueError: A Tensor shape or parameter is invalid.
  """
  if num_attention_heads <= 0:
    raise ValueError(
        "The number of attention heads (%d) must be positive" %
        num_attention_heads)
  if hidden_size % num_attention_heads != 0:
    raise ValueError(
        "The hidden size (%d) is not a multiple of the number of attention "
        "heads (%d)" % (hidden_size, num_attention_heads))

  attention_head_size = hidden_size // num_attention_heads
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  input_width = input_shape[2]

  # The Transformer performs sum residuals on all layers so the input needs
  # to be the same as the hidden size.
  if input_width != hidden_size:
    raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                     (input_width, hidden_size))

  prev_output = input_tensor
  # Running sum of attention scores, threaded through the layers (RealFormer's
  # residual attention edge). None before the first layer.
  prev_attention = None
  all_layer_outputs = []
  for layer_idx in range(num_hidden_layers):
    with tf.variable_scope("layer_%d" % layer_idx):
      layer_input = prev_output

      with tf.variable_scope("attention"):
        with tf.variable_scope("self"):
          attention_output, prev_attention = residual_attention_layer(
              from_tensor=layer_input,
              to_tensor=layer_input,
              attention_mask=attention_mask,
              num_attention_heads=num_attention_heads,
              size_per_head=attention_head_size,
              attention_probs_dropout_prob=attention_probs_dropout_prob,
              initializer_range=initializer_range,
              prev_attention=prev_attention,
              num_prev_layers=layer_idx,
              use_running_mean=use_running_mean)

        # Run a linear projection of `hidden_size` then add a residual
        # with `layer_input`.
        with tf.variable_scope("output"):
          attention_output = dense_layer_3d_proj(
              attention_output, hidden_size,
              num_attention_heads, attention_head_size,
              create_initializer(initializer_range), None, "dense")
          attention_output = dropout(attention_output, hidden_dropout_prob)
          attention_output = layer_norm(attention_output + layer_input)

      # The activation is only applied to the "intermediate" hidden layer.
      with tf.variable_scope("intermediate"):
        intermediate_output = dense_layer_2d(
            attention_output, intermediate_size,
            create_initializer(initializer_range), intermediate_act_fn, "dense")

      # Down-project back to `hidden_size` then add the residual.
      with tf.variable_scope("output"):
        layer_output = dense_layer_2d(intermediate_output, hidden_size,
                                      create_initializer(initializer_range),
                                      None, "dense")
        layer_output = dropout(layer_output, hidden_dropout_prob)
        layer_output = layer_norm(layer_output + attention_output)
        prev_output = layer_output
        all_layer_outputs.append(layer_output)

  if do_return_all_layers:
    return all_layer_outputs
  # `prev_output` is the last layer's output, or `input_tensor` when there are
  # no layers (where `all_layer_outputs[-1]` would raise IndexError).
  return prev_output
969

970

971
def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns the shape of `tensor` as a list, using static sizes when known.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int or list of ints. When given, `assert_rank` is
      called first, which raises an exception if `tensor` has a different rank.
    name: Optional name of the tensor for the error message.

  Returns:
    A list with one entry per dimension: a Python int for every statically
    known dimension, and a scalar tf.Tensor (taken from `tf.shape(tensor)`)
    for every dimension that is only known at run time.
  """
  if name is None:
    # Eager tensors do not support `Tensor.name`.
    name = "get_shape_list" if tf.executing_eagerly() else tensor.name

  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)

  static_dims = tensor.shape.as_list()
  if None not in static_dims:
    return static_dims

  # Only build the dynamic shape op when at least one dimension is unknown.
  runtime_dims = tf.shape(tensor)
  return [runtime_dims[axis] if size is None else size
          for axis, size in enumerate(static_dims)]
1010

1011

1012
def reshape_to_matrix(input_tensor):
  """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).

  All leading dimensions are flattened into the first axis; the last axis is
  kept as the matrix width.

  Args:
    input_tensor: A tf.Tensor of known rank >= 2.

  Returns:
    `input_tensor` itself if it is already rank 2, otherwise a rank-2 tensor
    of shape [-1, width].

  Raises:
    ValueError: If the rank of `input_tensor` is unknown or less than 2.
  """
  ndims = input_tensor.shape.ndims
  # `ndims` is None for unknown-rank tensors; comparing None < 2 would raise
  # a TypeError instead of the intended ValueError.
  if ndims is None or ndims < 2:
    raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                     (input_tensor.shape))
  if ndims == 2:
    return input_tensor

  width = input_tensor.shape.as_list()[-1]
  if width is None:
    # The last dimension is only known at run time; `[-1, None]` is not a
    # valid reshape target, so fall back to the dynamic shape.
    width = tf.shape(input_tensor)[-1]
  return tf.reshape(input_tensor, [-1, width])
1024

1025

1026
def reshape_from_matrix(output_tensor, orig_shape_list):
  """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.

  Args:
    output_tensor: Rank-2 tensor to restore.
    orig_shape_list: Shape list of the original tensor (for example, the list
      returned by `get_shape_list`). It must be a list, because its leading
      dimensions are concatenated with a list below.

  Returns:
    `output_tensor` unchanged when the original shape was rank 2. Otherwise
    `output_tensor` reshaped to `orig_shape_list[:-1] + [width]`, where
    `width` is the last dimension of `output_tensor`.
  """
  if len(orig_shape_list) != 2:
    width = get_shape_list(output_tensor)[-1]
    output_tensor = tf.reshape(output_tensor, orig_shape_list[:-1] + [width])
  return output_tensor
1037

1038

1039
def assert_rank(tensor, expected_rank, name=None):
  """Raises an exception if the tensor rank is not of the expected rank.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or list of integers, expected rank.
    name: Optional name of the tensor for the error message. Defaults to
      `tensor.name`, or a placeholder in eager mode, where tensors have no
      usable name.

  Raises:
    ValueError: If the expected shape doesn't match the actual shape,
      including when the rank of `tensor` is unknown.
  """
  if isinstance(expected_rank, six.integer_types):
    allowed_ranks = {expected_rank}
  else:
    allowed_ranks = set(expected_rank)

  actual_rank = tensor.shape.ndims
  if actual_rank in allowed_ranks:
    return

  # Resolve the name only when it is needed for the error message; reading
  # `Tensor.name` is not supported in Eager mode.
  if name is None:
    name = "assert_rank" if tf.executing_eagerly() else tensor.name
  scope_name = tf.get_variable_scope().name
  # `%s` rather than `%d` for the actual rank: it is None for unknown-rank
  # tensors, and `%d` would turn this ValueError into a TypeError.
  raise ValueError(
      "For the tensor `%s` in scope `%s`, the actual rank "
      "`%s` (shape = %s) is not equal to the expected rank `%s`" %
      (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
1067

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.