# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Sparse attention for the transformer."""
17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20
21from tensor2tensor.layers import common_attention22from tensor2tensor.layers import common_layers23
24import tensorflow.compat.v1 as tf25
26from state_of_sparsity.sparse_transformer.layers import common_sparse27from tensorflow.contrib.model_pruning.python import pruning # pylint: disable=g-direct-tensorflow-import28from tensorflow.python.ops import inplace_ops # pylint: disable=g-direct-tensorflow-import29
30
def compute_attention_component(antecedent,
                                total_depth,
                                filter_width=1,
                                padding="VALID",
                                name="c",
                                vars_3d_num_heads=0,
                                sparsity_technique=None,
                                threshold=3.0,
                                training=True,
                                clip_alpha=None,
                                initial_sparsity=None,
                                split_heads=False,
                                num_heads=None):
  """Computes attention component (query, key or value).

  Args:
    antecedent: a Tensor with shape [batch, length, channels]
    total_depth: an integer
    filter_width: An integer specifying how wide you want the attention
      component to be.
    padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    name: a string specifying scope name.
    vars_3d_num_heads: an optional integer (if we want to use 3d variables)
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    num_heads: The number of heads in the attention module.

  Returns:
    c : [batch, length, depth] tensor
  """
  # We don't support 3d attention variables or filter_width > 1 with sparsity
  # techniques
  assert not sparsity_technique or (not vars_3d_num_heads and filter_width == 1)

  if vars_3d_num_heads > 0:
    assert filter_width == 1
    input_depth = antecedent.get_shape().as_list()[-1]
    depth_per_head = total_depth // vars_3d_num_heads
    initializer_stddev = input_depth ** -0.5
    if "q" in name:
      initializer_stddev *= depth_per_head ** -0.5
    var = tf.get_variable(
        name, [input_depth,
               vars_3d_num_heads,
               total_depth // vars_3d_num_heads],
        initializer=tf.random_normal_initializer(stddev=initializer_stddev))
    var = tf.cast(var, antecedent.dtype)
    var = tf.reshape(var, [input_depth, total_depth])
    return tf.tensordot(antecedent, var, axes=1)
  if filter_width == 1:
    if sparsity_technique:
      if split_heads:
        # Prune each head's weights separately so that they are free
        # to have different weight magnitude distributions.
        if num_heads is None:
          raise ValueError("`num_heads` must be set for split head pruning.")
        if total_depth % num_heads != 0:
          raise ValueError("`total_depth` must be divisible by `num_heads`.")
        input_depth = antecedent.get_shape().as_list()[-1]
        depth_per_head = int(total_depth / num_heads)
        masked_head_weights = []
        for head_id in range(num_heads):
          head_name = name + "_shard_{}".format(head_id)
          with tf.variable_scope(head_name) as vs:
            head_weights = tf.get_variable(
                "kernel", [input_depth, depth_per_head])
            masked_head_weights.append(pruning.apply_mask(head_weights, vs))
        component_weights = tf.concat(masked_head_weights, axis=1)

        # compute the full component result
        return tf.tensordot(antecedent, component_weights, axes=1)
      else:
        return common_sparse.dense(
            antecedent,
            total_depth,
            use_bias=False,
            sparsity_technique=sparsity_technique,
            threshold=threshold,
            training=training,
            clip_alpha=clip_alpha,
            name=name,
            initial_sparsity=initial_sparsity)
    else:
      return common_layers.dense(
          antecedent, total_depth, use_bias=False, name=name)
  else:
    return common_layers.conv1d(
        antecedent, total_depth, filter_width, padding=padding, name=name)
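
# Usage sketch: computing a query projection with split-head pruning. The
# shapes and the sparsity_technique string below are illustrative assumptions;
# the snippet is left commented out so that importing this module has no side
# effects.
#
#   antecedent = tf.zeros([8, 64, 512])  # [batch, length, channels]
#   query = compute_attention_component(
#       antecedent,
#       total_depth=512,
#       name="q",
#       sparsity_technique="magnitude_pruning",  # assumed technique name
#       split_heads=True,
#       num_heads=8)
#   # query: [8, 64, 512]; one pruning mask is created per head shard
#   # (variable scopes "q_shard_0" ... "q_shard_7").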


def compute_qkv(query_antecedent,
                memory_antecedent,
                total_key_depth,
                total_value_depth,
                q_filter_width=1,
                kv_filter_width=1,
                q_padding="VALID",
                kv_padding="VALID",
                vars_3d_num_heads=0,
                sparsity_technique=None,
                threshold=3.0,
                training=True,
                clip_alpha=None,
                initial_sparsity=None,
                split_heads=False,
                num_heads=None):
  """Computes query, key and value.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels]
    total_key_depth: an integer
    total_value_depth: an integer
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
      to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    vars_3d_num_heads: an optional integer (if we want to use 3d variables)
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    num_heads: The number of heads in the attention module.

  Returns:
    q, k, v : [batch, length, depth] tensors
  """
  if memory_antecedent is None:
    memory_antecedent = query_antecedent
  q = compute_attention_component(
      query_antecedent,
      total_key_depth,
      q_filter_width,
      q_padding,
      "q",
      vars_3d_num_heads=vars_3d_num_heads,
      sparsity_technique=sparsity_technique,
      threshold=threshold,
      training=training,
      clip_alpha=clip_alpha,
      initial_sparsity=initial_sparsity,
      split_heads=split_heads,
      num_heads=num_heads)
  k = compute_attention_component(
      memory_antecedent,
      total_key_depth,
      kv_filter_width,
      kv_padding,
      "k",
      vars_3d_num_heads=vars_3d_num_heads,
      sparsity_technique=sparsity_technique,
      threshold=threshold,
      training=training,
      clip_alpha=clip_alpha,
      initial_sparsity=initial_sparsity,
      split_heads=split_heads,
      num_heads=num_heads)
  v = compute_attention_component(
      memory_antecedent,
      total_value_depth,
      kv_filter_width,
      kv_padding,
      "v",
      vars_3d_num_heads=vars_3d_num_heads,
      sparsity_technique=sparsity_technique,
      threshold=threshold,
      training=training,
      clip_alpha=clip_alpha,
      initial_sparsity=initial_sparsity,
      split_heads=split_heads,
      num_heads=num_heads)
  return q, k, v
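
# Usage sketch: computing q, k, v for encoder-decoder (cross) attention, where
# keys and values are projected from the encoder output. Shapes and the
# sparsity_technique value are illustrative assumptions; left commented out so
# importing this module has no side effects.
#
#   decoder_input = tf.zeros([8, 32, 512])   # [batch, length_q, channels]
#   encoder_output = tf.zeros([8, 64, 512])  # [batch, length_m, channels]
#   q, k, v = compute_qkv(
#       decoder_input,
#       encoder_output,
#       total_key_depth=512,
#       total_value_depth=512,
#       sparsity_technique="variational_dropout",  # assumed technique name
#       threshold=3.0,
#       training=True)
#   # q: [8, 32, 512], k: [8, 64, 512], v: [8, 64, 512]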


def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        attention_type="dot_product",
                        image_shapes=None,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        vars_3d=False,
                        sparsity_technique=None,
                        threshold=3.0,
                        training=True,
                        clip_alpha=None,
                        initial_sparsity=None,
                        split_heads=False,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    attention_type: a string, either "dot_product", "dot_product_relative",
      "local_mask_right", "local_unmasked", "masked_dilated_1d",
      "unmasked_dilated_1d", graph, or any attention function
      with the signature (query, key, value, **kwargs)
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
      to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
      no padding.
    cache: dict containing Tensors which are the results of previous
      attentions, used for fast decoding. Expects the dict to contain two
      keys ('k' and 'v'), for the initial call the values for these keys
      should be empty Tensors of the appropriate shape.
      'k' [batch_size, 0, key_channels]
      'v' [batch_size, 0, value_channels]
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      Saves memory.
    vars_3d: use 3-dimensional variables for input/output transformations
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous keys and values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim].
    Optionally returns additional loss parameters (e.g. load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  if vars_3d:
    raise ValueError("3d attention variables not supported.")
  if attention_type != "dot_product":
    raise ValueError(
        "Sparse multihead attention only supports dot_product attention.")

  vars_3d_num_heads = 0
  with tf.variable_scope(
      name,
      default_name="multihead_attention",
      values=[query_antecedent, memory_antecedent]):

    if cache is None or memory_antecedent is None:
      q, k, v = compute_qkv(query_antecedent, memory_antecedent,
                            total_key_depth, total_value_depth, q_filter_width,
                            kv_filter_width, q_padding, kv_padding,
                            vars_3d_num_heads=vars_3d_num_heads,
                            sparsity_technique=sparsity_technique,
                            threshold=threshold,
                            training=training,
                            clip_alpha=clip_alpha,
                            initial_sparsity=initial_sparsity,
                            split_heads=split_heads,
                            num_heads=num_heads)
    if cache is not None:
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        q = compute_attention_component(query_antecedent, total_key_depth,
                                        q_filter_width, q_padding, "q",
                                        vars_3d_num_heads=vars_3d_num_heads,
                                        sparsity_technique=sparsity_technique,
                                        threshold=threshold,
                                        training=training,
                                        clip_alpha=clip_alpha,
                                        initial_sparsity=initial_sparsity,
                                        split_heads=split_heads,
                                        num_heads=num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of current implementation is better than updating
          # the tensor by adding the result of matmul(one_hot,
          # update_in_current_step)
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = common_attention.split_heads(q, num_heads)
    if cache is None:
      k = common_attention.split_heads(k, num_heads)
      v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      q *= key_depth_per_head**-0.5

    # compute the attention
    x = common_attention.dot_product_attention(
        q, k, v, bias, dropout_rate, image_shapes,
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary,
        dropout_broadcast_dims=dropout_broadcast_dims)
    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if sparsity_technique:
      x = common_sparse.dense(
          x,
          output_depth,
          use_bias=False,
          sparsity_technique=sparsity_technique,
          threshold=threshold,
          training=training,
          clip_alpha=clip_alpha,
          name="output_transform",
          initial_sparsity=initial_sparsity)
    else:
      x = common_layers.dense(
          x,
          output_depth,
          use_bias=False,
          name="output_transform")
    return x
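
# Usage sketch: masked decoder self-attention with the sparse multihead layer.
# The bias helper and all hyperparameters are illustrative assumptions; left
# commented out so importing this module has no side effects.
#
#   x = tf.zeros([8, 32, 512])  # [batch, length_q, hidden_dim]
#   bias = common_attention.attention_bias_lower_triangle(32)
#   y = multihead_attention(
#       x,
#       memory_antecedent=None,  # self-attention
#       bias=bias,
#       total_key_depth=512,
#       total_value_depth=512,
#       output_depth=512,
#       num_heads=8,
#       dropout_rate=0.1,
#       sparsity_technique="variational_dropout",  # assumed technique name
#       training=True)
#   # y: [8, 32, 512]
#
# For incremental decoding, pass a cache dict whose 'k' and 'v' entries start
# as empty tensors (see the docstring above) and feed one query position per
# step; the layer then concatenates or in-place-updates the cached keys and
# values.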