custom_attention.py (google-research)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom attention modules for Charformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import mesh_tensorflow as mtf
from mesh_tensorflow.transformer import attention
from mesh_tensorflow.transformer.transformer import sinusoid_positional_embedding_weights
import tensorflow.compat.v1 as tf


def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None,
                       context=None):
  """Attention to a neighborhood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None, uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()
    context: optional context.

  Returns:
    a Tensor with the shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128).
  tf.logging.info(attention_kwargs)
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
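  # Illustrative worked example (added; not from the original source): with
  # radius=128 and length_per_split=200, the loop lowers block_length from 128
  # down to 100, the greatest divisor of 200 that is <= 128.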
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)
  def _reshape_query(x):
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])
  def _reshape_memory(x):
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if fully_autoregressive
            else mtf.halo_exchange)(
                x, num_blocks, memory_block_length, radius)
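  # Shape sketch (added; sizes are illustrative): with length=1024 and
  # block_length=128, queries become [..., num_blocks=8, query_block_length=128]
  # while memory blocks gain a halo of `radius` positions from the neighboring
  # block(s), giving memory_block_length = radius + 128 (or 2 * radius + 128
  # when not fully autoregressive).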
  q = _reshape_query(q)
  k = _reshape_memory(k)
  if v:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible, mtf.greater_equal(read_priority, write_priority))

  bias = attention.visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention.attention(q, k, v, padded_memory_block_length, key_dim,
                          value_dim, bias, context=context,
                          **attention_kwargs)
  return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)


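# The sketch below is added for illustration and is not part of the original
# module; it shows one way local_attention_1d might be wired up. The mesh
# name, the dimension names/sizes, and the helper name
# _example_local_attention_1d are assumptions, not from the source.
def _example_local_attention_1d():
  """Builds a tiny local-attention graph and returns the output mtf.Tensor."""
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 256)
  d_kv = mtf.Dimension("d_kv", 64)
  q = mtf.random_uniform(mesh, mtf.Shape([batch, length, d_kv]))
  k = mtf.random_uniform(mesh, mtf.Shape([batch, length, d_kv]))
  # v=None reuses k as the values; attention_kwargs must be a dict (it is
  # expanded with ** inside local_attention_1d), so pass {} rather than None.
  return local_attention_1d(
      q=q, k=k, v=None,
      length_dim=length,
      key_dim=d_kv,
      value_dim=d_kv,
      fully_autoregressive=True,
      radius=16,
      attention_kwargs={})

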
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
  """Implements GBST (gradient-based subword tokenization) from Charformer.

  Args:
    x: a Tensor containing length_dim
    length_dim: a Dimension
    max_subword_length: integer
    downsample: integer.
    use_offsets: boolean.
    consider_chars_as_blocks: boolean.
    use_block_pos_embedding: boolean.
    share_block_kernel: boolean.
    memory_embeddings: integer.
    context: Context.
    block_mixing_mode: Str for block mixing.
    activation: Str for block ranking.
    downsample_function: Str; only "mean" is implemented for now.

  Returns:
    a Tensor with the same shape as x (with length_dim shortened by a factor
    of downsample when downsample is set).

  Raises:
    ValueError: if channels or depth don't match.
  """
  # These arguments are not used for now.
  del max_subword_length
  del memory_embeddings
  all_blocks = []
  all_scores = []
  tf.logging.info("GSW block layer")

  def _tile(x, n, tile_dim):
    # Simple tile function in MTF.
    return mtf.concat([x] * n, tile_dim.name)

  def _repeat(x, n, repeat_dim):
    # Repeat function in MTF.
    tmp_dim = mtf.Dimension("tmp", 1)
    expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
    x = mtf.reshape(x, expand_shape)
    x = _tile(x, n, tmp_dim)
    output_shape = []
    for dim in x.shape.dims:
      if dim.name == "tmp":
        continue
      if dim.name == repeat_dim.name:
        dim = mtf.Dimension(dim.name, dim.size * n)
      output_shape.append(dim)
    output_shape = mtf.Shape(output_shape)
    x = mtf.reshape(x, output_shape)
    return x

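  # Added note: _tile concatenates n copies of a tensor along a temporary
  # unit dimension, and _repeat then folds that temporary dimension into
  # repeat_dim, so repeat_dim ends up n times its original size.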
  def _combined_dim(dims):
    return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

  # compute all subword blocks
  # TODO(yitay): handle offsets to get all blocks
  if activation == "sigtanh":
    # one score for sigmoid
    tmp_dim = mtf.Dimension("block_score", 2)
  else:
    tmp_dim = mtf.Dimension("block_score", 1)

  model_dim = x.shape[-1]
  subword_blocks_width = [2, 3, 4]

  if consider_chars_as_blocks:
    subword_blocks_width += [1]

  if share_block_kernel:
    block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
    block_kernel = mtf.get_variable(
        x.mesh, "block_kernel", block_kernel_shape, initializer=None,
        dtype=context.variable_dtype)
  else:
    block_kernel = None

  for subword_len in subword_blocks_width:
    if use_block_pos_embedding:
      # This is turned off by default. It is meant to support cases like
      # parameterized pooling or other features.
      block_len_dim = mtf.Dimension(length_dim.name, subword_len)
      # TODO(vqtran): Consider other positional embeddings.
      block_pos_emb = sinusoid_positional_embedding_weights(
          context.mesh, block_len_dim, x.shape[-1],
          context.variable_dtype.activation_dtype)
      block_pos_emb = _repeat(block_pos_emb,
                              math.ceil(length_dim.size / float(subword_len)),
                              block_len_dim)
    if use_offsets:
      offset_space = subword_len
    else:
      offset_space = 1
    for offsets in range(offset_space):
      if offsets > 0:
        xoff = mtf.shift(x, offsets, length_dim, wrap=False)
        if use_block_pos_embedding:
          block_pos_emb = mtf.shift(
              block_pos_emb, offsets, block_pos_emb.shape[-2], wrap=False)
      else:
        xoff = x
      tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
      if length_dim.size % subword_len != 0:
        tf.logging.info("Length not divisible by subword length")
        # add extra padding tokens
        pad_amt = int(subword_len) - int(
            length_dim.size % subword_len)
        kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
      else:
        kp = xoff

      if use_block_pos_embedding:
        kp += block_pos_emb

      bx = mtf.pool_tensor_1d(
          kp,
          pool_dim=kp.shape.get_dim_by_name("length"),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(subword_len))
      block_score = mtf.layers.dense(
          bx, [tmp_dim],
          use_bias=False,
          name="bx",
          reduced_dims=[model_dim],
          variable_dtype=None,
          kernel_weights=block_kernel)

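      # Added note: bx holds one mean-pooled embedding per non-overlapping
      # span of subword_len characters, and block_score scores each such
      # block. Both are repeated back to character resolution below so every
      # character position carries this candidate block and its score.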
      expand_bx = _repeat(bx, subword_len, length_dim)
      expand_scores = _repeat(block_score, subword_len, length_dim)
      if offsets > 0:
        # add offset.
        expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [offsets, 0], length_dim.name)
      new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
      if new_len.size < length_dim.size:
        pad_amt = length_dim.size - new_len.size
        expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [0, pad_amt], length_dim.name)
      elif new_len.size > length_dim.size:
        expand_bx = mtf.slice(expand_bx, 0, length_dim.size, length_dim.name)
        expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                  length_dim.name)

      new_tmp_dim = mtf.Dimension("extra_dim", 1)
      expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
      expand_scores_shape = mtf.Shape(expand_scores.shape.dims + [new_tmp_dim])
      expand_bx = mtf.reshape(expand_bx, expand_shape)
      expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
      all_blocks.append(expand_bx)
      all_scores.append(expand_scores)

  all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
  all_scores = mtf.concat(all_scores, new_tmp_dim.name)
  tf.logging.info(all_blocks)
  new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
  combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
  block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
  block_net = mtf.reshape(all_scores, block_net_shape)

  if block_mixing_mode == "score_attention":
    tf.logging.info("Using score attention")
    att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
    tf.logging.info(block_net)
    att = mtf.softmax(att, reduced_dim=att.shape[-1])
    block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
    tf.logging.info(block_net)

  if activation == "softmax":
    block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
  elif activation == "tanh":
    tf.logging.info("Using tanh")
    block_net = mtf.tanh(block_net)

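  # Added note: block_net now carries a weight per candidate block size at
  # each position (softmax-normalized by default); the weighted sum below is
  # the soft block selection that defines GBST.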
  all_blocks = block_net * all_blocks
  all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
  output = all_blocks

  if downsample:
    output_length = output.shape.get_dim_by_name("length")
    if output_length.size % int(downsample) != 0:
      pad_amt = int(downsample) - int(output_length.size % int(downsample))
      output = mtf.pad(output, [0, pad_amt], output_length.name)
    if downsample_function == "mean":
      output = mtf.pool_tensor_1d(
          output,
          pool_dim=output.shape.get_dim_by_name("length"),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(downsample))
    else:
      raise ValueError("Downsampling function not implemented.")

  return output

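
# The sketch below is added for illustration and is not part of the original
# module; it shows one way gradient_based_subword_tokenization might be
# called with its default options. The mesh name, the dimension names/sizes,
# and the helper name _example_gbst are assumptions, not from the source.
def _example_gbst():
  """Builds a tiny GBST graph with default options; returns an mtf.Tensor."""
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  # The length dimension must be named "length" because the pooling calls
  # inside gradient_based_subword_tokenization look it up by that name.
  length = mtf.Dimension("length", 12)
  d_model = mtf.Dimension("d_model", 8)
  x = mtf.random_uniform(mesh, mtf.Shape([batch, length, d_model]))
  # With the defaults (no block position embeddings, no shared block kernel),
  # context is not needed and can stay None.
  return gradient_based_subword_tokenization(
      x, length_dim=length, downsample=2)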
