google-research

Форк
0
900 строк · 29.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Dataset and model specific code.
17
"""
18
import logging
19
import os
20
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21

22
from absl import flags
23
import constants
24
import dataclasses
25
import tensorflow as tf
26
import tf_utils
27
import transformers
28
import utils
29

30
# Uncomment to run `tf.function`-decorated code eagerly while debugging:
# tf.config.run_functions_eagerly(True)


FLAGS = flags.FLAGS
LOGGER = logging.getLogger(__name__)

# Either flavour of Hugging Face tokenizer: the regular one or the "fast" one.
TokenizerType = Union[
    transformers.PreTrainedTokenizer,
    transformers.PreTrainedTokenizerFast,
]
38

39

40
################################################################################
41
# Model Specific
42
################################################################################
43
@dataclasses.dataclass
class CreateModelReturn:
  """Bundle produced by the model factories and returned by `load_model`.

  Attributes:
    tokenizer: Tokenizer matching the model.
    model: A single model, or a list of models (e.g. one per data-parallel
      replica).
    strategy: The `tf.distribute.Strategy` the model was built under, or
      `None` when no strategy is used.
  """
  tokenizer: TokenizerType
  model: Union[
      transformers.PreTrainedModel,
      List[transformers.PreTrainedModel],
  ]
  strategy: Optional[tf.distribute.Strategy]
48

49

50
def load_model(
    model_load_path,
    model_key,
    distribute_mode,
    tpu_setup,
    num_replicas,
    ):
  """Builds the tokenizer and the model selected by `model_key`.

  When `distribute_mode` names a `tf.distribute` strategy, the strategy is
  created first and the model is built inside its scope. Otherwise the model
  factory itself is responsible for any parallelism.

  Args:
    model_load_path: Intended location of the model weights. NOTE: the
      factories in MODEL_FACTORIES take the model *name* as their first
      argument, so this value is currently not forwarded. It is kept for
      interface compatibility.
    model_key: Key used to select the model creation function from the
      MODEL_FACTORIES dict. It is also passed to that function as the model
      name.
    distribute_mode: A string describing how the model is distributed. Must be
      one of `constants.DistributeModeChoices.choices()`.
    tpu_setup: TPU configuration information. Only `tpu_setup.resolver` is
      read, and only for the TPU strategy.
    num_replicas: Number of data parallelism replicas. Forwarded to the factory
      only when no tf.distribute strategy is used, since a strategy creates
      its own replicas.

  Returns:
    A `CreateModelReturn` holding the tokenizer, the model and the strategy
    (`None` when no strategy is used).

  Raises:
    ValueError: If `distribute_mode` is not a supported choice.
    KeyError: If `model_key` is not a key of MODEL_FACTORIES.
    NotImplementedError: If `distribute_mode` is a strategy mode with no
      construction logic here.
  """
  if distribute_mode not in constants.DistributeModeChoices.choices():
    raise ValueError(f"Unsupported distribute_mode: `{distribute_mode}`")

  try:
    factory = MODEL_FACTORIES[model_key]
  except KeyError:
    raise KeyError(
        f"Unsupported model_key: `{model_key}`. "
        f"Choices: {sorted(MODEL_FACTORIES)}"
    ) from None

  if distribute_mode in constants.STRATEGIES:
    ############################################################################
    # Model creation in case we are using tf.distribute.Strategies.
    ############################################################################
    # We currently don't support GPU strategies, though adding them would be
    # simple.
    if distribute_mode == constants.DistributeModeChoices.tpustrategy:
      strategy = tf.distribute.TPUStrategy(tpu_setup.resolver)
    elif distribute_mode == constants.DistributeModeChoices.onedevicestrategy:
      # Test mode with a single device, possibly a CPU.
      strategy = tf.distribute.OneDeviceStrategy(tf_utils.devices_to_use()[0])
    else:
      raise NotImplementedError(distribute_mode)

    with strategy.scope():
      # The replicas are created by the tf.distribute.Strategy object, so the
      # factory is not asked to build any itself.
      config: CreateModelReturn = factory(model_key, distribute_mode, None)
      config.strategy = strategy

  else:
    ############################################################################
    # Model creation in the case we aren't using strategies.
    ############################################################################
    # Most of the parallelism work is done inside of the specific model
    # creation functions. The factories take
    # (model_name, distribute_mode, num_replicas), the same convention as in
    # the strategy branch above. Passing four arguments here used to raise a
    # TypeError on every call.
    config: CreateModelReturn = factory(
        model_key,
        distribute_mode,
        num_replicas,
    )
    config.strategy = None
  return config
120

121

122
def _create_gpt2(
    model_name,
    distribute_mode,
    num_replicas  # pylint: disable=unused-argument
):
  """Loads the tokenizer and the model for a GPT-2 family checkpoint.

  Args:
    model_name: Passed to the `from_pretrained` methods of both the tokenizer
      and the model (e.g. "gpt2", "gpt2-xl").
    distribute_mode: How the model is distributed. The split modes
      (`split_and_data_parallel`, `split_vertically`) are not implemented.
    num_replicas: Currently unused.

  Returns:
    A `CreateModelReturn` with `strategy=None`.

  Raises:
    NotImplementedError: For the split distribute modes.
  """
  # Fail before downloading anything if the requested mode is unsupported.
  if distribute_mode in {
      constants.DistributeModeChoices.split_and_data_parallel,
      constants.DistributeModeChoices.split_vertically,
  }:
    # TODO(julesgm): Splitting the model between devices (per replica) needs
    # to be reworked. The previous draft is in the version history.
    raise NotImplementedError(
        f"distribute_mode `{distribute_mode}` is not supported by "
        "_create_gpt2 yet."
    )

  ##############################################################################
  # Load the tokenizer
  ##############################################################################
  LOGGER.debug("Loading the tokenizer: `%s`", model_name)
  tokenizer = transformers.GPT2TokenizerFast.from_pretrained(model_name)
  LOGGER.debug("Done loading the tokenizer.")

  ##############################################################################
  # Build the model instance
  ##############################################################################
  LOGGER.debug("Loading the model weights.")
  with utils.log_duration(LOGGER, "main", "Loading the model."):
    model = transformers.TFGPT2LMHeadModel.from_pretrained(model_name)
  # Use the module logger, like the rest of the file, not the root logger.
  LOGGER.debug("Done loading the %s model.", model_name)

  return CreateModelReturn(
      tokenizer=tokenizer,
      model=model,
      strategy=None,
  )
213

214

215
################################################################################
216
# Dataset Specific
217
################################################################################
218
def create_lm_ds_kilt_eli5(
    *,
    tokenizer,
    context_window_size,
    dataset_name,  # pylint: disable=unused-argument
    batch_size,
    split,
    db_path,  # pylint: disable=unused-argument
    random_seed,
    use_subset,  # pylint: disable=unused-argument
    subset_size,  # pylint: disable=unused-argument
    repeat,
    use_helper_words,
    approach_type,
    retriever,
    num_retrievals,
    retrieval_temperature,
    enable_debug_checks,
    retrieval_bank_size,  # pylint: disable=unused-argument
    dataset_type,
    qty_shuffle,
    tfr_prefix,
    max_length_generation,
):
  """Dataset preparation function for the KILT version of the ELI5 dataset.

  This is for when the dataset is consumed by language models. Records are
  read from TFRecord files and parsed. The dataset is then optionally
  repeated, shuffled and batched, and each batch is finally processed by the
  closure built by `_make_maybe_retrieve_and_merge_fn`.

  Args:
    tokenizer: Tokenizer of the reader model.
    context_window_size: Size of the context of the reader model. It is both
      the stored length of the token-id features and the length of the merged
      samples.
    dataset_name: Exact name of the dataset. Some datasets share the same
      function, with small specific differences. Not used here.
    batch_size: Size of the batch for the reader model.
    split: The train, evaluation or test split. Also the filename prefix of
      the TFRecord files to read.
    db_path: Not used here.
    random_seed: Seed used to shuffle the dataset. Should change at each epoch.
    use_subset: Whether to use a subset of the data. Not used here.
    subset_size: Size of the subset. Not used here.
    repeat: Whether to repeat the dataset.
    use_helper_words: Whether to add helper words in the merged samples.
    approach_type: Type of overall solution we are using.
    retriever: Object that does the retrieval. Forwarded to the merge-function
      builder, which currently does not use it.
    num_retrievals: Number of retrievals to do.
    retrieval_temperature: For the retrieval methods that do sampling, what
      temperature to use.
    enable_debug_checks: Whether to add runtime `tf.Assert` checks.
    retrieval_bank_size: Not used here.
    dataset_type: Storage format of the dataset. Only
      `constants.DatasetTypeChoices.tfr` is supported.
    qty_shuffle: Size of the shuffle buffer.
    tfr_prefix: Location of the TFRecord files. The files matching
      `os.path.join(tfr_prefix, f"{split}*")` are read.
    max_length_generation: Number of tokens reserved for the answer when
      deciding how many retrieved tokens fit in the context.

  Returns:
    A prefetched `tf.data.Dataset` yielding dicts with `input_ids` and
    `label_ids` (`label_ids` is `None` for the test split).

  Raises:
    ValueError: If `dataset_type` is hdf5, which is no longer supported.
    RuntimeError: If we didn't find any files with the glob pattern.
    RuntimeError: If we are using a dataset type that is not supported.
  """
  maybe_retrieve_and_merge = _make_maybe_retrieve_and_merge_fn(
      tokenizer=tokenizer,
      context_size=context_window_size,
      retriever=retriever,
      temperature=retrieval_temperature,
      num_retrievals=num_retrievals,
      ds_split=split,
      approach_type=approach_type,  # FLAG_APPROACH_TYPE.value
      use_helper_words=use_helper_words,  # FLAG_USE_HELPER_WORDS
      enable_debug_checks=enable_debug_checks,
      max_length_generation=max_length_generation,
  )

  if dataset_type == constants.DatasetTypeChoices.hdf5:
    raise ValueError("The hdf5 dataset type is not supported anymore. "
                     "It is strictly worse than tfr.")
  if dataset_type != constants.DatasetTypeChoices.tfr:
    raise RuntimeError(dataset_type)

  ds = _load_kilt_eli5_tfr(
      tfr_prefix=tfr_prefix,
      split=split,
      context_window_size=context_window_size,
  )

  if repeat:
    ds = ds.repeat()

  utils.check_not_none(random_seed)
  utils.check_not_none(qty_shuffle)
  ds = ds.shuffle(qty_shuffle, seed=random_seed)

  # Only the test split keeps its incomplete last batch.
  ds = ds.batch(
      batch_size,
      drop_remainder=split != constants.SplitChoices.test,
  )

  # We can't use parallel calls here, the huggingface Rust fast tokenizer
  # breaks with multiple threads. It seems to still be worth it over their
  # slow one though, vs using parallel threads.
  ds = ds.map(maybe_retrieve_and_merge)

  return ds.prefetch(tf.data.experimental.AUTOTUNE)


def _load_kilt_eli5_tfr(*, tfr_prefix, split, context_window_size):
  """Reads and parses the `{split}*` TFRecord files found under `tfr_prefix`.

  Each feature is stored as a serialized tensor (`tf.string`) and is parsed
  back with its dtype and static shape. The answer ids are only read for
  non-test splits.

  Raises:
    RuntimeError: If no file matches the glob pattern.
  """
  glob_pattern = os.path.join(tfr_prefix, f"{split}*")
  filenames = list(tf.io.gfile.glob(glob_pattern))
  if not filenames:
    raise RuntimeError(
        f"filenames is empty. Glob pattern was: {glob_pattern}"
    )

  ds = tf.data.TFRecordDataset(
      filenames=filenames,
      num_parallel_reads=tf.data.experimental.AUTOTUNE,
  )

  # Leading dimension of the stored distances and retrieved ids.
  num_stored_retrievals = 10
  feature_specs: Dict[str, Tuple[tf.dtypes.DType, Tuple[int, ...]]] = {
      constants.CTH5Fields.distances:
          (tf.float32, (num_stored_retrievals,)),
      constants.CTH5Fields.gpt2_retrieved_ids:
          (tf.int32, (num_stored_retrievals, context_window_size)),
      constants.CTH5Fields.gpt2_question_ids_inputs:
          (tf.int32, (context_window_size,)),
  }
  if split != constants.SplitChoices.test:
    feature_specs[constants.CTH5Fields.gpt2_answer_ids_inputs] = (
        tf.int32, (context_window_size,)
    )

  description: Dict[str, tf.io.FixedLenFeature] = {
      name: tf.io.FixedLenFeature((), tf.string) for name in feature_specs
  }

  @tf.function
  def parse(sample):
    example = tf.io.parse_single_example(sample, description)
    output = {}
    for name, serialized in example.items():
      dtype, shape = feature_specs[name]
      output[name] = tf.io.parse_tensor(serialized, out_type=dtype)
      output[name].set_shape(shape)
    return output

  return ds.map(
      parse,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
      deterministic=False,
  )
437

438

439
def _make_maybe_retrieve_and_merge_fn(
    *,
    tokenizer,
    context_size,
    ds_split,
    approach_type,  # FLAG_APPROACH_TYPE.value
    use_helper_words,  # FLAG_USE_HELPER_WORDS
    retriever,  # pylint: disable=unused-argument
    temperature,
    num_retrievals,
    enable_debug_checks,
    max_length_generation,
    tf_function_kwargs=None,
):
  """Builds the `maybe_retrieve_and_merge` closure applied to each batch.

  Args:
    tokenizer: Tokenizer of the reader model. Used to encode the helper words
      and for its `eos_token_id`.
    context_size: Width every output row is truncated / padded to.
    ds_split: Dataset split. Labels are only built outside of the test split.
    approach_type: `cached_pretok` merges the cached retrievals into the
      inputs. `naked_lm` uses only the question (and the answer).
    use_helper_words: Whether to insert the "Question:", "Context:" and
      "Answer:" helper words.
    retriever: Currently unused.
    temperature: Retrieval sampling temperature (`cached_pretok` only).
    num_retrievals: Number of retrievals to merge (`cached_pretok` only).
    enable_debug_checks: Whether to add runtime `tf.Assert` checks.
    max_length_generation: Number of tokens reserved for the answer
      (`cached_pretok` only).
    tf_function_kwargs: Extra keyword arguments for the `tf.function`
      decorator of the closure.

  Returns:
    A `tf.function` mapping a batch dict to a dict with dense `input_ids` and
    `label_ids`, both `context_size` wide (`label_ids` is `None` for the test
    split).

  Raises:
    RuntimeError: (when the closure is traced) if `approach_type` is not
      supported.
  """
  tf_function_kwargs = {} if tf_function_kwargs is None else tf_function_kwargs
  not_test_split = ds_split != constants.SplitChoices.test

  @tf.function(**tf_function_kwargs)
  def maybe_retrieve_and_merge(
      batch,
  ):
    """Retrieve if needed, then finalize the prep. for model consumption."""
    batch_size = tf.shape(
        batch[constants.CTH5Fields.gpt2_question_ids_inputs]
    )[0]

    def pad_right(tensor, pad_value):
      """Right-pads a dense [batch_size, width] tensor to context_size."""
      pad_qty = tf.math.maximum(
          0, tf.constant(context_size) - tf.shape(tensor)[1]
      )
      padding = pad_value * tf.ones([batch_size, pad_qty], dtype=tensor.dtype)
      return tf.concat((tensor, padding), axis=1)

    # Prepare the question ids inputs: strip the stored padding.
    question_ids_inputs = tf.RaggedTensor.from_tensor(
        batch[constants.CTH5Fields.gpt2_question_ids_inputs],
        padding=constants.RAGGED_PADDING_ID,
    )

    # Prepare the answer ids inputs, which are absent from the test split.
    answer_ids_inputs = None
    answer_ids_labels = None
    if not_test_split:
      answer_ids_inputs = tf.RaggedTensor.from_tensor(
          batch[constants.CTH5Fields.gpt2_answer_ids_inputs],
          padding=constants.RAGGED_PADDING_ID,
      )
      answer_ids_labels = answer_ids_inputs

    ############################################################################
    # Prepare the helper words
    ############################################################################
    # Encoded with the Python tokenizer at trace time, then tiled over the
    # batch.
    helper_word_token_ids = None
    if use_helper_words:
      helper_text = {
          "question": "Question:\n",
          "context": "\nContext:\n",
          "answer": "\nAnswer:\n",
      }
      helper_word_token_ids = {}
      for key, text in helper_text.items():
        ids = tf.constant(tokenizer.encode(text), dtype=tf.int32)
        helper_word_token_ids[key] = tf.repeat(
            tf.expand_dims(ids, 0), batch_size, axis=0
        )
      question_ids_inputs = tf.concat(
          [helper_word_token_ids["question"], question_ids_inputs],
          axis=1
      )

    label_ids = None
    if approach_type == constants.ApproachTypeChoices.cached_pretok:
      ##########################################################################
      # W/ Cached Retrievals
      ##########################################################################
      bpe_indices_gpt2 = tf.RaggedTensor.from_tensor(
          batch[constants.CTH5Fields.gpt2_retrieved_ids],
          ragged_rank=2,
          padding=constants.RAGGED_PADDING_ID,
      )
      # answer_ids_inputs is already None for the test split, and
      # helper_word_token_ids is already None without helper words.
      input_ids, label_ids = _prepare_samples_w_retrieval(
          split=ds_split,
          batch_size=batch_size,
          question_ids_inputs=question_ids_inputs,
          answer_ids_inputs=answer_ids_inputs,
          gpt2_tokenized_retrieved=bpe_indices_gpt2,
          num_retrievals=num_retrievals,
          temperature=temperature,
          context_size=context_size,
          enable_debug_checks=enable_debug_checks,
          distances=batch[constants.CTH5Fields.distances],
          max_generation_length=max_length_generation,
          helper_word_token_ids=helper_word_token_ids,
          use_helper_words=use_helper_words,
      )

    elif approach_type == constants.ApproachTypeChoices.naked_lm:
      ##########################################################################
      # Without Retrievals
      ##########################################################################
      if use_helper_words:
        question_ids_inputs = tf.concat([
            question_ids_inputs,
            helper_word_token_ids["answer"],
        ], axis=1)

      if not_test_split:
        # The question part of the labels is masked out.
        question_ids_labels = tf.ones_like(
            question_ids_inputs
        ) * constants.PPL_MASK_ID
        input_ids = tf.concat((question_ids_inputs, answer_ids_inputs),
                              axis=1)
        label_ids = tf.concat((question_ids_labels, answer_ids_labels),
                              axis=1)
      else:
        input_ids = question_ids_inputs
    else:
      raise RuntimeError("Unsupported approach_type value"
                         f" {approach_type}")

    ############################################################################
    # Finalize the preparation
    ############################################################################
    # Convert to dense tensors.
    input_ids = input_ids.to_tensor(tokenizer.eos_token_id)

    if not_test_split:
      # Terminate every label sequence with EOS, then densify with the mask.
      final_eos = tf.RaggedTensor.from_tensor(
          tokenizer.eos_token_id * tf.ones([batch_size, 1], dtype=tf.int32)
      )
      label_ids = tf.concat([label_ids, final_eos], axis=1)
      label_ids = label_ids.to_tensor(constants.PPL_MASK_ID)

    # All samples need to have at least one token != -100 (PPL_MASK_ID).
    if enable_debug_checks and not_test_split:
      has_unmasked_label = tf.reduce_any(
          label_ids != constants.PPL_MASK_ID, axis=1
      )
      # tf.cast requires an explicit dtype. The previous call omitted it and
      # raised a TypeError as soon as debug checks were enabled.
      qty_with_unmasked_label = tf.reduce_sum(
          tf.cast(has_unmasked_label, tf.int32)
      )
      check_has_unmasked = tf.Assert(
          tf.math.reduce_all(has_unmasked_label),
          [qty_with_unmasked_label]
      )
      with tf.control_dependencies([check_has_unmasked]):
        label_ids = tf.identity(label_ids)

    # Limit size.
    input_ids = input_ids[:, :context_size]
    if not_test_split:
      label_ids = label_ids[:, :context_size]

    # Pad `input_ids` (with EOS) and `label_ids` (with the mask id) up to
    # context_size.
    input_ids = pad_right(input_ids, tokenizer.eos_token_id)
    if not_test_split:
      label_ids = pad_right(label_ids, constants.PPL_MASK_ID)

    # Make checks.
    if enable_debug_checks:
      control_dependencies = [tf.Assert(
          tf.math.reduce_all(input_ids != -1),
          [input_ids],
          name="NoMinusOnesInputs"
      )]
      if not_test_split:
        control_dependencies.append(tf.Assert(
            tf.math.reduce_all(label_ids != -1),
            [label_ids],
            name="NoMinusOnesLabel"
        ))
        # No row may be entirely masked after truncation. The previous
        # condition used `!=` and instead asserted that no row was entirely
        # unmasked.
        control_dependencies.append(tf.Assert(
            tf.math.reduce_all(
                tf.math.reduce_any(label_ids != constants.PPL_MASK_ID, axis=1)
            ),
            [label_ids],
            name="NotAllMinusOneHundred"
        ))
      with tf.control_dependencies(control_dependencies):
        input_ids = tf.identity(input_ids)

    return dict(
        input_ids=input_ids,
        label_ids=label_ids if not_test_split else None
    )

  return maybe_retrieve_and_merge
658

659

660
@tf.function
def _tokenize_and_concat_while_loop(
    all_retrieved_tokens,
    indices,
    num_retrieved,
    batch_size,
):
  """Concatenates, for each batch row, the retrievals selected by `indices`.

  No tokenization happens here despite the name: the retrievals passed in are
  already token ids.

  Args:
    all_retrieved_tokens: Retrieved token ids, gathered per batch row with the
      values of `indices` (the caller passes a RaggedTensor).
    indices: Tensor whose column `i` holds, for each batch row, the index of
      the i-th retrieval to append.
    num_retrieved: How many columns of `indices` to consume.
    batch_size: Number of rows of the result. Must not be `None`.

  Returns:
    A ragged int32 tensor with `batch_size` rows. Each row is the
    concatenation, in order, of the selected retrievals.

  Raises:
    RuntimeError: If `batch_size` is `None`.
  """
  if batch_size is None:
    raise RuntimeError("batch_size is `None`. This should not happen.")

  def keep_going(step, _unused_accumulated):
    return tf.less(step, num_retrieved)

  def append_next_retrieval(step, accumulated):
    picked = tf.gather(all_retrieved_tokens, indices[:, step], batch_dims=1)
    return step + 1, tf.concat([accumulated, picked], axis=1)

  # Start from batch_size empty rows and grow them one retrieval at a time.
  empty_rows = tf.RaggedTensor.from_tensor(
      tf.zeros(shape=(batch_size, 0), dtype=tf.int32),
  )
  _, concatenated = tf.while_loop(
      keep_going, append_next_retrieval, [0, empty_rows]
  )
  return concatenated
699

700

701
@tf.function
def _prepare_samples_w_retrieval(
    split,
    batch_size,
    question_ids_inputs,
    answer_ids_inputs,
    gpt2_tokenized_retrieved,
    distances,
    num_retrievals,
    temperature,
    context_size,
    enable_debug_checks,
    use_helper_words,
    helper_word_token_ids,
    max_generation_length
):
  """Builds reader inputs and labels that include sampled retrievals.

  Retrievals are sampled with `tf_utils.sample_without_replacement` from
  `distances / temperature`. They are concatenated, then truncated so that the
  question, the retrievals, the optional "answer" helper words and
  `max_generation_length` answer tokens fit in `context_size`.

  Samples are laid out as:
  question, [context helper] + retrievals, [answer helper], answer.

  Args:
    split: Dataset split. Answers and labels are only used outside of the test
      split.
    batch_size: Number of samples in the batch.
    question_ids_inputs: Question token ids, dense or ragged. Dense inputs are
      made ragged by stripping `constants.RAGGED_PADDING_ID`.
    answer_ids_inputs: Ragged answer token ids. Must be `None` if and only if
      `split` is the test split.
    gpt2_tokenized_retrieved: Token ids of the candidate retrievals.
    distances: Scores of the candidate retrievals. `distances / temperature`
      is handed to the sampler.
    num_retrievals: Number of retrievals to sample.
    temperature: Sampling temperature.
    context_size: Size of the reader's context window.
    enable_debug_checks: Whether to add runtime `tf.Assert` checks.
    use_helper_words: Whether to insert the "context" and "answer" helper
      words.
    helper_word_token_ids: Dict of batch-tiled helper word token ids. Only
      read when `use_helper_words` is set.
    max_generation_length: Number of tokens reserved for the answer.

  Returns:
    An `(input_ids, label_ids)` tuple. `label_ids` is `None` for the test
    split. Otherwise it equals the answer ids on the answer and
    `constants.PPL_MASK_ID` everywhere else.
  """
  is_test = split == constants.SplitChoices.test
  # Answers are given if and only if this is not the test split. Checked at
  # trace time.
  assert is_test == (answer_ids_inputs is None), (is_test, answer_ids_inputs)

  if not isinstance(question_ids_inputs, tf.RaggedTensor):
    question_ids_inputs = tf.RaggedTensor.from_tensor(
        question_ids_inputs,
        padding=constants.RAGGED_PADDING_ID
    )

  if enable_debug_checks:
    asserts = [
        tf.Assert(
            tf.math.reduce_all(
                question_ids_inputs != constants.RAGGED_PADDING_ID,
            ),
            [question_ids_inputs.to_tensor()]
        )
    ]
    if not is_test:
      asserts.append(
          tf.Assert(
              tf.math.reduce_all(
                  answer_ids_inputs != constants.RAGGED_PADDING_ID,
              ),
              [answer_ids_inputs.to_tensor()]
          )
      )
    with tf.control_dependencies(asserts):
      question_ids_inputs = tf.identity(question_ids_inputs)

  # These checks are at graph composition time, so OK.
  utils.check_isinstance(question_ids_inputs, tf.RaggedTensor)
  if not is_test:
    utils.check_isinstance(answer_ids_inputs, tf.RaggedTensor)

  ##############################################################################
  # Sample from the possible retrievals
  ##############################################################################
  indices = tf_utils.sample_without_replacement(
      distances / temperature, num_retrievals
  )
  concat_retrieved = _tokenize_and_concat_while_loop(
      gpt2_tokenized_retrieved,
      indices=indices,
      batch_size=batch_size,
      num_retrieved=num_retrievals,
  )

  # Add the Context helper words.
  if use_helper_words:
    concat_retrieved = tf.concat([
        helper_word_token_ids["context"],
        concat_retrieved,
    ], axis=1)

  # Cut the retrievals down to max_lens_retrieval. The "question" helper
  # tokens, if any, are already included in question_ids_inputs. The same
  # budget applies to every split: max_generation_length tokens are always
  # reserved for the answer, whatever the real answer length is.
  answer_helper_length = (
      helper_word_token_ids["answer"].shape[1] if use_helper_words else 0
  )
  max_lens_retrieval = (
      context_size * tf.ones(shape=(batch_size,), dtype=tf.int64)
      - (question_ids_inputs.row_lengths()
         + max_generation_length
         + answer_helper_length)
  )

  concat_retrieved = tf.ragged.boolean_mask(
      concat_retrieved,
      (
          tf.ragged.range(concat_retrieved.row_lengths()) <
          tf.expand_dims(max_lens_retrieval, axis=1)
      )
  )

  if enable_debug_checks:
    asserts = [
        tf.Assert(
            tf.math.reduce_all(max_lens_retrieval < context_size),
            [max_lens_retrieval, context_size]
        ),
    ]
    with tf.control_dependencies(asserts):
      concat_retrieved = tf.identity(concat_retrieved)

  ##############################################################################
  # Assemble the samples
  ##############################################################################
  input_parts = [question_ids_inputs, concat_retrieved]
  if use_helper_words:
    input_parts.append(helper_word_token_ids["answer"])

  if is_test:
    return tf.concat(input_parts, axis=1), None

  # Everything before the answer is masked out of the labels.
  label_parts = [
      constants.PPL_MASK_ID * tf.ones_like(part) for part in input_parts
  ]
  new_input_ids = tf.concat(input_parts + [answer_ids_inputs], axis=1)
  new_label_ids = tf.concat(label_parts + [answer_ids_inputs], axis=1)
  return new_input_ids, new_label_ids
879

880

881
################################################################################
882
# Varia
883
################################################################################
884

885
# Number of elements in each split, per supported dataset.
DATASET_CARDINALITIES = {
    constants.DatasetNameChoices.kilt_eli5: dict(
        train=272637,
        eval=1507,
        test=600,
    ),
}
892

893
# Model creation function for each supported Hugging Face model key. Every
# GPT-2 variant is handled by the same loader.
MODEL_FACTORIES = {
    model_key: _create_gpt2
    for model_key in (
        "gpt2",
        "gpt2-medium",
        "gpt2-large",
        "gpt2-xl",
        "distilgpt2",
    )
}
901

902

903

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.