google-research

Форк
0
304 строки · 12.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
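"""TensorFlow 2 retrievers, including one using a (possibly fine-tuned) BERT query encoder.

BERTScaNNRetriever encodes queries with BERT and searches a ScaNN MIPS index;
FullyCachedRetriever instead samples from retrievals precomputed into an HDF5
file. All of the code in this file is currently commented out.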
17
"""
18
# import logging
19
# import os
20
# from typing import Any, Dict, Union  # pylint: disable=unused-import
21
#
22
# from absl import flags
23
# import dataclasses
24
# import numpy as np
25
# import tensorflow as tf
26
# import tensorflow_hub as hub
27
# import tf_utils
28
# import utils
29
# # import h5py
30
# # import scann_utils
31
# # import bert_utils
32
#
33
#
34
# FLAGS = flags.FLAGS
35
# LOGGER = logging.getLogger(__name__)
36
#
37
#
38
# @dataclasses.dataclass
39
# class ScannConfig:
40
#
41
#   num_neighbors: int
42
#   training_sample_size: int
43
#   num_leaves: int
44
#   num_leaves_to_search: int
45
#   reordering_num_neighbors: int
46

47
#
48
# class BERTScaNNRetriever:
49
#   """Class used for BERT-based retrievers such as REALM and DPR.
50
#
51
#   Parameters:
52
#     self.query_encoder: Model instance to encode the queries.
53
#     self.tokenizer: Tokenizer for the query model.
54
#     self.vocab_lookup_table: Vocabulary index for the query and key embedders.
55
#     self.scann_config: Configuration dataclass for the ScaNN builder.
56
#     self.block_emb: Contains the dense vectors over which MIPS is done.
57
#     self.scann_searcher: ScaNN MIPS index instance.
58
#     self.cls_token_id: Id of the CLS token for the query and the embedder
#       modules.
59
#     self.sep_token_id: Id of the SEP token for the query and the embedder
60
#       modules.
61
#     self.blocks: Object from which the raw text is obtained with indices.
62
#   """
63
#
64
#   def __init__(self, retriever_module_path: str,
65
#                block_records_path: str, num_block_records: int,
66
#                mode: tf.estimator.ModeKeys, scann_config: ScannConfig):
67
#     """Constructor for BERTScaNNRetriever.
68
#
69
#     Arguments:
70
#       retriever_module_path: Path of the BERT tf-hub checkpoint.
71
#       block_records_path: Path of the textual form of the retrieval dataset in
72
#                           the TFRecord format.
73
#       num_block_records: Number of samples in the retrieval dataset.
74
#       mode: tf.estimator.ModeKeys for the model, currently only eval is
75
#             supported.
76
#
77
#       scann_config: Configuration dataclass used to initialize the ScaNN MIPS
78
#                     searcher object.
79
#
80
#     """
81
#
82
#     # Two and a half min. on CPU
83
#     with utils.log_duration(LOGGER, "BERTScaNNRetriever.__init__",
84
#                             "hub load query enc"):
85
#       self.query_encoder = hub.load(retriever_module_path, tags={"train"} if
86
#                                   mode == tf.estimator.ModeKeys.TRAIN else {})
87
#
88
#     # Instantaneous
89
#     with utils.log_duration(LOGGER, "BERTScaNNRetriever.__init__",
90
#                             "build own tok info"):
91
#       # Building our own tokenization info saves us 5 min where we would load
92
#       # the BERT model again in bert_utils.get_tf_tokenizer
93
#       # Getting the vocab path from the tf2 hub object (from tf.load) seems
94
#       # broken
95
#       vocab_file = os.path.join(retriever_module_path, "assets", "vocab.txt")
96
#       utils.check_exists(vocab_file)
97
#       do_lower_case = self.query_encoder.signatures["tokenization_info"
98
#                                                     ]()["do_lower_case"]
99
#       tokenization_info = dict(vocab_file=vocab_file,
100
#                                do_lower_case=do_lower_case)
101
#
102
#     # Instantaneous (for something that happens once) if tokenization_info
103
#     # is passed (our addition); a few minutes otherwise on CPU
104
#     # (not passing tokenization_info makes it have to load BERT).
105
#     with utils.log_duration(LOGGER, "BERTScaNNRetriever.__init__",
106
#                             "get_tf_tokenizer"):
107
#
108
#       self.tokenizer, self.vocab_lookup_table = bert_utils.get_tf_tokenizer(
109
#           retriever_module_path, tokenization_info)
110
#
111
#     # 9 min on CPU if not in dev mode. Longest part of the setup phase.
112
#     # We are using a lot of default values in the load_scann_searcher call
113
#     # that it would probably be helpful to finetune
114
#     with utils.log_duration(LOGGER, "BERTScaNNRetriever.__init__",
115
#                             "load_scann_searcher"):
116
#       checkpoint_path = os.path.join(retriever_module_path, "encoded",
117
#                                      "encoded.ckpt")
118
#       self.scann_config = scann_config
119
#       self.block_emb, self.scann_searcher = scann_utils.load_scann_searcher(
120
#           var_name="block_emb", checkpoint_path=checkpoint_path,
121
#           **vars(scann_config))
122
#
123
#     # Instantaneous for something that happens once
124
#     with utils.log_duration(LOGGER, "BERTScaNNRetriever",
125
#     "CLS and SEP tokens"):
126
#       self.cls_token_id = tf.cast(self.vocab_lookup_table.lookup(
127
#           tf.constant("[CLS]")), tf.int32)
128
#       self.sep_token_id = tf.cast(self.vocab_lookup_table.lookup(
129
#           tf.constant("[SEP]")), tf.int32)
130
#
131
#     # 3 min on CPU when not in dev mode.
132
#     with utils.log_duration(LOGGER, "BERTScaNNRetriever",
133
#                             "Load the textual dataset"):
134
#       # Extract the appropriate text
135
#       # The buffer_size is taken from the original ORQA code.
136
#       blocks_dataset = tf.data.TFRecordDataset(block_records_path,
137
#                                                # Value taken from the REALM
138
#                                                # code.
139
#                                                buffer_size=512 * 1024 * 1024)
140
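#       # Group the records into batches of `num_block_records`, dropping any
#       # partial batch; with exactly that many records this yields one batch.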
141
#       blocks_dataset = blocks_dataset.batch(num_block_records,
142
#                                             drop_remainder=True)
143
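#       # Extract that single batch as one tensor of serialized records, which
#       # `retrieve` indexes with tf.gather. This raises unless the dataset
#       # has exactly one element.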
144
#       self.blocks = tf.data.experimental.get_single_element(blocks_dataset)
145
#
146
#   @tf.function
147
#   def retrieve(self, query_text: str) -> Dict[str, Any]:
148
#     """Retrieves over the retrieval dataset, from a batch of text queries.
149
#
150
#     First generates the query vector from the text, then queries the
151
#     approximate maximum inner-product search engine.
152
#     Args:
153
#       query_text: Batch of text queries. In string form.
154
#
155
#     Returns:
156
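#       The retrieved records (from `self.blocks`) of the approximate nearest
157
#       neighbors, of shape (batch_size, num_neighbors). This is a tensor, not
#       the `Dict[str, Any]` the signature annotates; the ScaNN inner-product
#       scores are discarded and not returned.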
158
#     """
159
#
160
#     # Tokenize the input tokens
161
#     utils.check_equal(len(query_text), FLAGS.batch_size)
162
#     question_token_ids = self.tokenizer.batch_encode_plus(
163
#         query_text)["input_ids"]
164
#     question_token_ids = tf.cast(
165
#         question_token_ids.merge_dims(1, 2).to_tensor(), tf.int32)
166
#
167
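#     # Add a CLS token at the start of the input, and a SEP token at the end.
#     # NOTE(review): to_tensor() pads shorter queries, so SEP is placed after
#     # the padding rather than right after each query's last token, and the
#     # input_mask below marks padding as real tokens; confirm this is intended.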
168
#     cls_ids = tf.fill((question_token_ids.shape[0], 1), self.cls_token_id)
169
#     sep_ids = tf.fill((question_token_ids.shape[0], 1), self.sep_token_id)
170
#     question_token_ids = tf.concat((cls_ids, question_token_ids, sep_ids), 1)
171
#     utils.check_equal(question_token_ids.shape[0], FLAGS.batch_size)
172
#
173
#     with utils.log_duration(LOGGER, "retrieve_multi", "Encode the query"):
174
#       question_emb = self.query_encoder.signatures["projected"](
175
#           input_ids=question_token_ids,
176
#           input_mask=tf.ones_like(question_token_ids),
177
#           segment_ids=tf.zeros_like(question_token_ids))["default"]
178
#     LOGGER.debug("question_emb.shape: %s", question_emb.shape)
179
#     utils.check_equal(question_emb.shape[0], FLAGS.batch_size)
180
#
181
#     with utils.log_duration(LOGGER, "retrieve_multi", "search with ScaNN"):
182
#       retrieved_block_ids, _ = self.scann_searcher.search_batched(
183
#       question_emb)
184
#       utils.check_equal(retrieved_block_ids.shape, (
185
#           FLAGS.batch_size, self.scann_config.num_neighbors))
186
#
187
#     # Gather the embeddings
188
#     # [batch_size, retriever_beam_size, projection_size]
189
#     retrieved_block_ids = retrieved_block_ids.astype(np.int64)
190
#     retrieved_block_emb = tf.gather(self.block_emb, retrieved_block_ids)
191
#     utils.check_equal(retrieved_block_emb.shape[:2], (
192
#         FLAGS.batch_size, self.scann_config.num_neighbors))
193
#
194
#     # Actually retrieve the text
195
#     retrieved_blocks = tf.gather(self.blocks, retrieved_block_ids)
196
#     utils.check_equal(retrieved_blocks.shape, (
197
#         FLAGS.batch_size, self.scann_config.num_neighbors
198
#     ))
199
#     return retrieved_blocks
200

201
#
202
# class FullyCachedRetriever:
203
#   def __init__(
204
#       self, db_path: str, block_records_path: str, num_block_records: int
205
#   ):
206
#     """Uses a file where all the retrievals have been made in advance.
207
#
208
#     Uses the exact retrievals from query_cacher.py, which have been made in
209
#     advance, as the questions don't change. The retrievals are made by
210
#     fetching the pre-made retrievals, using the question id as the key of a
211
#     lookup table.
212
#     The inner products are also present in the file; they are used to sample
213
#     from the pre-made retrievals to teach the model to adapt to having a wider
214
#     variety of retrievals each epoch.
215
#
216
#     Args:
217
#       db_path: Path to the HDF5 file generated by `query_cacher.py`,
218
#         which contains the pre-made retrievals for all questions of the
219
#         "train", "eval" and "test" splits.
220
#       block_records_path: Path to the TFRecord file with the reference
221
#         text (often Wikipedia) from which blocks are retrieved; it is read
222
#         with tf.data.TFRecordDataset.
223
#       num_block_records: Number of entries in the reference db.
224
#     """
225
#     # Load the db
226
#
227
#     input_file = h5py.File(tf.io.gfile.GFile(db_path, "rb"), "r")
228
#     self._keys = ["train", "eval", "test"]
229
#
230
#     LOGGER.debug("Building the hash table")
231
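#
#     # NOTE(review): the tables' default value is 1, so a question id missing
#     # from the HDF5 file silently maps to row 1 (another question's
#     # retrievals) instead of failing; consider -1 so misses can be detected.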
232
#     self._indices_by_ids = {}
233
#     for split in self._keys:
234
#       self._indices_by_ids[split] = (
235
#           tf.lookup.StaticHashTable(
236
#               tf.lookup.KeyValueTensorInitializer(
237
#                   input_file[split]["sample_ids"],
238
#                   tf.range(input_file[split]["retrieval"][
239
#                   "indices"].shape[0])
240
#               ), 1))
241
#
242
#     LOGGER.debug("Building the self._distances_by_h5_index")
243
#     self._distances_by_h5_index = {
244
#         split: tf.constant(input_file[split]["retrieval"]["distances"][:])
245
#         for split in self._keys
246
#     }
247
#
248
#     LOGGER.debug("Building the self._db_entry_by_h5_index")
249
#     self._db_entry_by_h5_index = {
250
#         split: tf.constant(input_file[split]["retrieval"]["indices"][:])
251
#         for split in self._keys
252
#     }
253
#
254
#     with utils.log_duration(
255
#         LOGGER, "FullyCachedRetriever.__init__", "Load the textual dataset"
256
#     ):
257
#       # Extract the appropriate text
258
#       # The buffer_size is taken from the original ORQA code.
259
#       blocks_dataset = tf.data.TFRecordDataset(
260
#           block_records_path, buffer_size=512 * 1024 * 1024
261
#       )
262
#       blocks_dataset = blocks_dataset.batch(
263
#           num_block_records, drop_remainder=True
264
#       )
265
#       self._blocks = tf.data.experimental.get_single_element(blocks_dataset)
266
#
267
#   @tf.function
268
#   def retrieve(
269
#       self, ds_split: str, question_ids: tf.Tensor, temperature: float, k: int
270
#   ) -> tf.Tensor:
271
#     """Samples `k` cached retrievals per question and returns their text.
272
#
273
#     Args:
274
#       ds_split:
275
#         The h5 files are split per dataset split "train", "eval", "test". This
276
#         argument tells us which one to use.
277
#       question_ids: Id(s) of the question(s), used as keys to look up the
278
#         cached retrievals.
279
#       temperature: Temperature to be used when sampling from the neighbors.
280
#       k: Number of neighbors to sample (without replacement) from the cached
#         ones.
281
#
282
#     Returns:
283
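#       The retrieved reference text blocks, gathered from `self._blocks` at
#       the sampled indices. The logits are not returned (the code computing
#       them is commented out).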
284
#     """
285
#
286
#     indices = self._indices_by_ids[ds_split].lookup(question_ids)
287
#     distances = tf.gather(self._distances_by_h5_index[ds_split], indices)
288
#     db_indices = tf.gather(self._db_entry_by_h5_index[ds_split], indices)
289
#
290
#     # Sample k block ids from the temperature-scaled cached distances.
291
#     logits = distances / temperature
292
#     selections = tf_utils.sample_without_replacement(logits, k)
293
#     final_indices = tf.gather(db_indices, selections, batch_dims=-1)
294
#     # final_logits = tf.gather(logits, selections, batch_dims=-1)
295
#
296
#     retrieved_blocks = tf.gather(self._blocks, final_indices)
297
#     # utils.check_equal(final_logits.shape, final_indices.shape)
298
#     return retrieved_blocks
299

300

301
# RetrieverType = Union[
302
  # BERTScaNNRetriever,
303
  # FullyCachedRetriever
304
# ]
305

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.