# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input pipeline for a WMT dataset.

This file was branched from flax/examples/wmt/input_pipeline.py.
"""

import collections
import csv
import os
import random
from typing import Dict, List, Optional, Union

from absl import logging
import numpy as np
import tensorflow.compat.v2 as tf

from data_selection.wmt import dataset_utils
from data_selection.wmt import tokenizer

AUTOTUNE = tf.data.AUTOTUNE
Features = Dict[str, tf.Tensor]

38
39# -----------------------------------------------------------------------------
40# Raw TFDS dataset.
41# -----------------------------------------------------------------------------
42def raw_wmt_datasets(dataset_name='wmt17_translate/de-en',
43eval_dataset_name=None,
44reverse_translation=False,
45shard_idx=0,
46shard_count=1,
47data_dir=None,
48paracrawl_size=0,
49shuffle_train_files=True,
50pseudo_path=None,
51newscommentary_size=None,
52newscomment_sample_ratio=1.0):
53"""Load raw WMT datasets and normalize feature keys.
54
55Args:
56dataset_name: str: TFDS WMT dataset name.
57eval_dataset_name: Optional[str]: separate dataset name for evaluation.
58e.g. for specifying the standard academic WMT14 test set.
59reverse_translation: bool: whether to reverse the translation direction.
60e.g. for 'de-en' this translates from english to german.
61shard_idx: int: for multihost training, index of this host.
62shard_count: int: for mulithost training, number of total hosts.
63data_dir: str: location of TFDS data directory.
64paracrawl_size: if paracrawl is used, we will sample this many examples.
65shuffle_train_files: whether to shuffle the input data files
66pseudo_path: path to pseudo references
67newscommentary_size: Size of news commentary ft set
68newscomment_sample_ratio: how much to downsample newscommentary data
69
70Returns:
71training tf.dataset, evaluation tf.dataset, and training features_info
72source and target language features are mapped to 'inputs' and 'targets'
73keys.
74"""
  wmt_dataset_builder = dataset_utils.WmtDatasetBuilder(shard_idx,
                                                        shard_count,
                                                        data_dir,
                                                        shuffle_train_files,
                                                        pseudo_path)
  train_data, eval_data = wmt_dataset_builder.build_train_and_eval_datasets(
      dataset_name, eval_dataset_name, paracrawl_size, newscommentary_size,
      newscomment_sample_ratio)

  builder = wmt_dataset_builder.retrieve_builder()

  if builder is not None:
    features_info = builder.info

    # Standardize on 'inputs' and 'targets' features.
    input_lang = features_info.supervised_keys[0]
    target_lang = features_info.supervised_keys[1]
    if reverse_translation:
      input_lang, target_lang = target_lang, input_lang

    def to_features_dict(x):
      return {'inputs': x[input_lang], 'targets': x[target_lang]}

    if 'pseudo' not in dataset_name:  # Perhaps remove this code path.
      train_data = train_data.map(to_features_dict,
                                  num_parallel_calls=AUTOTUNE)
      eval_data = eval_data.map(to_features_dict, num_parallel_calls=AUTOTUNE)
  else:
    features_info = None

  return train_data, eval_data, features_info


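# Illustrative usage sketch (not part of the original module; it assumes the
# TFDS data has been downloaded, and the paths below are made up for
# demonstration):
#
#   train_data, eval_data, info = raw_wmt_datasets(
#       dataset_name='wmt17_translate/de-en',
#       eval_dataset_name='wmt14_translate/de-en:test',
#       reverse_translation=True,
#       data_dir='/path/to/tensorflow_datasets')

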
def pack_dataset(dataset: tf.data.Dataset,
                 key2length: Union[int, Dict[str, int]],
                 keys: Optional[List[str]] = None) -> tf.data.Dataset:
  """Creates a 'packed' version of a dataset on-the-fly.

  Adapted from the mesh-tf implementation.

  This is meant to replace the irritation of having to create a separate
  "packed" version of a dataset to train efficiently on TPU.
  Each example in the output dataset represents several examples in the
  input dataset.
  For each key in the input dataset, two additional keys are created:
    <key>_segmentation: an int32 tensor identifying the parts
      representing the original example.
    <key>_position: an int32 tensor identifying the position within the
      original example.
  Example:
    Two input examples get combined to form an output example.
    The input examples are:
      {"inputs": [8, 7, 1, 0], "targets": [4, 1, 0]}
      {"inputs": [2, 3, 4, 1], "targets": [5, 6, 1]}
    The output example is:
      {
        "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0],
        "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0],
        "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0],
        "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0],
        "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0],
        "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0],
      }
    0 represents padding in both the inputs and the outputs.
  Sequences in the incoming examples are truncated to length "length", and the
  sequences in the output examples all have fixed (padded) length "length".

  Args:
    dataset: a tf.data.Dataset
    key2length: an integer, or a dict from feature-key to integer
    keys: a list of strings (e.g. ["inputs", "targets"])

  Returns:
    a tf.data.Dataset
  """
  shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec)
  if keys is None:
    keys = list(shapes.keys())
  for k in keys:
    if k not in shapes:
      raise ValueError('Key %s not found in dataset. Available keys are %s' %
                       (k, shapes.keys()))
    if not shapes[k].is_compatible_with(tf.TensorShape([None])):
      raise ValueError('Tensors to be packed must be one-dimensional.')
  # Make sure that the length dictionary contains all keys as well as the
  # keys suffixed by "_segmentation" and "_position".
  if isinstance(key2length, int):
    key2length = {k: key2length for k in keys}
  for k in keys:
    for suffix in ['_segmentation', '_position']:
      key2length[k + suffix] = key2length[k]

  # Trim to length.
  dataset = dataset.map(
      lambda x: {k: x[k][:key2length[k]] for k in keys},
      num_parallel_calls=AUTOTUNE)
  # Setting batch_size=length ensures that the concatenated sequences (if they
  # have length >=1) are sufficient to fill at least one packed example.
  batch_size = max(key2length.values())
  dataset = dataset.padded_batch(
      batch_size, padded_shapes={k: [-1] for k in keys})
  dataset = _pack_with_tf_ops(dataset, keys, key2length)

  # Set the Tensor shapes correctly since they get lost in the process.
  def my_fn(x):
    return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()}

  return dataset.map(my_fn, num_parallel_calls=AUTOTUNE)


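# Illustrative usage sketch (not part of the original module): packing the two
# toy examples from the pack_dataset docstring above. The helper name below is
# hypothetical and exists only for demonstration.
def _example_pack_dataset_usage():
  toy = tf.data.Dataset.from_generator(
      lambda: iter([
          {'inputs': [8, 7, 1], 'targets': [4, 1]},
          {'inputs': [2, 3, 4, 1], 'targets': [5, 6, 1]},
      ]),
      output_signature={
          'inputs': tf.TensorSpec([None], tf.int32),
          'targets': tf.TensorSpec([None], tf.int32),
      })
  packed = pack_dataset(toy, key2length=10)
  # Both toy examples fit into a single packed example of length 10.
  for ex in packed:
    print({k: v.numpy().tolist() for k, v in ex.items()})

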
def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str],
                      key2length: Dict[str, int]) -> tf.data.Dataset:
  """Helper-function for packing a dataset which has already been batched.

  Helper for pack_dataset(). Uses tf.while_loop.

  Args:
    dataset: a dataset containing padded batches of examples.
    keys: a list of strings
    key2length: a dict from feature-key to integer

  Returns:
    a dataset.
  """
  empty_example = {}
  for k in keys:
    empty_example[k] = tf.zeros([0], dtype=tf.int32)
    empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32)
  keys_etc = empty_example.keys()

  def write_packed_example(partial, outputs):
    new_partial = empty_example.copy()
    new_outputs = {}
    for k in keys_etc:
      new_outputs[k] = outputs[k].write(
          outputs[k].size(),
          tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]))
    return new_partial, new_outputs

  def map_fn(x):
    """Internal function to flat_map over.

    Consumes a batch of input examples and produces a variable number of
    output examples.

    Args:
      x: a single example

    Returns:
      a tf.data.Dataset
    """
    partial = empty_example.copy()
    i = tf.zeros([], dtype=tf.int32)
    dynamic_batch_size = tf.shape(x[keys[0]])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
      outputs[k + '_position'] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])

    def body_fn(i, partial, outputs):
      """Body function for while_loop.

      Args:
        i: integer scalar
        partial: dictionary of Tensor (partially-constructed example)
        outputs: dictionary of TensorArray

      Returns:
        A triple containing the new values of the inputs.
      """
      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]))

      def false_fn():
        return write_packed_example(partial, outputs)

      def true_fn():
        return partial, outputs

      partial, outputs = tf.cond(can_append, true_fn, false_fn)
      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:key2length[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + '_position'] = tf.concat(
            [partial[k + '_position'],
             tf.range(new_seq_len)], 0)
      partial = new_partial
      return i + 1, partial, outputs

    # For loop over all examples in the batch.
    i, partial, outputs = tf.while_loop(
        cond=lambda *_: True,
        body=body_fn,
        loop_vars=(i, partial, outputs),
        shape_invariants=(
            tf.TensorShape([]),
            {k: tf.TensorShape([None]) for k in keys_etc},
            {k: tf.TensorShape(None) for k in keys_etc},
        ),
        maximum_iterations=dynamic_batch_size)
    _, outputs = write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
    for k in keys:
      packed[k + '_segmentation'] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32),
              axis=1) *
          tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed

  dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
  return dataset.unbatch()


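# Illustrative sketch (not part of the original module): the segmentation ids
# computed above can be recovered from positions and padding alone. A cumsum
# over "position == 0" counts sequence starts, and multiplying by the nonzero
# mask zeroes out the padding. The same arithmetic in plain numpy:
def _example_segmentation_from_positions():
  position = np.array([[0, 1, 2, 0, 1, 2, 3, 0, 0, 0]])
  packed = np.array([[8, 7, 1, 2, 3, 4, 1, 0, 0, 0]])
  segmentation = (
      np.cumsum((position == 0).astype(np.int32), axis=1) * (packed != 0))
  print(segmentation)  # [[1 1 1 2 2 2 2 0 0 0]]

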
# -----------------------------------------------------------------------------
# Main dataset prep routines.
# -----------------------------------------------------------------------------
def preprocess_wmt_data(dataset,
                        shuffle: bool,
                        num_epochs: Optional[int] = 1,
                        pack_examples: bool = True,
                        shuffle_buffer_size: int = 1024000,
                        max_length: int = 512,
                        batch_size: int = 256,
                        drop_remainder: bool = True,
                        prefetch_size: int = AUTOTUNE,
                        is_scores_path=None,
                        num_to_keep=0,
                        truncate=False,
                        sample_size=-1):
313"""Shuffle and batch/pack the given dataset."""
314
315def length_filter(max_len):
316
317def filter_fn(x):
318source, target = x['inputs'], x['targets']
319l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
320return tf.less(l, max_len + 1)
321
322return filter_fn
323
324if truncate:
325dataset = dataset.map(
326lambda x: {k: v[:max_length] for k, v in x.items()},
327num_parallel_calls=AUTOTUNE)
328elif max_length > 0:
329dataset = dataset.filter(length_filter(max_length))
330
331if is_scores_path is not None:
332logging.info('Doing data selection!')
333logging.info('Num to keep = %d', num_to_keep)
334dataset = data_selection(dataset, is_scores_path, num_to_keep)
335
336if sample_size > 0:
337logging.info('Downsampling: %d', sample_size)
338shuff_buff = 200000 # num_to_keep if num_to_keep > 0 else 200000
339dataset = dataset.shuffle(shuff_buff).take(sample_size)
340
341if shuffle:
342dataset = dataset.shuffle(shuffle_buffer_size)
343dataset = dataset.repeat(num_epochs)
344
345if pack_examples:
346dataset = pack_dataset(dataset, max_length)
347dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
348else: # simple (static-shape) padded batching
349dataset = dataset.padded_batch(
350batch_size,
351padded_shapes={
352'inputs': max_length,
353'targets': max_length
354},
355padding_values={
356'inputs': 0,
357'targets': 0
358},
359drop_remainder=drop_remainder)
360
361if prefetch_size:
362dataset = dataset.prefetch(prefetch_size)
363
364return dataset
365
366
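# Illustrative usage sketch (not part of the original module): running a toy
# tokenized dataset through preprocess_wmt_data with packing enabled. All
# names and sizes below are made up for demonstration.
def _example_preprocess_wmt_data_usage():
  toy = tf.data.Dataset.from_generator(
      lambda: iter([
          {'inputs': [8, 7, 1], 'targets': [4, 1]},
          {'inputs': [2, 3, 4, 1], 'targets': [5, 6, 1]},
      ]),
      output_signature={
          'inputs': tf.TensorSpec([None], tf.int32),
          'targets': tf.TensorSpec([None], tf.int32),
      })
  batches = preprocess_wmt_data(
      toy, shuffle=False, num_epochs=1, pack_examples=True,
      max_length=10, batch_size=1)
  for batch in batches:
    print({k: v.shape for k, v in batch.items()})  # every tensor is (1, 10)

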
def data_selection(train_data, is_scores_path, num_to_keep=-1):
  """Select data based on intelligent selection scores."""
  if num_to_keep < 0:
    return train_data

  scores = []
  with tf.io.gfile.GFile(is_scores_path, 'r') as f:
    reader = csv.reader(f)
    for val in reader:
      scores.extend(val)
  scores = [float(s) for s in scores]

  lengths = []
  with tf.io.gfile.GFile(is_scores_path.replace('.csv', '_length.csv'),
                         'r') as f:
    reader = csv.reader(f)
    for val in reader:
      lengths.extend(val)
  lengths = [int(s) for s in lengths]

  if num_to_keep >= len(scores):
    return train_data

  threshold = np.sort(scores)[num_to_keep]

  tf_is_scores = tf.data.Dataset.from_tensor_slices(scores)
  tf_lengths = tf.data.Dataset.from_tensor_slices(lengths)

  scored_data = tf.data.Dataset.zip((tf_is_scores, tf_lengths, train_data))

  def filter_fn(score, _, __):  # pylint: disable=invalid-name
    return tf.math.less_equal(score, threshold)

  def remove_enum(_, length, el):
    targ_size = tf.math.count_nonzero(el['targets'], dtype=tf.dtypes.int32)
    assert_op = tf.debugging.assert_equal(
        length, targ_size, message='Lengths not aligned')
    with tf.control_dependencies([assert_op]):
      return el

  train_data = scored_data.filter(filter_fn).map(remove_enum)
  train_data = train_data.cache()
  return train_data


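# Illustrative sketch (an assumption about the on-disk format, inferred from
# the reader above): the scores file is CSV rows of floats, one score per
# training example in file order, and a parallel '<name>_length.csv' holds the
# nonzero target length of each example. A toy writer for testing:
def _example_write_is_scores(path='/tmp/is_scores.csv'):
  with tf.io.gfile.GFile(path, 'w') as f:
    csv.writer(f).writerow([0.12, 0.87, 0.45])
  with tf.io.gfile.GFile(path.replace('.csv', '_length.csv'), 'w') as f:
    csv.writer(f).writerow([5, 7, 3])

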
def get_wmt_datasets(dataset_name='wmt17_translate/de-en',
                     eval_dataset_name=None,
                     reverse_translation=True,
                     shard_idx=0,
                     shard_count=1,
                     data_dir=None,
                     vocab_path=None,
                     target_vocab_size=2**15,  # 32768
                     max_corpus_chars=10**7,
                     batch_size=256,
                     pack_examples=True,
                     max_length=256,
                     max_eval_length=256,
                     paracrawl_size=0,
                     is_scores_path=None,
                     num_to_keep=-1,
                     pseudo_path=None,
                     shuffle_repeat_train=True,
                     repeat_count=-1,
                     newscommentary_size=None,
                     split_tokenizer=False,
                     sample_size=-1,
                     newscomment_sample_ratio=1.0):
  """Load and return dataset of batched examples for use during training."""
  if vocab_path is None:
    vocab_path = os.path.expanduser('~/wmt_sentencepiece_model')

  # If is_scores_path is None, there is no data selection, so we can shuffle.
  # If it is not None, then we cannot shuffle the input files.
  train_data, eval_data, _ = raw_wmt_datasets(
      dataset_name=dataset_name,
      eval_dataset_name=eval_dataset_name,
      reverse_translation=reverse_translation,
      shard_idx=shard_idx,
      shard_count=shard_count,
      data_dir=data_dir,
      paracrawl_size=paracrawl_size,
      shuffle_train_files=(is_scores_path is None) and shuffle_repeat_train,
      pseudo_path=pseudo_path,
      newscommentary_size=newscommentary_size,
      newscomment_sample_ratio=newscomment_sample_ratio)

  # Tokenize data.
  if split_tokenizer:
    sp_tokenizer_input = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_input',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('inputs',))
    sp_tokenizer_target = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_target',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('targets',))
    train_data = train_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    eval_data = eval_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    sp_tokenizer = sp_tokenizer_target
  else:
    sp_tokenizer = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path,
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars)

    # Currently the pseudo references are stored in pickle files and are
    # pre-tokenized, so we do not tokenize them here. In the future we should
    # write the pseudo references to a TFRecord instead.
    if 'pseudo' not in dataset_name:
      train_data = train_data.map(
          tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)
      eval_data = eval_data.map(
          tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)

  train_ds = preprocess_wmt_data(
      train_data,
      shuffle=shuffle_repeat_train,
      num_epochs=repeat_count,
      pack_examples=pack_examples,
      batch_size=batch_size,
      max_length=max_length,
      is_scores_path=is_scores_path,
      num_to_keep=num_to_keep,
      sample_size=sample_size)

  eval_ds = preprocess_wmt_data(
      eval_data,
      shuffle=False,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_eval_length)

  predict_ds = preprocess_wmt_data(
      eval_data,
      shuffle=False,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_eval_length,
      drop_remainder=False)

  return train_ds, eval_ds, predict_ds, sp_tokenizer


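# Illustrative usage sketch (not part of the original module; assumes TFDS
# data is available locally, and the paths below are made up for
# demonstration):
#
#   train_ds, eval_ds, predict_ds, sp_tokenizer = get_wmt_datasets(
#       dataset_name='wmt17_translate/de-en',
#       data_dir='/path/to/tensorflow_datasets',
#       vocab_path='/path/to/wmt_sentencepiece_model',
#       batch_size=256,
#       max_length=256)

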
def get_wmt_is_datasets(n_devices,
                        dataset_name='wmt17_translate/de-en',
                        reverse_translation=True,
                        shard_idx=0,
                        shard_count=1,
                        data_dir=None,
                        vocab_path=None,
                        target_vocab_size=2**15,  # 32768
                        max_corpus_chars=10**7,
                        batch_size=256,
                        max_length=256,
                        paracrawl_size=0,
                        split_tokenizer=False,
                        use_eval_data=False,
                        truncate=False):
  """Load and return dataset of batched examples for use during training."""
  if batch_size % n_devices:
    raise ValueError("Batch size %d isn't divided evenly by n_devices %d" %
                     (batch_size, n_devices))
  if vocab_path is None:
    vocab_path = os.path.expanduser('~/wmt_sentencepiece_model')

  train_data, eval_data, _ = raw_wmt_datasets(
      dataset_name=dataset_name,
      eval_dataset_name=None,
      reverse_translation=reverse_translation,
      shard_idx=shard_idx,
      shard_count=shard_count,
      data_dir=data_dir,
      paracrawl_size=paracrawl_size,
      shuffle_train_files=False)

  if use_eval_data:
    # Unfortunate use of names, but easiest for the refactor w/o errors.
    train_data = eval_data

  # Tokenize data.
  if split_tokenizer:
    sp_tokenizer_input = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_input',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('inputs',))
    sp_tokenizer_target = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_target',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('targets',))
    train_data = train_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    sp_tokenizer = sp_tokenizer_target
  else:
    sp_tokenizer = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path,
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars)

    # Encode strings with sentencepiece tokenizer.
    train_data = train_data.map(
        tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)

  train_batches = preprocess_wmt_data(
      train_data,
      shuffle=False,
      num_epochs=1,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_length,
      drop_remainder=False,
      truncate=truncate)
  # Note: drop_remainder=False here; dropping the remainder would truncate the
  # training data, but the effect would be 0.017% of the dataset, so it
  # shouldn't affect the model either way.

  if split_tokenizer:
    return train_batches, (sp_tokenizer_input, sp_tokenizer_target)
  return train_batches, (sp_tokenizer, sp_tokenizer)


def get_dynamic_datasets(dataset_name='wmt17_translate/de-en',
                         eval_dataset_name=None,
                         reverse_translation=True,
                         shard_idx=0,
                         shard_count=1,
                         data_dir=None,
                         vocab_path=None,
                         target_vocab_size=2**15,  # 32768
                         max_corpus_chars=10**7,
                         batch_size=256,
                         max_length=256,
                         max_eval_length=256,
                         paracrawl_size=0,
                         is_scores_path=None,
                         num_buckets=100,
                         split_tokenizer=False):
  """Load and return dataset of batched examples for use during training."""
  if vocab_path is None:
    vocab_path = os.path.expanduser('~/wmt_sentencepiece_model')

  train_data, eval_data, _ = raw_wmt_datasets(
      dataset_name=dataset_name,
      eval_dataset_name=eval_dataset_name,
      reverse_translation=reverse_translation,
      shard_idx=shard_idx,
      shard_count=shard_count,
      data_dir=data_dir,
      paracrawl_size=paracrawl_size,
      shuffle_train_files=False)

  if split_tokenizer:
    sp_tokenizer_input = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_input',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('inputs',))
    sp_tokenizer_target = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path + '_target',
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars,
        data_keys=('targets',))
    train_data = train_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    eval_data = eval_data.map(
        tokenizer.DoubleTokenizeOp(sp_tokenizer_input=sp_tokenizer_input,
                                   sp_tokenizer_target=sp_tokenizer_target),
        num_parallel_calls=AUTOTUNE)
    sp_tokenizer = sp_tokenizer_target
  else:
    sp_tokenizer = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path,
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars)
    train_data = train_data.map(
        tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)
    eval_data = eval_data.map(
        tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)

  train_data_manager = build_dynamic_data(
      train_data,
      batch_size=batch_size,
      max_length=max_length,
      is_scores_path=is_scores_path,
      num_buckets=num_buckets)

  eval_batches = preprocess_wmt_data(
      eval_data,
      shuffle=False,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_eval_length)

  predict_batches = preprocess_wmt_data(
      eval_data,
      shuffle=False,
      pack_examples=False,
      batch_size=batch_size,
      max_length=max_eval_length,
      drop_remainder=False)

  return train_data_manager, eval_batches, predict_batches, sp_tokenizer


def build_dynamic_data(dataset,
                       shuffle_buffer_size=1000,
                       max_length=512,
                       batch_size=256,
                       is_scores_path=None,
                       num_buckets=100):
  """Shuffle and batch/pack the given dataset."""

  def length_filter(max_len):

    def filter_fn(x):
      source, target = x['inputs'], x['targets']
      l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
      return tf.less(l, max_len + 1)

    return filter_fn

  if max_length > 0:
    dataset = dataset.filter(length_filter(max_length))

  assert is_scores_path is not None
  # Break into buckets.
  buckets = create_buckets(dataset, is_scores_path, num_buckets)
  # Create DatasetBucketManager.
  bucket_manager = DatasetBucketManager(buckets, shuffle_buffer_size,
                                        max_length, batch_size)
  return bucket_manager


def create_buckets(dataset, is_scores_path, num_buckets):
  """Split dataset into buckets."""

  scores = []
  with tf.io.gfile.GFile(is_scores_path, 'r') as f:
    reader = csv.reader(f)
    for val in reader:
      scores.extend(val)
  scores = [float(s) for s in scores]

  lengths = []
  with tf.io.gfile.GFile(is_scores_path.replace('.csv', '_length.csv'),
                         'r') as f:
    reader = csv.reader(f)
    for val in reader:
      lengths.extend(val)
  lengths = [int(s) for s in lengths]

  # Compute bucket thresholds.
  sorted_scores = np.sort(scores)
  logging.info('len scores %d', len(scores))
  bucket_size = int(len(scores) / num_buckets)
  ends = sorted_scores[bucket_size - 1::bucket_size]

  # Iterate through the dataset and write the examples to memory.
  bin_assignments = np.digitize(scores, ends)
  tf_is_bins = tf.data.Dataset.from_tensor_slices(bin_assignments)
  tf_lengths = tf.data.Dataset.from_tensor_slices(lengths)
  scored_data = tf.data.Dataset.zip((tf_is_bins, tf_lengths, dataset))
  bucket_examples = collections.defaultdict(list)
  iter_index = 0
  for ex_bin, ex_len, data in iter(scored_data):
    assert ex_len.numpy() == np.count_nonzero(
        data['targets'].numpy()), (ex_len, data, iter_index)
    iter_index += 1
    bucket_examples[ex_bin.numpy()].append(data)

  bucket_datasets = []
  index_memory = [0] * num_buckets
  for i in range(num_buckets):
    logging.info('Bin %d num el: %d', i, len(bucket_examples[i]))

    def gen_creator(bin_i):
      def gen():
        for ex_i in range(index_memory[bin_i], len(bucket_examples[bin_i])):
          index_memory[bin_i] = ex_i + 1
          yield bucket_examples[bin_i][ex_i]
          if ex_i == len(bucket_examples[bin_i]) - 1:
            logging.info('SHUFFLING BIN!! %d', bin_i)
            index_memory[bin_i] = 0
            random.shuffle(bucket_examples[bin_i])
      return gen

    gen_ds = tf.data.Dataset.from_generator(
        gen_creator(i), output_types={
            'inputs': tf.int32,
            'targets': tf.int32
        })
    gen_ds = gen_ds.repeat()
    bucket_datasets.append(gen_ds)

  # Sanity check that we are not creating the same dataset on each loop.
  assert bucket_datasets[0] != bucket_datasets[1]
  return bucket_datasets


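# Illustrative sketch (not part of the original module) of the thresholding
# above: with six scores and three buckets, every bucket_size-th sorted score
# becomes a bucket boundary, and np.digitize maps each example to a bucket.
def _example_bucket_assignment():
  scores = [0.9, 0.1, 0.4, 0.7, 0.2, 0.5]
  num_buckets = 3
  sorted_scores = np.sort(scores)  # [0.1 0.2 0.4 0.5 0.7 0.9]
  bucket_size = int(len(scores) / num_buckets)  # 2
  ends = sorted_scores[bucket_size - 1::bucket_size]  # [0.2 0.5 0.9]
  # Prints [3 0 1 2 1 2]; note digitize places the maximum score past the
  # last boundary, into the edge bin num_buckets.
  print(np.digitize(scores, ends))

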
class DatasetBucketManager():
  """For dynamic data selection, sample or draw from buckets."""

  def __init__(self, datasets, shuffle_buffer_size=1000,
               max_length=256, batch_size=256):
    self.shuffle_buffer_size = shuffle_buffer_size
    self.max_length = max_length
    self.batch_size = batch_size
    self.unprocessed_buckets = datasets
    self.buckets = self._proc_buckets(self.unprocessed_buckets)

  def _proc_buckets(self, buckets):
    return list(map(iter, map(self._proc_dataset, buckets)))

  def _proc_dataset(self, dataset):
    dataset = dataset.repeat()
    dataset = dataset.shuffle(self.shuffle_buffer_size)
    dataset = pack_dataset(dataset, self.max_length)
    dataset = dataset.batch(self.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

  def sampled_dataset(self, distribution):
    """Return a dataset that samples from the buckets."""
    sampled_ds = tf.data.experimental.sample_from_datasets(
        self.unprocessed_buckets,
        weights=distribution)
    # Optionally, a seed can be added for better reproducibility:
    #   seed=dataset_utils.RANDOM_SAMPLE_SEED
    # This dataset should not be cached because it might not properly
    # resample.
    return self._proc_dataset(sampled_ds)

  def get_bucket(self, index):
    return self.buckets[index]
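

# Illustrative usage sketch (not part of the original module): drawing batches
# from the bucket manager built by get_dynamic_datasets. The uniform weights
# below are an assumption for demonstration.
#
#   manager, eval_batches, predict_batches, sp_tokenizer = get_dynamic_datasets(
#       dataset_name='wmt17_translate/de-en',
#       is_scores_path='/path/to/is_scores.csv',
#       num_buckets=100)
#   batch = next(manager.get_bucket(0))          # one batch from bucket 0
#   uniform = [1.0 / 100] * 100
#   sampled = manager.sampled_dataset(uniform)   # mixture over all buckets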