# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Caches REALM retrievals for the ELI5 dataset.

Iterates over the whole ELI5 dataset (for each dataset split), extracts REALM
or REALM++ embeddings, does exact retrieval for a large number of neighbors,
then saves the resulting db indices and distances (inner products).

This way we don't have to do real retrieval during training; we just sample
from the cached neighbors as a function of the inner product, which is much
faster.

Example of use:

# Local test
pytype query_cacher_tfrecord.py -P . --check-variable-types \
--check-container-types \
--check-parameter-types --precise-return && \
python3 check_flags.py query_cacher_tfrecord.py && \
python3 query_cacher_tfrecord.py $(python3 json_to_args.py \
configs/query_cacher_configs/local.json) \
--logger_levels=__main__:DEBUG,utils:DEBUG,tf_utils:DEBUG \
--use_subset=True
"""
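
# NOTE: a training-time consumer of this cache can sample neighbors as a
# function of the saved inner products, e.g. with a temperature softmax
# (a minimal sketch; `distances`, `db_indices`, `temperature` and
# `num_samples` are hypothetical variables, not names used by this script):
#
#   logits = distances / temperature
#   picks = tf.random.categorical(logits, num_samples)
#   sampled_neighbors = tf.gather(db_indices, picks, batch_dims=1)
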
import collections
import logging
import os
import time
from typing import Callable, Dict, List

from absl import app
from absl import flags
from absl import logging as absl_logging
import bert_utils
import constants
import datasets
import numpy as np
import tensorflow as tf
import tensorflow.python.distribute.values as values
import tensorflow.python.framework.ops as ops
import tensorflow.python.trackable.autotrackable as autotrackable
import tensorflow_hub as hub
import tf_utils
import tqdm
import transformers
import utils


_FLAG_JOB_NAME = flags.DEFINE_string(
    "run_name",
    None,
    "Name of the run."
)
_FLAG_OUTPUT_PATH = flags.DEFINE_string(
    "output_dir",
    None,
    "Directory in which to save, on the cloud.")
_FLAG_RETRIEVER_CONFIG_PATH = flags.DEFINE_string(
    "retriever_config_path",
    None,
    "Path to the retriever's configuration file."
)
_FLAG_BATCH_SIZE = flags.DEFINE_integer(
    "batch_size",
    100,
    "Size of the batch for the encoder BERT model."
)

_FLAG_DATASET_ROOT = flags.DEFINE_string(
    "dataset_root",
    None,
    "Root of the place where the datasets are saved."
)

# Flags specific to query encoding
_FLAG_EMBEDDING_DEPTH = flags.DEFINE_integer(
    "embedding_depth",
    128,
    "Size of the BERT (REALM) embeddings.",
)

# Flags specific to retrieval caching
_FLAG_NUM_RETRIEVALS = flags.DEFINE_integer(
    "num_retrievals",
    10,
    "Number of neighbors to retrieve.",
)
_FLAG_CONTEXT_SIZE = flags.DEFINE_integer(
    "context_size",
    1024,
    "Length to pad the GPT-2 token sequences to."
)
_FLAG_MAX_LENGTH_RETRIEVALS = flags.DEFINE_integer(
    "max_length_retrievals",
    350,
    "Maximum length of the retrievals."
)

_FLAG_NUM_SHARDS = flags.DEFINE_integer(
    "num_shards",
    2048,
    "Number of TFRecord shard files to write per split."
)

_FLAG_USE_SUBSET = flags.DEFINE_boolean(
    "use_subset",
    False,
    "Whether or not to use a subset of the data."
)

_FLAG_SUBSET_QTY = flags.DEFINE_integer(
    "subset_qty",
    500,
    "Number of samples to keep when `use_subset` is enabled."
)

LOGGER = logging.getLogger(__name__)


def _bytes_feature(value):
  """Returns a bytes_list Feature from a string / byte."""
  if isinstance(value, type(tf.constant(0))):  # Is `value` an eager tf.Tensor?
    value = value.numpy()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


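# Records written by this script serialize each field with
# `tf.io.serialize_tensor`; a reader recovers a field with
# `tf.io.parse_tensor` (a minimal sketch, using the distances field and its
# tf.float32 dtype from `feature_dtypes` in `main`):
#
#   parsed = tf.io.parse_single_example(
#       record, {constants.CTH5Fields.distances:
#                    tf.io.FixedLenFeature([], tf.string)})
#   distances = tf.io.parse_tensor(
#       parsed[constants.CTH5Fields.distances], tf.float32)

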
class BERTBatchFields(utils.FlagChoices):
  bert_question_token_ids = "bert_question_token_ids"
  bert_attention_masks = "bert_attention_masks"


def _make_transform_fn(
    bert_tokenizer,
    bert_cls_token_id,
    bert_sep_token_id,
):
  """Prepares the transformation function."""
  @tf.function
  def _prepare_for_bert(sample):
    """Prepares a question sample from ELI5 to be fed to BERT."""
    bert_question_token_ids = bert_tokenizer.tokenize(
        tf.expand_dims(sample["question"], 0))
    bert_question_token_ids = tf.cast(
        bert_question_token_ids.merge_dims(1, 2).to_tensor(), tf.int32)
    # Add a [CLS] token at the start of the input, and a [SEP] token at the
    # end.
    cls_ids = tf.fill([tf.shape(bert_question_token_ids)[0], 1],
                      bert_cls_token_id)
    sep_ids = tf.fill([tf.shape(bert_question_token_ids)[0], 1],
                      bert_sep_token_id)
    bert_question_token_ids = tf.concat(
        (cls_ids, bert_question_token_ids, sep_ids), 1)

    return dict(
        bert_question_token_ids=bert_question_token_ids,
        bert_attention_masks=tf.ones_like(bert_question_token_ids),
        **sample
    )

  return _prepare_for_bert


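# After `_prepare_for_bert`, a question is laid out as (hypothetical ids):
#   [cls_id, tok_1, ..., tok_n, sep_id]
# with `bert_attention_masks` a same-shape tensor of ones.

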
@tf.function
def _squeeze(batch):
  """Squeezes and converts ragged tensors to dense tensors w/ padding."""
  batch = dict(**batch)
  batch[BERTBatchFields.bert_question_token_ids] = tf.squeeze(
      batch[BERTBatchFields.bert_question_token_ids].to_tensor(0), 1)
  batch[BERTBatchFields.bert_attention_masks] = tf.squeeze(
      batch[BERTBatchFields.bert_attention_masks].to_tensor(0), 1)
  return batch


def _make_encode_fn(
    query_encoder
):
  """Prepares the BERT encoder function."""

  @tf.function(reduce_retracing=True)
  def _encode(batch):
    """Encodes a batch of questions with REALM's BERT query encoder."""
    return query_encoder.signatures["projected"](
        input_ids=batch[BERTBatchFields.bert_question_token_ids],
        input_mask=batch[BERTBatchFields.bert_attention_masks],
        segment_ids=tf.zeros_like(
            batch[BERTBatchFields.bert_question_token_ids]
        ))["default"]

  return _encode


def make_encode_fn_strategy_run_fn(
    strategy,
    encode_fn,
):
  """Builds the runner function for the REALM query encoder."""

  @tf.function(reduce_retracing=True)
  def encode_fn_strategy_run_fn(batch):
    """Runs the query encoder under the distribution strategy."""
    return strategy.run(encode_fn, args=(batch,))

  return encode_fn_strategy_run_fn


######################################################################
# Do the retrievals.
######################################################################
def _prep_field(field, gpt2_tokenizer):
  """Prepares the different fields that are to be saved in a tfrecord."""
  decoded_list = [sample.decode() for sample in field.numpy().tolist()]
  encoded = gpt2_tokenizer.batch_encode_plus(
      decoded_list,
      padding="max_length",
      truncation=True,
  ).input_ids

  ids = np.array(
      encoded,
      dtype=np.int32,
  )

  # The GPT-2 pad token is set to the eos token (see `main`), so mark the
  # eos/pad positions with -1 to distinguish them from real tokens.
  ids[ids == gpt2_tokenizer.eos_token_id] = -1
  return ids


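# E.g., with GPT-2's eos/pad id (50256) and otherwise hypothetical token ids:
#   [[464, 2159, 50256, 50256]] -> [[464, 2159, -1, -1]]

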
def main(argv):
  if len(argv) > 1:
    raise RuntimeError(argv)
  absl_logging.use_python_logging()
  utils.log_module_args(LOGGER, argv[0])

  retriever_config = tf_utils.REALMSave(
      **utils.from_json_file(_FLAG_RETRIEVER_CONFIG_PATH.value)
  )
  # Subset mode is currently disabled.
  assert not _FLAG_USE_SUBSET.value

  time_stamp = time.strftime("%Y%m%d-%H%M%S")
  target_path = os.path.join(_FLAG_OUTPUT_PATH.value, time_stamp.strip())
  if target_path[-1] != "/":
    target_path += "/"

  ##############################################################################
  # Setup devices and strategy
  ##############################################################################
  with utils.log_duration(LOGGER, "main", "Initializing devices"):
    tpu_config = tf_utils.init_tpus()
    device_type = tf_utils.current_accelerator_type()
    LOGGER.debug("Devices: %s", str(tf_utils.devices_to_use()))

    if device_type == "TPU":
      if tpu_config is None:
        raise RuntimeError("We should have a tpu_config.")
      strategy = tf.distribute.TPUStrategy(tpu_config.resolver)
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    elif device_type in ("GPU", "CPU"):
      strategy = tf.distribute.MirroredStrategy()
      batch_size = len(tf_utils.devices_to_use()) * _FLAG_BATCH_SIZE.value
    else:
      raise RuntimeError(device_type)

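  # `batch_size` is the global batch size: the per-device `--batch_size`
  # times the number of devices, e.g. 8 TPU cores * 100 = 800 samples/step.
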
  ##############################################################################
  # Load the dataset.
  ##############################################################################
  eli5 = {}
  keys = ["train", "eval", "test"]
  gpt2_tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2-xl")
  # GPT-2 has no pad token; reuse the eos token for padding.
  gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

  with utils.log_duration(LOGGER, "main", "Loading the ELI5 datasets."):
    for split in tqdm.tqdm(keys):
      load_path = os.path.join(
          _FLAG_DATASET_ROOT.value,
          "HuggingfaceDatasets",
          f"{split}_kilt_eli5.hf"
      )
      with tf.device("/job:localhost"):
        eli5[split] = datasets.load_from_disk(load_path)

  ##############################################################################
  # Load the textual dataset.
  ##############################################################################
  with utils.log_duration(
      LOGGER, "Main", "Load the textual dataset"
  ):
    # Extract the appropriate text.
    # The buffer_size is taken from the original ORQA code.
    blocks_dataset = tf.data.TFRecordDataset(
        retriever_config.text_records, buffer_size=512 * 1024 * 1024
    )
    # Batch the whole db into a single batch, then extract it as one string
    # tensor of shape [num_block_records].
    blocks_dataset = blocks_dataset.batch(
        retriever_config.num_block_records, drop_remainder=True
    )
    blocks = tf.data.experimental.get_single_element(blocks_dataset)

  ############################################################################
  # Prepare the output files.
  ############################################################################
  writers = {}

  all_paths = {}
  for split in keys:
    maybe_subset = "_subset" if _FLAG_USE_SUBSET.value else ""
    paths = [os.path.join(target_path + maybe_subset, f"{split}_{i}.tfr")
             for i in range(_FLAG_NUM_SHARDS.value)
             ]
    all_paths[split] = paths
    writers[split] = [tf.io.TFRecordWriter(filename) for filename in paths]

  # The reference db does not depend on the split, so load it once.
  with utils.log_duration(LOGGER, "main", "Loading the reference db."):
    checkpoint_path = os.path.join(
        retriever_config.query_embedder_path, "encoded", "encoded.ckpt"
    )

    # Keep the (large) reference db on CPU.
    reference_db_device = tf_utils.device_mapping().CPUs[0].name
    with tf.device(reference_db_device):
      reference_db = tf_utils.load_reference_db(
          checkpoint_path,
          variable_name="block_emb",
      )

  ############################################################################
  # Prep the encoder and the tokenizer
  ############################################################################
  with utils.log_duration(
      LOGGER, "main", "Loading the encoder model and the tokenizer."
  ):
    with strategy.scope():
      query_encoder = hub.load(retriever_config.query_embedder_path, tags={})
    encode_fn = _make_encode_fn(query_encoder)
    encode_fn_strategy_run = make_encode_fn_strategy_run_fn(
        strategy=strategy,
        encode_fn=encode_fn,
    )

    vocab_file = os.path.join(
        retriever_config.query_embedder_path, "assets", "vocab.txt"
    )
    utils.check_exists(vocab_file)
    do_lower_case = query_encoder.signatures["tokenization_info"](
    )["do_lower_case"]
    tokenization_info = dict(
        vocab_file=vocab_file, do_lower_case=do_lower_case
    )

    tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
        query_encoder, tokenization_info
    )

  ############################################################################
  # Preprocess the dataset
  ############################################################################
  cls_token_id = tf.cast(
      vocab_lookup_table.lookup(tf.constant("[CLS]")), tf.int32
  )
  sep_token_id = tf.cast(
      vocab_lookup_table.lookup(tf.constant("[SEP]")), tf.int32
  )
  transform = _make_transform_fn(
      bert_tokenizer=tokenizer,
      bert_cls_token_id=cls_token_id,
      bert_sep_token_id=sep_token_id,
  )

  feature_dtypes = {
      constants.CTH5Fields.distances: tf.float32,
      constants.CTH5Fields.gpt2_retrieved_ids: tf.int32,
      constants.CTH5Fields.gpt2_answer_ids_inputs: tf.int32,
      constants.CTH5Fields.gpt2_question_ids_inputs: tf.int32,
  }

  with utils.log_duration(LOGGER, "main", "generating codes"):
    for split in keys:
      sample_count = 0
      eli5: Dict[str, datasets.Dataset]

      if split != "test":
        for_slices = dict(
            sample_id=eli5[split]["id"],
            question=eli5[split]["input"],
            answer=[sample["answer"][0] for sample in eli5[split]["output"]]
        )
      else:
        # The test split has no answers.
        for_slices = dict(
            sample_id=eli5[split]["id"],
            question=eli5[split]["input"],
        )

      ds = tf.data.Dataset.from_tensor_slices(for_slices)
      ds = ds.map(transform, num_parallel_calls=tf.data.experimental.AUTOTUNE)

      ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size))
      ds = ds.map(_squeeze, num_parallel_calls=tf.data.experimental.AUTOTUNE)

      tqdm_inner = tqdm.tqdm(
          enumerate(ds),
          # Use the global batch size so the progress bar total matches the
          # actual number of batches.
          total=len(eli5[split]["id"]) // batch_size,
          desc=f"Split `{split}`: Batches"
      )

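      # `dense_to_ragged_batch` above batches the variable-length token
      # sequences into RaggedTensors, and `_squeeze` pads them into dense
      # tensors. E.g. (hypothetical lengths) sequences of lengths 3 and 5
      # become one dense [2, 5] batch, padded with 0.
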
      for _, batch in tqdm_inner:
        features = collections.defaultdict(list)

        ######################################################################
        # Check the current real batch size (the last batch can be smaller).
        ######################################################################
        current_batch_size = batch["sample_id"].shape[0]
        for v in batch.values():
          utils.check_equal(v.shape[0], current_batch_size)
        ######################################################################

        gpt2_question_ids_inputs = _prep_field(
            batch["question"], gpt2_tokenizer
        )
        utils.check_equal(gpt2_question_ids_inputs.dtype, np.int32)
        utils.check_equal(
            gpt2_question_ids_inputs.shape[0], current_batch_size
        )

        if split != "test":
          gpt2_answer_ids_inputs = _prep_field(
              batch["answer"], gpt2_tokenizer
          )
          utils.check_equal(gpt2_answer_ids_inputs.dtype, np.int32)
          utils.check_equal(
              gpt2_answer_ids_inputs.shape[0], current_batch_size
          )

          assert len(gpt2_answer_ids_inputs.shape) == 2, (
              gpt2_answer_ids_inputs.shape
          )

        ######################################################################
        # Save the gpt2 tokenized question and answer
        ######################################################################

        features[constants.CTH5Fields.gpt2_question_ids_inputs].extend(
            gpt2_question_ids_inputs)

        if split != "test":
          features[constants.CTH5Fields.gpt2_answer_ids_inputs].extend(
              gpt2_answer_ids_inputs)

        ######################################################################
        # Encode the samples.
        ######################################################################
        batch = strategy.experimental_distribute_values_from_function(
            tf_utils.make_dict_distribute_fn(batch)
        )

        embeddings = encode_fn_strategy_run(batch)
        embeddings = tf_utils.process_strat_output(
            embeddings, "embeddings", strategy, current_batch_size
        )
        utils.check_isinstance(embeddings, ops.EagerTensor)
        utils.check_equal(embeddings.shape[0], current_batch_size)

        # pytype doesn't seem to see that we check the type
        utils.check_equal(embeddings.shape[1], _FLAG_EMBEDDING_DEPTH.value)  # pytype: disable=attribute-error

        ######################################################################
        # Retrieve.
        ######################################################################
        with tf.device(reference_db_device):
          top_k, inner_prods = tf_utils.mips_exact_search(
              embeddings, _FLAG_NUM_RETRIEVALS.value, reference_db
          )
        top_k = tf_utils.process_strat_output(
            top_k, "top_k", strategy, current_batch_size
        )
        utils.check_equal(
            inner_prods.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)
        )
        utils.check_equal(
            top_k.shape, (current_batch_size, _FLAG_NUM_RETRIEVALS.value)
        )

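        # For intuition: exact MIPS here is conceptually a dense matmul
        # followed by top-k (a minimal sketch, not the actual `tf_utils`
        # implementation):
        #
        #   scores = tf.matmul(embeddings, reference_db, transpose_b=True)
        #   inner_prods, top_k = tf.math.top_k(
        #       scores, _FLAG_NUM_RETRIEVALS.value)
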
495
        features[constants.CTH5Fields.distances].extend(inner_prods)
496

497
        gathered = tf.gather(blocks, top_k).numpy()
498
        utils.check_equal(gathered.shape[0], current_batch_size)
499
        retrievals = []
500
        for j in range(gathered.shape[0]):
501
          local_gathered = gathered[j].tolist()
502
          utils.check_equal(len(local_gathered), _FLAG_NUM_RETRIEVALS.value)
503
          local_gathered = [sample.decode() for sample in local_gathered]
504
          token_ids = np.array(
505
              gpt2_tokenizer.batch_encode_plus(
506
                  local_gathered,
507
                  padding="max_length",
508
                  truncation=True,
509
              ).input_ids
510
          )
511
          for line in token_ids:
512
            assert not np.all(line == 0), line
513

514
          token_ids[token_ids == gpt2_tokenizer.eos_token_id] = -1
515
          retrievals.append(token_ids)
516
        features[constants.CTH5Fields.gpt2_retrieved_ids] = retrievals
517

518
        utils.check_equal(
519
            retrievals[0].shape,
520
            (_FLAG_NUM_RETRIEVALS.value, _FLAG_CONTEXT_SIZE.value)
521
        )
522

523
        for k, v in features.items():
524
          utils.check_equal(len(v), current_batch_size)
525

526
        for k in range(current_batch_size):
527
          feature = tf.train.Features(
528
              feature={
529
                  k: _bytes_feature(tf.io.serialize_tensor(
530
                      tf.cast(v[k], feature_dtypes[k])))
531
                  for k, v in features.items()
532
              }
533
          )
534

535
          writers[split][sample_count % _FLAG_NUM_SHARDS.value].write(
536
              tf.train.Example(features=feature).SerializeToString()
537
          )
538
          sample_count += 1
539
        if sample_count % 1000 == 0:
540
          LOGGER.debug("Paths: %s", str(all_paths[split][0]))
541

542
  LOGGER.debug("Done.")
543

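# To consume the cache later, the shards can be read back roughly like this
# (a minimal sketch; `target_dir` stands for the directory written above):
#
#   files = tf.io.gfile.glob(os.path.join(target_dir, "train_*.tfr"))
#   ds = tf.data.TFRecordDataset(files)
#   for record in ds:
#     ...  # tf.io.parse_single_example, then tf.io.parse_tensor per field.
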
if __name__ == "__main__":
  app.run(main)