google-research
545 lines · 18.0 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16r"""Iterates over the whole ELI5 dataset (for each dataset split), extracts REALM or REALM++ embeddings, then does exact retrieval for a large number of neighbors, then saves the resulting db indices and distance (inner product).
17
18This makes it so we don't have to do real retrieval during training, we just
19sample from the neighbors as a function of the inner product, which is much
20faster.
21
22Examples of use:
23
24# Local Test
25pytype query_cacher_tfrecord.py -P . --check-variable-types \
26--check-container-types \
27--check-parameter-types --precise-return && \
28python3 check_flags.py query_cacher_tfrecord.py && \
29python3 query_cacher_tfrecord.py $(python3 json_to_args.py \
30configs/query_cacher_configs/local.json) \
31--logger_levels=__main__:DEBUG,utils:DEBUG,tf_utils:DEBUG \
32--use_subset=True
33
34"""
35import collections
36import logging
37import os
38import time
39from typing import Callable, Dict, List
40
41from absl import app
42from absl import flags
43from absl import logging as absl_logging
44import bert_utils
45import constants
46import datasets
47import numpy as np
48import tensorflow as tf
49import tensorflow.python.distribute.values as values
50import tensorflow.python.framework.ops as ops
51import tensorflow.python.trackable.autotrackable as autotrackable
52import tensorflow_hub as hub
53import tf_utils
54import tqdm
55import transformers
56import utils
57
58
# Run identification / output location.
_FLAG_JOB_NAME = flags.DEFINE_string(
    "run_name",
    None,
    "Name of the run."
)
_FLAG_OUTPUT_PATH = flags.DEFINE_string(
    "output_dir",
    None,
    "Directory in which to save, on the cloud.")
_FLAG_RETRIEVER_CONFIG_PATH = flags.DEFINE_string(
    "retriever_config_path",
    None,
    "Path to the retriever's configuration file."
)
_FLAG_BATCH_SIZE = flags.DEFINE_integer(
    "batch_size",
    100,
    "Size of the batch for the encoder BERT model."
)

_FLAG_DATASET_ROOT = flags.DEFINE_string(
    "dataset_root",
    None,
    "Root of the place where the datasets are saved."
)

# Flags specific to query encoding
_FLAG_EMBEDDING_DEPTH = flags.DEFINE_integer(
    "embedding_depth",
    128,
    "Size of the BERT (REALM) embeddings.",
)

# Flags specific to retrieval caching
_FLAG_NUM_RETRIEVALS = flags.DEFINE_integer(
    "num_retrievals",
    10,
    "Number of neighbors to retrieve.",
)
_FLAG_CONTEXT_SIZE = flags.DEFINE_integer(
    "context_size",
    1024,
    "Length to pad to."
)
_FLAG_MAX_LENGTH_RETRIEVALS = flags.DEFINE_integer(
    "max_length_retrievals",
    350,
    "Maximum length of the retrievals."
)

_FLAG_NUM_SHARDS = flags.DEFINE_integer(
    "num_shards",
    2048,
    "Number of files to output tfr shards."
)

_FLAG_USE_SUBSET = flags.DEFINE_boolean(
    "use_subset",
    False,
    "Whether or not to use a subset."
)

_FLAG_SUBSET_QTY = flags.DEFINE_integer(
    "subset_qty",
    500,
    "subset_qty"
)

# Module-level logger for this script.
LOGGER = logging.getLogger(__name__)
128
129
def _bytes_feature(value):
  """Wraps a string / bytes value (or an eager tensor of one) in a Feature.

  Args:
    value: A python bytes/str, or an eager `tf.Tensor` holding one.

  Returns:
    A `tf.train.Feature` containing a single-element bytes_list.
  """
  # `type(tf.constant(0))` is the eager-tensor class; unwrap to raw bytes.
  eager_tensor_cls = type(tf.constant(0))
  raw = value.numpy() if isinstance(value, eager_tensor_cls) else value
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))
136
class BERTBatchFields(utils.FlagChoices):
  """Names of the fields added to batch dicts fed to the BERT query encoder."""
  # Token ids of the question, with [CLS]/[SEP] added by the transform fn.
  bert_question_token_ids = "bert_question_token_ids"
  # Attention mask (all ones, same shape as the token ids).
  bert_attention_masks = "bert_attention_masks"
141
def _make_transform_fn(
    bert_tokenizer,
    bert_cls_token_id,
    bert_sep_token_id,
):
  """Builds the per-sample preprocessing function for the BERT encoder.

  Args:
    bert_tokenizer: TF-text tokenizer producing ragged wordpiece ids.
    bert_cls_token_id: Id of the [CLS] token.
    bert_sep_token_id: Id of the [SEP] token.

  Returns:
    A `tf.function` mapping an ELI5 sample dict to the same dict augmented
    with `bert_question_token_ids` and `bert_attention_masks`.
  """

  @tf.function
  def _prepare_for_bert(sample):
    """Tokenizes an ELI5 question and brackets it with [CLS] ... [SEP]."""
    ragged_ids = bert_tokenizer.tokenize(
        tf.expand_dims(sample["question"], 0))
    # Flatten the per-word wordpiece dimension and densify to int32.
    token_ids = tf.cast(ragged_ids.merge_dims(1, 2).to_tensor(), tf.int32)
    num_rows = tf.shape(token_ids)[0]
    cls_column = tf.fill([num_rows, 1], bert_cls_token_id)
    sep_column = tf.fill([num_rows, 1], bert_sep_token_id)
    token_ids = tf.concat((cls_column, token_ids, sep_column), 1)
    return dict(
        bert_question_token_ids=token_ids,
        bert_attention_masks=tf.ones_like(token_ids),
        **sample
    )

  return _prepare_for_bert
169
170
@tf.function
def _squeeze(batch):
  """Densifies the ragged BERT fields and drops their singleton axis.

  Args:
    batch: Batch dict whose BERT fields are ragged tensors.

  Returns:
    A copy of `batch` where both BERT fields are dense (zero-padded) and
    squeezed along axis 1.
  """
  out = dict(**batch)
  for field in (BERTBatchFields.bert_question_token_ids,
                BERTBatchFields.bert_attention_masks):
    # `to_tensor(0)` zero-pads the ragged rows; axis 1 is the extra
    # dimension introduced by tokenizing the expand_dims'd question.
    out[field] = tf.squeeze(out[field].to_tensor(0), 1)
  return out
181
def _make_encode_fn(
    query_encoder
):
  """Builds the REALM BERT query-embedding function.

  Args:
    query_encoder: Loaded TF-hub module exposing a "projected" signature.

  Returns:
    A `tf.function` mapping a batch dict to the projected query embeddings.
  """

  @tf.function(reduce_retracing=True)
  def _encode(batch):
    """Encodes a batch of tokenized questions with REALM BERT."""
    input_ids = batch[BERTBatchFields.bert_question_token_ids]
    input_mask = batch[BERTBatchFields.bert_attention_masks]
    # Single-segment input: all segment ids are zero.
    outputs = query_encoder.signatures["projected"](
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=tf.zeros_like(input_ids),
    )
    return outputs["default"]

  return _encode
200
201
def make_encode_fn_strategy_run_fn(
    strategy,
    encode_fn,
):
  """Builds the distributed runner for the REALM query encoder.

  Args:
    strategy: `tf.distribute` strategy to dispatch over.
    encode_fn: Per-replica encoding function (from `_make_encode_fn`).

  Returns:
    A `tf.function` that runs `encode_fn` on `batch` under `strategy`.
  """

  @tf.function(reduce_retracing=True)
  def encode_fn_strategy_run_fn(batch):
    """Dispatches the encoder over the strategy's replicas."""
    return strategy.run(encode_fn, args=(batch,))

  return encode_fn_strategy_run_fn
217
218
219######################################################################
220# Effectuate the retrievals.
221######################################################################
222def _prep_field(field, gpt2_tokenizer):
223"""Prepares different fields to be saved in a tfr."""
224decoded_list = [sample.decode() for sample in field.numpy().tolist()]
225encoded = gpt2_tokenizer.batch_encode_plus(
226decoded_list,
227padding="max_length",
228truncation=True,
229).input_ids
230
231ids = np.array(
232encoded,
233dtype=np.int32,
234)
235
236ids[ids == gpt2_tokenizer.eos_token_id] = -1
237return ids
238
239
def main(argv):
  """Embeds ELI5 questions, retrieves neighbors, and caches them to tfrecords.

  For each dataset split, questions are embedded with the REALM query
  encoder, the top `--num_retrievals` neighbors are found by exact MIPS over
  the reference db, and the GPT2-tokenized questions, answers, retrievals
  and inner products are written to `--num_shards` sharded tfrecord files.

  Args:
    argv: Positional command-line args; only the program name is expected.

  Raises:
    RuntimeError: On unexpected positional arguments, a missing TPU config,
      or an unknown accelerator type.
  """
  if len(argv) > 1:
    raise RuntimeError(argv)
  absl_logging.use_python_logging()
  utils.log_module_args(LOGGER, argv[0])

  retriever_config = tf_utils.REALMSave(
      **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value)
  )
  # NOTE(review): this guard contradicts the `--use_subset` flag (and the
  # `--use_subset=True` example in the module docstring); the subset code
  # path appears to have been removed. Kept as-is to preserve behavior.
  assert not _FLAG_USE_SUBSET.value

  time_stamp = time.strftime("%Y%m%d-%H%M%S")
  target_path = os.path.join(_FLAG_OUTPUT_PATH.value, time_stamp.strip())
  if target_path[-1] != "/":
    target_path += "/"

  ##############################################################################
  # Setup devices and strategy
  ##############################################################################
  with utils.log_duration(LOGGER, "main", "Initializing devices"):
    tpu_config = tf_utils.init_tpus()
    device_type = tf_utils.current_accelerator_type()
    LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use()))

    if device_type == "TPU":
      if tpu_config is None:
        raise RuntimeError("We should have a tpu_config.")
      strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    elif device_type == "GPU" or device_type == "CPU":
      strategy = tf.distribute.MirroredStrategy()
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    else:
      raise RuntimeError(device_type)

  ##############################################################################
  # Load the dataset.
  ##############################################################################
  eli5 = {}
  keys = ["train", "eval", "test"]
  gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
  # GPT2 has no pad token; reuse eos so `padding="max_length"` works.
  gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

  with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."):
    for split in tqdm.tqdm(keys):
      load_path = os.path.join(
          _FLAG_DATASET_ROOT.value,
          "HuggingfaceDatasets",
          f"{split}_kilt_eli5.hf"
      )
      with tf.device("/job:localhost"):
        eli5[split] = datasets.load_from_disk(load_path)

  ##############################################################################
  # Load the textual reference blocks.
  ##############################################################################
  with utils.log_duration(
      LOGGER, "Main", "Load the textual dataset"
  ):
    # Extract the appropriate text.
    # The buffer_size is taken from the original ORQA code.
    blocks_dataset = tf.data.TFRecordDataset(
        retriever_config.text_records, buffer_size=512 * 1024 * 1024
    )
    # One batch containing every block, so `blocks` is a single tensor.
    blocks_dataset = blocks_dataset.batch(
        retriever_config.num_block_records, drop_remainder=True
    )
    blocks = tf.data.experimental.get_single_element(blocks_dataset)

  ############################################################################
  # Prepare the output files.
  ############################################################################
  writers = {}

  all_paths = {}
  for split in keys:
    maybe_subset = "_subset" if _FLAG_USE_SUBSET.value else ""
    paths = [os.path.join(target_path + maybe_subset, f"{split}_{i}.tfr")
             for i in range(_FLAG_NUM_SHARDS.value)
             ]
    all_paths[split] = paths
    writers[split] = [tf.io.TFRecordWriter(filename) for filename in paths]

  with utils.log_duration(LOGGER, "main", "Loading the reference db."):
    checkpoint_path = os.path.join(
        retriever_config.query_embedder_path, "encoded", "encoded.ckpt"
    )

    # The reference db is large; keep it on the host CPU.
    reference_db_device = tf_utils.device_mapping().CPUs[0].name
    with tf.device(reference_db_device):
      reference_db = tf_utils.load_reference_db(
          checkpoint_path,
          variable_name="block_emb",
      )

  ############################################################################
  # Prep the encoder and the tokenizer
  ############################################################################
  with utils.log_duration(
      LOGGER, "main", "Loading the encoder model and the tokenizer."
  ):
    with strategy.scope():
      query_encoder = hub.load(retriever_config.query_embedder_path, tags={})
    encode_fn = _make_encode_fn(query_encoder)
    encode_fn_strategy_run = make_encode_fn_strategy_run_fn(
        strategy=strategy,
        encode_fn=encode_fn,
    )

    vocab_file = os.path.join(
        retriever_config.query_embedder_path, "assets", "vocab.txt"
    )
    utils.check_exists(vocab_file)
    do_lower_case = query_encoder.signatures["tokenization_info"](
    )["do_lower_case"]
    tokenization_info = dict(
        vocab_file=vocab_file, do_lower_case=do_lower_case
    )

    tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
        query_encoder, tokenization_info
    )

  ############################################################################
  # Preprocess the dataset
  ############################################################################
  cls_token_id = tf.cast(
      vocab_lookup_table.lookup(tf.constant("[CLS]")), tf.int32
  )
  sep_token_id = tf.cast(
      vocab_lookup_table.lookup(tf.constant("[SEP]")), tf.int32
  )
  transform = _make_transform_fn(
      bert_tokenizer=tokenizer,
      bert_cls_token_id=cls_token_id,
      bert_sep_token_id=sep_token_id,
  )

  # dtype used to serialize each cached feature.
  feature_dtypes = {
      constants.CTH5Fields.distances:
          tf.float32,
      constants.CTH5Fields.gpt2_retrieved_ids:
          tf.int32,
      constants.CTH5Fields.gpt2_answer_ids_inputs:
          tf.int32,
      constants.CTH5Fields.gpt2_question_ids_inputs:
          tf.int32,
  }

  with utils.log_duration(LOGGER, "main", "generating codes"):
    for split in keys:
      sample_count = 0
      eli5: Dict[str, datasets.Dataset]

      if split != "test":
        for_slices = dict(
            sample_id=eli5[split]["id"],
            question=eli5[split]["input"],
            answer=[sample["answer"][0] for sample in eli5[split]["output"]]
        )
      else:
        # The test split has no gold answers.
        for_slices = dict(
            sample_id=eli5[split]["id"],
            question=eli5[split]["input"],
        )

      ds = tf.data.Dataset.from_tensor_slices(for_slices)
      ds = ds.map(transform, num_parallel_calls=tf.data.experimental.AUTOTUNE)

      ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size))
      ds = ds.map(_squeeze, num_parallel_calls=tf.data.experimental.AUTOTUNE)

      tqdm_inner = tqdm.tqdm(
          enumerate(ds),
          total=len(eli5[split]["id"]) // _FLAG_BATCH_SIZE.value,
          desc=f"Split `{split}`: Batches"
      )

      for _, batch in tqdm_inner:
        features = collections.defaultdict(list)

        ######################################################################
        # Enforce the current real batch size
        ######################################################################
        current_batch_size = batch["sample_id"].shape[0]
        for k, v in batch.items():
          utils.check_equal(v.shape[0], current_batch_size)
        ######################################################################

        gpt2_question_ids_inputs = _prep_field(
            batch["question"], gpt2_tokenizer
        )
        utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32)
        utils.check_equal(
            gpt2_question_ids_inputs.shape[0], current_batch_size
        )

        if split != "test":
          gpt2_answer_ids_inputs = _prep_field(
              batch["answer"], gpt2_tokenizer
          )
          utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32)
          utils.check_equal(
              gpt2_answer_ids_inputs.shape[0], current_batch_size
          )

          assert len(gpt2_answer_ids_inputs.shape) == 2, (
              gpt2_answer_ids_inputs.shape
          )

        ######################################################################
        # Save the gpt2 tokenized question and answer
        ######################################################################

        features[constants.CTH5Fields.gpt2_question_ids_inputs].extend(
            gpt2_question_ids_inputs)

        if split != "test":
          features[constants.CTH5Fields.gpt2_answer_ids_inputs].extend(
              gpt2_answer_ids_inputs)

        ######################################################################
        # Encode the samples.
        ######################################################################
        batch = strategy.experimental_distribute_values_from_function(
            tf_utils.make_dict_distribute_fn(batch)
        )

        embeddings = encode_fn_strategy_run(batch)
        embeddings = tf_utils.process_strat_output(
            embeddings, "embeddings", strategy, current_batch_size
        )
        utils.check_isinstance(embeddings, ops.EagerTensor)
        utils.check_equal(embeddings.shape[0], current_batch_size)

        # pytype doesn't seem to see that we check the type
        utils.check_equal(embeddings.shape[1], _FLAG_EMBEDDING_DEPTH.value)  # pytype: disable=attribute-error

        ######################################################################
        # Retrieve.
        ######################################################################
        with tf.device(reference_db_device):
          top_k, inner_prods = tf_utils.mips_exact_search(
              embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db
          )
        top_k = tf_utils.process_strat_output(
            top_k, "top_k", strategy, current_batch_size
        )
        utils.check_equal(
            inner_prods.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)
        )
        utils.check_equal(
            top_k.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)
        )

        features[constants.CTH5Fields.distances].extend(inner_prods)

        gathered = tf.gather(blocks, top_k).numpy()
        utils.check_equal(gathered.shape[0], current_batch_size)
        retrievals = []
        for j in range(gathered.shape[0]):
          local_gathered = gathered[j].tolist()
          utils.check_equal(len(local_gathered), _FLAG_NUM_RETRIEVALS.value)
          local_gathered = [sample.decode() for sample in local_gathered]
          token_ids = np.array(
              gpt2_tokenizer.batch_encode_plus(
                  local_gathered,
                  padding="max_length",
                  truncation=True,
              ).input_ids
          )
          for line in token_ids:
            assert not np.all(line == 0), line

          token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1
          retrievals.append(token_ids)
        features[constants.CTH5Fields.gpt2_retrieved_ids] = retrievals

        utils.check_equal(
            retrievals[0].shape,
            (_FLAG_NUM_RETRIEVALS.value, _FLAG_CONTEXT_SIZE.value)
        )

        for v in features.values():
          utils.check_equal(len(v), current_batch_size)

        # BUG FIX: the original loop reused `k` both as the per-sample index
        # here and as the comprehension key below, so `v[k]` indexed each
        # feature list with a *string* key and raised a TypeError. A distinct
        # index variable selects the sample dimension correctly.
        for idx in range(current_batch_size):
          feature = tf.train.Features(
              feature={
                  k: _bytes_feature(tf.io.serialize_tensor(
                      tf.cast(v[idx], feature_dtypes[k])))
                  for k, v in features.items()
              }
          )

          # Round-robin samples over the output shards.
          writers[split][sample_count % _FLAG_NUM_SHARDS.value].write(
              tf.train.Example(features=feature).SerializeToString()
          )
          sample_count += 1
          if sample_count % 1000 == 0:
            LOGGER.debug("Paths: %s", str(all_paths[split][0]))

  LOGGER.debug("Done.")
543
if __name__ == "__main__":
  # Let absl parse the flags and invoke `main`.
  app.run(main)
546