# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Dataset and model specific code.
17"""
18import logging19import os20from typing import Any, Callable, Dict, List, Optional, Tuple, Union21
22from absl import flags23import constants24import dataclasses25import tensorflow as tf26import tf_utils27import transformers28import utils29
# tf.config.run_functions_eagerly(True)


FLAGS = flags.FLAGS
LOGGER = logging.getLogger(__name__)

TokenizerType = Union[transformers.PreTrainedTokenizer,
                      transformers.PreTrainedTokenizerFast]


################################################################################
# Model Specific
################################################################################
@dataclasses.dataclass
class CreateModelReturn:
  tokenizer: TokenizerType
  model: Union[transformers.PreTrainedModel, List[transformers.PreTrainedModel]]
  strategy: Optional[tf.distribute.Strategy]


def load_model(
    model_load_path,
    model_key,
    distribute_mode,
    tpu_setup,
    num_replicas,
):
  """Tries to load the model.

  Logs duration and memory use. Logs additional information if loading the
  model fails.

  Args:
    model_load_path: Where to load the model from. Needs to be a **local**
      path.
    model_key: Key used to select the correct model loading function from
      the MODEL_FACTORIES dict.
    distribute_mode: A string describing how the model is distributed.
    tpu_setup: TPU configuration information.
    num_replicas: Number of data parallelism replicas.

  Returns:
    An object containing the tokenizer, the model and the strategy.

  Raises:
    RuntimeError: If model_load_path points to nothing.
  """
  if distribute_mode not in constants.DistributeModeChoices.choices():
    raise ValueError(f"Unsupported distribute_mode: `{distribute_mode}`")

  if distribute_mode in constants.STRATEGIES:
    ############################################################################
    # Model creation in case we are using tf.distribute.Strategies.
    ############################################################################
    # We currently don't support GPU strategies, though adding them would be
    # simple.

    if distribute_mode == constants.DistributeModeChoices.tpustrategy:
      strategy = tf.distribute.TPUStrategy(
          tpu_setup.resolver,
      )
    elif distribute_mode == constants.DistributeModeChoices.onedevicestrategy:
      # Test mode with a single device, possibly a CPU.
      strategy = tf.distribute.OneDeviceStrategy(tf_utils.devices_to_use()[0])
    else:
      raise NotImplementedError(distribute_mode)

    with strategy.scope():
      config: CreateModelReturn = MODEL_FACTORIES[model_key](
          model_key,
          distribute_mode,
          None  # The replicas are created by the tf.distribute.Strategy obj.
      )
      config.strategy = strategy

  else:
    ############################################################################
    # Model creation in the case we aren't using strategies.
    ############################################################################
    # In this case, most of the parallelism work is done inside of the specific
    # model creation functions.

    # Note: the factories take (model_name, distribute_mode, num_replicas), so
    # only the load path is passed as the first positional argument here.
    config: CreateModelReturn = MODEL_FACTORIES[model_key](
        model_load_path,
        distribute_mode,
        num_replicas,
    )
    config.strategy = None
  return config


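# A minimal `load_model` usage sketch (hypothetical values; with the
# single-device strategy, no TPU setup is needed):
#
#   returned = load_model(
#       model_load_path="/path/to/gpt2-xl",
#       model_key="gpt2-xl",
#       distribute_mode=constants.DistributeModeChoices.onedevicestrategy,
#       tpu_setup=None,
#       num_replicas=1,
#   )
#   model, tokenizer = returned.model, returned.tokenizer
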
def _create_gpt2(
    model_name,
    distribute_mode,
    num_replicas  # pylint: disable=unused-argument
):
  """Loads the tokenizer and the model for the GPT-2 models."""

  ##############################################################################
  # Load the tokenizer
  ##############################################################################
  LOGGER.debug("Loading the weights: `%s`", model_name)
  tokenizer = transformers.GPT2TokenizerFast.from_pretrained(model_name)
  LOGGER.debug("Done loading the tokenizer.")
  LOGGER.debug("Loading the model weights.")

  ##############################################################################
  # Build the model(s) if we are splitting the model between devices per
  # replica.
  ##############################################################################
  if distribute_mode in {
      constants.DistributeModeChoices.split_and_data_parallel,
      constants.DistributeModeChoices.split_vertically
  }:
    # TODO(julesgm): This part needs to be reworked.
    raise NotImplementedError()

    # target_devices_info = tf_utils.InformationOnDevices()

    ############################################################################
    # Build the model function
    ############################################################################
    # if tf_utils.devices_to_use()[0].device_type == "CPU":
    #   # Edge case of testing on a CPU-only device. Mostly for local
    #   # debugging. Our compute node tooling doesn't work in this case.
    #   def make_model(data_parallelism_rank=0):  # pylint: disable=unused-argument
    #     model = modeling_tf_gpt2_model_par.TFGPT2LMHeadModel.from_pretrained(
    #         tf_model_path,
    #         config=config,
    #         cache_dir=cache_dir,
    #         devices=tf_utils.devices_to_use(),
    #         cpu=tf_utils.devices_to_use()[0],
    #     )
    #     return model
    # else:
    #   # Regular case with GPUs or TPUs.
    #   def make_model(data_parallelism_rank=0):
    #     # TODO(julesgm): This part needs work.
    #     model = modeling_tf_gpt2_model_par.TFGPT2LMHeadModel.from_pretrained(
    #         tf_model_path,
    #         config=config,
    #         cache_dir=cache_dir,
    #         devices=target_devices_info.devices_by_device_id[
    #             data_parallelism_rank],
    #         cpu=tf_utils.device_mapping().CPUs[1],
    #     )
    #     return model

    # ##########################################################################
    # # Build the model(s)
    # ##########################################################################
    # if distribute_mode == constants.DistributeModeChoices.split_and_data_parallel:
    #   # Multiple instances if we are doing data parallelism.
    #   if num_replicas > target_devices_info.num_devices:
    #     raise ValueError("num_replicas larger than "
    #                      "target_devices_info.num_devices. \n"
    #                      f" - num_replicas: {num_replicas} \n"
    #                      f" - num_devices: {target_devices_info.num_devices}")
    #   model = [make_model(rank) for rank in range(num_replicas)]
    # else:
    #   model = make_model()

  ##############################################################################
  # Build the model instance otherwise
  ##############################################################################
  else:
    with utils.log_duration(LOGGER, "main", "Loading the model."):
      model = transformers.TFGPT2LMHeadModel.from_pretrained(
          model_name,
      )

  LOGGER.debug("Done loading the %s model.", model_name)
  return CreateModelReturn(
      tokenizer=tokenizer,
      model=model,
      strategy=None,
  )


################################################################################
# Dataset Specific
################################################################################
def create_lm_ds_kilt_eli5(
    *,
    tokenizer,
    context_window_size,
    dataset_name,  # pylint: disable=unused-argument
    batch_size,
    split,
    db_path,  # pylint: disable=unused-argument
    random_seed,
    use_subset,  # pylint: disable=unused-argument
    subset_size,  # pylint: disable=unused-argument
    repeat,
    use_helper_words,
    approach_type,
    retriever,
    num_retrievals,
    retrieval_temperature,
    enable_debug_checks,
    retrieval_bank_size,  # pylint: disable=unused-argument
    dataset_type,
    qty_shuffle,
    tfr_prefix,
    max_length_generation,
):
  """Dataset preparation function for the Kilt version of the ELI5 dataset.

  This is for when the dataset is consumed by language models.

  Args:
    tokenizer: Tokenizer of the reader model.
    context_window_size: Size of the context of the reader model.
    dataset_name: Exact name of the dataset. Some datasets share the same
      function, with small specific differences. Not used here.
    batch_size: Size of the batch for the reader model.
    split: The train, evaluation or test split.
    db_path: Path of the hdf5 database. Only used by the deprecated hdf5
      dataset type.
    random_seed: Seed used to shuffle the dataset. Should change at each epoch.
    use_subset: Whether to use a subset of the data. Not used here.
    subset_size: Size of the subset. Not used here.
    repeat: Whether to repeat the dataset.
    use_helper_words: Whether to add helper words in the merged samples.
    approach_type: Type of overall solution we are using.
    retriever: Object that does the retrieval.
    num_retrievals: Number of retrievals to do.
    retrieval_temperature: For the retrieval methods that do sampling, what
      temperature to use.
    enable_debug_checks: Whether to add runtime assertions to the graph.
    retrieval_bank_size: Number of pre-computed retrievals per sample. Not
      used here.
    dataset_type: Storage format of the dataset, e.g. `tfr`.
    qty_shuffle: Size of the shuffle buffer.
    tfr_prefix: Directory prefix of the TFRecord files.
    max_length_generation: Maximum length of the generated text.

  Returns:
    A tf.data.Dataset object that generates input_ids and label_ids for the
    generator model.

  Raises:
    RuntimeError: If we didn't find any files with the glob pattern.
    RuntimeError: If we are using a dataset type that is not supported.
  """

  maybe_retrieve_and_merge = _make_maybe_retrieve_and_merge_fn(
      tokenizer=tokenizer,
      context_size=context_window_size,
      retriever=retriever,
      temperature=retrieval_temperature,
      num_retrievals=num_retrievals,
      ds_split=split,
      approach_type=approach_type,  # FLAG_APPROACH_TYPE.value
      use_helper_words=use_helper_words,  # FLAG_USE_HELPER_WORDS
      enable_debug_checks=enable_debug_checks,
      max_length_generation=max_length_generation,
  )

  if dataset_type == constants.DatasetTypeChoices.hdf5:
    raise ValueError("The hdf5 dataset type is not supported anymore. "
                     "It is strictly worse than tfr.")
    # with utils.log_duration(LOGGER, "create_lm_ds_kilt_eli5",
    #                         "loading codes.h5"):
    #   input_file = h5py.File(tf.io.gfile.GFile(db_path, "rb"), "r")[split]
    #
    #   if use_subset:
    #     new = {}
    #     for k, v in input_file.items():
    #       new[k] = v[:subset_size]
    #     input_file = new
    #
    # def load(field_name):
    #   if field_name == constants.CTH5Fields.gpt2_retrieved_ids:
    #     return input_file[field_name][:, :retrieval_bank_size]
    #   else:
    #     return input_file[field_name]
    #
    # with utils.log_duration(
    #     LOGGER, "create_lm_ds_kilt_eli5", "gpt2_question_ids_inputs"
    # ):
    #   gpt2_question_ids_inputs = load(
    #       constants.CTH5Fields.gpt2_question_ids_inputs
    #   )
    #
    # with utils.log_duration(
    #     LOGGER,
    #     "create_lm_ds_kilt_eli5",
    #     constants.CTH5Fields.gpt2_answer_ids_inputs
    # ):
    #   answer_ids_inputs = load(
    #       constants.CTH5Fields.gpt2_answer_ids_inputs
    #   )
    #
    # stacks = {
    #     constants.CTH5Fields.gpt2_question_ids_inputs:
    #         gpt2_question_ids_inputs,
    #     constants.CTH5Fields.gpt2_answer_ids_inputs:
    #         answer_ids_inputs,
    # }
    #
    # if approach_type == constants.ApproachTypeChoices.cached_pretok:
    #   with utils.log_duration(
    #       LOGGER, "create_lm_ds_kilt_eli5", constants.CTH5Fields.distances
    #   ):
    #     stacks[constants.CTH5Fields.distances] = load(
    #         constants.CTH5Fields.distances
    #     )
    #   with utils.log_duration(
    #       LOGGER,
    #       "create_lm_ds_kilt_eli5",
    #       constants.CTH5Fields.gpt2_retrieved_ids
    #   ):
    #     stacks[constants.CTH5Fields.gpt2_retrieved_ids] = load(
    #         constants.CTH5Fields.gpt2_retrieved_ids,
    #         retrieval_bank_size=retrieval_bank_size,
    #     )
    #
    # LOGGER.debug("from_tensor_slices")
    #
    # ds = tf.data.Dataset.from_tensor_slices(stacks)
  elif dataset_type == constants.DatasetTypeChoices.tfr:
    glob_pattern = os.path.join(tfr_prefix, f"{split}*")
    filenames = list(tf.io.gfile.glob(glob_pattern))
    if not filenames:
      raise RuntimeError(
          f"filenames is empty. Glob pattern was: {glob_pattern}"
      )

    ds = tf.data.TFRecordDataset(
        filenames=filenames,
        num_parallel_reads=tf.data.experimental.AUTOTUNE,
    )

    description: Dict[str, tf.io.FixedLenFeature] = {
        constants.CTH5Fields.distances:
            tf.io.FixedLenFeature((), tf.string),
        constants.CTH5Fields.gpt2_retrieved_ids:
            tf.io.FixedLenFeature((), tf.string),
        constants.CTH5Fields.gpt2_question_ids_inputs:
            tf.io.FixedLenFeature((), tf.string),
    }
    if split != constants.SplitChoices.test:
      description[
          constants.CTH5Fields.gpt2_answer_ids_inputs
      ] = tf.io.FixedLenFeature((), tf.string)

    feature_dtypes: Dict[str, tf.dtypes.DType] = {
        constants.CTH5Fields.distances:
            tf.float32,
        constants.CTH5Fields.gpt2_retrieved_ids:
            tf.int32,
        constants.CTH5Fields.gpt2_question_ids_inputs:
            tf.int32,
    }
    if split != constants.SplitChoices.test:
      feature_dtypes[
          constants.CTH5Fields.gpt2_answer_ids_inputs
      ] = tf.int32

    feature_shape: Dict[str, Tuple[int, Ellipsis]] = {
        constants.CTH5Fields.distances:
            (10,),
        constants.CTH5Fields.gpt2_retrieved_ids:
            (10, context_window_size,),
        constants.CTH5Fields.gpt2_question_ids_inputs:
            (context_window_size,),
    }
    if split != constants.SplitChoices.test:
      feature_shape[constants.CTH5Fields.gpt2_answer_ids_inputs] = (
          context_window_size,
      )

    @tf.function
    def parse(sample):
      example = tf.io.parse_single_example(sample, description)
      output = {}
      for k, v in example.items():
        output[k] = tf.io.parse_tensor(v, out_type=feature_dtypes[k])
        output[k].set_shape(feature_shape[k])
      return output

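    # Each feature arrives as a scalar string: presumably a tensor serialized
    # with `tf.io.serialize_tensor` when the TFRecords were written (inferred
    # from the use of `tf.io.parse_tensor` above). Parsing restores the
    # original dtype; `set_shape` only asserts the statically known shape.
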
    ds = ds.map(
        parse,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=False
    )
  else:
    raise RuntimeError(dataset_type)

  if repeat:
    ds = ds.repeat()

  utils.check_not_none(random_seed)
  utils.check_not_none(qty_shuffle)
  ds = ds.shuffle(qty_shuffle, seed=random_seed)

  ds = ds.batch(
      batch_size,
      drop_remainder=split != constants.SplitChoices.test
  )

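  # Note: only the test split keeps incomplete final batches; for train and
  # eval, `drop_remainder=True` keeps the batch dimension static, which TPU
  # execution favors.
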
  # We can't use parallel calls here: the Hugging Face Rust fast tokenizer
  # breaks with multiple threads. Even single-threaded, it still seems to be
  # worth it over the slow pure-Python tokenizer.
  ds = ds.map(maybe_retrieve_and_merge)

  return ds.prefetch(tf.data.experimental.AUTOTUNE)


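# A minimal sketch of building and consuming the ELI5 dataset (hypothetical
# values; `retriever` may be None for `cached_pretok`, where the retrievals
# are precomputed and stored in the TFRecords):
#
#   ds = create_lm_ds_kilt_eli5(
#       tokenizer=tokenizer,
#       context_window_size=1024,
#       dataset_name="kilt_eli5",
#       batch_size=8,
#       split=constants.SplitChoices.train,
#       db_path=None,
#       random_seed=0,
#       use_subset=False,
#       subset_size=-1,
#       repeat=True,
#       use_helper_words=True,
#       approach_type=constants.ApproachTypeChoices.cached_pretok,
#       retriever=None,
#       num_retrievals=10,
#       retrieval_temperature=1.0,
#       enable_debug_checks=False,
#       retrieval_bank_size=10,
#       dataset_type=constants.DatasetTypeChoices.tfr,
#       qty_shuffle=1000,
#       tfr_prefix="/path/to/tfr",
#       max_length_generation=256,
#   )
#   for batch in ds:
#     ...  # batch["input_ids"], batch["label_ids"]
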
def _make_maybe_retrieve_and_merge_fn(
    *,
    tokenizer,
    context_size,
    ds_split,
    approach_type,  # FLAG_APPROACH_TYPE.value
    use_helper_words,  # FLAG_USE_HELPER_WORDS
    retriever,  # pylint: disable=unused-argument
    temperature,
    num_retrievals,
    enable_debug_checks,
    max_length_generation,
    tf_function_kwargs=None,
):
  """Build the `maybe_retrieve_and_merge` closure."""
  tf_function_kwargs = {} if tf_function_kwargs is None else tf_function_kwargs
  not_test_split = ds_split != constants.SplitChoices.test

  @tf.function(**tf_function_kwargs)
  def maybe_retrieve_and_merge(
      batch,
  ):
    """Retrieve if needed, then finalize the prep. for model consumption."""
    batch_size = tf.shape(batch[
        constants.CTH5Fields.gpt2_question_ids_inputs
    ])[0]

    # Prepare the question ids inputs
    question_ids_inputs = batch[constants.CTH5Fields.gpt2_question_ids_inputs]
    question_ids_inputs = tf.RaggedTensor.from_tensor(
        question_ids_inputs,
        padding=constants.RAGGED_PADDING_ID
    )

    # Prepare the answer ids inputs
    answer_ids_inputs = None
    answer_ids_labels = None
    if not_test_split:
      answer_ids_inputs = batch[constants.CTH5Fields.gpt2_answer_ids_inputs]
      answer_ids_inputs = tf.RaggedTensor.from_tensor(
          answer_ids_inputs,
          padding=constants.RAGGED_PADDING_ID
      )
      answer_ids_labels = answer_ids_inputs

    ############################################################################
    # Prepare the helper words
    ############################################################################
    helper_word_token_ids = None
    if use_helper_words:
      helper_text = {"question": "Question:\n",
                     "context": "\nContext:\n",
                     "answer": "\nAnswer:\n"}
      helper_word_token_ids = {}
      for k in helper_text:
        ids = tf.constant(tokenizer.encode(helper_text[k]), dtype=tf.int32)
        ids = tf.repeat(tf.expand_dims(ids, 0), batch_size, axis=0)
        helper_word_token_ids[k] = ids
      question_ids_inputs = tf.concat(
          [helper_word_token_ids["question"], question_ids_inputs],
          axis=1
      )
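      # (The helper ids above were tiled to [batch_size, n]: `tokenizer.encode`
      # returns the few BPE ids of, e.g., "Question:\n", and the same prefix is
      # concatenated in front of every sample of the batch.)
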
    ############################################################################
    # W/ Cached Retrievals
    ############################################################################
    label_ids = None
    if approach_type == constants.ApproachTypeChoices.cached_pretok:
      bpe_indices_gpt2 = batch[constants.CTH5Fields.gpt2_retrieved_ids]
      bpe_indices_gpt2 = tf.RaggedTensor.from_tensor(
          bpe_indices_gpt2,
          ragged_rank=2,
          padding=constants.RAGGED_PADDING_ID
      )

      distances = batch[constants.CTH5Fields.distances]
      input_ids, label_ids = _prepare_samples_w_retrieval(
          split=ds_split,
          batch_size=batch_size,
          question_ids_inputs=question_ids_inputs,
          answer_ids_inputs=(
              answer_ids_inputs if not_test_split else None
          ),
          gpt2_tokenized_retrieved=bpe_indices_gpt2,
          num_retrievals=num_retrievals,
          temperature=temperature,
          context_size=context_size,
          enable_debug_checks=enable_debug_checks,
          distances=distances,
          max_generation_length=max_length_generation,
          helper_word_token_ids=(
              helper_word_token_ids if use_helper_words else None
          ),
          use_helper_words=use_helper_words,
      )

    elif approach_type == constants.ApproachTypeChoices.naked_lm:
      ##########################################################################
      # Without Retrievals
      ##########################################################################
      if use_helper_words:
        question_ids_inputs = tf.concat([
            question_ids_inputs,
            helper_word_token_ids["answer"],
        ], axis=1)

      question_ids_labels = tf.ones_like(
          question_ids_inputs
      ) * constants.PPL_MASK_ID

      if not_test_split:
        input_ids = tf.concat((question_ids_inputs, answer_ids_inputs),
                              axis=1)
        label_ids = tf.concat((question_ids_labels, answer_ids_labels),
                              axis=1)
      else:
        input_ids = question_ids_inputs
    else:
      raise RuntimeError("Unsupported approach_type value"
                         f" {approach_type}")

    ############################################################################
    # Finalize the preparation
    ############################################################################
    # Convert to dense tensors
    input_ids = input_ids.to_tensor(tokenizer.eos_token_id)

    if not_test_split:
      final_eos = tf.RaggedTensor.from_tensor(
          tokenizer.eos_token_id * tf.ones([batch_size, 1], dtype=tf.int32)
      )
      label_ids = tf.concat([label_ids, final_eos], axis=1)
      label_ids = label_ids.to_tensor(constants.PPL_MASK_ID)

    # All samples need to have at least one token != -100 (PPL_MASK_ID)
    if enable_debug_checks and not_test_split:
      not_any_padding = tf.reduce_any(
          label_ids != constants.PPL_MASK_ID, axis=1
      )
      none_has_padding = tf.math.reduce_all(
          not_any_padding
      )
      qty_doesnt_have_padding = tf.reduce_sum(
          tf.cast(not_any_padding, tf.int32)  # `tf.cast` needs a target dtype.
      )

      check_no_padding = tf.Assert(
          none_has_padding,
          [qty_doesnt_have_padding]
      )
      with tf.control_dependencies([check_no_padding]):
        label_ids = tf.identity(label_ids)

    # Limit size
    input_ids = input_ids[:, :context_size]
    if not_test_split:
      label_ids = label_ids[:, :context_size]

    ############################################################################
    # Pad `input_ids` and `label_ids` to context_size
    ############################################################################
    # Prepare the ones
    pad_qty = tf.math.maximum(
        0, tf.constant(context_size) - tf.shape(input_ids)[1]
    )
    padding_ones = tf.ones(
        [batch_size, pad_qty],
        dtype=input_ids.dtype
    )
    # Pad the inputs
    input_padding = tokenizer.eos_token_id * padding_ones
    input_ids = tf.concat((input_ids, input_padding), axis=1)

    # Pad the labels
    if not_test_split:
      pad_qty = tf.math.maximum(
          0, tf.constant(context_size) - tf.shape(label_ids)[1]
      )
      padding_ones = tf.ones(
          [batch_size, pad_qty],
          dtype=input_ids.dtype
      )
      label_padding = -100 * padding_ones
      label_ids = tf.concat((label_ids, label_padding), axis=1)

    # Make checks
    if enable_debug_checks:
      control_dependencies = []
      control_dependencies.append(tf.Assert(
          tf.math.reduce_all(input_ids != -1),
          [input_ids],
          name="NoMinusOnesInputs"
      ))
      if not_test_split:
        control_dependencies.append(tf.Assert(
            tf.math.reduce_all(label_ids != -1),
            [label_ids],
            name="NoMinusOnesLabel"
        ))
        control_dependencies.append(tf.Assert(
            tf.logical_not(
                tf.math.reduce_any(
                    tf.math.reduce_all(label_ids != -100, axis=1)
                )
            ),
            [label_ids],
            name="NotAllMinusOneHundred"
        ))
      with tf.control_dependencies(control_dependencies):
        input_ids = tf.identity(input_ids)

    return dict(
        input_ids=input_ids,
        label_ids=label_ids if not_test_split else None
    )

  return maybe_retrieve_and_merge


@tf.function
def _tokenize_and_concat_while_loop(
    all_retrieved_tokens,
    indices,
    num_retrieved,
    batch_size,
):
  """Tokenizes and puts together the retrievals, per batch unit."""

  def condition(
      index,
      _  # pylint: disable=unused-argument
  ):
    return tf.less(index, num_retrieved)

  def body(
      index,
      concat_tokens,
  ):
    addition = tf.gather(all_retrieved_tokens, indices[:, index], batch_dims=1)
    concat_tokens = tf.concat([
        concat_tokens, addition
    ], axis=1)
    return index + 1, concat_tokens

  if batch_size is None:
    raise RuntimeError("batch_size is `None`. This should not happen.")

  return tf.while_loop(
      condition, body, [
          0, tf.RaggedTensor.from_tensor(
              tf.zeros(
                  shape=(batch_size, 0),
                  dtype=tf.int32
              ),
          )
      ])[1]


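# Shape sketch for `_tokenize_and_concat_while_loop` (illustrative only):
# `all_retrieved_tokens` is a ragged [batch, bank_size, None] tensor and
# `indices` is [batch, num_retrieved]. Each loop iteration gathers one
# retrieval per batch row (`batch_dims=1`) and appends its tokens along
# axis 1, so the loop returns a ragged [batch, None] tensor holding the
# selected retrievals' tokens in sampling order.

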
@tf.function
def _prepare_samples_w_retrieval(
    split,
    batch_size,
    question_ids_inputs,
    answer_ids_inputs,
    gpt2_tokenized_retrieved,
    distances,
    num_retrievals,
    temperature,
    context_size,
    enable_debug_checks,
    use_helper_words,
    helper_word_token_ids,
    max_generation_length
):
  """Prepares the samples that use retrieval."""
  # If and only if: the answer inputs are None exactly for the test split.
  assert (split == constants.SplitChoices.test) == (
      answer_ids_inputs is None
  ), (split == constants.SplitChoices.test, answer_ids_inputs)

  is_not_test = split != constants.SplitChoices.test

  if not isinstance(question_ids_inputs, tf.RaggedTensor):
    question_ids_inputs = tf.RaggedTensor.from_tensor(
        question_ids_inputs,
        padding=constants.RAGGED_PADDING_ID
    )

  if enable_debug_checks:
    asserts = []
    asserts.append(
        tf.Assert(
            tf.math.reduce_all(
                question_ids_inputs != constants.RAGGED_PADDING_ID,
            ),
            [question_ids_inputs.to_tensor()]
        )
    )
    if is_not_test:
      asserts.append(
          tf.Assert(
              tf.math.reduce_all(
                  answer_ids_inputs != constants.RAGGED_PADDING_ID,
              ),
              [answer_ids_inputs.to_tensor()]
          )
      )
    with tf.control_dependencies(asserts):
      question_ids_inputs = tf.identity(question_ids_inputs)

  # These checks happen at graph composition time, so they are OK.
  utils.check_isinstance(question_ids_inputs, tf.RaggedTensor)
  if is_not_test:
    utils.check_isinstance(answer_ids_inputs, tf.RaggedTensor)

  ##############################################################################
  # Sample from the possible retrievals
  ##############################################################################
  # Choose the indices
  indices = tf_utils.sample_without_replacement(
      distances / temperature, num_retrievals
  )

  # Concatenate the retrievals
  concat_retrieved = _tokenize_and_concat_while_loop(
      gpt2_tokenized_retrieved,
      indices=indices,
      batch_size=batch_size,
      num_retrieved=num_retrievals,
  )

  # Add Context and Answer helper words
  if use_helper_words:
    concat_retrieved = tf.concat([
        helper_word_token_ids["context"],
        concat_retrieved,
    ], axis=1)

  # Cut the lengths down to max_lens_retrieval.
  # The eventual length of the ["question"] helper tokens is already included
  # in question_ids_inputs.
  if is_not_test:
    max_lens_retrieval = (
        context_size * tf.ones(
            shape=(batch_size,),
            dtype=tf.int64,
        )
        - (question_ids_inputs.row_lengths() +
           # We always generate the same length of text.
           max_generation_length +  # answer_ids_inputs.row_lengths() +
           (helper_word_token_ids["answer"].shape[1] if use_helper_words else 0)
           )
    )
  else:
    max_lens_retrieval = (
        context_size * tf.ones(
            shape=(batch_size,),
            dtype=tf.int64,
        )
        - (question_ids_inputs.row_lengths() +
           max_generation_length +
           (helper_word_token_ids["answer"].shape[1]
            if use_helper_words else 0)
           )
    )

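  # Worked example with hypothetical numbers: with context_size=1024, a
  # 100-token question (question helper included), max_generation_length=256
  # and a 3-token "\nAnswer:\n" helper, each row keeps
  # 1024 - (100 + 256 + 3) = 665 tokens of retrieved context.
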
  concat_retrieved = tf.ragged.boolean_mask(
      concat_retrieved,
      (
          tf.ragged.range(concat_retrieved.row_lengths()) <
          tf.expand_dims(max_lens_retrieval, axis=1)
      )
  )

  if enable_debug_checks:
    asserts = [
        tf.Assert(
            tf.math.reduce_all(max_lens_retrieval < context_size),
            [max_lens_retrieval, context_size]
        ),
    ]
    with tf.control_dependencies(asserts):
      concat_retrieved = tf.identity(concat_retrieved)

  if use_helper_words:
    if is_not_test:
      new_input_ids = tf.concat(
          [question_ids_inputs,
           concat_retrieved,
           helper_word_token_ids["answer"],
           answer_ids_inputs
           ],
          axis=1
      )
      new_label_ids = tf.concat(
          [-100 * tf.ones_like(question_ids_inputs),
           -100 * tf.ones_like(concat_retrieved),
           -100 * tf.ones_like(helper_word_token_ids["answer"]),
           answer_ids_inputs
           ],
          axis=1
      )
    else:
      new_input_ids = tf.concat(
          [question_ids_inputs,
           concat_retrieved,
           helper_word_token_ids["answer"],
           ],
          axis=1
      )
  else:
    if is_not_test:
      new_input_ids = tf.concat(
          [question_ids_inputs,
           concat_retrieved,
           answer_ids_inputs
           ],
          axis=1
      )
      new_label_ids = tf.concat(
          [-100 * tf.ones_like(question_ids_inputs),
           -100 * tf.ones_like(concat_retrieved),
           answer_ids_inputs
           ],
          axis=1
      )
    else:
      new_input_ids = tf.concat(
          [question_ids_inputs,
           concat_retrieved,
           ],
          axis=1
      )

  return new_input_ids, new_label_ids if is_not_test else None


################################################################################
# Varia
################################################################################

DATASET_CARDINALITIES = {
    constants.DatasetNameChoices.kilt_eli5: {
        "train": 272637,
        "eval": 1507,
        "test": 600,
    }
}

# Pick the correct model creation function from the Hugging Face model key.
MODEL_FACTORIES = {
    "gpt2": _create_gpt2,
    "gpt2-medium": _create_gpt2,
    "gpt2-large": _create_gpt2,
    "gpt2-xl": _create_gpt2,
    "distilgpt2": _create_gpt2,
}
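
# All current keys map to the same GPT-2 factory: each entry builds a
# TFGPT2LMHeadModel and its GPT2TokenizerFast, wrapped in a CreateModelReturn
# (e.g. MODEL_FACTORIES["distilgpt2"], called through `load_model`).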