# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipeline for a WMT dataset.

This file was branched from flax/examples/wmt/input_pipeline.py.
"""

import collections
import csv
import os
import random
from typing import Dict, List, Optional, Union

from absl import logging
import numpy as np
import tensorflow.compat.v2 as tf

from data_selection.wmt import dataset_utils
from data_selection.wmt import tokenizer

AUTOTUNE = tf.data.AUTOTUNE
Features = Dict[str, tf.Tensor]


# -----------------------------------------------------------------------------
# Raw TFDS dataset.
# -----------------------------------------------------------------------------
def raw_wmt_datasets(dataset_name='wmt17_translate/de-en',
                     eval_dataset_name=None,
                     reverse_translation=False,
                     shard_idx=0,
                     shard_count=1,
                     data_dir=None,
                     paracrawl_size=0,
                     shuffle_train_files=True,
                     pseudo_path=None,
                     newscommentary_size=None,
                     newscomment_sample_ratio=1.0):
  """Load raw WMT datasets and normalize feature keys.

  Args:
    dataset_name: str: TFDS WMT dataset name.
    eval_dataset_name: Optional[str]: separate dataset name for evaluation,
      e.g. for specifying the standard academic WMT14 test set.
    reverse_translation: bool: whether to reverse the translation direction,
      e.g. for 'de-en' this translates from English to German.
    shard_idx: int: for multihost training, index of this host.
    shard_count: int: for multihost training, number of total hosts.
    data_dir: str: location of TFDS data directory.
    paracrawl_size: if ParaCrawl is used, we will sample this many examples.
    shuffle_train_files: whether to shuffle the input data files.
    pseudo_path: path to pseudo references.
    newscommentary_size: size of the News Commentary fine-tuning set.
    newscomment_sample_ratio: how much to downsample the News Commentary data.

  Returns:
    training tf.dataset, evaluation tf.dataset, and training features_info;
    source and target language features are mapped to 'inputs' and 'targets'
    keys.
  """
  wmt_dataset_builder = dataset_utils.WmtDatasetBuilder(shard_idx,
                                                        shard_count,
                                                        data_dir,
                                                        shuffle_train_files,
                                                        pseudo_path)
  train_data, eval_data = wmt_dataset_builder.build_train_and_eval_datasets(
      dataset_name, eval_dataset_name, paracrawl_size, newscommentary_size,
      newscomment_sample_ratio)

  builder = wmt_dataset_builder.retrieve_builder()

  if builder is not None:
    features_info = builder.info

    # standardize on 'inputs' and 'targets' features.
    input_lang = features_info.supervised_keys[0]
    target_lang = features_info.supervised_keys[1]
    if reverse_translation:
      input_lang, target_lang = target_lang, input_lang
    def to_features_dict(x):
      return {'inputs': x[input_lang], 'targets': x[target_lang]}
    if 'pseudo' not in dataset_name:  # Perhaps remove this code path.
      train_data = train_data.map(to_features_dict, num_parallel_calls=AUTOTUNE)
    eval_data = eval_data.map(to_features_dict, num_parallel_calls=AUTOTUNE)
  else:
    features_info = None

  return train_data, eval_data, features_info
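
# Illustrative usage (not executed here; paths are placeholders):
#   train, evaluation, info = raw_wmt_datasets(
#       dataset_name='wmt17_translate/de-en',
#       reverse_translation=True,
#       data_dir='/path/to/tensorflow_datasets')
# Both returned datasets yield text examples keyed as
# {'inputs': <source sentence>, 'targets': <target sentence>}.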


def pack_dataset(dataset: tf.data.Dataset,
                 key2length: Union[int, Dict[str, int]],
                 keys: Optional[List[str]] = None) -> tf.data.Dataset:
  """Creates a 'packed' version of a dataset on-the-fly.

  Adapted from the mesh-tf implementation.

  This is meant to replace the irritation of having to create a separate
  "packed" version of a dataset to train efficiently on TPU.
  Each example in the output dataset represents several examples in the
  input dataset.
  For each key in the input dataset, two additional keys are created:
  <key>_segmentation: an int32 tensor identifying the parts
     representing the original example.
  <key>_position: an int32 tensor identifying the position within the original
     example.
  Example:
  Two input examples get combined to form an output example.
  The input examples are:
  {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]}
  {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]}
  The output example is:
  {
                 "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
    "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
        "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
                "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
   "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
       "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
  }
  0 represents padding in both the inputs and the outputs.
  Sequences in the incoming examples are truncated to length "length", and the
  sequences in the output examples all have fixed (padded) length "length".

  Args:
    dataset: a tf.data.Dataset
    key2length: an integer, or a dict from feature-key to integer
    keys: a list of strings (e.g. ["inputs", "targets"])

  Returns:
    a tf.data.Dataset
  """
  shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec)
  if keys is None:
    keys = list(shapes.keys())
  for k in keys:
    if k not in shapes:
      raise ValueError('Key %s not found in dataset.  Available keys are %s' %
                       (k, shapes.keys()))
    if not shapes[k].is_compatible_with(tf.TensorShape([None])):
      raise ValueError('Tensors to be packed must be one-dimensional.')
  # make sure that the length dictionary contains all keys as well as the
  # keys suffixed by "_segmentation" and "_position"
  if isinstance(key2length, int):
    key2length = {k: key2length for k in keys}
  for k in keys:
    for suffix in ['_segmentation', '_position']:
      key2length[k + suffix] = key2length[k]

  # trim to length
  dataset = dataset.map(
      lambda x: {k: x[k][:key2length[k]] for k in keys},
      num_parallel_calls=AUTOTUNE)
  # Setting batch_size=length ensures that the concatenated sequences (if they
  # have length >=1) are sufficient to fill at least one packed example.
  batch_size = max(key2length.values())
  dataset = dataset.padded_batch(
      batch_size, padded_shapes={k: [-1] for k in keys})
  dataset = _pack_with_tf_ops(dataset, keys, key2length)

  # Set the Tensor shapes correctly since they get lost in the process.
  def my_fn(x):
    return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()}

  return dataset.map(my_fn, num_parallel_calls=AUTOTUNE)
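
# Minimal illustrative sketch of pack_dataset (not executed here): a dataset of
# variable-length 1-D int32 features packed to fixed length 10, reproducing the
# worked example from the docstring above.
#   ds = tf.data.Dataset.from_generator(
#       lambda: iter([{'inputs': [8, 7, 1, 0], 'targets': [4, 1, 0]},
#                     {'inputs': [2, 3, 4, 1], 'targets': [5, 6, 1]}]),
#       output_signature={'inputs': tf.TensorSpec([None], tf.int32),
#                         'targets': tf.TensorSpec([None], tf.int32)})
#   packed = pack_dataset(ds, 10)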


def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str],
                      key2length: Dict[str, int]) -> tf.data.Dataset:
  """Helper-function for packing a dataset which has already been batched.

  Helper for pack_dataset(). Uses tf.while_loop.

  Args:
    dataset: a dataset containing padded batches of examples.
    keys: a list of strings
    key2length: a dict from feature-key to integer

  Returns:
    a dataset.
  """
  empty_example = {}
  for k in keys:
    empty_example[k] = tf.zeros([0], dtype=tf.int32)
    empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32)
  keys_etc = empty_example.keys()

  def write_packed_example(partial, outputs):
    new_partial = empty_example.copy()
    new_outputs = {}
    for k in keys_etc:
      new_outputs[k] = outputs[k].write(
          outputs[k].size(),
          tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]))
    return new_partial, new_outputs

  def map_fn(x):
    """Internal function mapped over the batched dataset.

    Consumes a batch of input examples and produces a variable number of output
    examples.
    Args:
      x: a single padded batch of examples.

    Returns:
      a dict of packed, padded Tensors.
    """
    partial = empty_example.copy()
    i = tf.zeros([], dtype=tf.int32)
    dynamic_batch_size = tf.shape(x[keys[0]])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
      outputs[k + '_position'] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])

    def body_fn(i, partial, outputs):
      """Body function for while_loop.

      Args:
        i: integer scalar
        partial: dictionary of Tensor (partially-constructed example)
        outputs: dictionary of TensorArray

      Returns:
        A triple containing the new values of the inputs.
      """
      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]))

      def false_fn():
        return write_packed_example(partial, outputs)

      def true_fn():
        return partial, outputs

      partial, outputs = tf.cond(can_append, true_fn, false_fn)
      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:key2length[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + '_position'] = tf.concat(
            [partial[k + '_position'],
             tf.range(new_seq_len)], 0)
      partial = new_partial
      return i + 1, partial, outputs

    # For loop over all examples in the batch.
    i, partial, outputs = tf.while_loop(
        cond=lambda *_: True,
        body=body_fn,
        loop_vars=(i, partial, outputs),
        shape_invariants=(
            tf.TensorShape([]),
            {k: tf.TensorShape([None]) for k in keys_etc},
            {k: tf.TensorShape(None) for k in keys_etc},
        ),
        maximum_iterations=dynamic_batch_size)
    _, outputs = write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
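    # Derive <key>_segmentation from <key>_position: every position that resets
    # to 0 marks the start of a packed-in sequence, so a cumulative sum over
    # those resets numbers the sequences 1, 2, ...; multiplying by the non-zero
    # token mask zeroes out the padding. E.g. positions [0, 1, 2, 0, 1, 0, 0]
    # with tokens [8, 7, 1, 2, 1, 0, 0] give segmentation [1, 1, 1, 2, 2, 0, 0].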
    for k in keys:
      packed[k + '_segmentation'] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) *
          tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed

  dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
  return dataset.unbatch()


# -----------------------------------------------------------------------------
# Main dataset prep routines.
# -----------------------------------------------------------------------------
def preprocess_wmt_data(dataset,
                        shuffle: bool,
                        num_epochs: Optional[int] = 1,
                        pack_examples: bool = True,
                        shuffle_buffer_size: int = 1024000,
                        max_length: int = 512,
                        batch_size: int = 256,
                        drop_remainder: bool = True,
                        prefetch_size: int = AUTOTUNE,
                        is_scores_path=None,
                        num_to_keep=0,
                        truncate=False,
                        sample_size=-1):
  """Shuffle and batch/pack the given dataset."""

  def length_filter(max_len):

    def filter_fn(x):
      source, target = x['inputs'], x['targets']
      l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
      return tf.less(l, max_len + 1)

    return filter_fn

  if truncate:
    dataset = dataset.map(
        lambda x: {k: v[:max_length] for k, v in x.items()},
        num_parallel_calls=AUTOTUNE)
  elif max_length > 0:
    dataset = dataset.filter(length_filter(max_length))

  if is_scores_path is not None:
    logging.info('Doing data selection!')
    logging.info('Num to keep = %d', num_to_keep)
    dataset = data_selection(dataset, is_scores_path, num_to_keep)

  if sample_size > 0:
    logging.info('Downsampling: %d', sample_size)
    shuff_buff = 200000  # num_to_keep if num_to_keep > 0 else 200000
    dataset = dataset.shuffle(shuff_buff).take(sample_size)

  if shuffle:
    dataset = dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.repeat(num_epochs)

  if pack_examples:
    dataset = pack_dataset(dataset, max_length)
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
  else:  # simple (static-shape) padded batching
    dataset = dataset.padded_batch(
        batch_size,
        padded_shapes={
            'inputs': max_length,
            'targets': max_length
        },
        padding_values={
            'inputs': 0,
            'targets': 0
        },
        drop_remainder=drop_remainder)

  if prefetch_size:
    dataset = dataset.prefetch(prefetch_size)

  return dataset
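
# Illustrative call (not executed here): with pack_examples=True each batch has
# static shape [batch_size, max_length] for 'inputs', 'targets' and their
# '*_position'/'*_segmentation' companions. `tokenized_train_data` below stands
# for a dataset of 1-D int32 {'inputs', 'targets'} examples, e.g. the output of
# tokenizer.TokenizeOp.
#   train_ds = preprocess_wmt_data(
#       tokenized_train_data, shuffle=True, num_epochs=-1,
#       batch_size=256, max_length=256)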


def data_selection(train_data, is_scores_path, num_to_keep=-1):
  """Select data based on intelligent selection scores."""
  if num_to_keep < 0:
    return train_data

  scores = []
  with tf.io.gfile.GFile(is_scores_path, 'r') as f:
    reader = csv.reader(f)
    for val in reader:
      scores.extend(val)
  scores = [float(s) for s in scores]

  lengths = []
  with tf.io.gfile.GFile(is_scores_path.replace('.csv', '_length.csv'),
                         'r') as f:
    reader = csv.reader(f)
    for val in reader:
      lengths.extend(val)
  lengths = [int(s) for s in lengths]

  if num_to_keep >= len(scores):
    return train_data

  threshold = np.sort(scores)[num_to_keep]
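  # The filter below keeps examples whose score is at most the num_to_keep-th
  # smallest score, i.e. roughly the num_to_keep lowest-scoring examples
  # (ties at the threshold may let a few more through).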

  tf_is_scores = tf.data.Dataset.from_tensor_slices(scores)
  tf_lengths = tf.data.Dataset.from_tensor_slices(lengths)

  scored_data = tf.data.Dataset.zip((tf_is_scores, tf_lengths, train_data))
  def filter_fn(score, _, __):  # pylint: disable=invalid-name
    return tf.math.less_equal(score, threshold)

  def remove_enum(_, length, el):
    targ_size = tf.math.count_nonzero(el['targets'], dtype=tf.dtypes.int32)
    assert_op = tf.debugging.assert_equal(
        length, targ_size, message='Lengths not aligned')
    with tf.control_dependencies([assert_op]):
      return el

  train_data = scored_data.filter(filter_fn).map(remove_enum)
  train_data = train_data.cache()
  return train_data
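
# The score files read above are assumed to be plain CSVs of floats written in
# the same order as the training examples (values may span several rows), with
# a companion '<name>_length.csv' holding the corresponding target lengths,
# e.g. for a hypothetical 'is_scores.csv':
#   0.31,0.87,0.02
#   1.45
# and 'is_scores_length.csv':
#   12,7,31
#   9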


def get_wmt_datasets(dataset_name='wmt17_translate/de-en',
                     eval_dataset_name=None,
                     reverse_translation=True,
                     shard_idx=0,
                     shard_count=1,
                     data_dir=None,
                     vocab_path=None,
                     target_vocab_size=2**15,  # 32000
                     max_corpus_chars=10**7,
                     batch_size=256,
                     pack_examples=True,
                     max_length=256,
                     max_eval_length=256,
                     paracrawl_size=0,
                     is_scores_path=None,
                     num_to_keep=-1,
                     pseudo_path=None,
                     shuffle_repeat_train=True,
                     repeat_count=-1,
                     newscommentary_size=None,
                     split_tokenizer=False,
                     sample_size=-1,
                     newscomment_sample_ratio=1.0):
  """Load and return dataset of batched examples for use during training."""
  if vocab_path is None:
    vocab_path = os.path.expanduser('~/wmt_sentencepiece_model')

  train_data, eval_data, _ = raw_wmt_datasets(
      dataset_name=dataset_name,
      eval_dataset_name=eval_dataset_name,
      reverse_translation=reverse_translation,
      shard_idx=shard_idx,
      shard_count=shard_count,
      data_dir=data_dir,
      paracrawl_size=paracrawl_size,
      shuffle_train_files=(is_scores_path is None) and shuffle_repeat_train,
      pseudo_path=pseudo_path,
      newscommentary_size=newscommentary_size,
      newscomment_sample_ratio=newscomment_sample_ratio)
  # If is_scores_path is None, there is no data selection so we can shuffle.
  # If it is not None, then we cannot shuffle the input files.

  # Tokenize data.
  if split_tokenizer:
    sp_tokenizer_input = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_input',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('inputs',))
    sp_tokenizer_target = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_target',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('targets',))
    train_data = train_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    eval_data = eval_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    sp_tokenizer = sp_tokenizer_target
  else:
    sp_tokenizer = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path,
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars)

    # Currently the pseudo references are stored in pickle files and are
    # pre-tokenized, so we do not tokenize them here. Instead we should write
    # the pseudo references to a TFRecord in the future.
    if 'pseudo' not in dataset_name:
      train_data = train_data.map(
          tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)
    eval_data = eval_data.map(
        tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)

  train_ds = preprocess_wmt_data(
      train_data,
      shuffle=shuffle_repeat_train,
      num_epochs=repeat_count,
      pack_examples=pack_examples,
      batch_size=batch_size,
      max_length=max_length,
      is_scores_path=is_scores_path,
      num_to_keep=num_to_keep,
      sample_size=sample_size)

  eval_ds = preprocess_wmt_data(
      eval_data,
      shuffle=False,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_eval_length)

  predict_ds = preprocess_wmt_data(
      eval_data,
      shuffle=False,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_eval_length,
      drop_remainder=False)

  return train_ds, eval_ds, predict_ds, sp_tokenizer
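
# Illustrative usage (not executed here; paths and sizes are placeholders):
#   train_ds, eval_ds, predict_ds, sp_tokenizer = get_wmt_datasets(
#       dataset_name='wmt17_translate/de-en',
#       reverse_translation=True,
#       batch_size=256,
#       max_length=256,
#       is_scores_path='/path/to/is_scores.csv',
#       num_to_keep=1_000_000)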


def get_wmt_is_datasets(n_devices,
                        dataset_name='wmt17_translate/de-en',
                        reverse_translation=True,
                        shard_idx=0,
                        shard_count=1,
                        data_dir=None,
                        vocab_path=None,
                        target_vocab_size=2**15,  # 32000
                        max_corpus_chars=10**7,
                        batch_size=256,
                        max_length=256,
                        paracrawl_size=0,
                        split_tokenizer=False,
                        use_eval_data=False,
                        truncate=False):
  """Load and return dataset of batched examples for data selection scoring."""
  if batch_size % n_devices:
    raise ValueError("Batch size %d isn't divided evenly by n_devices %d" %
                     (batch_size, n_devices))
  if vocab_path is None:
    vocab_path = os.path.expanduser('~/wmt_sentencepiece_model')

  train_data, eval_data, _ = raw_wmt_datasets(
      dataset_name=dataset_name,
      eval_dataset_name=None,
      reverse_translation=reverse_translation,
      shard_idx=shard_idx,
      shard_count=shard_count,
      data_dir=data_dir,
      paracrawl_size=paracrawl_size,
      shuffle_train_files=False)

  if use_eval_data:
    # Unfortunate use of names but easiest for refactor w/o errors.
    train_data = eval_data

  # Tokenize data.
  if split_tokenizer:
    sp_tokenizer_input = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_input',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('inputs',))
    sp_tokenizer_target = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_target',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('targets',))
    train_data = train_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    sp_tokenizer = sp_tokenizer_target
  else:
    sp_tokenizer = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path,
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars)

    # Encode strings with sentencepiece tokenizer.
    train_data = train_data.map(
        tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)

  train_batches = preprocess_wmt_data(
      train_data,
      shuffle=False,
      num_epochs=1,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_length,
      drop_remainder=False,
      truncate=truncate)
  # Note: we drop remainder, which will truncate the training data, but the
  # effect is 0.017% of the dataset so it shouldn't affect the model.

  if split_tokenizer:
    return train_batches, (sp_tokenizer_input, sp_tokenizer_target)
  return train_batches, (sp_tokenizer, sp_tokenizer)
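
# Illustrative usage (not executed here): batches keep the original example
# order (no file shuffling, no example shuffling, no dropped remainder), so
# per-example scores computed from them can be written out in the order that
# data_selection() and create_buckets() expect.
#   scoring_batches, (tok_input, tok_target) = get_wmt_is_datasets(
#       n_devices=8, dataset_name='wmt17_translate/de-en', batch_size=256)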


def get_dynamic_datasets(dataset_name='wmt17_translate/de-en',
                         eval_dataset_name=None,
                         reverse_translation=True,
                         shard_idx=0,
                         shard_count=1,
                         data_dir=None,
                         vocab_path=None,
                         target_vocab_size=2**15,  # 32000
                         max_corpus_chars=10**7,
                         batch_size=256,
                         max_length=256,
                         max_eval_length=256,
                         paracrawl_size=0,
                         is_scores_path=None,
                         num_buckets=100,
                         split_tokenizer=False):
  """Load data for dynamic data selection.

  Returns a training bucket manager, eval and predict batches, and the
  tokenizer.
  """
  if vocab_path is None:
    vocab_path = os.path.expanduser('~/wmt_sentencepiece_model')

  train_data, eval_data, _ = raw_wmt_datasets(
      dataset_name=dataset_name,
      eval_dataset_name=eval_dataset_name,
      reverse_translation=reverse_translation,
      shard_idx=shard_idx,
      shard_count=shard_count,
      data_dir=data_dir,
      paracrawl_size=paracrawl_size,
      shuffle_train_files=False)

  if split_tokenizer:
    sp_tokenizer_input = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_input',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('inputs',))
    sp_tokenizer_target = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_target',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('targets',))
    train_data = train_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    eval_data = eval_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    sp_tokenizer = sp_tokenizer_target
  else:
    sp_tokenizer = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path,
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars)
    train_data = train_data.map(
        tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)
    eval_data = eval_data.map(
        tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)

  train_data_manager = build_dynamic_data(
      train_data,
      batch_size=batch_size,
      max_length=max_length,
      is_scores_path=is_scores_path,
      num_buckets=num_buckets)

  eval_batches = preprocess_wmt_data(
      eval_data,
      shuffle=False,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_eval_length)

  predict_batches = preprocess_wmt_data(
      eval_data,
      shuffle=False,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_eval_length,
      drop_remainder=False)

  return train_data_manager, eval_batches, predict_batches, sp_tokenizer


def build_dynamic_data(dataset,
                       shuffle_buffer_size=1000,
                       max_length=512,
                       batch_size=256,
                       is_scores_path=None,
                       num_buckets=100):
  """Bucket the given dataset by score and wrap it in a DatasetBucketManager."""
  def length_filter(max_len):
    def filter_fn(x):
      source, target = x['inputs'], x['targets']
      l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
      return tf.less(l, max_len + 1)
    return filter_fn

  if max_length > 0:
    dataset = dataset.filter(length_filter(max_length))

  assert is_scores_path is not None
  # Break into buckets
  buckets = create_buckets(dataset, is_scores_path, num_buckets)
  # Create DatasetBucketManager
  bucket_manager = DatasetBucketManager(buckets, shuffle_buffer_size,
                                        max_length, batch_size)
  return bucket_manager


def create_buckets(dataset, is_scores_path, num_buckets):
  """Split dataset into buckets."""

  scores = []
  with tf.io.gfile.GFile(is_scores_path, 'r') as f:
    reader = csv.reader(f)
    for val in reader:
      scores.extend(val)
  scores = [float(s) for s in scores]

  lengths = []
  with tf.io.gfile.GFile(is_scores_path.replace('.csv', '_length.csv'),
                         'r') as f:
    reader = csv.reader(f)
    for val in reader:
      lengths.extend(val)
  lengths = [int(s) for s in lengths]

  # compute bucket thresholds
  sorted_scores = np.sort(scores)
  logging.info('len scores %d', len(scores))
  bucket_size = int(len(scores) / num_buckets)
  ends = sorted_scores[bucket_size-1::bucket_size]
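  # `ends` holds the upper score edge of each bucket: with, say, 1000 scores
  # and num_buckets=100, bucket_size is 10 and `ends` is every 10th sorted
  # score (indices 9, 19, ..., 999). np.digitize below then maps each score to
  # the index of the first edge strictly greater than it.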
741

742
  # Iterate through dataset and write to memory
743
  bin_assignments = np.digitize(scores, ends)
744
  tf_is_bins = tf.data.Dataset.from_tensor_slices(bin_assignments)
745
  tf_lengths = tf.data.Dataset.from_tensor_slices(lengths)
746
  scored_data = tf.data.Dataset.zip((tf_is_bins, tf_lengths, dataset))
747
  bucket_examples = collections.defaultdict(list)
748
  iter_index = 0
749
  for ex_bin, ex_len, data in iter(scored_data):
750
    assert ex_len.numpy() == np.count_nonzero(
751
        data['targets'].numpy()), (ex_len, data, iter_index)
752
    iter_index += 1
753
    bucket_examples[ex_bin.numpy()].append(data)
754

755
  bucket_datasets = []
756
  index_memory = [0]* num_buckets
757
  for i in range(num_buckets):
758
    logging.info('Bin %d num el: %d', i, len(bucket_examples[i]))
759
    def gen_creator(bin_i):
760
      def gen():
761
        for ex_i in range(index_memory[bin_i], len(bucket_examples[bin_i])):
762
          index_memory[bin_i] = ex_i + 1
763
          yield bucket_examples[bin_i][ex_i]
764
          if ex_i == len(bucket_examples[bin_i]) - 1:
765
            logging.info('SHUFFLING BIN!! %d', bin_i)
766
            index_memory[bin_i] = 0
767
            random.shuffle(bucket_examples[bin_i])
768
      return gen
769

770
    gen_ds = tf.data.Dataset.from_generator(
771
        gen_creator(i), output_types={
772
            'inputs': tf.int32,
773
            'targets': tf.int32
774
        })
775
    gen_ds = gen_ds.repeat()
776
    bucket_datasets.append(gen_ds)
777

778
  # Sanity check that we are not creating the same dataset on each loop
779
  assert bucket_datasets[0] != bucket_datasets[1]
780
  return bucket_datasets
781

782

783
class DatasetBucketManager():
784
  """For dynamic data selection, sample or draw from buckets."""
785

786
  def __init__(self, datasets, shuffle_buffer_size=1000,
787
               max_length=256, batch_size=256):
788
    self.shuffle_buffer_size = shuffle_buffer_size
789
    self.max_length = max_length
790
    self.batch_size = batch_size
791
    self.unproccessed_buckets = datasets
792
    self.buckets = self._proc_buckets(self.unproccessed_buckets)
793

794
  def _proc_buckets(self, buckets):
795
    return list(map(iter, map(self._proc_dataset, buckets)))
796

797
  def _proc_dataset(self, dataset):
798
    dataset = dataset.repeat()
799
    dataset = dataset.shuffle(self.shuffle_buffer_size)
800
    dataset = pack_dataset(dataset, self.max_length)
801
    dataset = dataset.batch(self.batch_size, drop_remainder=True)
802
    dataset = dataset.prefetch(AUTOTUNE)
803
    return dataset
804

805
  def sampled_dataset(self, distribution):
806
    """Return a dataset that samples from the buckets."""
807
    sampled_ds = tf.data.experimental.sample_from_datasets(
808
        self.unproccessed_buckets,
809
        weights=distribution)
810
    # Optionally you can add a seed for better reproducibility
811
    # seed=dataset_utils.RANDOM_SAMPLE_SEED)
812
    # You shouldn't cache this dataset because it might not properly resample
813
    return self._proc_dataset(sampled_ds)
814

815
  def get_bucket(self, index):
816
    return self.buckets[index]
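
# Illustrative usage (not executed here; the uniform distribution is a
# placeholder for whatever the selection policy produces):
#   manager, eval_batches, predict_batches, sp_tok = get_dynamic_datasets(
#       is_scores_path='/path/to/is_scores.csv', num_buckets=100)
#   uniform = [1.0 / 100] * 100
#   mixed_ds = manager.sampled_dataset(uniform)   # tf.data.Dataset of batches
#   first_bucket_batch = next(manager.get_bucket(0))  # one packed batch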