# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper module for dealing with datasets loaded from TFDS."""

import copy
import enum
from typing import List, Dict, Optional, Text, Any, Tuple, Callable

import attr
import cv2
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds


def tfds_load_dataset(dataset_name, *args, **kwargs):
  """Helper function used to bridge internal Google code and the external world."""
  data_dir = kwargs.pop("data_dir", None)
  return tfds.load(
      dataset_name, *args, data_dir=data_dir, download=True, **kwargs)


class Split(enum.Enum):
  """Enum representing different splits of data for cross validation.

  Two validation sets are needed for meta-learning optimizers.
  """

  TRAIN = "TRAIN"
  VALID_INNER = "VALID_INNER"
  VALID_OUTER = "VALID_OUTER"
  TEST = "TEST"


def split_dataset(dataset, num_per_split, num_splits=3):
  """Helper to split a dataset for cross validation.

  The first num_splits - 1 datasets each contain num_per_split examples.
  The last dataset contains the remaining examples.

  This is often used to split a training set into more validation sets, e.g.
  train_old --> [valid_inner, valid_outer, train].

  Args:
    dataset: the tf.data.Dataset to split.
    num_per_split: number of examples in each of the split-off datasets.
    num_splits: number of splits to create.

  Returns:
    A list of the requested datasets.
  """
  new_datasets = []

  # Make the first num_splits - 1 splits, each containing num_per_split
  # examples.
  for i in range(num_splits - 1):
    new_datasets.append(dataset.skip(num_per_split * i).take(num_per_split))
  # The remainder of the dataset.
  new_datasets.append(dataset.skip(num_per_split * (num_splits - 1)))

  return new_datasets
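

# A minimal usage sketch for `split_dataset` (illustrative only; this helper
# is not part of the original module). Ten scalar examples are split into
# datasets of sizes [2, 2, 6].
def _example_split_dataset_usage():
  ds = tf.data.Dataset.range(10)
  valid_inner, valid_outer, train = split_dataset(ds, num_per_split=2)
  return valid_inner, valid_outer, train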


def _add_onehot_label_to_dict(d, num_label):
  """Returns a new dictionary with a label_onehot key."""
  d = copy.copy(d)
  d["label_onehot"] = tf.one_hot(d["label"], num_label)
  return d


def _process_image_in_dict(d):
  """Returns a new dict with the uint8 image rescaled to floats in [0, 1]."""
  d = copy.copy(d)
  image = d["image"]
  if image.dtype != tf.uint8:
    raise ValueError("Only supports uint8 images")
  d["image"] = tf.cast(image, tf.float32) / 255.
  return d


@attr.s
class Datasets(object):
  train = attr.ib(Any)
  valid_inner = attr.ib(Any)
  valid_outer = attr.ib(Any)
  test = attr.ib(Any)


def get_image_datasets(
    dataset_name,
    batch_size,
    num_per_valid=3000,
    num_train=None,
    cache_dataset=True,
    shuffle_buffer=None,
    data_dir=None,
    augmentation_fn=None,
):
  """Get an image `Datasets` instance that is ready to train with.

  This includes caching for speed, repeating, shuffling, preprocessing, and
  batching for each of the 4 splits.

  Args:
    dataset_name: Name of tfds dataset.
    batch_size: Batch size to use.
    num_per_valid: Number of validation images.
    num_train: Number of training examples to use. If None, use all.
    cache_dataset: Optionally cache the dataset for speed.
    shuffle_buffer: Size of shuffle buffer. If None, use the full train set
      size.
    data_dir: Location of tfds data_dir.
    augmentation_fn: Function to apply before batching for augmentation.

  Returns:
    `Datasets` ready to train with.
  """
  # TODO(lmetz) pin all versions of datasets so they are consistent in time.

  splits, info = tfds_load_dataset(
      dataset_name, with_info=True, data_dir=data_dir)
  num_classes = info.features["label"].num_classes

  # Some datasets have different splits defined. For meta-learning we need 4
  # splits. The following takes the splits that are defined, and tries to use
  # them when possible. For missing splits, examples are taken off of the
  # train dataset.

  if set(splits.keys()) == set(["train", "validation", "test"]):
    train = splits["train"]
    test = splits["test"]
    valid_outer = splits["validation"]

    # pylint: disable=unbalanced-tuple-unpacking
    valid_inner, train = split_dataset(
        train, num_per_split=num_per_valid, num_splits=2)
    num_test = info.splits["test"].num_examples
    total_num_train = info.splits["train"].num_examples
    num_valid = info.splits["validation"].num_examples

  elif (set(splits.keys()) == set(["train", "test"]) or
        set(splits.keys()) == set(["train", "validation"])):

    train = splits["train"]
    # pylint: disable=unbalanced-tuple-unpacking
    valid_inner, valid_outer, train = split_dataset(
        train, num_per_split=num_per_valid, num_splits=3)

    if "test" in info.splits:
      heldout_split = info.splits["test"]
    else:
      heldout_split = info.splits["validation"]
    num_test = heldout_split.num_examples

    test = splits["test"] if "test" in splits else splits["validation"]
    total_num_train = info.splits["train"].num_examples - num_per_valid * 2
    num_valid = num_per_valid

  elif set(splits.keys()) == set(["train"]):
    train = splits["train"]
    # pylint: disable=unbalanced-tuple-unpacking
    valid_inner, valid_outer, test, train = split_dataset(
        train, num_per_split=num_per_valid, num_splits=4)

    total_num_train = info.splits["train"].num_examples - num_per_valid * 3
    num_test = num_per_valid
    num_valid = num_per_valid
  else:
    raise ValueError("Unsure how to manage the following splits: %s" %
                     str(list(splits.keys())))

  if num_train:
    train = train.take(num_train)
  else:
    num_train = total_num_train

  datasets = Datasets(
      train=train, valid_inner=valid_inner, valid_outer=valid_outer, test=test)

  if cache_dataset:
    datasets = tf.nest.map_structure(lambda ds: ds.cache(), datasets)

  datasets = tf.nest.map_structure(lambda ds: ds.repeat(), datasets)

  train_shuffle = shuffle_buffer if shuffle_buffer else num_train
  valid_shuffle = shuffle_buffer if shuffle_buffer else num_valid
  test_shuffle = shuffle_buffer if shuffle_buffer else num_test

  datasets = Datasets(
      train=datasets.train.shuffle(train_shuffle),
      valid_inner=datasets.valid_inner.shuffle(valid_shuffle),
      valid_outer=datasets.valid_outer.shuffle(valid_shuffle),
      test=datasets.test.shuffle(test_shuffle))

  def pre_process(example):
    example = _add_onehot_label_to_dict(example, num_classes)
    return _process_image_in_dict(example)

  datasets = tf.nest.map_structure(lambda ds: ds.map(pre_process), datasets)

  if augmentation_fn:
    datasets = tf.nest.map_structure(lambda ds: ds.map(augmentation_fn),
                                     datasets)

  return tf.nest.map_structure(
      lambda ds: ds.batch(batch_size, drop_remainder=True), datasets)
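

# A minimal sketch of how `get_image_datasets` might be used (the "mnist"
# dataset name is an assumption about what is available locally; this helper
# is not part of the original module).
def _example_get_image_datasets_usage():
  datasets = get_image_datasets("mnist", batch_size=32, num_per_valid=1000)
  # Each split is an infinite, shuffled, batched tf.data.Dataset of dicts
  # with "image" (float32 in [0, 1]), "label", and "label_onehot" keys.
  batch = tf.data.make_one_shot_iterator(datasets.train).get_next()
  return batch["image"], batch["label_onehot"]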


def _random_slice(example, length):
  """Extract a random slice or pad to make all sequences a fixed length.

  For example -- if one passes in [1, 2, 3, 4] with length=2, this would
  return one of the following: [1, 2], [2, 3], [3, 4].

  If the input is [1, 2] with length=4, this would return [1, 2, 0, 0].

  Args:
    example: Dictionary containing a single example with the "text" key. This
      "text" key should be a vector with an integer type.
    length: Length of the slice.

  Returns:
    An example containing only a fixed slice of text.
  """
  input_length = tf.shape(example["text"])[0]
  max_idx = input_length - length
  # pylint: disable=g-long-lambda
  start_idx = tf.cond(
      tf.greater(max_idx, 0), lambda: tf.random_uniform(
          [], tf.to_int32(0), tf.cast(max_idx, tf.int32), dtype=tf.int32),
      lambda: 0)
  # pylint: enable=g-long-lambda

  to_pad = tf.maximum(length - input_length, 0)
  pad_input = tf.pad(example["text"], [[0, to_pad]])
  # Copy to prevent a mutation of inputs.
  example = copy.copy(example)
  example["text"] = pad_input[start_idx:start_idx + length]
  example["text"].set_shape([length])

  pad_mask = tf.pad(tf.ones([input_length]), [[0, to_pad]])
  example["mask"] = pad_mask[start_idx:start_idx + length]
  example["mask"].set_shape([length])

  return example


def random_slice_text_data(
    dataset_name,
    batch_size,
    num_train=None,
    patch_length=128,
    num_per_valid=3000,
    cache_dataset=False,
    shuffle_buffer=None,
):
  """Gets a text dataset ready to train on.

  This splits the dataset into 4 cross validation splits, takes a random
  slice to make all entries the same length, and batches the examples.

  Args:
    dataset_name: tensorflow_datasets dataset name.
    batch_size: batch size.
    num_train: number of training examples. If None, use all examples.
    patch_length: length of the patch to extract.
    num_per_valid: number of examples for each validation set.
    cache_dataset: whether or not to cache the dataset.
    shuffle_buffer: shuffle buffer size. If None, use the dataset size.

  Returns:
    Datasets object containing tf.data.Dataset splits.
  """

  train, info = tfds_load_dataset(
      dataset_name, split="train", with_info=True, shuffle_files=True)
  total_num_train = info.splits["train"].num_examples
  num_test = info.splits["test"].num_examples

  # pylint: disable=unbalanced-tuple-unpacking
  valid_inner, valid_outer, train = split_dataset(
      train, num_per_split=num_per_valid)
  # pylint: enable=unbalanced-tuple-unpacking
  if num_train:
    train = train.take(num_train)

  test = tfds_load_dataset(dataset_name, split="test", shuffle_files=True)

  datasets = Datasets(
      train=train, valid_inner=valid_inner, valid_outer=valid_outer, test=test)

  if cache_dataset:
    datasets = tf.nest.map_structure(lambda ds: ds.cache(), datasets)

  datasets = tf.nest.map_structure(lambda ds: ds.repeat(), datasets)

  train_shuffle = (
      shuffle_buffer if shuffle_buffer else total_num_train - num_per_valid * 2)
  valid_shuffle = shuffle_buffer if shuffle_buffer else num_per_valid
  test_shuffle = shuffle_buffer if shuffle_buffer else num_test

  datasets = Datasets(
      train=datasets.train.shuffle(train_shuffle),
      valid_inner=datasets.valid_inner.shuffle(valid_shuffle),
      valid_outer=datasets.valid_outer.shuffle(valid_shuffle),
      test=datasets.test.shuffle(test_shuffle))

  def pre_process(example):
    """Preprocess example by adding onehot label and taking a random slice."""
    if "label" in info.features:
      num_classes = info.features["label"].num_classes
      example = _add_onehot_label_to_dict(example, num_classes)
    return _random_slice(example, patch_length)

  datasets = tf.nest.map_structure(lambda ds: ds.map(pre_process), datasets)
  return tf.nest.map_structure(
      lambda ds: ds.batch(batch_size, drop_remainder=True), datasets)
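

# A minimal sketch of `random_slice_text_data` usage (the
# "imdb_reviews/subwords8k" config name is an assumption about what is
# available locally; this helper is not part of the original module).
def _example_random_slice_text_usage():
  datasets = random_slice_text_data(
      "imdb_reviews/subwords8k", batch_size=32, patch_length=128)
  # Each batch has integer "text" tokens of shape [32, 128], a float "mask"
  # of the same shape, and "label"/"label_onehot" for labeled datasets.
  return tf.data.make_one_shot_iterator(datasets.train).get_next()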


class ResizedDataset(tfds.core.GeneratorBasedBuilder):
  """Base class for a resized image tensorflow dataset."""

  def __init__(self, parent_builder, size, *args, **kwargs):
    """Initialize the resized image dataset builder.

    Args:
      parent_builder: The builder to build the resized image dataset from.
      size: size to resize each example to.
      *args: args passed to the super class.
      **kwargs: kwargs passed to the super class.
    """

    parent_builder.download_and_prepare()
    self._builder = parent_builder
    self._size = size
    super(ResizedDataset, self).__init__(*args, **kwargs)

  def _info(self):
    info = self._builder.info
    description = "\n This dataset has been resized to %dx%d!" % (
        self._size[0], self._size[1])

    new_feature_dict = {k: v for k, v in info.features.items()}
    new_feature_dict["image"] = tfds.features.Image(
        shape=list(self._size) + [3])

    return tfds.core.DatasetInfo(
        builder=self,
        description=info.description + description,
        homepage=info.homepage,
        features=tfds.features.FeaturesDict(new_feature_dict),
        supervised_keys=info.supervised_keys,
        citation=info.citation)

  def _split_generators(self, dl_manager):
    return [
        tfds.core.SplitGenerator(name=split, gen_kwargs=dict(split=split))
        for split in self._builder.info.splits.keys()
    ]

  def _generate_examples(self, split):
    for exi, ex in enumerate(
        tfds.as_numpy(self._builder.as_dataset(split=split))):
      ex = self._process_example(ex)
      yield exi, ex

  def _process_example(self, example):
    # As of now, this simply resizes the image to the passed in size.
    # TODO(lmetz) It might also make sense to resize then crop out the center.
    example["image"] = cv2.resize(
        example["image"], dsize=self._size, interpolation=cv2.INTER_CUBIC)
    return example


class Food101_32x32(ResizedDataset):  # pylint: disable=invalid-name
  """The food101 dataset resized to be 32x32."""

  VERSION = "1.0.0"

  def __init__(self, *args, **kwargs):
    parent_builder = tfds.builder("food101", version="1.0.0")
    super(Food101_32x32, self).__init__(
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)


class Food101_64x64(ResizedDataset):  # pylint: disable=invalid-name
  """The food101 dataset resized to be 64x64."""

  VERSION = "1.0.0"

  def __init__(self, *args, **kwargs):
    parent_builder = tfds.builder("food101", version="1.0.0")
    super(Food101_64x64, self).__init__(
        *args, parent_builder=parent_builder, size=(64, 64), **kwargs)


class Coil100_32x32(ResizedDataset):  # pylint: disable=invalid-name
  """The coil100 dataset resized to be 32x32."""

  VERSION = "1.0.0"

  def __init__(self, *args, **kwargs):
    parent_builder = tfds.builder("coil100", version="1.0.0")
    super(Coil100_32x32, self).__init__(
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)


class ColorectalHistology_32x32(ResizedDataset):  # pylint: disable=invalid-name
  """The colorectal_histology dataset resized to be 32x32."""

  VERSION = "1.0.0"

  def __init__(self, *args, **kwargs):
    parent_builder = tfds.builder("colorectal_histology", version="2.*.*")
    super(ColorectalHistology_32x32, self).__init__(
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)


class DeepWeeds_32x32(ResizedDataset):  # pylint: disable=invalid-name
  """The deep_weeds dataset resized to be 32x32."""

  VERSION = "1.0.0"

  def __init__(self, *args, **kwargs):
    parent_builder = tfds.builder("deep_weeds", version="1.0.0")
    super(DeepWeeds_32x32, self).__init__(
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)


class Sun397_32x32(ResizedDataset):  # pylint: disable=invalid-name
  """The sun397/tfds dataset resized to be 32x32."""

  VERSION = "1.0.0"

  def __init__(self, *args, **kwargs):
    parent_builder = tfds.builder("sun397/tfds", version="4.0.0")
    super(Sun397_32x32, self).__init__(
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)
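

# Defining these subclasses registers them with TFDS under the snake_case
# class name, so -- assuming the parent datasets have been downloaded and
# prepared -- they can be loaded like any other dataset. A sketch of the
# intended usage (dataset name derived from the class name above):
#
#   splits = tfds_load_dataset("food101_32x32")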


class TokenizedConfig(tfds.core.BuilderConfig):
  """BuilderConfig for tokenized text datasets."""

  def __init__(self, version=None, text_encoder_config=None, **kwargs):
    """BuilderConfig for tokenized text datasets.

    Args:
      version (string): version as string.
      text_encoder_config: `tfds.deprecated.text.TextEncoderConfig`,
        configuration for the `tfds.deprecated.text.TextEncoder` used for the
        `"text"` feature.
      **kwargs: keyword arguments forwarded to super.
    """
    super(TokenizedConfig, self).__init__(
        version=tfds.core.Version(version), **kwargs)
    self.text_encoder_config = (
        text_encoder_config or tfds.deprecated.text.TextEncoderConfig())


# This is an arbitrarily chosen subset of languages.
WIKIPEDIA_PREFIX = [
    "20190301.zh", "20190301.ru", "20190301.ja", "20190301.hsb", "20190301.en"
]


def _get_builder_configs(base_configs):
  """Get the builder configs for tokenized datasets."""
  configs = []
  for prefix in base_configs:
    configs.append(
        TokenizedConfig(
            name="%s_bytes" % prefix,
            version="0.0.1",
            description=("Uses byte-level text encoding with "
                         "`tfds.deprecated.text.ByteTextEncoder`"),
            text_encoder_config=tfds.deprecated.text.TextEncoderConfig(
                encoder=tfds.deprecated.text.ByteTextEncoder()),
        ))
    configs.append(
        TokenizedConfig(
            name="%s_subwords8k" % prefix,
            version="0.0.1",
            description=("Uses `tfds.deprecated.text.SubwordTextEncoder` with "
                         "8k vocab size"),
            text_encoder_config=tfds.deprecated.text.TextEncoderConfig(
                encoder_cls=tfds.deprecated.text.SubwordTextEncoder,
                vocab_size=8192),
        ))
  return configs


class TokenizedWikipedia(tfds.core.GeneratorBasedBuilder):
  """Builder which tokenizes the tfds wikipedia datasets.

  This dataset returns 1 paragraph (split via newline) per example extracted
  from the articles. We additionally filter examples to have more than 5
  bytes. Encoding is either bytes or subwords. The vocab is constructed out
  of the first 200k examples. While this is likely not perfect, it should be
  sufficient for meta-learning optimizers.

  Additionally, we make a train and test split by hashing the article index.

  Finally, for computational reasons we only use 1 million articles. For the
  size of the models we are training here this should be plenty.
  """
  BUILDER_CONFIGS = _get_builder_configs(WIKIPEDIA_PREFIX)

  def __init__(self, config=None, **kwargs):
    """Initialize the tokenized Wikipedia dataset builder.

    Args:
      config: str Config string specifying which dataset to build.
      **kwargs: kwargs passed to the super class.
    """

    # Extract the base dataset.
    base, _ = config.split("_")
    self._builder = tfds.builder("wikipedia/%s" % base)
    super(TokenizedWikipedia, self).__init__(config=config, **kwargs)

    self._perc_train = 0.7
    self._max_num_articles = 1000000
    # Number of examples used to build the tokenizer.
    self._examples_for_tokenizer = 200000

  def _info(self):
    info = self._builder.info
    description = "\n This dataset has been tokenized!"
    return tfds.core.DatasetInfo(
        builder=self,
        description=info.description + description,
        features=tfds.features.FeaturesDict({
            "title":
                tfds.features.Text(),
            "text":
                tfds.features.Text(
                    encoder_config=self.builder_config.text_encoder_config),
        }),
        supervised_keys=("text", "text"),
        homepage=info.homepage,
        citation=info.citation)

  def _split_generators(self, dl_manager):
    self.info.features["text"].maybe_build_from_corpus(self._vocab_text_gen())

    return [
        tfds.core.SplitGenerator(name=split, gen_kwargs=dict(split=split))
        for split in ["train", "test"]
    ]

  def _split_article(self, ex):
    for i, split in enumerate(ex["text"].split("\n")):
      if len(split.strip()) > 5:
        yield i, {"title": ex["title"], "text": split}

  def _generate_examples(self, split):
    hasher = tfds.core.hashing.Hasher("token_wikipedia_salt")
    for exi, example in enumerate(
        tfds.as_numpy(self._builder.as_dataset(split="train"))):

      if exi > self._max_num_articles:
        return

      # To make a train/test split we first hash the key and convert it to a
      # floating point value between 0-1. Depending on this value and the
      # requested split we either yield the example or not.
      p = hasher.hash_key(exi) % 100000 / 100000.

      if split == "train" and p < self._perc_train:
        for i, sub_example in self._split_article(example):
          key = (exi, i)
          yield key, sub_example

      elif split == "test" and p >= self._perc_train:
        for i, sub_example in self._split_article(example):
          key = (exi, i)
          yield key, sub_example

  def _vocab_text_gen(self):
    for i, (_, ex) in enumerate(self._generate_examples("train")):
      # Only yield a subset of the data used for tokenization for
      # performance reasons.
      if self._examples_for_tokenizer > i:
        yield ex["text"]
      else:
        return


# Arbitrary subset of datasets.
AMAZON_PRODUCTS = ["Books_v1_02", "Camera_v1_00", "Home_v1_00", "Video_v1_00"]


class TokenizedAmazonReviews(tfds.core.GeneratorBasedBuilder):
  """Builder which tokenizes the tfds amazon reviews datasets.

  For compute reasons we only build the tokenizer from 200000 examples.

  We make a train and test split by hashing the example index.
  """
  BUILDER_CONFIGS = _get_builder_configs(AMAZON_PRODUCTS)

  def __init__(self, config=None, **kwargs):
    """Initialize the tokenized Amazon reviews dataset builder.

    Args:
      config: str Config string specifying which dataset to build.
      **kwargs: kwargs passed to the super class.
    """

    # Extract the base dataset.
    base = "_".join(config.split("_")[0:-1])
    self._builder = tfds.builder("amazon_us_reviews/%s" % base)

    super(TokenizedAmazonReviews, self).__init__(config=config, **kwargs)

    self._perc_train = 0.7
    self._examples_for_tokenizer = 200000

  def _info(self):
    info = self._builder.info
    description = "\n This dataset has been tokenized!"
    return tfds.core.DatasetInfo(
        builder=self,
        description=info.description + description,
        features=tfds.features.FeaturesDict({
            # 1-5 stars are the labels.
            "label":
                tfds.features.ClassLabel(num_classes=5),
            "text":
                tfds.features.Text(
                    encoder_config=self.builder_config.text_encoder_config),
        }),
        supervised_keys=("text", "label"),
        homepage=info.homepage,
        citation=info.citation)

  def _split_generators(self, dl_manager):
    self.info.features["text"].maybe_build_from_corpus(self._vocab_text_gen())

    return [
        tfds.core.SplitGenerator(name=split, gen_kwargs=dict(split=split))
        for split in ["train", "test"]
    ]

  def _generate_examples(self, split):
    hasher = tfds.core.hashing.Hasher("token_wikipedia_salt")
    for exi, example in enumerate(
        tfds.as_numpy(self._builder.as_dataset(split="train"))):

      p = hasher.hash_key(exi) % 1000 / 1000.

      example = {
          "text": example["data"]["review_body"],
          # Subtract one to zero index.
          "label": example["data"]["star_rating"] - 1
      }
      if split == "train" and p < self._perc_train:
        yield exi, example

      # Use >= so that examples with p exactly equal to perc_train land in
      # the test split rather than being dropped.
      elif split == "test" and p >= self._perc_train:
        yield exi, example

  def _vocab_text_gen(self):
    for i, (_, ex) in enumerate(self._generate_examples("train")):
      if self._examples_for_tokenizer > i:
        yield ex["text"]
      else:
        return
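

# A sketch of how these tokenized builders might be loaded once the parent
# tfds data is available (the exact dataset/config string below is an
# assumption derived from the class and config names above):
#
#   ds = tfds_load_dataset(
#       "tokenized_amazon_reviews/Books_v1_02_subwords8k", split="train")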


def _single_associative_retrieval(batch_size=128, num_pairs=5, num_tokens=10):
  """See `associative_sequence` for docs."""

  def _onehot_pack(inp, out, loss_mask):
    inp_seq, outputs, loss_mask = (tf.one_hot(inp, num_tokens + 2),
                                   tf.one_hot(out, num_tokens + 2), loss_mask)
    return {"input": inp_seq, "output": outputs, "loss_mask": loss_mask}

  def _py_make_example():
    """Iterator that makes single examples in python."""
    while True:
      keys = np.random.choice(num_tokens, size=num_pairs, replace=False)
      values = np.random.choice(num_tokens, size=num_pairs, replace=True)
      empty_token_idx = num_tokens
      query_token_idx = num_tokens + 1
      input_seq = []
      output_seq = []
      for k, v in zip(keys, values):
        input_seq.extend([k, v])
        output_seq.extend([empty_token_idx, empty_token_idx])

      input_seq.append(query_token_idx)
      output_seq.append(empty_token_idx)

      query_key = np.random.randint(0, num_pairs)
      input_seq.append(keys[query_key])
      output_seq.append(values[query_key])
      loss_mask = np.zeros(2 * num_pairs + 2, dtype=np.float32)
      loss_mask[-1] = 1.
      input_seq = np.asarray(input_seq, dtype=np.int32)
      output_seq = np.asarray(output_seq, dtype=np.int32)
      yield input_seq, output_seq, loss_mask

  # Per pair, there is a key and a value. The extra 2 account for the query
  # indicator and the query key.
  seq_len = 2 * num_pairs + 2
  dataset = tf.data.Dataset.from_generator(_py_make_example,
                                           (tf.int32, tf.int32, tf.float32),
                                           ([seq_len], [seq_len], [seq_len]))
  dataset = dataset.map(_onehot_pack)
  return dataset.batch(batch_size, drop_remainder=True)


def associative_sequence(batch_size=128, num_pairs=5, num_tokens=10):
  """Associative retrieval datasets.

  The inputs consist of pairs of keys and values presented sequentially,
  followed by an indicator token and then a retrieval (query) key.

  The output consists of the value associated with the query key in the final
  step of the sequence, preceded by empty tokens.

  The problem can be solved perfectly, as the 'key' tokens are unique. There
  can, however, be duplicate values for different keys.

  Example (using characters instead of the onehot representations):

  input:     A1B2C3D4?A
  output:    _________1
  loss_mask: 0000000001

  The inputs and outputs are represented using a one-hot encoding.

  The problem is based on the one used in
  https://arxiv.org/pdf/1610.06258.pdf.

  Args:
    batch_size: int
    num_pairs: int, number of pairs to put into memory.
    num_tokens: int, number of possible tokens to choose from.

  Returns:
    datasets: Datasets object with each split containing the same data
      generating process.
  """
  fn = lambda: _single_associative_retrieval(batch_size, num_pairs, num_tokens)
  return Datasets(train=fn(), valid_inner=fn(), valid_outer=fn(), test=fn())
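

# A quick sketch inspecting the associative retrieval data (illustrative
# only; this helper is not part of the original module). With these
# arguments each batch has sequence length 2 * num_pairs + 2 = 12 and
# one-hot depth num_tokens + 2 = 12.
def _example_associative_sequence_usage():
  datasets = associative_sequence(batch_size=4, num_pairs=5, num_tokens=10)
  batch = tf.data.make_one_shot_iterator(datasets.train).get_next()
  # batch["input"], batch["output"]: [4, 12, 12]; batch["loss_mask"]: [4, 12].
  return batch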


def _single_copy_sequence(batch_size=128,
                          sequence_length=5,
                          num_separator=1,
                          num_tokens=10):
  """See `copy_sequence` for docs."""

  def _build_batch(_):
    """Construct a batch.

    Args:
      _: tf.Tensor Needed to construct a tf.data.Dataset that iteratively
        calls this function. This is a dummy value that never changes.

    Returns:
      batch: dict of tensors containing a batch of sequences.
    """
    inp = tf.random_uniform([batch_size, sequence_length],
                            0,
                            num_tokens,
                            dtype=tf.int32)
    sep = tf.ones([batch_size, num_separator], dtype=tf.int32) * num_tokens
    emit = tf.ones([batch_size, sequence_length], dtype=tf.int32) * (
        num_tokens + 1)
    inp_seq_pre_onehot = tf.concat([inp, sep, emit], axis=1)
    inp_seq = tf.one_hot(inp_seq_pre_onehot, num_tokens + 2)

    loss_mask = tf.concat([
        tf.zeros([batch_size, sequence_length + num_separator]),
        tf.ones([batch_size, sequence_length])
    ],
                          axis=1)

    outputs_pre_onehot = tf.concat(
        [tf.zeros_like(inp), tf.zeros_like(sep), inp], axis=1)
    outputs = tf.one_hot(outputs_pre_onehot, num_tokens + 2)

    return {"input": inp_seq, "output": outputs, "loss_mask": loss_mask}

  return tf.data.Dataset.from_tensor_slices([0]).repeat().map(_build_batch)


def copy_sequence(batch_size=128,
                  sequence_length=5,
                  num_separator=1,
                  num_tokens=10):
  """A simple copy-input-to-output task.

  The input consists of `sequence_length` tokens drawn from a vocab of size
  `num_tokens`, followed by `num_separator` separation tokens, followed by
  `sequence_length` empty tokens.

  The output consists of `sequence_length + num_separator` empty tokens
  followed by the same tokens as the input.

  All token outputs are onehot.

  A sample input/output pair for sequence_length=3, num_tokens=3,
  num_separator=1:

  input:     <tokenA><tokenB><tokenC><sep>  <empty> <empty> <empty>
  output:    <empty> <empty> <empty> <empty><tokenA><tokenB><tokenC>
  loss_mask: 0.      0.      0.      0.     1.      1.      1.

  Args:
    batch_size: int
    sequence_length: int, length of sequence to copy
    num_separator: int, number of separator tokens separating input from
      output
    num_tokens: int, number of tokens to build the input from

  Returns:
    datasets: Datasets object with each split containing the same data
      generating process.
  """

  def fn():
    return _single_copy_sequence(batch_size, sequence_length, num_separator,
                                 num_tokens)

  return Datasets(train=fn(), valid_inner=fn(), valid_outer=fn(), test=fn())
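

# A quick sketch inspecting the copy task (illustrative only; this helper is
# not part of the original module). With these arguments each batch has
# sequence length 2 * sequence_length + num_separator = 11 and one-hot depth
# num_tokens + 2 = 12.
def _example_copy_sequence_usage():
  datasets = copy_sequence(batch_size=4, sequence_length=5, num_separator=1)
  batch = tf.data.make_one_shot_iterator(datasets.train).get_next()
  # batch["input"], batch["output"]: [4, 11, 12]; batch["loss_mask"]: [4, 11].
  return batch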