google-research

Форк
0
442 строки · 17.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Implements loss functions for dual encoder training with a cache."""
17

18
import abc
19
import collections
20
from typing import Callable, Dict, Iterable, List, Optional, Tuple
21

22
import tensorflow.compat.v2 as tf
23

24
from negative_cache import negative_cache
25
from negative_cache import retrieval_fns
26
from negative_cache import util
27

28
# Result of calling a CacheLoss in this module.
CacheLossReturn = collections.namedtuple(
    'CacheLossReturn',
    (
        'training_loss',  # Loss to differentiate for training.
        'interpretable_loss',  # Gradient-free value suitable for summaries.
        'updated_item_data',  # Per data source: fresh data for cache updates.
        'updated_item_indices',  # Per data source: indices to update.
        'updated_item_mask',  # Per data source: which updates are valid.
        'staleness',  # Cached-vs-fresh embedding error for the retrieved items.
    ))
36

37

38
class CacheLoss(object, metaclass=abc.ABCMeta):
  """Interface for losses that are computed against a negative cache.

  Subclasses must implement `__call__`. The concrete losses in this module
  return a `CacheLossReturn` from it.
  """

  @abc.abstractmethod
  def __call__(self, doc_network, query_embeddings, pos_doc_embeddings, cache):
    """Computes the loss.

    Args:
      doc_network: The network that embeds document data.
      query_embeddings: Embeddings for the queries.
      pos_doc_embeddings: Embeddings of the positive documents for the queries.
      cache: The cache of document data and embeddings.
    """
46

47

48
_RetrievalReturn = collections.namedtuple('_RetrievalReturn', [
49
    'retrieved_data', 'scores', 'retrieved_indices',
50
    'retrieved_cache_embeddings'
51
])
52

53

54
def _score_documents(query_embeddings,
                     doc_embeddings,
                     score_transform=None,
                     all_pairs=False):
  """Scores queries against documents with a dot product.

  Args:
    query_embeddings: Query embeddings, one per row.
    doc_embeddings: Document embeddings, one per row.
    score_transform: Optional function applied to the raw dot products.
    all_pairs: If true, score every query against every document
      (`query_embeddings @ doc_embeddings^T`). Otherwise score row i of the
      queries against row i of the documents only.

  Returns:
    The (optionally transformed) scores: a matrix when `all_pairs` is set,
    otherwise one score per row.
  """
  if not all_pairs:
    raw_scores = tf.reduce_sum(query_embeddings * doc_embeddings, axis=1)
  else:
    raw_scores = tf.matmul(query_embeddings, doc_embeddings, transpose_b=True)
  return raw_scores if score_transform is None else score_transform(raw_scores)
66

67

68
def _batch_concat_with_no_op(tensors):
69
  """If there is only one tensor to concatenate, this is a no-op."""
70
  if len(tensors) == 1:
71
    return tensors[0]
72
  else:
73
    return tf.concat(tensors, axis=0)
74

75

76
def _retrieve_from_caches(query_embeddings,
                          cache,
                          retrieval_fn,
                          embedding_key,
                          data_keys,
                          sorted_data_sources,
                          score_transform=None,
                          top_k=None):
  """Retrieves one element per query from the concatenation of several caches.

  The caches listed in `sorted_data_sources` are concatenated along axis 0 in
  that order. Every query is scored against every cached embedding, and
  `retrieval_fn` selects one index per query from those scores. When `top_k`
  is truthy, `util.approximate_top_k_with_indices` first keeps `top_k` scores
  per query. `retrieval_fn` then chooses among those, and the chosen position
  is mapped back to an index into the concatenated cache.

  Args:
    query_embeddings: Query embeddings; scored with all-pairs dot products.
    cache: Mapping from data source to an object whose `.data` dict holds the
      cached tensors.
    retrieval_fn: Callable mapping a score matrix to selected column indices.
      Its output is assumed to be int64 with shape [num_queries, 1], given how
      it is concatenated and gathered below (confirm against retrieval_fns).
    embedding_key: Key of the embedding tensor in each cache's `.data`.
    data_keys: Keys of the document data to gather for the retrieved items.
    sorted_data_sources: Order in which the caches are concatenated.
    score_transform: Optional function applied to the raw dot-product scores.
    top_k: If truthy, restrict retrieval to an approximate top-k per query.

  Returns:
    A `_RetrievalReturn` with these fields:
      - `retrieved_data`: the gathered data for each data key.
      - `scores`: the scores the retrieval was drawn from (the top-k scores
        when `top_k` is set).
      - `retrieved_indices`: the indices into the concatenated cache.
      - `retrieved_cache_embeddings`: the cached embeddings of the retrieved
        items.
  """

  def concat_over_sources(key):
    # Stack the `key` tensor of every cache, in sorted data-source order.
    return _batch_concat_with_no_op(
        [cache[source].data[key] for source in sorted_data_sources])

  cache_embeddings = concat_over_sources(embedding_key)
  cache_data = {key: concat_over_sources(key) for key in data_keys}
  scores = _score_documents(
      query_embeddings,
      cache_embeddings,
      score_transform=score_transform,
      all_pairs=True)

  if not top_k:
    indices = retrieval_fn(scores)
  else:
    # Keep only the (approximately) best `top_k` candidates per query; from
    # here on, `scores` holds only those candidates' scores.
    scores, candidate_ids = util.approximate_top_k_with_indices(scores, top_k)
    candidate_ids = tf.cast(candidate_ids, dtype=tf.int64)
    # `retrieval_fn` picks a column of the top-k matrix; translate that column
    # back into the corresponding index of the concatenated cache.
    chosen_columns = retrieval_fn(scores)
    row_ids = tf.expand_dims(
        tf.range(tf.shape(chosen_columns)[0], dtype=tf.int64), axis=1)
    row_and_column = tf.concat([row_ids, chosen_columns], axis=1)
    indices = tf.expand_dims(
        tf.gather_nd(candidate_ids, row_and_column), axis=1)
  # Indices are discrete selections; no gradient should flow through them.
  indices = tf.stop_gradient(indices)

  retrieved_data = {
      key: tf.gather_nd(values, indices) for key, values in cache_data.items()
  }
  retrieved_embeddings = tf.gather_nd(cache_embeddings, indices)
  return _RetrievalReturn(
      retrieved_data=retrieved_data,
      scores=scores,
      retrieved_indices=indices,
      retrieved_cache_embeddings=retrieved_embeddings)
118

119

120
def _get_data_sorce_start_position_and_cache_sizes(
    cache, embedding_key, sorted_data_sources):
  """Locates each data source inside the concatenation of all caches.

  The caches are laid out back to back in `sorted_data_sources` order. Each
  cache's size is the leading dimension of its `embedding_key` tensor.

  Returns:
    A pair of dicts keyed by data source: the int64 offset of the source's
    first element in the concatenation, and the source's int64 size.
  """
  starts = {}
  sizes = {}
  offset = tf.constant(0, dtype=tf.int64)
  for source in sorted_data_sources:
    size = tf.shape(cache[source].data[embedding_key], out_type=tf.int64)[0]
    starts[source] = offset
    sizes[source] = size
    offset = offset + size
  return starts, sizes
134

135

136
def _get_retrieved_embedding_updates(
    cache, embedding_key, sorted_data_sources, retrieved_indices,
    retrieved_embeddings):
  """Splits the re-embedded retrieved items into per-data-source updates.

  `retrieved_indices` index the concatenation of all caches (in
  `sorted_data_sources` order) and are squeezed along axis 1. For every data
  source, this returns:
    - data: `{embedding_key: retrieved_embeddings}`, shared by all sources;
    - indices: the retrieved indices shifted to be relative to this source;
    - mask: True where the retrieved index falls inside this source's range.
  Indices are only meaningful where the mask is True.

  Returns:
    Three dicts keyed by data source: (data, indices, mask).
  """
  start_positions, cache_sizes = _get_data_sorce_start_position_and_cache_sizes(
      cache, embedding_key, sorted_data_sources)
  item_data = {}
  item_indices = {}
  item_mask = {}
  for source in sorted_data_sources:
    begin = start_positions[source]
    end = begin + cache_sizes[source]
    flat_indices = tf.squeeze(retrieved_indices, axis=1)
    item_data[source] = {embedding_key: retrieved_embeddings}
    item_indices[source] = flat_indices - begin
    item_mask[source] = (flat_indices >= begin) & (flat_indices < end)
  return item_data, item_indices, item_mask
160

161

162
def _get_staleness(cache_embeddings, updated_embeddings):
  """Measures how far cached embeddings drifted from freshly computed ones.

  Computes, per row (reducing over axis 1),
  ||cache_embeddings - updated_embeddings||^2 / ||updated_embeddings||^2.
  There is no guard against a zero-norm updated embedding; such a row yields
  inf or NaN.
  """
  squared_error = tf.reduce_sum(
      (cache_embeddings - updated_embeddings)**2, axis=1)
  squared_norm = tf.reduce_sum(updated_embeddings**2, axis=1)
  return squared_error / squared_norm
168

169

170
_LossCalculationReturn = collections.namedtuple('_LossCalculationReturn', [
171
    'training_loss', 'interpretable_loss', 'staleness', 'retrieval_return',
172
    'retrieved_negative_embeddings'
173
])
174

175

176
class AbstractCacheClassificationLoss(CacheLoss, metaclass=abc.ABCMeta):
  """Abstract base class for cache classification losses.

  Inherit from this object and override `_retrieve_from_cache` and
  `_score_documents` to implement a cache classification loss based on the
  specified retrieval and scoring approaches. The shared loss computation
  lives in `_calculate_training_loss_and_summaries`.
  """

  @abc.abstractmethod
  def _retrieve_from_cache(self, query_embeddings, cache):
    """Retrieves one cached negative per query.

    Implementations return a `_RetrievalReturn`.
    `_calculate_training_loss_and_summaries` reads its `retrieved_data`,
    `scores` and `retrieved_cache_embeddings` fields.
    """
    pass

  @abc.abstractmethod
  def _score_documents(self, query_embeddings, doc_embeddings):
    """Scores each query against its aligned document embedding."""
    pass

  def _calculate_training_loss_and_summaries(
      self,
      doc_network,
      query_embeddings,
      pos_doc_embeddings,
      cache,
      reducer=tf.math.reduce_mean):
    """Calculates the cache classification loss and associated summaries.

    Args:
      doc_network: Callable that embeds the retrieved document data.
      query_embeddings: Embeddings for the queries.
      pos_doc_embeddings: Embeddings of the positive document of each query.
      cache: The cache of document data and embeddings; passed to
        `_retrieve_from_cache`.
      reducer: Reduces each per-query quantity to a scalar. If None, the
        per-query values are returned.

    Returns:
      A `_LossCalculationReturn` holding the training loss, the interpretable
      loss, the staleness, the raw retrieval result, and the freshly computed
      embeddings of the retrieved negatives.
    """
    positive_scores = self._score_documents(query_embeddings,
                                            pos_doc_embeddings)
    retrieval_return = self._retrieve_from_cache(query_embeddings, cache)
    # Only the retrieved negatives go through the document network.
    retrieved_negative_embeddings = doc_network(retrieval_return.retrieved_data)
    retrieved_negative_scores = self._score_documents(
        query_embeddings, retrieved_negative_embeddings)
    # Column 0 is the positive; the remaining columns are the candidate scores
    # produced by the retrieval.
    cache_and_pos_scores = tf.concat(
        [tf.expand_dims(positive_scores, axis=1), retrieval_return.scores],
        axis=1)
    prob_pos = tf.nn.softmax(cache_and_pos_scores, axis=1)[:, 0]
    prob_pos = tf.stop_gradient(prob_pos)
    # The surrogate's gradient is (1 - p_pos) * (grad s_neg - grad s_pos).
    # Suppose the negative is sampled among the cache candidates with
    # probability proportional to exp(score). Then its expectation
    # approximately equals the gradient of -log p_pos.
    training_loss = (1.0 - prob_pos) * (
        retrieved_negative_scores - positive_scores)
    # -log p_pos computed via log_softmax: -tf.math.log(prob_pos) becomes inf
    # once prob_pos underflows to 0. stop_gradient keeps this value
    # summary-only, exactly like the stop-gradient'ed prob_pos above.
    interpretable_loss = -tf.stop_gradient(
        tf.nn.log_softmax(cache_and_pos_scores, axis=1)[:, 0])
    staleness = _get_staleness(retrieval_return.retrieved_cache_embeddings,
                               retrieved_negative_embeddings)
    if reducer is not None:
      training_loss = reducer(training_loss)
      interpretable_loss = reducer(interpretable_loss)
      staleness = reducer(staleness)
    return _LossCalculationReturn(
        training_loss=training_loss,
        interpretable_loss=interpretable_loss,
        staleness=staleness,
        retrieval_return=retrieval_return,
        retrieved_negative_embeddings=retrieved_negative_embeddings)
226

227

228
class CacheClassificationLoss(AbstractCacheClassificationLoss):
  """Implements an efficient way to train with a cache classification loss.

  The cache classification loss is the negative log probability of the
  positive document under the softmax over the positive and the cached
  documents (or the mined top-k when `top_k` is set). This object computes:
    (1) An efficient stochastic loss whose gradient approximately matches, in
        expectation, the gradient of the cache classification loss. It only
        needs O(batch_size) documents to be fed through the document network,
        instead of O(cache_size) for a standard implementation.
    (2) An approximation of the value of the cache classification loss using
        the cached embeddings. The loss in (1) is not interpretable; this one
        is, but no gradient can be calculated through it.

  Calling a CacheClassificationLoss returns a CacheLossReturn with the fields:
    training_loss: Use this to calculate gradients.
    interpretable_loss: An interpretable value of the cache classification
        loss, e.g. for a TensorBoard summary.
    updated_item_data, updated_item_indices, updated_item_mask: Use these in
        the negative cache updates. Keyed by data source, they describe the
        retrieved cache elements and their freshly computed embeddings.
    staleness: The squared error between the retrieved cache embeddings and
        their embeddings under the current model, normalized by the squared
        norm of the current embeddings. Summarize it as a proxy for the error
        due to cache staleness.
  """

  def __init__(self,
               embedding_key,
               data_keys,
               score_transform=None,
               top_k=None,
               reducer=tf.math.reduce_mean):
    """Initializes the CacheClassificationLoss object.

    Args:
      embedding_key: The key containing the embedding in the cache.
      data_keys: The keys containing the document data in the cache.
      score_transform: Scores are transformed by this function before use.
        Specifically, scores(i, j) = score_transform(dot(query_embed_i,
        doc_embed_j)).
      top_k: If set, the top k scoring negative elements are mined and the
        rest of the elements are masked before calculating the loss.
      reducer: Function that reduces the losses to a single scalar. If None,
        the elementwise losses are returned.
    """
    self._retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()
    self.embedding_key = embedding_key
    self.data_keys = data_keys
    self.score_transform = score_transform
    self.top_k = top_k
    self.reducer = reducer

  def _retrieve_from_cache(self, query_embeddings, cache):
    """Samples one negative per query from all data sources of the cache."""
    return _retrieve_from_caches(
        query_embeddings,
        cache,
        self._retrieval_fn,
        self.embedding_key,
        self.data_keys,
        sorted(cache.keys()),
        score_transform=self.score_transform,
        top_k=self.top_k)

  def _score_documents(self, query_embeddings, doc_embeddings):
    """Scores aligned (query, document) pairs with the configured transform."""
    return _score_documents(
        query_embeddings, doc_embeddings, score_transform=self.score_transform)

  def __call__(self, doc_network, query_embeddings, pos_doc_embeddings, cache):
    """Calculates the cache classification losses.

    Args:
      doc_network: The network that embeds the document data.
      query_embeddings: Embeddings for the queries.
      pos_doc_embeddings: Embeddings for the documents that are positive for
        the given queries.
      cache: The cache of document data and embeddings.

    Returns:
      A CacheLossReturn with the training loss, the interpretable loss, and
      the data needed to update the cache elements that were retrieved and
      re-embedded.
    """
    losses = self._calculate_training_loss_and_summaries(
        doc_network, query_embeddings, pos_doc_embeddings, cache, self.reducer)
    item_data, item_indices, item_mask = _get_retrieved_embedding_updates(
        cache, self.embedding_key, sorted(cache.keys()),
        losses.retrieval_return.retrieved_indices,
        losses.retrieved_negative_embeddings)
    return CacheLossReturn(
        training_loss=losses.training_loss,
        interpretable_loss=losses.interpretable_loss,
        updated_item_data=item_data,
        updated_item_indices=item_indices,
        updated_item_mask=item_mask,
        staleness=losses.staleness)
334

335

336
def _get_local_elements_global_data(all_elements_local_data, num_replicas):
  """Gives each replica the rows for its own elements, gathered from all replicas.

  A new axis 1 is inserted. Axis 0 is then split into `num_replicas` equal
  chunks with `tf.raw_ops.AllToAll`; chunk k goes to replica k, and the
  received chunks are concatenated along the new axis 1. The leading dimension
  is assumed to be divisible by `num_replicas` and laid out replica-major, as
  produced by a cross-replica concat (confirm at call sites).
  """
  replica_ids = list(range(num_replicas))
  with_replica_axis = tf.expand_dims(all_elements_local_data, axis=1)
  return tf.raw_ops.AllToAll(
      input=with_replica_axis,
      group_assignment=[replica_ids],
      concat_dimension=1,
      split_dimension=0,
      split_count=num_replicas)
344

345

346
class DistributedCacheClassificationLoss(AbstractCacheClassificationLoss):
  """Implements a cache classification loss with a sharded cache.

  This object implements a cache classification loss when the cache is sharded
  onto multiple replicas. This code calculates the loss treating the sharded
  cache as one unit, so all queries are affected by all cache elements in every
  replica.

  Retrieval happens in two stages:
    1. Every replica scores the queries of all replicas against its local
       shard and samples one candidate per query.
    2. Each query's own replica picks one of those candidates by sampling a
       shard with the retrieval function over the per-shard logsumexp weights.

  Currently, the updated_item_* fields (i.e., the embedding updates for items
  already in the cache) in the CacheLossReturn map every data source to None.
  This does not affect new items introduced to the cache.
  """

  def __init__(self,
               embedding_key,
               data_keys,
               score_transform=None,
               top_k=None,
               reducer=tf.math.reduce_mean):
    """Initializes the DistributedCacheClassificationLoss object.

    Args:
      embedding_key: The key containing the embedding in the cache.
      data_keys: The keys containing the document data in the cache.
      score_transform: Scores are transformed by this function before use.
      top_k: If set, the total number of top scoring negatives to mine across
        replicas. Each replica mines `top_k // num_replicas` elements from its
        local shard, but never fewer than one.
      reducer: Function that reduces the losses to a single scalar. If None,
        the elementwise losses are returned.
    """
    self.embedding_key = embedding_key
    self.data_keys = data_keys
    self.score_transform = score_transform
    self.top_k = top_k
    self.reducer = reducer
    self._retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()

  def _score_documents(self, query_embeddings, doc_embeddings):
    """Scores aligned (query, document) pairs with the configured transform."""
    return _score_documents(
        query_embeddings, doc_embeddings, score_transform=self.score_transform)

  def _retrieve_from_cache(self, query_embeddings, cache):
    """Samples one negative per local query from the shards of all replicas.

    Must run inside a replica context: it calls
    `tf.distribute.get_replica_context()` and uses cross-replica collectives.

    Returns:
      A `_RetrievalReturn` with the following contents:
        - `scores`: for each local query, the logsumexp of each replica's
          candidate scores (expected shape [local_batch, num_replicas]).
        - `retrieved_indices`: None.
    """
    sorted_data_sources = sorted(cache.keys())
    all_query_embeddings = util.cross_replica_concat(query_embeddings, axis=0)
    num_replicas = tf.distribute.get_replica_context().num_replicas_in_sync
    # Performs approximate top k across replicas by giving each replica an
    # equal share. Clamp the share to at least 1. A share of 0 would be read
    # as "top-k disabled" by `_retrieve_from_caches` (which tests
    # `if top_k:`). The whole local shard would then be scored silently
    # whenever top_k < num_replicas.
    if self.top_k:
      top_k_per_replica = max(1, self.top_k // num_replicas)
    else:
      top_k_per_replica = self.top_k
    retrieval_return = _retrieve_from_caches(all_query_embeddings, cache,
                                             self._retrieval_fn,
                                             self.embedding_key, self.data_keys,
                                             sorted_data_sources,
                                             self.score_transform,
                                             top_k_per_replica)
    # We transfer all queries to all replicas and retrieve from every shard.
    # A shard's total weight for a query is the logsumexp of its scores.
    all_queries_local_weight = tf.math.reduce_logsumexp(
        retrieval_return.scores, axis=1)
    local_queries_global_weights = _get_local_elements_global_data(
        all_queries_local_weight, num_replicas)
    local_queries_all_retrieved_data = {}
    for key in retrieval_return.retrieved_data:
      local_queries_all_retrieved_data[key] = _get_local_elements_global_data(
          retrieval_return.retrieved_data[key], num_replicas)
    local_queries_all_retrieved_embeddings = _get_local_elements_global_data(
        retrieval_return.retrieved_cache_embeddings, num_replicas)
    # We then sample a shard index proportional to its total weight.
    # This allows us to do Gumbel-Max sampling without modifying APIs.
    selected_replica = self._retrieval_fn(local_queries_global_weights)
    selected_replica = tf.stop_gradient(selected_replica)
    num_elements = tf.shape(selected_replica)[0]
    batch_indices = tf.range(num_elements)
    batch_indices = tf.cast(batch_indices, tf.int64)
    batch_indices = tf.expand_dims(batch_indices, axis=1)
    # (row, selected replica) pairs pick each query's candidate from the
    # replica axis added by `_get_local_elements_global_data`.
    selected_replica_with_batch = tf.concat([batch_indices, selected_replica],
                                            axis=1)
    retrieved_data = {
        k: tf.gather_nd(v, selected_replica_with_batch)
        for k, v in local_queries_all_retrieved_data.items()
    }
    retrieved_cache_embeddings = tf.gather_nd(
        local_queries_all_retrieved_embeddings, selected_replica_with_batch)
    return _RetrievalReturn(
        retrieved_data=retrieved_data,
        scores=local_queries_global_weights,
        retrieved_indices=None,
        retrieved_cache_embeddings=retrieved_cache_embeddings)

  def __call__(self, doc_network, query_embeddings, pos_doc_embeddings, cache):
    """Calculates the cache classification losses over the sharded cache.

    Args:
      doc_network: The network that embeds the document data.
      query_embeddings: Embeddings for the local queries.
      pos_doc_embeddings: Embeddings for the documents that are positive for
        the given queries.
      cache: The local shard of the cache of document data and embeddings.

    Returns:
      A CacheLossReturn with the training loss, the interpretable loss, and
      the staleness. The updated_item_* fields map each data source to None.
    """
    loss_calculation_return = self._calculate_training_loss_and_summaries(
        doc_network, query_embeddings, pos_doc_embeddings, cache, self.reducer)
    training_loss = loss_calculation_return.training_loss
    interpretable_loss = loss_calculation_return.interpretable_loss
    staleness = loss_calculation_return.staleness
    return CacheLossReturn(
        training_loss=training_loss,
        interpretable_loss=interpretable_loss,
        updated_item_data={k: None for k in cache},
        updated_item_indices={k: None for k in cache},
        updated_item_mask={k: None for k in cache},
        staleness=staleness)
443

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.