google-research

Форк
0
844 строки · 28.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Helper module for dealing with datasets loaded from TFDS."""
17

18
import copy
19
import enum
20
from typing import List, Dict, Optional, Text, Any, Tuple, Callable
21
import attr
22
import cv2
23
import numpy as np
24
import tensorflow.compat.v1 as tf
25
import tensorflow_datasets as tfds
26

27

28
def tfds_load_dataset(dataset_name, *args, **kwargs):
29
  """Helper function used to bridge internal google, and the external world."""
30
  data_dir = kwargs.pop("data_dir", None)
31
  return tfds.load(
32
      dataset_name, *args, data_dir=data_dir, download=True, **kwargs)
33

34

35
class Split(enum.Enum):
36
  """Enum representing different splits of data for cross validation.
37

38
  Two validation sets are needed for meta-learning optimizers.
39
  """
40

41
  TRAIN = "TRAIN"
42
  VALID_INNER = "VALID_INNER"
43
  VALID_OUTER = "VALID_OUTER"
44
  TEST = "TEST"
45

46

47
def split_dataset(
48
    dataset,
49
    num_per_split,
50
    num_splits = 3,
51
):
52
  """Helper to split a dataset for cross validaton.
53

54
  The first num_splits-1 datasets contain num_per_split examples.
55

56
  The last dataset contains the remaining number of examples.
57
  This often used to split a training set into more validation sets:
58
  e.g. train_old --> [valid_inner, valid_outer, train]
59

60
  Args:
61
    dataset: name of tfds dataset
62
    num_per_split: number of examples to have for each of the split off dataset.
63
    num_splits: number of splits to create.
64

65
  Returns:
66
    A list of the requested datasets.
67
  """
68
  new_datasets = []
69

70
  # make the first n_split-1 splits containing num_per_split examples
71
  for i in range(num_splits - 1):
72
    new_datasets.append(dataset.skip(num_per_split * i).take(num_per_split))
73
  # The remainder of the dataset
74
  new_datasets.append(dataset.skip(num_per_split * (num_splits - 1)))
75

76
  return new_datasets
77

78

79
def _add_onehot_label_to_dict(d,
80
                              num_label):
81
  """Returns a new dictionary with a label_onehot key."""
82
  d = copy.copy(d)
83
  d["label_onehot"] = tf.one_hot(d["label"], num_label)
84
  return d
85

86

87
def _process_image_in_dict(d):
88
  """Returns a new dict with a uint8 image converted to 0-1 scaled image."""
89
  d = copy.copy(d)
90
  image = d["image"]
91
  if image.dtype != tf.uint8:
92
    raise ValueError("Only supports uint8 images")
93
  d["image"] = tf.cast(image, tf.float32) / 255.
94
  return d
95

96

97
@attr.s
98
class Datasets(object):
99
  train = attr.ib(Any)
100
  valid_inner = attr.ib(Any)
101
  valid_outer = attr.ib(Any)
102
  test = attr.ib(Any)
103

104

105
def get_image_datasets(
106
    dataset_name,
107
    batch_size,
108
    num_per_valid = 3000,
109
    num_train = None,
110
    cache_dataset = True,
111
    shuffle_buffer = None,
112
    data_dir = None,
113
    augmentation_fn = None,
114
):
115
  """Get an image `Datasets` instance that is ready to train with.
116

117
  This includes caching for speed, repeating, shuffling, preprocessing, and
118
  batching for each of the 4 splits.
119

120
  Args:
121
    dataset_name: Name of tfds dataset.
122
    batch_size: Batch size to use.
123
    num_per_valid: Number of validation images.
124
    num_train: Number of training examples to use. If None, use all.
125
    cache_dataset: Optionally cache the dataset for speed.
126
    shuffle_buffer: Size of shuffle buffer. If none, use the full train set
127
      size.
128
    data_dir: Location of tfds data_dir.
129
    augmentation_fn: Function to apply before batching for augmentation.
130

131
  Returns:
132
    `Datasets` ready to train with.
133
  """
134
  # TODO(lmetz) pin all versions of datasets so they are consistent in time.
135

136
  splits, info = tfds_load_dataset(
137
      dataset_name, with_info=True, data_dir=data_dir)
138
  num_classes = info.features["label"].num_classes
139

140
  # Some datasets have different splits defined. For meta-learning we need 4
141
  # splits. The following takes the splits that are defined, and tries to use
142
  # them when possible. For missing splits, examples are taken off of the train
143
  # dataset.
144

145
  if set(splits.keys()) == set(["train", "validation", "test"]):
146
    train = splits["train"]
147
    test = splits["test"]
148
    valid_outer = splits["validation"]
149

150
    # pylint: disable=unbalanced-tuple-unpacking
151
    valid_inner, train = split_dataset(
152
        train, num_per_split=num_per_valid, num_splits=2)
153
    num_test = info.splits["test"].num_examples
154
    total_num_train = info.splits["train"].num_examples
155
    num_valid = info.splits["validation"].num_examples
156

157
  elif (set(splits.keys()) == set(["train", "test"]) or
158
        set(splits.keys()) == set(["train", "validation"])):
159

160
    train = splits["train"]
161
    # pylint: disable=unbalanced-tuple-unpacking
162
    valid_inner, valid_outer, train = split_dataset(
163
        train, num_per_split=num_per_valid, num_splits=3)
164

165
    if "test" in info.splits:
166
      heldout_split = info.splits["test"]
167
    else:
168
      heldout_split = info.splits["validation"]
169
    num_test = heldout_split.num_examples
170

171
    test = splits["test"] if "test" in splits else splits["validation"]
172
    total_num_train = info.splits["train"].num_examples - num_per_valid * 2
173
    num_valid = num_per_valid
174

175
  elif set(splits.keys()) == set(["train"]):
176
    train = splits["train"]
177
    # pylint: disable=unbalanced-tuple-unpacking
178
    valid_inner, valid_outer, test, train = split_dataset(
179
        train, num_per_split=num_per_valid, num_splits=4)
180

181
    total_num_train = info.splits["train"].num_examples - num_per_valid * 3
182
    num_test = num_per_valid
183
    num_valid = num_per_valid
184
  else:
185
    raise ValueError("Unsure how to manage the following splits: %s" %
186
                     str(list(splits.keys())))
187

188
  if num_train:
189
    train = train.take(num_train)
190
  else:
191
    num_train = total_num_train
192

193
  datasets = Datasets(
194
      train=train, valid_inner=valid_inner, valid_outer=valid_outer, test=test)
195

196
  if cache_dataset:
197
    datasets = tf.nest.map_structure(lambda ds: ds.cache(), datasets)
198

199
  datasets = tf.nest.map_structure(lambda ds: ds.repeat(), datasets)
200

201
  train_shuffle = shuffle_buffer if shuffle_buffer else num_train
202
  valid_shuffle = shuffle_buffer if shuffle_buffer else num_valid
203
  test_shuffle = shuffle_buffer if shuffle_buffer else num_test
204

205
  datasets = Datasets(
206
      train=datasets.train.shuffle(train_shuffle),
207
      valid_inner=datasets.valid_inner.shuffle(valid_shuffle),
208
      valid_outer=datasets.valid_outer.shuffle(valid_shuffle),
209
      test=datasets.test.shuffle(test_shuffle))
210

211
  def pre_process(example):
212
    example = _add_onehot_label_to_dict(example, num_classes)
213
    return _process_image_in_dict(example)
214

215
  datasets = tf.nest.map_structure(lambda ds: ds.map(pre_process), datasets)
216

217
  if augmentation_fn:
218
    datasets = tf.nest.map_structure(lambda ds: ds.map(augmentation_fn),
219
                                     datasets)
220

221
  return tf.nest.map_structure(
222
      lambda ds: ds.batch(batch_size, drop_remainder=True), datasets)
223

224

225
def _random_slice(example,
226
                  length):
227
  """Extract a random slice or pad to make all sequences a fixed length.
228

229
  For example -- if one passes in [1,2,3,4] with length=2, this would return
230
  one of the following: [1,2], [2,3], [3,4].
231

232
  If the input is [1, 2] with length=4, this would return [1, 2, 0, 0].
233

234
  Args:
235
    example: Dictionary containing a single example with the "text" key. This
236
      "text" key should be a vector with an integer type.
237
    length: Length of the slice.
238

239
  Returns:
240
    An example containing only a fixed slice of text.
241
  """
242
  input_length = tf.shape(example["text"])[0]
243
  max_idx = input_length - length
244
  # pylint: disable=g-long-lambda
245
  start_idx = tf.cond(
246
      tf.greater(max_idx, 0), lambda: tf.random_uniform(
247
          [], tf.to_int32(0), tf.cast(max_idx, tf.int32), dtype=tf.int32),
248
      lambda: 0)
249
  # pylint: enable=g-long-lambda
250

251
  to_pad = tf.maximum(length - input_length, 0)
252
  pad_input = tf.pad(example["text"], [[0, to_pad]])
253
  # copy to prevent a mutation of inputs.
254
  example = copy.copy(example)
255
  example["text"] = pad_input[start_idx:start_idx + length]
256
  example["text"].set_shape([length])
257

258
  pad_mask = tf.pad(tf.ones([input_length]), [[0, to_pad]])
259
  example["mask"] = pad_mask[start_idx:start_idx + length]
260
  example["mask"].set_shape([length])
261

262
  return example
263

264

265
def random_slice_text_data(
266
    dataset_name,
267
    batch_size,
268
    num_train = None,
269
    patch_length = 128,
270
    num_per_valid = 3000,
271
    cache_dataset = False,
272
    shuffle_buffer = None,
273
):
274
  """Gets a text dataset ready to train on.
275

276
  This splits the dataset into 4 cross validation splits, takes a random slice
277
  to make all entries the same length, and batches the examples.
278

279
  Args:
280
    dataset_name: tensorflow_dataset's dataset name.
281
    batch_size: batch size.
282
    num_train: number of training examples. If None use all examples.
283
    patch_length: length of patch to extract.
284
    num_per_valid: number of images for each validation set.
285
    cache_dataset: Cache the dataset or not.
286
    shuffle_buffer: Shuffle buffer size. If None, use dataset size.
287

288
  Returns:
289
    Datasets object containing tf.Dataset.
290
  """
291

292
  train, info = tfds_load_dataset(
293
      dataset_name, split="train", with_info=True, shuffle_files=True)
294
  total_num_train = info.splits["train"].num_examples
295
  num_test = info.splits["test"].num_examples
296

297
  # pylint: disable=unbalanced-tuple-unpacking
298
  valid_inner, valid_outer, train = split_dataset(
299
      train, num_per_split=num_per_valid)
300
  # pylint: enable=unbalanced-tuple-unpacking
301
  if num_train:
302
    train = train.take(num_train)
303

304
  test = tfds_load_dataset(dataset_name, split="test", shuffle_files=True)
305

306
  datasets = Datasets(
307
      train=train, valid_inner=valid_inner, valid_outer=valid_outer, test=test)
308

309
  if cache_dataset:
310
    datasets = tf.nest.map_structure(lambda ds: ds.cache(), datasets)
311

312
  datasets = tf.nest.map_structure(lambda ds: ds.repeat(), datasets)
313

314
  train_shuffle = shuffle_buffer if shuffle_buffer else total_num_train - num_per_valid * 2
315
  valid_shuffle = shuffle_buffer if shuffle_buffer else num_per_valid
316
  test_shuffle = shuffle_buffer if shuffle_buffer else num_test
317

318
  datasets = Datasets(
319
      train=datasets.train.shuffle(train_shuffle),
320
      valid_inner=datasets.valid_inner.shuffle(valid_shuffle),
321
      valid_outer=datasets.valid_outer.shuffle(valid_shuffle),
322
      test=datasets.test.shuffle(test_shuffle))
323

324
  def pre_process(example):
325
    """Preprocess example by adding onehot label, and taking a random slice."""
326
    if "label" in info.features:
327
      num_classes = info.features["label"].num_classes
328
      example = _add_onehot_label_to_dict(example, num_classes)
329
    return _random_slice(example, patch_length)
330

331
  datasets = tf.nest.map_structure(lambda ds: ds.map(pre_process), datasets)
332
  return tf.nest.map_structure(
333
      lambda ds: ds.batch(batch_size, drop_remainder=True), datasets)
334

335

336
class ResizedDataset(tfds.core.GeneratorBasedBuilder):
337
  """Base class for a resized image tensorflow dataset."""
338

339
  def __init__(self, parent_builder,
340
               size, *args, **kwargs):
341
    """Initialize the resized image dataset builder.
342

343
    Args:
344
      parent_builder: The builder to build the resized image dataset from.
345
      size: size to resize each example to.
346
      *args: args passed super class.
347
      **kwargs: kwargs passed super class.
348
    """
349

350
    parent_builder.download_and_prepare()
351
    self._builder = parent_builder
352
    self._size = size
353
    super(ResizedDataset, self).__init__(*args, **kwargs)
354

355
  def _info(self):
356
    info = self._builder.info
357
    description = "\n This dataset has been resized to %dx%d!" % (self._size[0],
358
                                                                  self._size[1])
359

360
    new_feature_dict = {k: v for k, v in info.features.items()}
361
    new_feature_dict["image"] = tfds.features.Image(
362
        shape=list(self._size) + [3])
363

364
    return tfds.core.DatasetInfo(
365
        builder=self,
366
        description=info.description + description,
367
        homepage=info.homepage,
368
        features=tfds.features.FeaturesDict(new_feature_dict),
369
        supervised_keys=info.supervised_keys,
370
        citation=info.citation)
371

372
  def _split_generators(self, dl_manager):
373
    return [
374
        tfds.core.SplitGenerator(
375
            name=split, gen_kwargs=dict(split=split))
376
        for split in self._builder.info.splits.keys()
377
    ]
378

379
  def _generate_examples(self, split):
380
    for exi, ex in enumerate(
381
        tfds.as_numpy(self._builder.as_dataset(split=split))):
382
      ex = self._process_example(ex)
383
      yield exi, ex
384

385
  def _process_example(self, example):
386
    # As of now, this simply converts the image to the passed in size.
387
    # TODO(lmetz) It might also make sense to resize then crop out the center.
388
    example["image"] = cv2.resize(
389
        example["image"], dsize=self._size, interpolation=cv2.INTER_CUBIC)
390
    return example
391

392

393
class Food101_32x32(ResizedDataset):  # pylint: disable=invalid-name
394
  """The Food101 dataset resized to be 32x32."""
395

396
  VERSION = "1.0.0"
397

398
  def __init__(self, *args, **kwargs):
399
    parent_builder = tfds.builder("food101", version="1.0.0")
400
    super(Food101_32x32, self).__init__(
401
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)
402

403

404
class Food101_64x64(ResizedDataset):  # pylint: disable=invalid-name
405
  """The Food101 dataset resized to be 64x64."""
406

407
  VERSION = "1.0.0"
408

409
  def __init__(self, *args, **kwargs):
410
    parent_builder = tfds.builder("food101", version="1.0.0")
411
    super(Food101_64x64, self).__init__(
412
        *args, parent_builder=parent_builder, size=(64, 64), **kwargs)
413

414

415
class Coil100_32x32(ResizedDataset):  # pylint: disable=invalid-name
416
  """The coil100 dataset resized to be 32x32."""
417

418
  VERSION = "1.0.0"
419

420
  def __init__(self, *args, **kwargs):
421
    parent_builder = tfds.builder("coil100", version="1.0.0")
422
    super(Coil100_32x32, self).__init__(
423
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)
424

425

426
class ColorectalHistology_32x32(ResizedDataset):  # pylint: disable=invalid-name
427
  """The colorectal_histology dataset resized to be 32x32."""
428

429
  VERSION = "1.0.0"
430

431
  def __init__(self, *args, **kwargs):
432
    parent_builder = tfds.builder("colorectal_histology", version="2.*.*")
433
    super(ColorectalHistology_32x32, self).__init__(
434
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)
435

436

437
class DeepWeeds_32x32(ResizedDataset):  # pylint: disable=invalid-name
438
  """The deep_weeds dataset resized to be 32x32."""
439

440
  VERSION = "1.0.0"
441

442
  def __init__(self, *args, **kwargs):
443
    parent_builder = tfds.builder("deep_weeds", version="1.0.0")
444
    super(DeepWeeds_32x32, self).__init__(
445
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)
446

447

448
class Sun397_32x32(ResizedDataset):  # pylint: disable=invalid-name
449
  """The sun397/tfds dataset resized to be 32x32."""
450

451
  VERSION = "1.0.0"
452

453
  def __init__(self, *args, **kwargs):
454
    parent_builder = tfds.builder("sun397/tfds", version="4.0.0")
455
    super(Sun397_32x32, self).__init__(
456
        *args, parent_builder=parent_builder, size=(32, 32), **kwargs)
457

458

459
class TokenizedConfig(tfds.core.BuilderConfig):
460
  """BuilderConfig for tokenized text datasets."""
461

462
  def __init__(self, version=None, text_encoder_config=None, **kwargs):
463
    """BuilderConfig for tokenized text datasets.
464

465
    Args:
466
      version (string): version as string.
467
      text_encoder_config: `tfds.deprecated.text.TextEncoderConfig`, configuration
468
        for the `tfds.deprecated.text.TextEncoder` used for the `"text"` feature.
469
      **kwargs: keyword arguments forwarded to super.
470
    """
471
    super(TokenizedConfig, self).__init__(
472
        version=tfds.core.Version(version), **kwargs)
473
    self.text_encoder_config = (
474
        text_encoder_config or tfds.deprecated.text.TextEncoderConfig())
475

476

477
# This is an arbitrarily chosen subset of languages.
478
WIKIPEDIA_PREFIX = [
479
    "20190301.zh", "20190301.ru", "20190301.ja", "20190301.hsb", "20190301.en"
480
]
481

482

483
def _get_builder_configs(base_configs):
484
  """Get the builder configs for tokenized datasets."""
485
  configs = []
486
  for prefix in base_configs:
487
    configs.append(
488
        TokenizedConfig(
489
            name="%s_bytes" % prefix,
490
            version="0.0.1",
491
            description=("Uses byte-level text encoding with "
492
                         "`tfds.deprecated.text.ByteTextEncoder`"),
493
            text_encoder_config=tfds.deprecated.text.TextEncoderConfig(
494
                encoder=tfds.deprecated.text.ByteTextEncoder()),
495
        ))
496
    configs.append(
497
        TokenizedConfig(
498
            name="%s_subwords8k" % prefix,
499
            version="0.0.1",
500
            description=("Uses `tfds.deprecated.text.SubwordTextEncoder` with 8k "
501
                         "vocab size"),
502
            text_encoder_config=tfds.deprecated.text.TextEncoderConfig(
503
                encoder_cls=tfds.deprecated.text.SubwordTextEncoder,
504
                vocab_size=8192),
505
        ))
506
  return configs
507

508

509
class TokenizedWikipedia(tfds.core.GeneratorBasedBuilder):
510
  """Builder which tokenizes the tfds wikipedia datasets.
511

512
  This dataset returns 1 paragraph (split via new line) per example
513
  extracted from the articles. We additionally filter examples to have more than
514
  5 bytes. Encoding is either bytes, or subwords. The vocab is constructed out
515
  of the first 200k examples. While this is likely not perfect this should be
516
  sufficient for meta-learning optimizers.
517

518
  Additionally, we make a train and test split by hashing the article seed.
519

520
  Finally, for computational reasons we only use 1 millon articles. For the size
521
  of the models we are training here this should be plenty.
522
  """
523
  BUILDER_CONFIGS = _get_builder_configs(WIKIPEDIA_PREFIX)
524

525
  def __init__(self, config=None, **kwargs):
526
    """Initialize the resized image dataset builder.
527

528
    Args:
529
      config: str Config string specified to build dataset with.
530
      **kwargs: kwargs passed super class.
531
    """
532

533
    # extract the base dataset.
534
    base, _ = config.split("_")
535
    self._builder = tfds.builder("wikipedia/%s" % base)
536
    super(TokenizedWikipedia, self).__init__(config=config, **kwargs)
537

538
    self._perc_train = 0.7
539
    self._max_num_articles = 1000000
540
    # Number of examples used to build the tokenizer.
541
    self._examples_for_tokenizer = 200000
542

543
  def _info(self):
544
    info = self._builder.info
545
    description = "\n This dataset has been tokenized!"
546
    return tfds.core.DatasetInfo(
547
        builder=self,
548
        description=info.description + description,
549
        features=tfds.features.FeaturesDict({
550
            "title":
551
                tfds.features.Text(),
552
            "text":
553
                tfds.features.Text(
554
                    encoder_config=self.builder_config.text_encoder_config),
555
        }),
556
        supervised_keys=("text", "text"),
557
        homepage=info.homepage,
558
        citation=info.citation)
559

560
  def _split_generators(self, dl_manager):
561
    self.info.features["text"].maybe_build_from_corpus(self._vocab_text_gen())
562

563
    return [
564
        tfds.core.SplitGenerator(
565
            name=split, gen_kwargs=dict(split=split))
566
        for split in ["train", "test"]
567
    ]
568

569
  def _split_article(self, ex):
570
    for i, split in enumerate(ex["text"].split("\n")):
571
      if len(split.strip()) > 5:
572
        yield i, {"title": ex["title"], "text": split}
573

574
  def _generate_examples(self, split):
575
    hasher = tfds.core.hashing.Hasher("token_wikipedia_salt")
576
    for exi, example in enumerate(
577
        tfds.as_numpy(self._builder.as_dataset(split="train"))):
578

579
      if exi > self._max_num_articles:
580
        return
581

582
      # To make a train test split we first hash the key and convert it to a
583
      # floating point value between 0-1. Depending on this value we either
584
      # yield the example or not depending on the split.
585
      p = hasher.hash_key(exi) % 100000 / 100000.
586

587
      if split == "train" and p < self._perc_train:
588
        for i, sub_example in self._split_article(example):
589
          key = (exi, i)
590
          yield key, sub_example
591

592
      elif split == "test" and p >= self._perc_train:
593
        for i, sub_example in self._split_article(example):
594
          key = (exi, i)
595
          yield key, sub_example
596

597
  def _vocab_text_gen(self):
598
    for i, (_, ex) in enumerate(self._generate_examples("train")):
599
      # Only yield a subset of the data used for tokenization for
600
      # performance reasons.
601
      if self._examples_for_tokenizer > i:
602
        yield ex["text"]
603
      else:
604
        return
605

606

607
# Arbitrary subset of datasets.
608
AMAZON_PRODUCTS = ["Books_v1_02", "Camera_v1_00", "Home_v1_00", "Video_v1_00"]
609

610

611
class TokenizedAmazonReviews(tfds.core.GeneratorBasedBuilder):
612
  """Builder which tokenizes the tfds amazon reviews datasets.
613

614
  For compute reasons we only tokenize with 200000 examples.
615

616
  We make a train and test split by hashing the example index.
617
  """
618
  BUILDER_CONFIGS = _get_builder_configs(AMAZON_PRODUCTS)
619

620
  def __init__(self, config=None, **kwargs):
621
    """Initialize the resized image dataset builder.
622

623
    Args:
624
      config: str Config string specified to build dataset with.
625
      **kwargs: kwargs passed super class.
626
    """
627

628
    # extract the base dataset.
629
    base = "_".join(config.split("_")[0:-1])
630
    self._builder = tfds.builder("amazon_us_reviews/%s" % base)
631

632
    super(TokenizedAmazonReviews, self).__init__(config=config, **kwargs)
633

634
    self._perc_train = 0.7
635
    self._examples_for_tokenizer = 200000
636

637
  def _info(self):
638
    info = self._builder.info
639
    description = "\n This dataset has been tokenized!"
640
    return tfds.core.DatasetInfo(
641
        builder=self,
642
        description=info.description + description,
643
        features=tfds.features.FeaturesDict({
644
            # 1-5 stars are the labels.
645
            "label":
646
                tfds.features.ClassLabel(num_classes=5),
647
            "text":
648
                tfds.features.Text(
649
                    encoder_config=self.builder_config.text_encoder_config),
650
        }),
651
        supervised_keys=("text", "label"),
652
        homepage=info.homepage,
653
        citation=info.citation)
654

655
  def _split_generators(self, dl_manager):
656
    self.info.features["text"].maybe_build_from_corpus(self._vocab_text_gen())
657

658
    return [
659
        tfds.core.SplitGenerator(
660
            name=split, gen_kwargs=dict(split=split))
661
        for split in ["train", "test"]
662
    ]
663

664
  def _generate_examples(self, split):
665
    hasher = tfds.core.hashing.Hasher("token_wikipedia_salt")
666
    for exi, example in enumerate(
667
        tfds.as_numpy(self._builder.as_dataset(split="train"))):
668

669
      p = hasher.hash_key(exi) % 1000 / 1000.
670

671
      example = {
672
          "text": example["data"]["review_body"],
673
          # subtract one to zero index.
674
          "label": example["data"]["star_rating"] - 1
675
      }
676
      if split == "train" and p < self._perc_train:
677
        yield exi, example
678

679
      elif split == "test" and p > self._perc_train:
680
        yield exi, example
681

682
  def _vocab_text_gen(self):
683
    for i, (_, ex) in enumerate(self._generate_examples("train")):
684
      if self._examples_for_tokenizer > i:
685
        yield ex["text"]
686
      else:
687
        return
688

689

690
def _single_associative_retrieval(batch_size=128, num_pairs=5, num_tokens=10):
691
  """See associative_retrieval."""
692

693
  def _onehot_pack(inp, out, loss_mask):
694
    inp_seq, outputs, loss_mask = (tf.one_hot(inp, num_tokens + 2),
695
                                   tf.one_hot(out, num_tokens + 2), loss_mask)
696
    return {"input": inp_seq, "output": outputs, "loss_mask": loss_mask}
697

698
  def _py_make_example():
699
    """Iterator that makes single examples in python."""
700
    while True:
701
      keys = np.random.choice(num_tokens, size=num_pairs, replace=False)
702
      values = np.random.choice(num_tokens, size=num_pairs, replace=True)
703
      empty_token_idx = num_tokens
704
      query_token_idx = num_tokens + 1
705
      input_seq = []
706
      output_seq = []
707
      for k, v in zip(keys, values):
708
        input_seq.extend([k, v])
709
        output_seq.extend([empty_token_idx, empty_token_idx])
710

711
      input_seq.append(query_token_idx)
712
      output_seq.append(empty_token_idx)
713

714
      query_key = np.random.randint(0, num_pairs)
715
      input_seq.append(keys[query_key])
716
      output_seq.append(values[query_key])
717
      loss_mask = np.zeros(2 * num_pairs + 2, dtype=np.float32)
718
      loss_mask[-1] = 1.
719
      input_seq = np.asarray(input_seq, dtype=np.int32)
720
      output_seq = np.asarray(output_seq, dtype=np.int32)
721
      yield input_seq, output_seq, loss_mask
722

723
  # per pair, there is a key and a value. Extra 2 account for query indicator
724
  # and query key.
725
  seq_len = 2 * num_pairs + 2
726
  dataset = tf.data.Dataset.from_generator(_py_make_example,
727
                                           (tf.int32, tf.int32, tf.float32),
728
                                           ([seq_len], [seq_len], [seq_len]))
729
  dataset = dataset.map(_onehot_pack)
730
  return dataset.batch(batch_size, drop_remainder=True)
731

732

733
def associative_sequence(batch_size=128, num_pairs=5, num_tokens=10):
734
  """Associative Retrieval datasets.
735

736
  The inputs consist of pairs of key and value sequentially followed by an
737
  indicator token and then a retrieval token.
738

739
  Output consists of the value associated with the retrieval key in the final
740
  step of the sequence, preceded by empty tokens.
741

742
  The problem can be perfectly solved, as in the 'key' tokens will be unique.
743
  There can be duplicate values, however, for different keys.
744

745
  Example (using characters instead of the onehot representations):
746

747
  input:     A1B2C3D4?A
748
  output:    _________1
749
  loss_mask: 0000000001
750

751
  The outputs are represented using a one-hot encoding.
752

753
  The problem is based off of the one used in
754
  https://arxiv.org/pdf/1610.06258.pdf.
755

756
  Args:
757
    batch_size: int
758
    num_pairs: int, number of pairs to put into memory.
759
    num_tokens: int, number of possible tokens to choose from.
760

761
  Returns:
762
    datasets: Datasets object with each split containing the same data
763
      generating process.
764
  """
765
  fn = lambda: _single_associative_retrieval(batch_size, num_pairs, num_tokens)
766
  return Datasets(train=fn(), valid_inner=fn(), valid_outer=fn(), test=fn())
767

768

769
def _single_copy_sequence(batch_size=128,
770
                          sequence_length=5,
771
                          num_separator=1,
772
                          num_tokens=10):
773
  """See copy_sequence for docs."""
774

775
  def _build_batch(_):
776
    """Construct a batch.
777

778
    Args:
779
      _: tf.Tensor Needed to construct a tf.data.Dataset that iteratively calls
780
        this function. This is a dummy value that never changes.
781

782
    Returns:
783
      batch: SequencePrediction, containing a batch of sequences.
784
    """
785
    inp = tf.random_uniform([batch_size, sequence_length],
786
                            0,
787
                            num_tokens,
788
                            dtype=tf.int32)
789
    sep = tf.ones([batch_size, num_separator], dtype=tf.int32) * num_tokens
790
    emit = tf.ones([batch_size, sequence_length], dtype=tf.int32) * (
791
        num_tokens + 1)
792
    inp_seq_pre_onehot = tf.concat([inp, sep, emit], axis=1)
793
    inp_seq = tf.one_hot(inp_seq_pre_onehot, num_tokens + 2)
794

795
    loss_mask = tf.concat([
796
        tf.zeros([batch_size, sequence_length + num_separator]),
797
        tf.ones([batch_size, sequence_length])
798
    ],
799
                          axis=1)
800

801
    outputs_pre_onehot = tf.concat(
802
        [tf.zeros_like(inp), tf.zeros_like(sep), inp], axis=1)
803
    outputs = tf.one_hot(outputs_pre_onehot, num_tokens + 2)
804

805
    return {"input": inp_seq, "output": outputs, "loss_mask": loss_mask}
806

807
  return tf.data.Dataset.from_tensor_slices([0]).repeat().map(_build_batch)
808

809

810
def copy_sequence(batch_size=128,
811
                  sequence_length=5,
812
                  num_separator=1,
813
                  num_tokens=10):
814
  """A simple input copy to output task.
815

816
  Input consists of `seq_len` tokens drawn from a vocab size of `num_tokens`
817
  followed by `n_sep` separation tokens, followed by 3 empty tokens.
818

819
  The output consists of `seq_len + n_sep` empty tokens followed by the same
820
  input tokens from the input.
821

822
  All token outputs are onehot.
823

824
  A sample input output pair for seq_len=3, num_tokens=3, n_sep=1
825

826
  input::        <tokenA><tokenB><tokenC><sep>  <empty> <empty> <empty>
827
  output::       <empty> <empty> <empty> <empty><tokenA><tokenB><tokenC>
828
  loss_mask::  0.       0.     0.      0.     1.      1.      1.
829

830
  Args:
831
    batch_size: int
832
    sequence_length: int, length of sequence to copy
833
    num_separator: int, number of empty tokens separating input from output
834
    num_tokens: int, number of tokens to build input from
835

836
  Returns:
837
    dataset: tf.Data.Dataset
838
  """
839

840
  def fn():
841
    return _single_copy_sequence(batch_size, sequence_length, num_separator,
842
                                 num_tokens)
843

844
  return Datasets(train=fn(), valid_inner=fn(), valid_outer=fn(), test=fn())
845

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.