# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sparse attention for the transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers

import tensorflow.compat.v1 as tf

from state_of_sparsity.sparse_transformer.layers import common_sparse
from tensorflow.contrib.model_pruning.python import pruning  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import inplace_ops  # pylint: disable=g-direct-tensorflow-import


def compute_attention_component(antecedent,
                                total_depth,
                                filter_width=1,
                                padding="VALID",
                                name="c",
                                vars_3d_num_heads=0,
                                sparsity_technique=None,
                                threshold=3.0,
                                training=True,
                                clip_alpha=None,
                                initial_sparsity=None,
                                split_heads=False,
                                num_heads=None):
  """Computes attention component (query, key or value).

  Args:
    antecedent: a Tensor with shape [batch, length, channels]
    total_depth: an integer
    filter_width: An integer specifying how wide you want the attention
      component to be.
    padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    name: a string specifying scope name.
    vars_3d_num_heads: an optional integer (if we want to use 3d variables)
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    num_heads: The number of heads in the attention module.

  Returns:
    c : [batch, length, depth] tensor
  """
  # We don't support 3d attention variables or filter_width > 1 with sparsity
  # techniques.
  assert not sparsity_technique or (not vars_3d_num_heads and filter_width == 1)

  if vars_3d_num_heads > 0:
    assert filter_width == 1
    input_depth = antecedent.get_shape().as_list()[-1]
    depth_per_head = total_depth // vars_3d_num_heads
    initializer_stddev = input_depth ** -0.5
    if "q" in name:
      initializer_stddev *= depth_per_head ** -0.5
    var = tf.get_variable(
        name, [input_depth,
               vars_3d_num_heads,
               total_depth // vars_3d_num_heads],
        initializer=tf.random_normal_initializer(stddev=initializer_stddev))
    var = tf.cast(var, antecedent.dtype)
    var = tf.reshape(var, [input_depth, total_depth])
    return tf.tensordot(antecedent, var, axes=1)
  if filter_width == 1:
    if sparsity_technique:
      if split_heads:
        # Prune each head's weights separately so that they are free
        # to have different weight magnitude distributions.
        if num_heads is None:
          raise ValueError("`num_heads` must be set for split head pruning.")
        if total_depth % num_heads != 0:
          raise ValueError("`total_depth` must be divisible by `num_heads`.")
        input_depth = antecedent.get_shape().as_list()[-1]
        depth_per_head = int(total_depth / num_heads)
        masked_head_weights = []
        for head_id in range(num_heads):
          head_name = name + "_shard_{}".format(head_id)
          with tf.variable_scope(head_name) as vs:
            head_weights = tf.get_variable(
                "kernel", [input_depth, depth_per_head])
            masked_head_weights.append(pruning.apply_mask(head_weights, vs))
        component_weights = tf.concat(masked_head_weights, axis=1)

        # Compute the full component result.
        return tf.tensordot(antecedent, component_weights, axes=1)
      else:
        return common_sparse.dense(
            antecedent,
            total_depth,
            use_bias=False,
            sparsity_technique=sparsity_technique,
            threshold=threshold,
            training=training,
            clip_alpha=clip_alpha,
            name=name,
            initial_sparsity=initial_sparsity)
    else:
      return common_layers.dense(
          antecedent, total_depth, use_bias=False, name=name)
  else:
    return common_layers.conv1d(
        antecedent, total_depth, filter_width, padding=padding, name=name)
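
# A minimal usage sketch (not part of the original module). With
# `split_heads=True` and any truthy `sparsity_technique`, the projection is
# built from `num_heads` kernels of shape [input_depth, total_depth //
# num_heads], each masked via tensorflow.contrib.model_pruning, then
# concatenated back to [input_depth, total_depth]. The shapes and the
# technique name below are illustrative assumptions.
#
#   x = tf.ones([2, 5, 64])  # [batch, length, channels]
#   q = compute_attention_component(
#       x, total_depth=64, name="q",
#       sparsity_technique="magnitude_pruning",
#       split_heads=True, num_heads=4)
#   # => q: [2, 5, 64]; kernels created under scopes "q_shard_0" .. "q_shard_3".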


def compute_qkv(query_antecedent,
                memory_antecedent,
                total_key_depth,
                total_value_depth,
                q_filter_width=1,
                kv_filter_width=1,
                q_padding="VALID",
                kv_padding="VALID",
                vars_3d_num_heads=0,
                sparsity_technique=None,
                threshold=3.0,
                training=True,
                clip_alpha=None,
                initial_sparsity=None,
                split_heads=False,
                num_heads=None):
  """Computes query, key and value.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels]
    total_key_depth: an integer
    total_value_depth: an integer
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and
      values to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    vars_3d_num_heads: an optional integer (if we want to use 3d variables)
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    num_heads: The number of heads in the attention module.

  Returns:
    q, k, v : [batch, length, depth] tensors
  """
  if memory_antecedent is None:
    memory_antecedent = query_antecedent
  q = compute_attention_component(
      query_antecedent,
      total_key_depth,
      q_filter_width,
      q_padding,
      "q",
      vars_3d_num_heads=vars_3d_num_heads,
      sparsity_technique=sparsity_technique,
      threshold=threshold,
      training=training,
      clip_alpha=clip_alpha,
      initial_sparsity=initial_sparsity,
      split_heads=split_heads,
      num_heads=num_heads)
  k = compute_attention_component(
      memory_antecedent,
      total_key_depth,
      kv_filter_width,
      kv_padding,
      "k",
      vars_3d_num_heads=vars_3d_num_heads,
      sparsity_technique=sparsity_technique,
      threshold=threshold,
      training=training,
      clip_alpha=clip_alpha,
      initial_sparsity=initial_sparsity,
      split_heads=split_heads,
      num_heads=num_heads)
  v = compute_attention_component(
      memory_antecedent,
      total_value_depth,
      kv_filter_width,
      kv_padding,
      "v",
      vars_3d_num_heads=vars_3d_num_heads,
      sparsity_technique=sparsity_technique,
      threshold=threshold,
      training=training,
      clip_alpha=clip_alpha,
      initial_sparsity=initial_sparsity,
      split_heads=split_heads,
      num_heads=num_heads)
  return q, k, v
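
# A minimal self-attention sketch (not part of the original module): passing
# `memory_antecedent=None` reuses the query input for the keys and values, so
# all three projections share the same [batch, length, channels] input. The
# shapes and depths below are illustrative assumptions.
#
#   x = tf.ones([2, 5, 64])                      # [batch, length, channels]
#   q, k, v = compute_qkv(x, None,
#                         total_key_depth=64,
#                         total_value_depth=64)  # each of shape [2, 5, 64]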


def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        attention_type="dot_product",
                        image_shapes=None,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        vars_3d=False,
                        sparsity_technique=None,
                        threshold=3.0,
                        training=True,
                        clip_alpha=None,
                        initial_sparsity=None,
                        split_heads=False,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and
                     values to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'); for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      Saves memory.
    vars_3d: use 3-dimensional variables for input/output transformations
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    **kwargs (dict): Parameters for the attention function.

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim].
    Optionally returns additional loss parameters (e.g. load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  if vars_3d:
    raise ValueError("3d attention variables not supported.")
  if attention_type != "dot_product":
    raise ValueError(
        "Sparse multihead attention only supports dot_product attention.")

  vars_3d_num_heads = 0
  with tf.variable_scope(
      name,
      default_name="multihead_attention",
      values=[query_antecedent, memory_antecedent]):

    if cache is None or memory_antecedent is None:
      q, k, v = compute_qkv(query_antecedent, memory_antecedent,
                            total_key_depth, total_value_depth, q_filter_width,
                            kv_filter_width, q_padding, kv_padding,
                            vars_3d_num_heads=vars_3d_num_heads,
                            sparsity_technique=sparsity_technique,
                            threshold=threshold,
                            training=training,
                            clip_alpha=clip_alpha,
                            initial_sparsity=initial_sparsity,
                            split_heads=split_heads,
                            num_heads=num_heads)
    if cache is not None:
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        q = compute_attention_component(query_antecedent, total_key_depth,
                                        q_filter_width, q_padding, "q",
                                        vars_3d_num_heads=vars_3d_num_heads,
                                        sparsity_technique=sparsity_technique,
                                        threshold=threshold,
                                        training=training,
                                        clip_alpha=clip_alpha,
                                        initial_sparsity=initial_sparsity,
                                        split_heads=split_heads,
                                        num_heads=num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of the current implementation is better than
          # updating the tensor by adding the result of
          # matmul(one_hot, update_in_current_step).
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = common_attention.split_heads(q, num_heads)
    if cache is None:
      k = common_attention.split_heads(k, num_heads)
      v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      q *= key_depth_per_head**-0.5

    # Compute the attention.
    x = common_attention.dot_product_attention(
        q, k, v, bias, dropout_rate, image_shapes,
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary,
        dropout_broadcast_dims=dropout_broadcast_dims)
    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if sparsity_technique:
      x = common_sparse.dense(
          x,
          output_depth,
          use_bias=False,
          sparsity_technique=sparsity_technique,
          threshold=threshold,
          training=training,
          clip_alpha=clip_alpha,
          name="output_transform",
          initial_sparsity=initial_sparsity)
    else:
      x = common_layers.dense(
          x,
          output_depth,
          use_bias=False,
          name="output_transform")
    return x

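
# A minimal end-to-end sketch (not part of the original module): dense (no
# sparsity technique) encoder self-attention over a toy batch. The bias is
# built with tensor2tensor's common_attention helper; shapes and
# hyperparameters are illustrative assumptions.
#
#   x = tf.ones([2, 5, 64])     # [batch, length, channels]
#   padding = tf.zeros([2, 5])  # no padded positions
#   bias = common_attention.attention_bias_ignore_padding(padding)
#   y = multihead_attention(
#       x, None, bias,
#       total_key_depth=64, total_value_depth=64, output_depth=64,
#       num_heads=4, dropout_rate=0.0)
#   # => y: [2, 5, 64]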