# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Some utils for attention layers."""

import functools
import itertools
import operator
import numpy as np
import tensorflow.compat.v2 as tf


def index_to_step(index, shape):
  """Compute step for a given nd index if we were enumerating to shape."""
  step = index[0]
  for i, s in enumerate(shape[1:]):
    step = step * s + index[i + 1]
  return step
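
# Illustrative worked example (added sketch, not part of the original file):
# enumeration is row-major, so for shape (4, 5) the index (2, 3) maps to
# step 2 * 5 + 3 = 13, and for shape (2, 3, 4) the index (1, 2, 3) maps to
# step (1 * 3 + 2) * 4 + 3 = 23.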


def pad_to_multiple_nd(x, shape):
  """Pads x such that nd-shape axes are multiples of shape axes.

  Args:
    x: Tensor of shape [B] + nd_shape + [...].
    shape: Shape tuple of same length as nd_shape.

  Returns:
    x padded so that each axis in nd_shape becomes a multiple of the
    corresponding axis in shape.
  """
  x_shape = x.shape.as_list()
  num_feat_dim = len(x_shape) - len(shape) - 1
  if all(s for s in x_shape[1:len(shape) + 1]):
    pad_amount = np.mod(-np.asarray(x_shape[1:len(shape) + 1]), shape)
    paddings = [[0, 0]] + [[0, p] for p in pad_amount] + [[0, 0]] * num_feat_dim

    return tf.pad(x, paddings) if any(any(p) for p in paddings) else x
  else:
    # If the static shape is not fully defined, fall back to dynamic shapes.
    tf_shape = tf.shape(x)
    last = x_shape[-num_feat_dim:]
    paddings = [[0, -(x_shape[i + 1] or tf_shape[i + 1]) % s]
                for i, s in enumerate(shape)]
    paddings = [[0, 0]] + paddings + [[0, 0]] * num_feat_dim
    padded_x = tf.pad(x, paddings)
    padded_shape = padded_x.shape.as_list()
    padded_shape = padded_shape[:-1] + last
    return padded_x
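
# Shape sketch (illustrative, not part of the original file): for an input of
# shape [2, 5, 7, 16] and shape=(4, 4), the spatial axes 5 and 7 are padded up
# to the next multiples of 4:
#   pad_to_multiple_nd(tf.zeros([2, 5, 7, 16]), (4, 4))  # -> [2, 8, 8, 16]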


def divide_nd_blocks(inputs, nd_block_size, collapse=False):
  """Divides input into non-overlapping n-dimensional blocks.

  Args:
    inputs: [B, D1, D2, ..., Dk, ...] tensor.
    nd_block_size: Shape tuple of length k.
    collapse: If True, collapse the batch and block-grid axes into a single
      leading axis and the per-block axes into a single block-length axis.

  Returns:
    A [B, D1 // S1, D2 // S2, ..., Dk // Sk, S1, S2, ..., Sk, ...] tensor, or
    a [B * D1 // S1 * ... * Dk // Sk, S1 * S2 * ... * Sk, ...] tensor if
    collapse is True.
  """
  nd_block_size = list(nd_block_size)
  inputs = pad_to_multiple_nd(inputs, nd_block_size)

  shape = list(inputs.shape)
  for i, s in enumerate(shape):
    if s is None:
      shape[i] = tf.shape(inputs)[i]

  block_axes = shape[1:len(nd_block_size) + 1]
  num_blocks = [l // s for l, s in zip(block_axes, nd_block_size)]
  num_nd_axes = len(nd_block_size)
  num_feat_axes = len(shape) - num_nd_axes - 1
  features_shape = shape[-num_feat_axes:]

  # Reshape into [B, D1 // S1, S1, D2 // S2, S2, ..., Dk // Sk, Sk, ...].
  mid_shape = list(itertools.chain(*zip(num_blocks, nd_block_size)))
  cut_shape = shape[:1] + mid_shape + features_shape
  cut_inputs = tf.reshape(inputs, cut_shape)

  # Permute into [B, D1 // S1, D2 // S2, ..., Dk // Sk, S1, S2, ..., Sk, ...].
  num_mid_axes = num_nd_axes * 2
  num_feat_axes = len(shape) - num_nd_axes - 1
  mid_permute = itertools.chain(
      range(1, num_mid_axes, 2), range(2, num_mid_axes + 1, 2))
  post_permute = range(num_mid_axes + 1, num_mid_axes + num_feat_axes + 1)
  permutation = [0] + list(mid_permute) + list(post_permute)
  permuted_inputs = tf.transpose(cut_inputs, permutation)

  if not collapse:
    return permuted_inputs
  # Collapse to [B * D1 // S1 * ... * Dk // Sk, S1 * S2 * ... * Sk, ...].
  block_length = functools.reduce(operator.mul, nd_block_size, 1)
  collapsed_inputs = tf.reshape(permuted_inputs, [-1, block_length] +
                                features_shape)

  return collapsed_inputs
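
# Shape sketch (illustrative, not part of the original file): for inputs of
# shape [2, 8, 8, 16] and nd_block_size=(4, 4):
#   divide_nd_blocks(x, (4, 4))                 # -> [2, 2, 2, 4, 4, 16]
#   divide_nd_blocks(x, (4, 4), collapse=True)  # -> [8, 16, 16]
# i.e. 2 * 2 * 2 = 8 blocks of 4 * 4 = 16 positions, each with 16 features.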


def relative_attn_bias(rel_bias, num_heads, decode_step=None):
  """Computes attention bias based on relative positions.

  Content-based relative position attention bias was used in:
    https://arxiv.org/pdf/1803.02155.
  Non-content-based relative position attention bias was used in:
    https://arxiv.org/abs/1606.01933.

  Args:
    rel_bias: Relative bias variable of shape [num_heads, 2 * length].
    num_heads: Number of attention heads.
    decode_step: Optional decode step, used for slicing during decoding.

  Returns:
    A [length, num_heads, length] relative attention bias tensor, or a
    [1, num_heads, length] tensor for the current step when decoding.
  """
  num_rel_pos = rel_bias.shape[-1]
  length = num_rel_pos // 2

  if tf.is_tensor(decode_step):
    # This is decoding so we need to select the current slice within rel_bias.
    # E.g.: len_k = 3, decode_step = 1
    # We have: rel_bias = [-2, -1, 0, 1, 2, 3]
    # We want: [-1, 0, 1]
    # We slice at len_k - decode_step - 1 = 1
    rel_bias = tf.reshape(rel_bias, [1, num_heads, num_rel_pos])
    start = ((length - 1) - decode_step)
    rel_bias = tf.slice(rel_bias, [0, 0, start], [1, num_heads, length])
    return rel_bias

  # Now we have to shift in order to compute relative biases.
  # Example: length = 3
  # Say we want:  [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]
  # Start: [[-2, -1, 0, 1, 2, 3], [-2, -1, 0, 1, 2, 3], [-2, -1, 0, 1, 2, 3]]
  # We linearize: [-2, -1, 0, 1, 2, 3, -2, -1, 0, 1, 2, 3, -2, -1, 0, 1, 2, 3]
  # We slice: [-2, -1, 0, 1, 2, 3, -2, -1, 0, 1, 2, 3, -2, -1, 0]
  # We reshape: [[-2, -1, 0, 1, 2], [3, -2, -1, 0, 1], [2, 3, -2, -1, 0]]
  # We slice: [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]
  # Tadaaa!

  # [heads, len_q * num_rel_pos]
  rel_bias = tf.tile(rel_bias, [1, length])

  # [heads, len_q * (num_rel_pos - 1)]
  num_rel_pos -= 1
  rel_bias = rel_bias[Ellipsis, :length * num_rel_pos]

  # [heads, len_q, num_rel_pos - 1]
  # Now every row is shifted by 1 to the right.
  rel_bias = tf.reshape(rel_bias, [num_heads, length, num_rel_pos])

  # [heads, len_q, len_k]
  # Slice the overlapping elements from start.
  rel_bias = rel_bias[Ellipsis, num_rel_pos - length:]
  # [len_q, heads, len_k]
  rel_bias = tf.transpose(rel_bias, [1, 0, 2])

  return rel_bias
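
# Shape sketch (illustrative, not part of the original file): with
# num_heads = 8 and a bias variable of shape [8, 32] (so length = 16):
#   relative_attn_bias(rel_bias, 8)               # -> [16, 8, 16]
#   relative_attn_bias(rel_bias, 8, decode_step)  # -> [1, 8, 16]
#                                                 #    (decode_step a tensor)
# Following the worked example above, entry [q, h, k] of the full output
# holds head h's bias for relative offset k - q.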