# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Some utils for attention layers."""

import functools
import itertools
import operator
import numpy as np
import tensorflow.compat.v2 as tf


def index_to_step(index, shape):
  """Compute step for a given nd index if we were enumerating to shape."""
  step = index[0]
  for i, s in enumerate(shape[1:]):
    step = step * s + index[i + 1]
  return step


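# A minimal usage sketch (added for illustration; `_example_index_to_step` is
# not part of the original module): enumerating a (4, 5) grid in row-major
# order, the nd index (2, 3) lands on step 2 * 5 + 3 = 13.
def _example_index_to_step():
  assert index_to_step((2, 3), (4, 5)) == 13

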
def pad_to_multiple_nd(x, shape):
  """Pads x such that nd-shape axes are multiples of shape axes.

  Args:
    x: Tensor of shape [B] + nd_shape + [...].
    shape: Shape tuple of same length as nd_shape.

  Returns:
    x padded to make each axis in nd_shape divisible by the same shape axis.
  """
  x_shape = x.shape.as_list()
  num_feat_dim = len(x_shape) - len(shape) - 1
  if all(s for s in x_shape[1:len(shape) + 1]):
    pad_amount = np.mod(-np.asarray(x_shape[1:len(shape) + 1]), shape)
    paddings = [[0, 0]] + [[0, p] for p in pad_amount] + [[0, 0]] * num_feat_dim
    return tf.pad(x, paddings) if any(any(p) for p in paddings) else x
  else:
    # If the static shape of x is not fully defined, fall back to dynamic
    # shapes for the unknown axes.
    tf_shape = tf.shape(x)
    last = x_shape[-num_feat_dim:]
    paddings = [[0, -(x_shape[i + 1] or tf_shape[i + 1]) % s]
                for i, s in enumerate(shape)]
    paddings = [[0, 0]] + paddings + [[0, 0]] * num_feat_dim
    padded_x = tf.pad(x, paddings)
    padded_shape = padded_x.shape.as_list()
    padded_shape = padded_shape[:-1] + last
    return padded_x


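# A minimal usage sketch (added for illustration; `_example_pad_to_multiple_nd`
# is not part of the original module): both spatial axes of a [B, 5, 7, C]
# tensor are padded up to the next multiple of 4.
def _example_pad_to_multiple_nd():
  x = tf.ones([1, 5, 7, 3])
  padded = pad_to_multiple_nd(x, (4, 4))
  assert padded.shape.as_list() == [1, 8, 8, 3]
  return padded

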
def divide_nd_blocks(inputs, nd_block_size, collapse=False):
  """Divides input into non-overlapping n-dimensional blocks.

  Args:
    inputs: [B, D1, D2, ..., Dk, ...] tensor.
    nd_block_size: Shape tuple of length k.
    collapse: If True, collapses the batch and block-index axes into a single
      leading axis and flattens each block, giving a
      [B * D1 // S1 * ... * Dk // Sk, S1 * S2 * ... * Sk, ...] tensor.

  Returns:
    A [B, D1 // S1, D2 // S2, ..., Dk // Sk, S1, S2, ..., Sk, ...] tensor.
  """
  nd_block_size = list(nd_block_size)
  inputs = pad_to_multiple_nd(inputs, nd_block_size)

  shape = list(inputs.shape)
  for i, s in enumerate(shape):
    if s is None:
      shape[i] = tf.shape(inputs)[i]

  block_axes = shape[1:len(nd_block_size) + 1]
  num_blocks = [l // s for l, s in zip(block_axes, nd_block_size)]
  num_nd_axes = len(nd_block_size)
  num_feat_axes = len(shape) - num_nd_axes - 1
  features_shape = shape[-num_feat_axes:]

  # Reshape into [B, D1 // S1, S1, D2 // S2, S2, ..., Dk // Sk, Sk, ...].
  mid_shape = list(itertools.chain(*zip(num_blocks, nd_block_size)))
  cut_shape = shape[:1] + mid_shape + features_shape
  cut_inputs = tf.reshape(inputs, cut_shape)

  # Permute into [B, D1 // S1, D2 // S2, ..., Dk // Sk, S1, S2, ..., Sk, ...].
  num_mid_axes = num_nd_axes * 2
  num_feat_axes = len(shape) - num_nd_axes - 1
  mid_permute = itertools.chain(
      range(1, num_mid_axes, 2), range(2, num_mid_axes + 1, 2))
  post_permute = range(num_mid_axes + 1, num_mid_axes + num_feat_axes + 1)
  permutation = [0] + list(mid_permute) + list(post_permute)
  permuted_inputs = tf.transpose(cut_inputs, permutation)

  if not collapse:
    return permuted_inputs
  # Collapse to [B * D1 // S1 * D2 // S2 * ... * Dk // Sk, S1 * S2 * ... * Sk, ...].
  block_length = functools.reduce(operator.mul, nd_block_size, 1)
  collapsed_inputs = tf.reshape(permuted_inputs, [-1, block_length] +
                                features_shape)

  return collapsed_inputs


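# A minimal usage sketch (added for illustration; `_example_divide_nd_blocks`
# is not part of the original module): a [2, 4, 6, 8] tensor cut into 2x2
# blocks yields 2 * (4 // 2) * (6 // 2) = 12 collapsed blocks of length 4.
def _example_divide_nd_blocks():
  x = tf.random.normal([2, 4, 6, 8])
  blocks = divide_nd_blocks(x, (2, 2), collapse=True)
  assert blocks.shape.as_list() == [12, 4, 8]
  return blocks

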
def relative_attn_bias(rel_bias, num_heads, decode_step=None):
  """Computes attention bias based on relative positions.

  Content-based relative position attention bias was used in:
  https://arxiv.org/pdf/1803.02155.
  Non-content-based relative position attention bias was used in:
  https://arxiv.org/abs/1606.01933.

  Args:
    rel_bias: Relative bias variable of shape [num_heads, 2 * length].
    num_heads: Number of attention heads.
    decode_step: Optional decode step, used for slicing during decoding.

  Returns:
    A [length, num_heads, length] relative attention bias tensor, or a
    [1, num_heads, length] slice for the current position when decode_step
    is a tensor.
  """
  num_rel_pos = rel_bias.shape[-1]
  length = num_rel_pos // 2

  if tf.is_tensor(decode_step):
    # This is decoding so we need to select the current slice within rel_bias.
    # E.g.: len_k = 3, decode_step = 1
    # We have: rel_bias = [-2, -1, 0, 1, 2, 3]
    # We want: [-1, 0, 1]
    # We slice at len_k - decode_step - 1 = 1
    rel_bias = tf.reshape(rel_bias, [1, num_heads, num_rel_pos])
    start = ((length - 1) - decode_step)
    rel_bias = tf.slice(rel_bias, [0, 0, start], [1, num_heads, length])
    return rel_bias

  # Now we have to shift in order to compute relative biases.
  # Example: length = 3
  # Say we want: [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]
  # Start: [[-2, -1, 0, 1, 2, 3], [-2, -1, 0, 1, 2, 3], [-2, -1, 0, 1, 2, 3]]
  # We linearize: [-2, -1, 0, 1, 2, 3, -2, -1, 0, 1, 2, 3, -2, -1, 0, 1, 2, 3]
  # We slice: [-2, -1, 0, 1, 2, 3, -2, -1, 0, 1, 2, 3, -2, -1, 0]
  # We reshape: [[-2, -1, 0, 1, 2], [3, -2, -1, 0, 1], [2, 3, -2, -1, 0]]
  # We slice: [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]
  # Tadaaa!

  # [heads, len_q * num_rel_pos]
  rel_bias = tf.tile(rel_bias, [1, length])

  # [heads, len_q * (num_rel_pos - 1)]
  num_rel_pos -= 1
  rel_bias = rel_bias[Ellipsis, :length * num_rel_pos]

  # [heads, len_q, num_rel_pos - 1]
  # Now every row is shifted by 1 to the right.
  rel_bias = tf.reshape(rel_bias, [num_heads, length, num_rel_pos])

  # [heads, len_q, len_k]
  # Slice the overlapping elements from start.
  rel_bias = rel_bias[Ellipsis, num_rel_pos - length:]
  # [len_q, heads, len_k]
  rel_bias = tf.transpose(rel_bias, [1, 0, 2])

  return rel_bias


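# A minimal usage sketch (added for illustration; `_example_relative_attn_bias`
# is not part of the original module): for length 3 and 2 heads, the training
# path turns a [2, 6] bias table into a [3, 2, 3] per-query bias.
def _example_relative_attn_bias():
  num_heads, length = 2, 3
  rel_bias = tf.random.normal([num_heads, 2 * length])
  bias = relative_attn_bias(rel_bias, num_heads)
  assert bias.shape.as_list() == [length, num_heads, length]
  return bias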