google-research
442 lines · 17.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Implements loss functions for dual encoder training with a cache."""
17
18import abc19import collections20from typing import Callable, Dict, Iterable, List, Optional, Tuple21
22import tensorflow.compat.v2 as tf23
24from negative_cache import negative_cache25from negative_cache import retrieval_fns26from negative_cache import util27
# Result bundle returned by every `CacheLoss.__call__` (fields per the
# `CacheClassificationLoss` docstring below):
#   training_loss: loss tensor to differentiate for training.
#   interpretable_loss: human-readable loss value for summaries.
#   updated_item_data, updated_item_indices, updated_item_mask: per-data-source
#     updates describing the cache elements that were retrieved and
#     re-embedded; feed these to the negative cache update.
#   staleness: squared error between the retrieved cache embeddings and the
#     freshly computed embeddings (proxy for error due to cache staleness).
CacheLossReturn = collections.namedtuple('CacheLossReturn', [
    'training_loss',
    'interpretable_loss',
    'updated_item_data',
    'updated_item_indices',
    'updated_item_mask',
    'staleness',
])
37
class CacheLoss(object, metaclass=abc.ABCMeta):
  """Interface for dual-encoder losses computed against a negative cache."""

  @abc.abstractmethod
  def __call__(self, doc_network, query_embeddings, pos_doc_embeddings,
               cache):
    """Computes the loss and cache updates.

    Args:
      doc_network: Callable that embeds document data.
      query_embeddings: Embeddings for the queries.
      pos_doc_embeddings: Embeddings for the positive documents aligned with
        `query_embeddings`.
      cache: Dict mapping data-source name to its negative cache.

    Returns:
      A `CacheLossReturn`.
    """
    pass
47
# Internal result of retrieving negatives from the cache:
#   retrieved_data: dict of retrieved item tensors, keyed by data key.
#   scores: score matrix of queries against cache elements used for retrieval.
#   retrieved_indices: indices of the retrieved items in the concatenated
#     cache (None in `DistributedCacheClassificationLoss`).
#   retrieved_cache_embeddings: embeddings of the retrieved items as currently
#     stored in the cache (possibly stale).
_RetrievalReturn = collections.namedtuple('_RetrievalReturn', [
    'retrieved_data', 'scores', 'retrieved_indices',
    'retrieved_cache_embeddings'
])
53
def _score_documents(query_embeddings,
                     doc_embeddings,
                     score_transform=None,
                     all_pairs=False):
  """Scores query/document embeddings by dot product.

  Args:
    query_embeddings: One embedding row per query.
    doc_embeddings: One embedding row per document.
    score_transform: Optional function applied to the raw scores before they
      are returned.
    all_pairs: When True, every query is scored against every document
      (a [num_queries, num_docs] matrix); when False, queries and documents
      are scored pairwise row-by-row (a [num_queries] vector).

  Returns:
    The (optionally transformed) dot-product scores.
  """
  if all_pairs:
    raw_scores = tf.matmul(query_embeddings, doc_embeddings, transpose_b=True)
  else:
    raw_scores = tf.reduce_sum(query_embeddings * doc_embeddings, axis=1)
  if score_transform is None:
    return raw_scores
  return score_transform(raw_scores)
67
def _batch_concat_with_no_op(tensors):
  """Concatenates tensors along axis 0, returning a lone tensor unchanged."""
  if len(tensors) == 1:
    # Skip the concat op entirely for the common single-source case.
    return tensors[0]
  return tf.concat(tensors, axis=0)
75
def _retrieve_from_caches(query_embeddings,
                          cache,
                          retrieval_fn,
                          embedding_key,
                          data_keys,
                          sorted_data_sources,
                          score_transform=None,
                          top_k=None):
  """Retrieve elements from a cache with the given retrieval function.

  Args:
    query_embeddings: Embeddings of the queries, one row per query.
    cache: Dict mapping data-source name to its cache; each cache exposes a
      `.data` dict of tensors.
    retrieval_fn: Maps a score matrix to retrieved indices. NOTE(review): the
      gather logic below is only consistent if it returns one index per query
      with shape [num_queries, 1] — confirm against retrieval_fns.
    embedding_key: Key of the embedding tensor in each cache's data.
    data_keys: Keys of the document-data tensors in each cache's data.
    sorted_data_sources: Data-source names in the order used for concatenation.
    score_transform: Optional transform applied to scores before retrieval.
    top_k: If set, retrieval is restricted to an approximate top k of the
      scored cache elements.

  Returns:
    A `_RetrievalReturn` with the retrieved data, the score matrix, the
    retrieved indices into the concatenated cache, and the retrieved items'
    cached embeddings.
  """
  # Concatenate all data sources into one logical cache so a single retrieval
  # can pick from any of them.
  all_embeddings = _batch_concat_with_no_op([
      cache[data_source].data[embedding_key]
      for data_source in sorted_data_sources
  ])
  all_data = {}
  for key in data_keys:
    all_data[key] = _batch_concat_with_no_op(
        [cache[data_source].data[key] for data_source in sorted_data_sources])
  scores = _score_documents(
      query_embeddings,
      all_embeddings,
      score_transform=score_transform,
      all_pairs=True)
  if top_k:
    # Restrict retrieval to the (approximate) top-k scoring elements, then
    # translate the retrieved position within the top-k slate back into an
    # index into the full concatenated cache.
    scores, top_k_indices = util.approximate_top_k_with_indices(scores, top_k)
    top_k_indices = tf.cast(top_k_indices, dtype=tf.int64)
    retrieved_indices = retrieval_fn(scores)
    # Pair each per-query retrieved position with its row so gather_nd can
    # look up the original cache index from the per-query top-k table.
    batch_index = tf.expand_dims(
        tf.range(tf.shape(retrieved_indices)[0], dtype=tf.int64), axis=1)
    retrieved_indices_with_batch_index = tf.concat(
        [batch_index, retrieved_indices], axis=1)
    retrieved_indices = tf.gather_nd(top_k_indices,
                                     retrieved_indices_with_batch_index)
    retrieved_indices = tf.expand_dims(retrieved_indices, axis=1)
  else:
    retrieved_indices = retrieval_fn(scores)
  # Retrieval is a sampling decision, not a differentiable op.
  retrieved_indices = tf.stop_gradient(retrieved_indices)
  retrieved_data = {
      k: tf.gather_nd(v, retrieved_indices) for k, v in all_data.items()
  }
  retrieved_cache_embeddings = tf.gather_nd(all_embeddings, retrieved_indices)
  return _RetrievalReturn(retrieved_data, scores, retrieved_indices,
                          retrieved_cache_embeddings)
119
def _get_data_sorce_start_position_and_cache_sizes(
    cache, embedding_key, sorted_data_sources):
  """Gets the first index and size per data source in the concatenated data.

  Args:
    cache: Dict mapping data-source name to its cache.
    embedding_key: Key of the embedding tensor, used to measure each cache.
    sorted_data_sources: Data-source names in concatenation order.

  Returns:
    Two dicts keyed by data source: the start offset of each source in the
    concatenated cache, and each source's cache size (both int64 tensors).
  """
  start_positions = {}
  cache_sizes = {}
  offset = tf.constant(0, dtype=tf.int64)
  for source in sorted_data_sources:
    size = tf.shape(
        cache[source].data[embedding_key], out_type=tf.int64)[0]
    start_positions[source] = offset
    cache_sizes[source] = size
    offset = offset + size
  return start_positions, cache_sizes
135
def _get_retrieved_embedding_updates(
    cache, embedding_key, sorted_data_sources, retrieved_indices,
    retrieved_embeddings):
  """Gets the cache updates for the retrieved data.

  Translates indices into the concatenated cache back into per-data-source
  local indices, with a mask selecting which retrieved items actually belong
  to each source.

  Args:
    cache: Dict mapping data-source name to its cache.
    embedding_key: Key under which updated embeddings are stored.
    sorted_data_sources: Data-source names in concatenation order.
    retrieved_indices: [num_queries, 1] indices into the concatenated cache.
    retrieved_embeddings: Freshly computed embeddings of the retrieved items.

  Returns:
    A (updated_item_data, updated_item_indices, updated_item_mask) triple of
    per-data-source dicts, as consumed by `CacheLossReturn`.
  """
  start_positions, cache_sizes = _get_data_sorce_start_position_and_cache_sizes(
      cache, embedding_key, sorted_data_sources)
  updated_item_data = {}
  updated_item_indices = {}
  updated_item_mask = {}
  for source in sorted_data_sources:
    start = start_positions[source]
    end = start + cache_sizes[source]
    # Convert the global (concatenated) index into this source's local index;
    # the mask marks the rows whose global index falls inside this source.
    local_indices = retrieved_indices - start
    in_source = (retrieved_indices >= start) & (retrieved_indices < end)
    updated_item_data[source] = {embedding_key: retrieved_embeddings}
    updated_item_indices[source] = tf.squeeze(local_indices, axis=1)
    updated_item_mask[source] = tf.squeeze(in_source, axis=1)
  return updated_item_data, updated_item_indices, updated_item_mask
161
def _get_staleness(cache_embeddings, updated_embeddings):
  """Per-row squared error between cached and fresh embeddings.

  The error is normalized by the squared norm of the fresh embedding, so the
  result is a relative staleness measure per retrieved item.
  """
  difference = cache_embeddings - updated_embeddings
  squared_error = tf.reduce_sum(difference * difference, axis=1)
  fresh_norm_sq = tf.reduce_sum(updated_embeddings * updated_embeddings, axis=1)
  return squared_error / fresh_norm_sq
169
# Intermediate results of
# `AbstractCacheClassificationLoss._calculate_training_loss_and_summaries`;
# carries the (possibly reduced) losses plus the raw retrieval results needed
# by subclasses to build cache updates.
_LossCalculationReturn = collections.namedtuple('_LossCalculationReturn', [
    'training_loss', 'interpretable_loss', 'staleness', 'retrieval_return',
    'retrieved_negative_embeddings'
])
175
class AbstractCacheClassificationLoss(CacheLoss, metaclass=abc.ABCMeta):
  """Abstract method for cache classification losses.

  Inherit from this object and override `_retrieve_from_cache` and
  `_score_documents` to implement a cache classification loss based on the
  specified retrieval and scoring approaches.
  """

  @abc.abstractmethod
  def _retrieve_from_cache(self, query_embeddings, cache):
    # Returns a `_RetrievalReturn` of sampled negatives for the queries.
    pass

  @abc.abstractmethod
  def _score_documents(self, query_embeddings, doc_embeddings):
    # Scores aligned query/document embedding pairs.
    pass

  def _calculate_training_loss_and_summaries(
      self,
      doc_network,
      query_embeddings,
      pos_doc_embeddings,
      cache,
      reducer=tf.math.reduce_mean):
    """Calculates the cache classification loss and associated summaries.

    Args:
      doc_network: Callable that embeds the retrieved document data.
      query_embeddings: Embeddings for the queries.
      pos_doc_embeddings: Embeddings for the positive documents aligned with
        the queries.
      cache: Dict mapping data-source name to its negative cache.
      reducer: Reduces elementwise losses to a scalar; if None, elementwise
        losses are returned.

    Returns:
      A `_LossCalculationReturn` with the training loss, interpretable loss,
      staleness, and the retrieval results.
    """
    positive_scores = self._score_documents(query_embeddings,
                                            pos_doc_embeddings)
    retrieval_return = self._retrieve_from_cache(query_embeddings, cache)
    # Re-embed only the retrieved negatives with the current model, so just
    # O(batch_size) documents go through `doc_network` (see
    # `CacheClassificationLoss` docstring).
    retrieved_negative_embeddings = doc_network(retrieval_return.retrieved_data)
    retrieved_negative_scores = self._score_documents(
        query_embeddings, retrieved_negative_embeddings)
    # Column 0 holds the positive's score; the remaining columns are the
    # (cached) scores of all cache elements.
    cache_and_pos_scores = tf.concat(
        [tf.expand_dims(positive_scores, axis=1), retrieval_return.scores],
        axis=1)
    prob_pos = tf.nn.softmax(cache_and_pos_scores, axis=1)[:, 0]
    # prob_pos is treated as a constant weight: gradients flow only through
    # the (negative - positive) score difference below. Per the
    # `CacheClassificationLoss` docstring, this stochastic loss's gradient
    # approximates the cache classification loss gradient in expectation.
    prob_pos = tf.stop_gradient(prob_pos)
    training_loss = (1.0 - prob_pos) * (
        retrieved_negative_scores - positive_scores)
    # Negative log-probability of the positive, using cached embeddings:
    # interpretable for summaries, but not used for gradients.
    interpretable_loss = -tf.math.log(prob_pos)
    staleness = _get_staleness(retrieval_return.retrieved_cache_embeddings,
                               retrieved_negative_embeddings)
    if reducer is not None:
      training_loss = reducer(training_loss)
      interpretable_loss = reducer(interpretable_loss)
      staleness = reducer(staleness)
    return _LossCalculationReturn(
        training_loss=training_loss,
        interpretable_loss=interpretable_loss,
        staleness=staleness,
        retrieval_return=retrieval_return,
        retrieved_negative_embeddings=retrieved_negative_embeddings)
227
class CacheClassificationLoss(AbstractCacheClassificationLoss):
  """Implements an efficient way to train with a cache classification loss.

  The cache classification loss is the negative log probability of the
  positive document under a softmax over all documents. This object computes:
    (1) An efficient stochastic training loss whose gradient matches the
        cache classification loss gradient in expectation. It requires
        feeding only O(batch_size) documents through the document network
        instead of O(cache_size).
    (2) An approximation of the cache classification loss value using the
        cached embeddings. The training loss above is not interpretable;
        this one is, but it has no usable gradient.

  Calling this object returns a `CacheLossReturn`:
    training_loss: Use this to calculate gradients.
    interpretable_loss: An interpretable value suitable for Tensorboard
      summaries.
    updated_item_data, updated_item_indices, updated_item_mask: Describe the
      cache elements that were retrieved and re-embedded; use these in the
      negative cache updates.
    staleness: Squared error between the retrieved cache embeddings and the
      current model's embeddings — a proxy for error due to cache staleness.
  """

  def __init__(self,
               embedding_key,
               data_keys,
               score_transform=None,
               top_k=None,
               reducer=tf.math.reduce_mean):
    """Initializes the CacheClassificationLoss object.

    Args:
      embedding_key: The key containing the embedding in the cache.
      data_keys: The keys containing the document data in the cache.
      score_transform: Scores are transformed by this function before use:
        scores(i, j) = score_transform(dot(query_embed_i, doc_embed_j)).
      top_k: If set, only the top k scoring negative elements are candidates
        for retrieval; the rest are masked before the loss.
      reducer: Reduces the elementwise losses to a single scalar. If None,
        elementwise losses are returned.
    """
    self.embedding_key = embedding_key
    self.data_keys = data_keys
    self.score_transform = score_transform
    self.top_k = top_k
    self.reducer = reducer
    # Gumbel-Max retrieval samples the negatives from the cache scores.
    self._retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()

  def _retrieve_from_cache(self, query_embeddings, cache):
    """Samples negatives for each query from the concatenated cache."""
    return _retrieve_from_caches(query_embeddings, cache, self._retrieval_fn,
                                 self.embedding_key, self.data_keys,
                                 sorted(cache.keys()), self.score_transform,
                                 self.top_k)

  def _score_documents(self, query_embeddings, doc_embeddings):
    """Scores aligned query/document pairs with the configured transform."""
    return _score_documents(
        query_embeddings, doc_embeddings, score_transform=self.score_transform)

  def __call__(self, doc_network, query_embeddings, pos_doc_embeddings,
               cache):
    """Calculates the cache classification losses.

    Args:
      doc_network: The network that embeds the document data.
      query_embeddings: Embeddings for the queries.
      pos_doc_embeddings: Embeddings for the documents that are positive for
        the given queries.
      cache: The cache of document data and embeddings.

    Returns:
      A CacheLossReturn with the training loss, interpretable loss, and the
      data needed to update the embeddings of the cache elements that were
      retrieved and recalculated.
    """
    calculation = self._calculate_training_loss_and_summaries(
        doc_network, query_embeddings, pos_doc_embeddings, cache, self.reducer)
    item_data, item_indices, item_mask = _get_retrieved_embedding_updates(
        cache, self.embedding_key, sorted(cache.keys()),
        calculation.retrieval_return.retrieved_indices,
        calculation.retrieved_negative_embeddings)
    return CacheLossReturn(
        training_loss=calculation.training_loss,
        interpretable_loss=calculation.interpretable_loss,
        updated_item_data=item_data,
        updated_item_indices=item_indices,
        updated_item_mask=item_mask,
        staleness=calculation.staleness)
335
def _get_local_elements_global_data(all_elements_local_data, num_replicas):
  """Gathers, for this replica's elements, the data from every replica.

  The input's first dimension indexes all (global) elements with values
  computed locally on this replica. The AllToAll splits that dimension across
  replicas (split_dimension=0) and concatenates each replica's contribution
  along the new axis 1 (concat_dimension=1), so each replica ends up holding,
  for its local slice of elements, a column per replica.

  NOTE(review): `group_assignment` puts all replicas in one group; presumably
  the element count divides evenly by `num_replicas` — confirm with callers.
  """
  all_elements_local_data = tf.expand_dims(all_elements_local_data, axis=1)
  return tf.raw_ops.AllToAll(
      input=all_elements_local_data,
      group_assignment=[list(range(num_replicas))],
      concat_dimension=1,
      split_dimension=0,
      split_count=num_replicas)
345
class DistributedCacheClassificationLoss(AbstractCacheClassificationLoss):
  """Implements a cache classification loss with a sharded cache.

  This object implements a cache classification loss when the cache is sharded
  onto multiple replicas. This code calculates the loss treating the sharded
  cache as one unit, so all queries are affected by all cache elements in
  every replica.

  Currently, the updated_item_* fields (i.e., the embedding updates for items
  already in the cache) in the CacheLossReturn are empty. This does not affect
  new items introduced to the cache.
  """

  def __init__(self,
               embedding_key,
               data_keys,
               score_transform=None,
               top_k=None,
               reducer=tf.math.reduce_mean):
    """Initializes the loss; arguments mirror `CacheClassificationLoss`."""
    self.embedding_key = embedding_key
    self.data_keys = data_keys
    self.score_transform = score_transform
    self.top_k = top_k
    self.reducer = reducer
    self._retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()

  def _score_documents(self, query_embeddings, doc_embeddings):
    """Scores aligned query/document pairs with the configured transform."""
    return _score_documents(
        query_embeddings, doc_embeddings, score_transform=self.score_transform)

  def _retrieve_from_cache(self, query_embeddings, cache):
    """Samples one negative per local query from the globally sharded cache.

    Each replica retrieves candidates from its own shard for *all* queries;
    a per-query replica is then sampled in proportion to the shard's total
    softmax weight, which together implements Gumbel-Max sampling over the
    full (sharded) cache.
    """
    sorted_data_sources = sorted(cache.keys())
    # Gather every replica's queries so each shard can score all of them.
    all_query_embeddings = util.cross_replica_concat(query_embeddings, axis=0)
    num_replicas = tf.distribute.get_replica_context().num_replicas_in_sync
    # Performs approximate top k across replicas.
    if self.top_k:
      top_k_per_replica = self.top_k // num_replicas
    else:
      top_k_per_replica = self.top_k
    retrieval_return = _retrieve_from_caches(all_query_embeddings, cache,
                                             self._retrieval_fn,
                                             self.embedding_key, self.data_keys,
                                             sorted_data_sources,
                                             self.score_transform,
                                             top_k_per_replica)
    # We transfer all queries to all replica and retrieve from every shard.
    # logsumexp over this shard's scores is the log of the shard's total
    # (unnormalized) softmax weight per query.
    all_queries_local_weight = tf.math.reduce_logsumexp(
        retrieval_return.scores, axis=1)
    local_queries_global_weights = _get_local_elements_global_data(
        all_queries_local_weight, num_replicas)
    local_queries_all_retrieved_data = {}
    for key in retrieval_return.retrieved_data:
      local_queries_all_retrieved_data[key] = _get_local_elements_global_data(
          retrieval_return.retrieved_data[key], num_replicas)
    local_queries_all_retrieved_embeddings = _get_local_elements_global_data(
        retrieval_return.retrieved_cache_embeddings, num_replicas)
    # We then sample a shard index proportional to its total weight.
    # This allows us to do Gumbel-Max sampling without modifying APIs.
    selected_replica = self._retrieval_fn(local_queries_global_weights)
    selected_replica = tf.stop_gradient(selected_replica)
    # Pair each local query's row index with its sampled replica column so
    # gather_nd can pick that replica's candidate for the query.
    num_elements = tf.shape(selected_replica)[0]
    batch_indices = tf.range(num_elements)
    batch_indices = tf.cast(batch_indices, tf.int64)
    batch_indices = tf.expand_dims(batch_indices, axis=1)
    selected_replica_with_batch = tf.concat([batch_indices, selected_replica],
                                            axis=1)
    retrieved_data = {
        k: tf.gather_nd(v, selected_replica_with_batch)
        for k, v in local_queries_all_retrieved_data.items()
    }
    retrieved_cache_embeddings = tf.gather_nd(
        local_queries_all_retrieved_embeddings, selected_replica_with_batch)
    return _RetrievalReturn(
        retrieved_data=retrieved_data,
        # Per-replica total weights stand in for the full score matrix here.
        scores=local_queries_global_weights,
        # Global indices are not tracked across shards (see class docstring).
        retrieved_indices=None,
        retrieved_cache_embeddings=retrieved_cache_embeddings)

  def __call__(self, doc_network, query_embeddings, pos_doc_embeddings,
               cache):
    """Computes the distributed loss; see `CacheLoss.__call__`.

    The updated_item_* fields are filled with None per data source, since
    embedding updates for existing items are not supported here (see class
    docstring).
    """
    loss_calculation_return = self._calculate_training_loss_and_summaries(
        doc_network, query_embeddings, pos_doc_embeddings, cache, self.reducer)
    training_loss = loss_calculation_return.training_loss
    interpretable_loss = loss_calculation_return.interpretable_loss
    staleness = loss_calculation_return.staleness
    return CacheLossReturn(
        training_loss=training_loss,
        interpretable_loss=interpretable_loss,
        updated_item_data={k: None for k in cache},
        updated_item_indices={k: None for k in cache},
        updated_item_mask={k: None for k in cache},
        staleness=staleness)