# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom attention modules for Charformer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import mesh_tensorflow as mtf
from mesh_tensorflow.transformer import attention
from mesh_tensorflow.transformer.transformer import sinusoid_positional_embedding_weights
import tensorflow.compat.v1 as tf


def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None,
                       context=None):
  """Attention to a neighborhood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if write_priority and read_priority are provided, then attention
  is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional keyword arguments for attention()
    context: optional context.

  Returns:
    a Tensor with the shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if channels or depth don't match.
  """
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128).
  tf.logging.info(attention_kwargs)
  # Tolerate the default None so the ** expansion below is safe.
  attention_kwargs = attention_kwargs or {}
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)

  def _reshape_query(x):
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])

  def _reshape_memory(x):
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if fully_autoregressive
            else mtf.halo_exchange)(
                x, num_blocks, memory_block_length, radius)

  q = _reshape_query(q)
  k = _reshape_memory(k)
  if v:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible, mtf.greater_equal(read_priority, write_priority))

  bias = attention.visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention.attention(q, k, v, padded_memory_block_length, key_dim,
                          value_dim, bias, context=context,
                          **attention_kwargs)
  return mtf.replace_dimensions(
      o, [num_blocks, query_block_length], length_dim)


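# A minimal, self-contained sketch (an illustration, not part of the
# original module). It re-derives the block size chosen above: the greatest
# divisor of length_per_split that is at most max(radius, 128). The helper
# name is hypothetical.
def _example_block_length(length_per_split, radius=128):
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  return block_length

# For example, _example_block_length(512) == 128, while
# _example_block_length(320) == 80 (the greatest divisor of 320 that is
# <= 128). For the visibility window: with radius=4 and
# fully_autoregressive=True, query position 10 attends to memory positions
# 7..10, i.e. the range (10 - 4, 10].

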
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
  """Implements GBSWT from Charformer.

  Args:
    x: a Tensor containing length_dim
    length_dim: a Dimension
    max_subword_length: integer
    downsample: optional integer, the rate at which to downsample the output
      length.
    use_offsets: boolean.
    consider_chars_as_blocks: boolean.
    use_block_pos_embedding: boolean.
    share_block_kernel: boolean.
    memory_embeddings: integer.
    context: Context.
    block_mixing_mode: optional string for block mixing; currently supports
      "score_attention".
    activation: string activation for block ranking ("softmax", "tanh" or
      "sigtanh").
    downsample_function: string; only "mean" is currently implemented.

  Returns:
    a Tensor with the same shape as x (downsampled along length_dim if
    downsample is set).

  Raises:
    ValueError: if channels or depth don't match.
  """
  # Unused for now.
  del max_subword_length
  del memory_embeddings
  all_blocks = []
  all_scores = []
  tf.logging.info("GSW block layer")

  def _tile(x, n, tile_dim):
    # Simple tile function in MTF.
    return mtf.concat([x] * n, tile_dim.name)

  def _repeat(x, n, repeat_dim):
    # Repeat function in MTF: expands a size-1 "tmp" dim, tiles along it,
    # then folds it back into repeat_dim.
    tmp_dim = mtf.Dimension("tmp", 1)
    expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
    x = mtf.reshape(x, expand_shape)
    x = _tile(x, n, tmp_dim)
    output_shape = []
    for dim in x.shape.dims:
      if dim.name == "tmp":
        continue
      if dim.name == repeat_dim.name:
        dim = mtf.Dimension(dim.name, dim.size * n)
      output_shape.append(dim)
    output_shape = mtf.Shape(output_shape)
    x = mtf.reshape(x, output_shape)
    return x
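  # Illustration (not in the original): for a rank-1 x = [a, b, c] along
  # repeat_dim, _tile(x, 2, repeat_dim) gives [a, b, c, a, b, c], while
  # _repeat(x, 2, repeat_dim) gives [a, a, b, b, c, c], analogous to
  # tf.tile vs. tf.repeat.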

  def _combined_dim(dims):
    return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)
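  # Illustration (not in the original): combining [Dimension("extra_dim", 4),
  # Dimension("block_score", 1)] yields Dimension("extra_dim", 4): the first
  # dimension's name with the product of all the sizes.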

  # Compute all subword blocks.
  # TODO(yitay): handle offsets to get all blocks
  if activation == "sigtanh":
    # Two scores: one for the sigmoid, one for the tanh.
    tmp_dim = mtf.Dimension("block_score", 2)
  else:
    tmp_dim = mtf.Dimension("block_score", 1)

  model_dim = x.shape[-1]
  subword_blocks_width = [2, 3, 4]

  if consider_chars_as_blocks:
    subword_blocks_width += [1]

  if share_block_kernel:
    block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
    block_kernel = mtf.get_variable(
        x.mesh, "block_kernel", block_kernel_shape, initializer=None,
        dtype=context.variable_dtype)
  else:
    block_kernel = None

  for subword_len in subword_blocks_width:
    if use_block_pos_embedding:
      # This is turned off by default. It is meant to support cases like
      # parameterized pooling or other features.
      block_len_dim = mtf.Dimension(length_dim.name, subword_len)
      # TODO(vqtran): Consider other positional embeddings.
      block_pos_emb = sinusoid_positional_embedding_weights(
          context.mesh, block_len_dim, x.shape[-1],
          context.variable_dtype.activation_dtype)
      block_pos_emb = _repeat(block_pos_emb,
                              math.ceil(length_dim.size / float(subword_len)),
                              block_len_dim)
    if use_offsets:
      offset_space = subword_len
    else:
      offset_space = 1
    for offsets in range(offset_space):
      if offsets > 0:
        xoff = mtf.shift(x, offsets, length_dim, wrap=False)
        if use_block_pos_embedding:
          block_pos_emb = mtf.shift(
              block_pos_emb, offsets, block_pos_emb.shape[-2], wrap=False)
      else:
        xoff = x
      tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
      if length_dim.size % subword_len != 0:
        tf.logging.info("Sequence length not divisible by subword length.")
        # Add extra padding tokens.
        pad_amt = int(subword_len) - int(length_dim.size % subword_len)
        kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
      else:
        kp = xoff

      if use_block_pos_embedding:
        kp += block_pos_emb

      # Mean-pool each non-overlapping block of subword_len characters.
      bx = mtf.pool_tensor_1d(
          kp,
          pool_dim=kp.shape.get_dim_by_name("length"),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(subword_len))
      # Score each block for the block-ranking step below.
      block_score = mtf.layers.dense(
          bx, [tmp_dim],
          use_bias=False,
          name="bx",
          reduced_dims=[model_dim],
          variable_dtype=None,
          kernel_weights=block_kernel)

      # Upsample block representations and scores back to character length.
      expand_bx = _repeat(bx, subword_len, length_dim)
      expand_scores = _repeat(block_score, subword_len, length_dim)
      if offsets > 0:
        # Add offset.
        expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [offsets, 0], length_dim.name)
      new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
      if new_len.size < length_dim.size:
        pad_amt = length_dim.size - new_len.size
        expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [0, pad_amt], length_dim.name)
      elif new_len.size > length_dim.size:
        expand_bx = mtf.slice(expand_bx, 0, length_dim.size, length_dim.name)
        expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                  length_dim.name)

      new_tmp_dim = mtf.Dimension("extra_dim", 1)
      expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
      expand_scores_shape = mtf.Shape(expand_scores.shape.dims + [new_tmp_dim])
      expand_bx = mtf.reshape(expand_bx, expand_shape)
      expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
      all_blocks.append(expand_bx)
      all_scores.append(expand_scores)

  all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
  all_scores = mtf.concat(all_scores, new_tmp_dim.name)
  tf.logging.info(all_blocks)
  new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim")
  combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
  block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
  block_net = mtf.reshape(all_scores, block_net_shape)

  if block_mixing_mode == "score_attention":
    tf.logging.info("Using score attention")
    att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
    tf.logging.info(block_net)
    att = mtf.softmax(att, reduced_dim=att.shape[-1])
    block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
    tf.logging.info(block_net)

  if activation == "softmax":
    block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
  elif activation == "tanh":
    tf.logging.info("Using tanh")
    block_net = mtf.tanh(block_net)

  # Soft selection: weight each candidate block sequence by its (normalized)
  # score and sum over the candidates.
  all_blocks = block_net * all_blocks
  all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
  output = all_blocks

  if downsample:
    output_length = output.shape.get_dim_by_name("length")
    if output_length.size % int(downsample) != 0:
      pad_amt = int(downsample) - int(output_length.size % int(downsample))
      output = mtf.pad(output, [0, pad_amt], output_length.name)
    if downsample_function == "mean":
      output = mtf.pool_tensor_1d(
          output,
          pool_dim=output.shape.get_dim_by_name("length"),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(downsample))
    else:
      raise ValueError("Downsampling function not implemented.")

  return output


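# A minimal usage sketch (an illustration, not part of the original module):
# runs the GBSWT layer over random character embeddings on a toy mesh. The
# dimension names and sizes are hypothetical; with the default flags (no
# block position embeddings, no shared kernel) the optional `context` is
# never dereferenced, so None suffices here.
def _example_gbswt_usage():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 2)
  length = mtf.Dimension("length", 12)
  model = mtf.Dimension("d_model", 8)
  x = mtf.random_uniform(mesh, mtf.Shape([batch, length, model]))
  # Mixes 2-, 3- and 4-character blocks, then mean-pools the output
  # length by a factor of 2.
  return gradient_based_subword_tokenization(
      x, length, downsample=2, context=None)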