google-research
93 строки · 2.8 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains auxiliary functions that operate on "paired" tensors."""
17
18
19from typing import Tuple
20
21import tensorflow as tf
22
23
def pair_masks(mask_x, mask_y):
  """Combines a pair of 2D masks into a single 3D mask.

  Args:
    mask_x: A tf.Tensor[batch, len_x] with binary entries (bool, or numeric
      0/1).
    mask_y: A tf.Tensor[batch, len_y] with binary entries (bool, or numeric
      0/1).

  Returns:
    A tf.Tensor<bool>[batch, len_x, len_y] defined as
      out[n][i][j] := mask_x[n][i] AND mask_y[n][j],
    where a nonzero input entry counts as True.
  """
  # Stay in the boolean domain: multiplying float masks and casting the
  # product back to bool gives the same result for 0/1 masks, but needlessly
  # round-trips through float arithmetic (where tiny nonzero products could
  # even underflow to 0).
  bool_x = tf.cast(mask_x, tf.bool)
  bool_y = tf.cast(mask_y, tf.bool)
  # Outer "product" per batch element via broadcasting:
  # [batch, len_x, 1] AND [batch, 1, len_y] -> [batch, len_x, len_y].
  return tf.logical_and(bool_x[:, :, tf.newaxis], bool_y[:, tf.newaxis, :])
37
38
def build(indices, *args):
  """Gathers, for every row of `indices`, the matching pair from each arg.

  Args:
    indices: a tf.Tensor<int32>[batch, 2]; each row holds two positions along
      the leading axis of every tensor in `args`.
    *args: a sequence of tf.Tensor[2 * batch, ...] to be paired up.

  Returns:
    A tuple with one tf.Tensor[batch, 2, ...] per element of `args`, in the
    same order as `args`.
  """
  def _pair_up(tensor):
    # Gather along axis 0 so each index row selects two leading-axis slices.
    return tf.gather(tensor, indices)

  return tuple(map(_pair_up, args))
50
51
def consecutive_indices(batch):
  """Pairs up adjacent rows of a batch of size 2N.

  Row k of the result is [2k, 2k + 1]: position 0 is paired with 1, 2 with 3,
  and so on. An odd leading dimension cannot be reshaped into pairs, and the
  reshape raises an error.

  Args:
    batch: tf.Tensor<float>[2N, ...]. Only its leading dimension is read.

  Returns:
    A tf.Tensor<int32>[N, 2] of consecutive indices.
  """
  flat_positions = tf.range(tf.shape(batch)[0])
  return tf.reshape(flat_positions, [-1, 2])
63
64
def roll_indices(indices, shift=1):
  """Builds a batch of non-matching index pairs by rotating the second column.

  The first column is kept as-is, while the second column is cyclically
  shifted along the batch axis. Row k therefore pairs indices[k, 0] with
  indices[(k - shift) mod N, 1]. When `shift` is not a multiple of N, each
  row's second entry comes from a different row than its first.

  Args:
    indices: a tf.Tensor<int32>[N, 2] of indices.
    shift: how many positions to roll the second column by.

  Returns:
    A tf.Tensor<int32>[N, 2] of indices where the second column has been
    rolled.
  """
  anchors = indices[:, 0]
  partners = tf.roll(indices[:, 1], shift=shift, axis=0)
  return tf.stack([anchors, partners], axis=1)
78
79
def square_distances(embs_1, embs_2):
  """Returns the matrix of pairwise square Euclidean distances.

  Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>.

  Args:
    embs_1: tf.Tensor<float>[batch, len_1, dim].
    embs_2: tf.Tensor<float>[batch, len_2, dim].

  Returns:
    A tf.Tensor<float>[batch, len_1, len_2] whose entry [n, i, j] is the
    square distance between embs_1[n, i] and embs_2[n, j]. Entries are
    clipped at 0 to absorb the small negative values that floating-point
    cancellation in the expansion can produce.
  """
  gram_embs = tf.matmul(embs_1, embs_2, transpose_b=True)
  # Sum of squares instead of tf.linalg.norm(...)**2: same value, no
  # sqrt/square round trip, and a finite gradient at the zero vector (the
  # gradient of the norm is NaN there, e.g. for zero-padded positions).
  sq_norm_embs_1 = tf.reduce_sum(tf.square(embs_1), axis=-1, keepdims=True)
  sq_norm_embs_2 = tf.reduce_sum(tf.square(embs_2), axis=-1)
  # [batch, len_1, 1] + [batch, 1, len_2] - [batch, len_1, len_2].
  sq_dists = sq_norm_embs_1 + sq_norm_embs_2[:, tf.newaxis, :] - 2 * gram_embs
  # True square distances are non-negative; negatives here are rounding error.
  return tf.maximum(sq_dists, 0.0)
94