# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Implementation of lower triangular multiplication algorithm.
17
18This file provides a function lt_multiply to compute lt(a @ b.T) @ c given three
19matrices a, b and c of appropriate dimensions.
20
21This file also provides another function lt_tensor_multiply that computes
22lt( (a tensor a) @ (b tensor b).T) @ c given three matrices a, b and c of
23appropriate dimensions.
24
25The functions implement a block-based algorithm described in
26https://arxiv.org/abs/2310.01655
27"""
28
from typing import Optional, Tuple

import jax
import jax.numpy as jnp


def lt_multiply(
    a: jnp.ndarray,
    b: jnp.ndarray,
    c: jnp.ndarray,
    grain_size: int,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    build_cache: bool = False,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
  """Computes the 'lower triangular product'.

  Given a, b and c, the lower triangular product is defined to be the matrix
  lt(a @ b.T) @ c. When batch dimensions are present, the operation is defined
  with respect to the last two dimensions.

  Args:
    a: An array of shape [batch, ..., n, r].
    b: An array of shape [batch, ..., n, r].
    c: An array of shape [batch, ..., n, d].
    grain_size: An integer parameter that divides n, giving the number of rows
      in a block.
    precision: The precision at which the intermediate multiplications are
      performed.
    build_cache: If set to True, builds and returns a cache that is useful for
      inference.

  Returns:
    A tuple. If build_cache is True, the first item is an array of shape
    [batch, ..., n, d] equal to lt(a @ b.T) @ c and the second item is an
    array called "cache" of shape [batch, ..., r, d] equal to b.T @ c. If
    build_cache is False, the first item is the same as before and the second
    item is None.
  """
  assert a.shape == b.shape and a.shape[:-1] == c.shape[:-1]

  batch_dims = a.shape[:-2]
  n, r = a.shape[-2:]
  _, d = c.shape[-2:]

  assert n % grain_size == 0  # Grain size must divide the number of rows.
  if n == grain_size:
    result = (
        jnp.tril(
            jnp.einsum(
                '...ti, ...si -> ...ts',
                a,
                b,
                precision=precision,
            )
        )
        @ c
    )
    cache = None
    if build_cache:
      cache = jnp.einsum(
          '...ti, ...tj -> ...ij',
          b,
          c,
          precision=precision,
      )
    return result, cache
  # We list the meaning of each array in a comment on the side/above.
  # The operations are described for a single example in the batch.
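  # For row block i (1-indexed), the output decomposes as
  #   lt(a_i @ b_i.T) @ c_i + a_i @ (b_1.T @ c_1 + ... + b_{i-1}.T @ c_{i-1}),
  # i.e. a local triangular term plus a correction that folds in all earlier
  # blocks. The local terms are computed in parallel below, and the
  # corrections come from a cumulative sum of the b_j.T @ c_j products.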
  a_view = a.reshape(
      batch_dims + (-1, grain_size, r)
  )  # [a_1, ...] where a_i is a grain_size x r matrix.
  b_view = b.reshape(
      batch_dims + (-1, grain_size, r)
  )  # [b_1, ...] where b_i is a grain_size x r matrix.
  c_view = c.reshape(
      batch_dims + (-1, grain_size, d)
  )  # [c_1, ...] where c_i is a grain_size x d matrix.

  # Computes [a_1 @ b_1.T, ...].
  ab_products = jnp.einsum(
      '...ti, ...si -> ...ts',
      a_view,
      b_view,
      precision=precision,
  )

  # Computes [b_1.T @ c_1, ...], excluding the last matrix.
  bc_products = jnp.einsum(
      '...si, ...sk -> ...ik',
      b_view[Ellipsis, :-1, :, :],  # Excludes the last matrix.
      c_view[Ellipsis, :-1, :, :],  # Excludes the last matrix.
      precision=precision,
  )

  lt_ab_products = jnp.tril(ab_products)  # [lt(a_1 @ b_1.T), ...]

  # Computes [lt(a_1 @ b_1.T) @ c_1, ...].
  result = jnp.matmul(lt_ab_products, c_view, precision=precision)

  # Computes [b_1.T @ c_1, b_1.T @ c_1 + b_2.T @ c_2, ...].
  bc_products_cum_sum = jnp.cumsum(bc_products, axis=-3)

  # Computes [a_2 @ (b_1.T @ c_1), a_3 @ (b_1.T @ c_1 + b_2.T @ c_2), ...].
  correction = jnp.matmul(
      a_view[Ellipsis, 1:, :, :], bc_products_cum_sum, precision=precision
  )
  pad_list = [(0, 0)] * (len(a.shape) + 1)
  pad_list[-3] = (1, 0)
  # Prepends a 0 matrix: block 1 receives no correction, so a zero matrix
  # aligns the corrections with [block_1, ..., block_t].
  correction = jnp.pad(correction, pad_list)
  # [lt(a_1 @ b_1.T) @ c_1, lt(a_2 @ b_2.T) @ c_2 + a_2 @ (b_1.T @ c_1), ...]
  result = result + correction

  result = result.reshape(c.shape)

  cache = None
  if build_cache:
    cache = bc_products_cum_sum[Ellipsis, -1, :, :] + jnp.einsum(
        '...si, ...sd -> ...id',
        b_view[Ellipsis, -1, :, :],
        c_view[Ellipsis, -1, :, :],
        precision=precision,
    )
  return result, cache
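

# A minimal sanity check for lt_multiply, assuming small arbitrary shapes and
# a fixed PRNG key (this helper is illustrative and not part of the original
# algorithm): the blocked computation should agree with the naive
# jnp.tril(a @ b.T) @ c, and the cache should equal b.T @ c.
def _lt_multiply_example():
  key = jax.random.PRNGKey(0)
  ka, kb, kc = jax.random.split(key, 3)
  a = jax.random.normal(ka, (2, 8, 4))  # [batch, n, r]
  b = jax.random.normal(kb, (2, 8, 4))  # [batch, n, r]
  c = jax.random.normal(kc, (2, 8, 16))  # [batch, n, d]
  result, cache = lt_multiply(a, b, c, grain_size=4, build_cache=True)
  naive = jnp.tril(a @ jnp.swapaxes(b, -1, -2)) @ c
  assert jnp.allclose(result, naive, atol=1e-4)
  # The cache summarizes the whole prefix: during autoregressive inference a
  # new row a_new contributes a_new @ cache.
  assert jnp.allclose(cache, jnp.swapaxes(b, -1, -2) @ c, atol=1e-4)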


def tensor_lt_multiply(
    a: jnp.ndarray,
    b: jnp.ndarray,
    c: jnp.ndarray,
    grain_size: int,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    build_cache: bool = False,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
  """Computes the lower triangular product after tensoring a and b.

  Given a matrix a, the matrix (a tensor a) is defined by tensoring each row
  of a with itself, which "squares" the number of columns of a.

  This function takes matrices a, b and c of appropriate sizes and computes
  lt((a tensor a) @ (b tensor b).T) @ c using a block-based algorithm with the
  given grain_size parameter. When batch dimensions are present, the operation
  is defined with respect to the last two dimensions.

  Instead of tensoring a and b and passing the results to lt_multiply, we
  implement this operation directly using einsums. This is more efficient in
  practice on TPUs/GPUs.

  Args:
    a: An array of shape [batch, ..., n, r].
    b: An array of shape [batch, ..., n, r].
    c: An array of shape [batch, ..., n, d].
    grain_size: An integer parameter that divides n, giving the number of rows
      in a block.
    precision: The precision of the einsum and matmul operations.
    build_cache: If set to True, builds and returns a cache that is useful for
      inference.

  Returns:
    A tuple. If build_cache is True, the first item is an array of shape
    [batch, ..., n, d] equal to lt((a tensor a) @ (b tensor b).T) @ c and the
    second item is an array called "cache" of shape [batch, ..., r, r, d]
    holding the representation of (b tensor b).T @ c. If build_cache is False,
    the first item is the same as before and the second item is None.
  """
  assert a.shape == b.shape and a.shape[:-1] == c.shape[:-1]

  batch_dims = a.shape[:-2]
  n, r = a.shape[-2:]
  _, d = c.shape[-2:]

  assert n % grain_size == 0
  if n == grain_size:
    result = (
        jnp.tril(
            jnp.einsum('...ti, ...si->...ts', a, b, precision=precision) ** 2
        )
        @ c
    )
    cache = None
    if build_cache:
      cache = jnp.einsum(
          '...ti, ...tj, ...td->...ijd', b, b, c, precision=precision
      )
    return result, cache
  a_view = a.reshape(batch_dims + (-1, grain_size, r))  # [a_1, ..., a_t]
  b_view = b.reshape(batch_dims + (-1, grain_size, r))  # [b_1, ..., b_t]
  c_view = c.reshape(batch_dims + (-1, grain_size, d))  # [c_1, ..., c_t]
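  # The block decomposition from lt_multiply applies verbatim with a_i and
  # b_i replaced by their row-wise self-tensorings. Since
  # ((a_i tensor a_i) @ (b_j tensor b_j).T)[t, s] == (a_i @ b_j.T)[t, s] ** 2,
  # the einsums below never materialize the r * r tensored columns.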
  # Analog of ab_products in lt_multiply above.
  a_tensor_b_tensor_products = (
      jnp.einsum(
          '...ti, ...si -> ...ts',
          a_view,
          b_view,
          precision=precision,
      )
      ** 2
  )
  # Analog of bc_products above: [(b_1 tensor b_1).T @ c_1, ...] stored as
  # r x r x d arrays, excluding the last block.
  b_tensor_transpose_c_products = jnp.einsum(
      '...ti, ...tj, ...td -> ...ijd',
      b_view[Ellipsis, :-1, :, :],
      b_view[Ellipsis, :-1, :, :],
      c_view[Ellipsis, :-1, :, :],
      precision=precision,
  )

  lt_a_tensor_b_tensor_products = jnp.tril(a_tensor_b_tensor_products)

  result = jnp.matmul(
      lt_a_tensor_b_tensor_products, c_view, precision=precision
  )

  # Prefix sums over the block axis (axis -4, since each entry is r x r x d).
  b_tensor_transpose_c_products_cum_sum = jnp.cumsum(
      b_tensor_transpose_c_products, axis=-4
  )

  # Corrections for blocks 2..t: contract each a_i tensor a_i against the
  # accumulated prefix.
  correction = jnp.einsum(
      '...ti, ...tj, ...ijd -> ...td',
      a_view[Ellipsis, 1:, :, :],
      a_view[Ellipsis, 1:, :, :],
      b_tensor_transpose_c_products_cum_sum,
      precision=precision,
  )
  pad_list = [(0, 0)] * (len(a.shape) + 1)
  pad_list[-3] = (1, 0)
  correction = jnp.pad(correction, pad_list)  # Prepends a 0 matrix.
  result = result + correction
  result = result.reshape(c.shape)
  cache = None
  if build_cache:
    cache = b_tensor_transpose_c_products_cum_sum[
        Ellipsis, -1, :, :, :  # Take the last matrix.
    ] + jnp.einsum(
        '...ti, ...tj, ...td -> ...ijd',
        b_view[Ellipsis, -1, :, :],
        b_view[Ellipsis, -1, :, :],
        c_view[Ellipsis, -1, :, :],
        precision=precision,
    )  # Add the remaining contribution.
  return result, cache
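

# A minimal sanity check for tensor_lt_multiply, assuming small arbitrary
# shapes and a fixed PRNG key (illustrative only, not part of the original
# algorithm): the result should match the naive computation on the explicitly
# tensored inputs.
def _tensor_lt_multiply_example():
  key = jax.random.PRNGKey(1)
  ka, kb, kc = jax.random.split(key, 3)
  a = jax.random.normal(ka, (2, 8, 4))  # [batch, n, r]
  b = jax.random.normal(kb, (2, 8, 4))  # [batch, n, r]
  c = jax.random.normal(kc, (2, 8, 16))  # [batch, n, d]
  result, _ = tensor_lt_multiply(a, b, c, grain_size=4)
  # Tensor each row with itself, flattening to [batch, n, r * r].
  a_tensor = jnp.einsum('...ti, ...tj -> ...tij', a, a).reshape(2, 8, 16)
  b_tensor = jnp.einsum('...ti, ...tj -> ...tij', b, b).reshape(2, 8, 16)
  naive = jnp.tril(a_tensor @ jnp.swapaxes(b_tensor, -1, -2)) @ c
  assert jnp.allclose(result, naive, atol=1e-3)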