google-research
304 строки · 12.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tensorflow 2 Retriever using a (poss. fine-tuned) BERT query encoder.
17"""
18# import logging
19# import os
20# from typing import Any, Dict, Union # pylint: disable=unused-import
21#
22# from absl import flags
23# import dataclasses
24# import numpy as np
25# import tensorflow as tf
26# import tensorflow_hub as hub
27# import tf_utils
28# import utils
29# # import h5py
30# # import scann_utils
31# # import bert_utils
32#
33#
34# FLAGS = flags.FLAGS
35# LOGGER = logging.getLogger(__name__)
36#
37#
38# @dataclasses.dataclass
39# class ScannConfig:
40#
41# num_neighbors: int
42# training_sample_size: int
43# num_leaves: int
44# num_leaves_to_search: int
45# reordering_num_neighbors: int
46
47#
48# class BERTScaNNRetriever:
49# """Class used for BERT based retrievers such as REALM and DPR.
50#
51# Parameters:
52# self.query_encoder: Model instance to encode the queries.
53# self.tokenizer: Tokenizer for the query model.
54# self.vocab_lookup_table: Vocabulary index for the query and key embedders.
55# self.scann_config: Configuration dataclass for the ScaNN builder.
56# self.block_emb: Contains the dense vectors over which MIPS is done.
57# self.scann_searcher: ScaNN MIPS index instance.
58# self.cls_token_id: Id of the CLS token for the query and embedder modules.
59# self.sep_token_id: Id of the SEP token for the query and the embedder
60# modules.
61# self.blocks: Object from which the raw text is obtained with indices.
62# """
63#
64# def __init__(self, retriever_module_path: str,
65# block_records_path: str, num_block_records: int,
66# mode: tf.estimator.ModeKeys, scann_config: ScannConfig):
67# """Constructor for BERTScaNNRetriever.
68#
69# Arguments:
70# retriever_module_path: Path of the BERT tf-hub checkpoint.
71# block_records_path: Path of the textual form of the retrieval dataset in
72# the TFRecord format.
73# num_block_records: Number of samples in the retrieval dataset.
74# mode: tf.estimator.ModeKeys for the model, currently only eval is
75# supported.
76#
77# scann_config: Configuration dataclass used to initialize the ScaNN MIPS
78# searcher object.
79#
80# """
81#
82# # Two and a half min. on CPU
83# with utils.log_duration(LOGGER, "BERTScaNNRetriever.__init__",
84# "hub load query enc"):
85# self.query_encoder = hub.load(retriever_module_path, tags={"train"} if
86# mode == tf.estimator.ModeKeys.TRAIN else {})
87#
88# # Instantaneous
89# with utils.log_duration(LOGGER, "BERTScaNNRetriever.__init__",
90# "build own tok info"):
91# # Building our own tokenization info saves us 5 min where we would load
92# # the BERT model again in bert_utils.get_tf_tokenizer
93# # Getting the vocab path from the tf2 hub object (from tf.load) seems
94# # broken
95# vocab_file = os.path.join(retriever_module_path, "assets", "vocab.txt")
96# utils.check_exists(vocab_file)
97# do_lower_case = self.query_encoder.signatures["tokenization_info"
98# ]()["do_lower_case"]
99# tokenization_info = dict(vocab_file=vocab_file,
100# do_lower_case=do_lower_case)
101#
102# # Instantaneous (for something that happens once) if tokenization_info
103# # is passed (our addition); a few minutes otherwise, on CPU
104# # (not passing tokenization_info forces BERT to be loaded again).
105# with utils.log_duration(LOGGER, "BERTScaNNRetriever.__init__",
106# "get_tf_tokenizer"):
107#
108# self.tokenizer, self.vocab_lookup_table = bert_utils.get_tf_tokenizer(
109# retriever_module_path, tokenization_info)
110#
111# # 9 min on CPU if not in dev mode. Longest part of the setup phase.
112# # We are using a lot of default values in the load_scann_searcher call
113# # that would probably be worth tuning.
114# with utils.log_duration(LOGGER, "BERTScaNNRetriever.__init__",
115# "load_scann_searcher"):
116# checkpoint_path = os.path.join(retriever_module_path, "encoded",
117# "encoded.ckpt")
118# self.scann_config = scann_config
119# self.block_emb, self.scann_searcher = scann_utils.load_scann_searcher(
120# var_name="block_emb", checkpoint_path=checkpoint_path,
121# **vars(scann_config))
122#
123# # Instantaneous for something that happens once
124# with utils.log_duration(LOGGER, "BERTScaNNRetriever",
125# "CLS and SEP tokens"):
126# self.cls_token_id = tf.cast(self.vocab_lookup_table.lookup(
127# tf.constant("[CLS]")), tf.int32)
128# self.sep_token_id = tf.cast(self.vocab_lookup_table.lookup(
129# tf.constant("[SEP]")), tf.int32)
130#
131# # 3 min on CPU when not in dev mode
132# with utils.log_duration(LOGGER, "BERTScaNNRetriever",
133# "Load the textual dataset"):
134# # Extract the appropriate text
135# # The buffer_size is taken from the original ORQA code.
136# blocks_dataset = tf.data.TFRecordDataset(block_records_path,
137# # Value taken from the REALM
138# # code.
139# buffer_size=512 * 1024 * 1024)
140# # Group the dataset into one batch of num_block_records records.
141# blocks_dataset = blocks_dataset.batch(num_block_records,
142# drop_remainder=True)
143# # Extract that single batch so blocks can later be gathered by index.
144# self.blocks = tf.data.experimental.get_single_element(blocks_dataset)
145#
146# @tf.function
147# def retrieve(self, query_text: str) -> Dict[str, Any]:
148# """Retrieves over the retrieval dataset, from a batch of text queries.
149#
150# First generates the query vector from the text, then queries the
151# approximate maximum inner-product search engine.
152# Args:
153# query_text: Batch of text queries. In string form.
154#
155# Returns:
156# The text of the approximate nearest neighbors of each query. The inner
157# product similarities are computed by the search but are not returned.
158# """
159#
160# # Tokenize the input tokens
161# utils.check_equal(len(query_text), FLAGS.batch_size)
162# question_token_ids = self.tokenizer.batch_encode_plus(
163# query_text)["input_ids"]
164# question_token_ids = tf.cast(
165# question_token_ids.merge_dims(1, 2).to_tensor(), tf.int32)
166#
167# # Add a CLS token at the start of the input, and a SEP token at the end
168# cls_ids = tf.fill((question_token_ids.shape[0], 1), self.cls_token_id)
169# sep_ids = tf.fill((question_token_ids.shape[0], 1), self.sep_token_id)
170# question_token_ids = tf.concat((cls_ids, question_token_ids, sep_ids), 1)
171# utils.check_equal(question_token_ids.shape[0], FLAGS.batch_size)
172#
173# with utils.log_duration(LOGGER, "retrieve_multi", "Encode the query"):
174# question_emb = self.query_encoder.signatures["projected"](
175# input_ids=question_token_ids,
176# input_mask=tf.ones_like(question_token_ids),
177# segment_ids=tf.zeros_like(question_token_ids))["default"]
178# LOGGER.debug("question_emb.shape: %s", question_emb.shape)
179# utils.check_equal(question_emb.shape[0], FLAGS.batch_size)
180#
181# with utils.log_duration(LOGGER, "retrieve_multi", "search with ScaNN"):
182# retrieved_block_ids, _ = self.scann_searcher.search_batched(
183# question_emb)
184# utils.check_equal(retrieved_block_ids.shape, (
185# FLAGS.batch_size, self.scann_config.num_neighbors))
186#
187# # Gather the embeddings
188# # [batch_size, retriever_beam_size, projection_size]
189# retrieved_block_ids = retrieved_block_ids.astype(np.int64)
190# retrieved_block_emb = tf.gather(self.block_emb, retrieved_block_ids)
191# utils.check_equal(retrieved_block_emb.shape[:2], (
192# FLAGS.batch_size, self.scann_config.num_neighbors))
193#
194# # Actually retrieve the text
195# retrieved_blocks = tf.gather(self.blocks, retrieved_block_ids)
196# utils.check_equal(retrieved_blocks.shape, (
197# FLAGS.batch_size, self.scann_config.num_neighbors
198# ))
199# return retrieved_blocks
200
201#
202# class FullyCachedRetriever:
203# def __init__(
204# self, db_path: str, block_records_path: str, num_block_records: int
205# ):
206# """Uses a file where all the retrievals have been made in advance.
207#
208# Uses the exact retrievals from query_cacher.py, which have been made in
209# advance, as the questions don't change. Retrieving consists of looking
210# up the question id in a hash table
211# to find the pre-made retrievals stored for that question.
212# The inner products are also present in the file; they are used to sample
213# from the pre-made retrievals to teach the model to adapt to having a wider
214# variety of retrievals each epoch.
215#
216# Args:
217# db_path: Path to the hdf5 file that was generated with
218# `query_cacher.py`,
219# that contains the pre-made retrievals for all questions.
220# block_records_path: Path to the file with the reference
221# (often wikipedia)
222# text, that gets retrieved.
223# num_block_records: Number of entries in the reference db.
224# """
225# # Load the db
226#
227# input_file = h5py.File(tf.io.gfile.GFile(db_path, "rb"), "r")
228# self._keys = ["train", "eval", "test"]
229#
230# LOGGER.debug("Building the hash table")
231#
232# self._indices_by_ids = {}
233# for split in self._keys:
234# self._indices_by_ids[split] = (
235# tf.lookup.StaticHashTable(
236# tf.lookup.KeyValueTensorInitializer(
237# input_file[split]["sample_ids"],
238# tf.range(input_file[split]["retrieval"][
239# "indices"].shape[0])
240# ), 1))
241#
242# LOGGER.debug("Building the self._distances_by_h5_index")
243# self._distances_by_h5_index = {
244# split: tf.constant(input_file[split]["retrieval"]["distances"][:])
245# for split in self._keys
246# }
247#
248# LOGGER.debug("Building the self._db_entry_by_h5_index")
249# self._db_entry_by_h5_index = {
250# split: tf.constant(input_file[split]["retrieval"]["indices"][:])
251# for split in self._keys
252# }
253#
254# with utils.log_duration(
255# LOGGER, "FullyCachedRetriever.__init__", "Load the textual dataset"
256# ):
257# # Extract the appropriate text
258# # The buffer_size is taken from the original ORQA code.
259# blocks_dataset = tf.data.TFRecordDataset(
260# block_records_path, buffer_size=512 * 1024 * 1024
261# )
262# blocks_dataset = blocks_dataset.batch(
263# num_block_records, drop_remainder=True
264# )
265# self._blocks = tf.data.experimental.get_single_element(blocks_dataset)
266#
267# @tf.function
268# def retrieve(
269# self, ds_split: str, question_ids: tf.Tensor, temperature: float, k: int
270# ) -> tf.Tensor:
271# """Does the retrieving.
272#
273# Args:
274# ds_split:
275# The h5 files are split per dataset split "train", "eval", "test". This
276# argument tells us which one to use.
277# question_ids: Id of the question. To be used to get the
278# cached retrievals.
279# temperature: Temperature to be used when sampling from the neighbors.
280# k: Number of neighbors to use.
281#
282# Returns:
283# The retrieved reference text blocks (the logits are not returned).
284# """
285#
286# indices = self._indices_by_ids[ds_split].lookup(question_ids)
287# distances = tf.gather(self._distances_by_h5_index[ds_split], indices)
288# db_indices = tf.gather(self._db_entry_by_h5_index[ds_split], indices)
289#
290# # pick block ids
291# logits = distances / temperature
292# selections = tf_utils.sample_without_replacement(logits, k)
293# final_indices = tf.gather(db_indices, selections, batch_dims=-1)
294# # final_logits = tf.gather(logits, selections, batch_dims=-1)
295#
296# retrieved_blocks = tf.gather(self._blocks, final_indices)
297# # utils.check_equal(final_logits.shape, final_indices.shape)
298# return retrieved_blocks
299
300
301# RetrieverType = Union[
302# BERTScaNNRetriever,
303# FullyCachedRetriever
304# ]
305