google-research / lower_triangular_multiplication.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the lower triangular multiplication algorithm.

This file provides a function lt_multiply to compute lt(a @ b.T) @ c given
three matrices a, b and c of appropriate dimensions.

This file also provides another function tensor_lt_multiply that computes
lt((a tensor a) @ (b tensor b).T) @ c given three matrices a, b and c of
appropriate dimensions.

The functions implement a block-based algorithm described in
https://arxiv.org/abs/2310.01655
"""

from typing import Optional, Tuple
import jax
import jax.numpy as jnp

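# Block recurrence used by both functions below (summary added for exposition;
# see the linked paper for the derivation): split a, b and c into row blocks
# a_i, b_i, c_i of grain_size rows each. The i-th row block of lt(a @ b.T) @ c
# equals
#
#   lt(a_i @ b_i.T) @ c_i  +  a_i @ (b_1.T @ c_1 + ... + b_{i-1}.T @ c_{i-1}),
#
# so the blocks below the block diagonal are accounted for by a cumulative sum
# of the small r x d products b_j.T @ c_j, and the full n x n matrix
# lt(a @ b.T) is never materialized.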

def lt_multiply(
    a: jnp.ndarray,
    b: jnp.ndarray,
    c: jnp.ndarray,
    grain_size: int,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    build_cache: bool = False,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
  """Computes the 'lower triangular product'.

  Given a, b and c, the lower triangular product is defined to be the matrix
  lt(a @ b.T) @ c. When batch dimensions are present, the operation is defined
  with respect to the last two dimensions.

  Args:
    a: An array of shape [batch, ..., n, r]
    b: An array of shape [batch, ..., n, r]
    c: An array of shape [batch, ..., n, d]
    grain_size: an integer parameter that divides n
    precision: a precision parameter that defines the precision at which the
      intermediate multiplications are performed.
    build_cache: If set to True, also builds and returns a cache that is
      useful for inference.

  Returns:
    A tuple. The first item is an array of shape [batch, ..., n, d] equal to
    lt(a @ b.T) @ c. If build_cache is True, the second item is an array
    called "cache" of shape [batch, ..., r, d] equal to b.T @ c; otherwise the
    second item is None.
  """
  assert a.shape == b.shape and a.shape[:-1] == c.shape[:-1]

  batch_dims = a.shape[:-2]
  n, r = a.shape[-2:]
  _, d = c.shape[-2:]

  assert n % grain_size == 0  # Grain size must divide the number of rows.

  if n == grain_size:
    result = (
        jnp.tril(
            jnp.einsum(
                '...ti, ...si -> ...ts',
                a,
                b,
                precision=precision,
            )
        )
        @ c
    )
    cache = None
    if build_cache:
      cache = jnp.einsum(
          '...ti, ...tj -> ...ij',
          b,
          c,
          precision=precision,
      )
    return result, cache

  # We list the meaning of each array in a comment on the side/above.
  # The operations are described for a single example in the batch.

  a_view = a.reshape(
      batch_dims + (-1, grain_size, r)
  )  # [a_1, ...] where a_i is a grain_size x r matrix
  b_view = b.reshape(
      batch_dims + (-1, grain_size, r)
  )  # [b_1, ...] where b_i is a grain_size x r matrix
  c_view = c.reshape(
      batch_dims + (-1, grain_size, d)
  )  # [c_1, ...] where c_i is a grain_size x d matrix

  # Computes [a_1 @ b_1.T, ...]
  ab_products = jnp.einsum(
      '...ti, ...si -> ...ts',
      a_view,
      b_view,
      precision=precision,
  )

  # Computes [b_1.T @ c_1, ...] excluding the last block
  bc_products = jnp.einsum(
      '...si, ...sk -> ...ik',
      b_view[Ellipsis, :-1, :, :],  # Excludes the last block
      c_view[Ellipsis, :-1, :, :],  # Excludes the last block
      precision=precision,
  )

  lt_ab_products = jnp.tril(ab_products)  # [lt(a_1 @ b_1.T), ...]

  # Computes [lt(a_1 @ b_1.T) @ c_1, ...]
  result = jnp.matmul(lt_ab_products, c_view, precision=precision)

  # Computes [b_1.T @ c_1, b_1.T @ c_1 + b_2.T @ c_2, ...]
  bc_products_cum_sum = jnp.cumsum(bc_products, axis=-3)

  # Computes [a_2 @ (b_1.T @ c_1), a_3 @ (b_1.T @ c_1 + b_2.T @ c_2), ...]
  correction = jnp.matmul(
      a_view[Ellipsis, 1:, :, :], bc_products_cum_sum, precision=precision
  )

  pad_list = [(0, 0)] * (len(a.shape) + 1)
  pad_list[-3] = (1, 0)
  correction = jnp.pad(correction, pad_list)  # Prepends a 0 matrix.

  # [lt(a_1 @ b_1.T) @ c_1, lt(a_2 @ b_2.T) @ c_2 + a_2 @ (b_1.T @ c_1), ...]
  result = result + correction

  result = result.reshape(c.shape)

  cache = None
  if build_cache:
    cache = bc_products_cum_sum[Ellipsis, -1, :, :] + jnp.einsum(
        '...si, ...sd -> ...id',
        b_view[Ellipsis, -1, :, :],
        c_view[Ellipsis, -1, :, :],
        precision=precision,
    )
  return result, cache


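# The helper below is not part of the original file. It is a minimal reference
# sketch that materializes the full n x n matrix lt(a @ b.T); it is quadratic
# in n and intended only for sanity-checking lt_multiply on small inputs (see
# the example at the end of this file).
def _naive_lt_multiply(a, b, c, precision=jax.lax.Precision.DEFAULT):
  """Directly computes lt(a @ b.T) @ c without the block decomposition."""
  return jnp.tril(
      jnp.einsum('...ti, ...si -> ...ts', a, b, precision=precision)
  ) @ c

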
def tensor_lt_multiply(
    a: jnp.ndarray,
    b: jnp.ndarray,
    c: jnp.ndarray,
    grain_size: int,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    build_cache: bool = False,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
  """Computes the lower triangular product after tensoring a and b.

  Given a matrix a, the matrix (a tensor a) is defined by tensoring each row
  of a with itself, which "squares" the number of columns of a.

  This function takes matrices a, b and c of appropriate input sizes and
  computes lt((a tensor a) @ (b tensor b).T) @ c using a block-based algorithm
  with the given grain_size parameter. When batch dimensions are present, the
  operation is defined with respect to the last two dimensions.

  Instead of tensoring a and b and passing them to lt_multiply, we implement
  the algorithm directly using einsums. This is more efficient in practice
  when using TPUs/GPUs.

  Args:
    a: An array of shape [batch, ..., n, r]
    b: An array of shape [batch, ..., n, r]
    c: An array of shape [batch, ..., n, d]
    grain_size: number of rows in a block; must divide n
    precision: precision of the einsum and matmul operations
    build_cache: If set to True, also builds and returns a cache that is
      useful for inference.

  Returns:
    A tuple. The first item is an array of shape [batch, ..., n, d] equal to
    lt((a tensor a) @ (b tensor b).T) @ c. If build_cache is True, the second
    item is an array called "cache" of shape [batch, ..., r, r, d] holding the
    representation of (b tensor b).T @ c; otherwise the second item is None.
  """

  assert a.shape == b.shape and a.shape[:-1] == c.shape[:-1]

  batch_dims = a.shape[:-2]
  n, r = a.shape[-2:]
  _, d = c.shape[-2:]

  assert n % grain_size == 0

  if n == grain_size:
    # (a tensor a) @ (b tensor b).T equals (a @ b.T) ** 2 elementwise, so the
    # Gram matrix is squared instead of being formed in the tensored space.
    result = (
        jnp.tril(
            jnp.einsum('...ti, ...si->...ts', a, b, precision=precision) ** 2
        )
        @ c
    )
    cache = None
    if build_cache:
      cache = jnp.einsum(
          '...ti, ...tj, ...td->...ijd', b, b, c, precision=precision
      )
    return result, cache

  a_view = a.reshape(batch_dims + (-1, grain_size, r))  # [a_1, ..., a_t]
  b_view = b.reshape(batch_dims + (-1, grain_size, r))  # [b_1, ..., b_t]
  c_view = c.reshape(batch_dims + (-1, grain_size, d))  # [c_1, ..., c_t]

  # Analog of ab_products in lt_multiply above.
  a_tensor_b_tensor_products = jnp.einsum(
      '...ti, ...si -> ...ts',
      a_view,
      b_view,
      precision=precision,
  ) ** 2

  b_tensor_transpose_c_products = jnp.einsum(
      '...ti, ...tj, ...td -> ...ijd',
      b_view[Ellipsis, :-1, :, :],
      b_view[Ellipsis, :-1, :, :],
      c_view[Ellipsis, :-1, :, :],
      precision=precision,
  )

  lt_a_tensor_b_tensor_products = jnp.tril(a_tensor_b_tensor_products)

  result = jnp.matmul(
      lt_a_tensor_b_tensor_products, c_view, precision=precision
  )

  b_tensor_transpose_c_products_cum_sum = jnp.cumsum(
      b_tensor_transpose_c_products, axis=-4
  )

  correction = jnp.einsum(
      '...ti, ...tj, ...ijd -> ...td',
      a_view[Ellipsis, 1:, :, :],
      a_view[Ellipsis, 1:, :, :],
      b_tensor_transpose_c_products_cum_sum,
      precision=precision,
  )

  pad_list = [(0, 0)] * (len(a.shape) + 1)
  pad_list[-3] = (1, 0)
  correction = jnp.pad(correction, pad_list)
  result = result + correction
  result = result.reshape(c.shape)

  cache = None
  if build_cache:
    cache = b_tensor_transpose_c_products_cum_sum[
        Ellipsis, -1, :, :, :  # Take the last matrix
    ] + jnp.einsum(
        '...ti, ...tj, ...td -> ...ijd',
        b_view[Ellipsis, -1, :, :],
        b_view[Ellipsis, -1, :, :],
        c_view[Ellipsis, -1, :, :],
        precision=precision,
    )  # Add the remaining contribution.
  return result, cache
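

# Example usage (not part of the original file): a minimal sketch that checks
# the block-based functions against direct computations on small random inputs
# and shows how the returned cache (equal to b.T @ c) can extend the product
# by one extra row, which is how such a cache is typically consumed at
# inference time.
if __name__ == '__main__':
  key = jax.random.PRNGKey(0)
  k1, k2, k3, k4, k5, k6 = jax.random.split(key, 6)
  a = jax.random.normal(k1, (2, 8, 4))  # [batch, n, r]
  b = jax.random.normal(k2, (2, 8, 4))  # [batch, n, r]
  c = jax.random.normal(k3, (2, 8, 3))  # [batch, n, d]

  # Block-based lt_multiply vs. the direct O(n^2) reference above.
  result, cache = lt_multiply(a, b, c, grain_size=4, build_cache=True)
  print('lt_multiply matches reference:',
        jnp.allclose(result, _naive_lt_multiply(a, b, c), atol=1e-4))

  # The cache equals b.T @ c, so a new row (a_new, b_new, c_new) can be
  # processed without revisiting the first n rows: update the cache with
  # b_new.T @ c_new and multiply by a_new.
  a_new = jax.random.normal(k4, (2, 1, 4))
  b_new = jax.random.normal(k5, (2, 1, 4))
  c_new = jax.random.normal(k6, (2, 1, 3))
  cache = cache + jnp.einsum('...ti, ...td -> ...id', b_new, c_new)
  new_row = jnp.einsum('...ti, ...id -> ...td', a_new, cache)
  full_row = _naive_lt_multiply(
      jnp.concatenate([a, a_new], axis=-2),
      jnp.concatenate([b, b_new], axis=-2),
      jnp.concatenate([c, c_new], axis=-2),
  )[:, -1:, :]
  print('cache-based update matches:',
        jnp.allclose(new_row, full_row, atol=1e-4))

  # tensor_lt_multiply vs. squaring the Gram matrix directly, using the
  # identity (a tensor a) @ (b tensor b).T == (a @ b.T) ** 2.
  result_t, _ = tensor_lt_multiply(a, b, c, grain_size=4)
  direct_t = jnp.tril(
      jnp.einsum('...ti, ...si -> ...ts', a, b) ** 2
  ) @ c
  print('tensor_lt_multiply matches direct:',
        jnp.allclose(result_t, direct_t, atol=1e-4))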
