google-research

Форк
0
93 строки · 2.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Contains auxiliary functions to operator over "paired" tensors."""
17

18

19
from typing import Tuple
20

21
import tensorflow as tf
22

23

24
def pair_masks(mask_x, mask_y):
25
  """Combines a pair of 2D masks into a single 3D mask.
26

27
  Args:
28
    mask_x: A tf.Tensor<float>[batch, len_x] with binary entries.
29
    mask_y: A tf.Tensor<float>[batch, len_y] with binary entries.
30

31
  Returns:
32
    A tf.Tensor<float>[batch, len_x, len_y] with binary entries, defined as
33
      out[n][i][j] := mask_x[n][i] * mask_y[n][j].
34
  """
35
  mask1, mask2 = tf.cast(mask_x, tf.float32), tf.cast(mask_y, tf.float32)
36
  return tf.cast(tf.einsum('ij,ik->ijk', mask1, mask2), tf.bool)
37

38

39
def build(indices, *args):
40
  """Builds the pairs of whatever is passed as args for the given indices.
41

42
  Args:
43
    indices: a tf.Tensor<int32>[batch, 2]
44
    *args: a sequence of tf.Tensor[2 * batch, ...].
45

46
  Returns:
47
    A tuple of tf.Tensor[batch, 2, ...]
48
  """
49
  return tuple(tf.gather(arg, indices) for arg in args)
50

51

52
def consecutive_indices(batch):
53
  """Builds a batch of consecutive indices of size N from a batch of size 2N.
54

55
  Args:
56
    batch: tf.Tensor<float>[2N, ...].
57

58
  Returns:
59
    A tf.Tensor<int32>[N, 2] of consecutive indices.
60
  """
61
  batch_size = tf.shape(batch)[0]
62
  return tf.reshape(tf.range(batch_size), (-1, 2))
63

64

65
def roll_indices(indices, shift = 1):
66
  """Build a batch of non matching indices by shifting the batch of indices.
67

68
  Args:
69
    indices: a tf.Tensor<int32>[N, 2] of indices.
70
    shift: how much to shift the second column.
71

72
  Returns:
73
    A tf.Tensor<int32>[N, 2] of indices where the second columns has been
74
    rolled.
75
  """
76
  return tf.stack([indices[:, 0], tf.roll(indices[:, 1], shift=shift, axis=0)],
77
                  axis=1)
78

79

80
def square_distances(embs_1, embs_2):
81
  """Returns the matrix of square distances.
82

83
  Args:
84
    embs_1: tf.Tensor<float>[batch, len, dim].
85
    embs_2: tf.Tensor<float>[batch, len, dim].
86

87
  Returns:
88
    A tf.Tensor<float>[batch, len, len] containing the square distances.
89
  """
90
  gram_embs = tf.matmul(embs_1, embs_2, transpose_b=True)
91
  sq_norm_embs_1 = tf.linalg.norm(embs_1, axis=-1, keepdims=True)**2
92
  sq_norm_embs_2 = tf.linalg.norm(embs_2, axis=-1)**2
93
  return sq_norm_embs_1 + sq_norm_embs_2[:, tf.newaxis, :] - 2 * gram_embs
94

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.