google-research

Форк
0
/
retrieval_fns.py 
61 строка · 1.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""A collection of retrieval functions for negative mining.
17

18
Retrieval functions take in a matrix of scores and return a batch x `k` set of
19
indices indicating the `k` items retrieved.
20
"""
21
import abc
22
import tensorflow.compat.v2 as tf
23

24

25
class AbstractRetrievalFn(tf.Module, metaclass=abc.ABCMeta):
26

27
  @abc.abstractmethod
28
  def __call__(self, scores):
29
    pass
30

31

32
class MaxScoreRetrievalFn(AbstractRetrievalFn):
33

34
  def __call__(self, scores):
35
    indices = tf.argmax(scores, axis=1)
36
    return tf.expand_dims(indices, 1)
37

38

39
def _sample_gumbel(shape):
40
  uniform_vals = tf.random.uniform(shape)
41
  gumbel_vals = -tf.math.log(-tf.math.log(uniform_vals))
42
  return gumbel_vals
43

44

45
class GumbelMaxRetrievalFn(AbstractRetrievalFn):
46
  """Creates a retrieval function that uses Gumbel-max sampling.
47

48
  Gumbel-max sampling is an approach to sample from the softmax distribution of
49
  a set of scores by perturbing the scores then taking the argmax. The scores
50
  are first scaled by `inv_temp` then perturbed by adding Gumbel noise.
51
  """
52

53
  def __init__(self, inv_temp=1.0):
54
    super(GumbelMaxRetrievalFn, self).__init__()
55
    self.inv_temp = inv_temp
56

57
  def __call__(self, scores):
58
    gumbel_vals = _sample_gumbel(tf.shape(scores))
59
    perturbed_scores = self.inv_temp * scores + gumbel_vals
60
    indices = tf.argmax(perturbed_scores, axis=1)
61
    return tf.expand_dims(indices, 1)
62

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.