google-research
189 lines · 6.9 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Code for serializing raw data into tfrecords."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import collections23import os24import random25import numpy as np26import tensorflow.compat.v1 as tf27
28from bam.data import feature_spec29from bam.helpers import utils30
31
class PaddingInputExample(object):
  """Placeholder example used to pad a dataset to a batch-size multiple.

  TPU evaluation/prediction requires a fixed batch size, so the number of
  input examples must be rounded up to a multiple of it; the alternative of
  dropping the final partial batch is worse because part of the output data
  would never be generated.

  A dedicated sentinel class is used instead of `None` so that padding
  entries cannot be silently confused with genuinely missing data.
  """
43
44
class Preprocessor(object):
  """Class for loading, preprocessing, and serializing datasets."""

  def __init__(self, config, tasks):
    # `config` is a project configuration object; the attributes read in this
    # class are: distill, preprocessed_data_dir, train_batch_size,
    # eval_batch_size, use_tfrecords_if_existing, num_train_epochs, n_tasks.
    self._config = config
    self._tasks = tasks
    self._name_to_task = {task.name: task for task in tasks}

    # Feature specs shared across all tasks come first, then each task
    # appends its own task-specific specs.
    self._feature_specs = feature_spec.get_shared_feature_specs(config)
    for task in tasks:
      self._feature_specs += task.get_feature_specs()
    self._name_to_feature_config = {
        spec.name: spec.get_parsing_spec()
        for spec in self._feature_specs
    }
    # Feature names must be unique; duplicates would have collapsed into a
    # single dict entry, making the two lengths disagree.
    assert len(self._name_to_feature_config) == len(self._feature_specs)

  def prepare_train(self):
    # Serializes the "train" split for all configured tasks.
    return self._serialize_dataset(self._tasks, True, "train")

  def prepare_eval(self, task):
    # Serializes the "dev" split for a single task.
    return self._serialize_dataset([task], False, "dev")

  def prepare_predict(self, tasks, split):
    # Serializes an arbitrary split for the given tasks (no shuffling).
    return self._serialize_dataset(tasks, False, split)

  def _serialize_dataset(self, tasks, is_training, split):
    """Writes out tfrecord examples for the specified tasks.

    Returns a tuple (input_fn, steps, sizes) where `input_fn` is suitable
    for a TPUEstimator, `steps` is the number of batches for the split, and
    `sizes` maps task name -> number of (unpadded) examples for that task.
    """
    # Dataset name is deterministic in the task set (sorted) and split so
    # repeated runs resolve to the same cached tfrecord files.
    dataset_name = "_".join(sorted([task.name for task in tasks]))
    dataset_name += "_" + split
    if self._config.distill:
      dataset_name += "_distill"
    dataset_prefix = os.path.join(
        self._config.preprocessed_data_dir, dataset_name)
    tfrecords_path = dataset_prefix + ".tfrecord"
    metadata_path = dataset_prefix + ".metadata"
    batch_size = (self._config.train_batch_size if is_training else
                  self._config.eval_batch_size)

    utils.log("Loading dataset", dataset_name)
    n_examples = None
    sizes = {}
    # Reuse previously written tfrecords when allowed; the metadata file is
    # the marker that a complete serialization exists.
    if (self._config.use_tfrecords_if_existing and
        tf.gfile.Exists(metadata_path)):
      utils.log("Using already-written tfrecords")
      metadata = utils.load_json(metadata_path)
      n_examples = metadata["n_examples"]
      sizes = metadata["sizes"]

    if n_examples is None:
      utils.log("Existing tfrecords not found so creating")
      examples = []
      for task in tasks:
        task_examples = task.get_examples(split)
        # `sizes` records the real (pre-padding) example count per task.
        sizes[task.name] = len(task_examples)
        examples += task_examples
      # Pad with sentinel examples so the total is a batch-size multiple
      # (TPU requires fixed batch sizes; see PaddingInputExample).
      while len(examples) % batch_size != 0:
        examples.append(PaddingInputExample())
      # Shuffle AFTER padding so padding examples are spread across batches
      # during training; eval/predict order is preserved.
      if is_training:
        random.shuffle(examples)
      n_examples = len(examples)
      assert n_examples % batch_size == 0
      # Ensure the output directory exists before writing.
      utils.mkdir(tfrecords_path.rsplit("/", 1)[0])
      self.serialize_examples(examples, is_training, tfrecords_path)
      utils.write_json({"n_examples": n_examples,
                        "sizes": sizes}, metadata_path)

    input_fn = self._input_fn_builder(tfrecords_path, is_training)
    if is_training:
      # Training repeats the dataset, so total steps scale with epochs.
      steps = int(n_examples // batch_size * self._config.num_train_epochs)
    else:
      steps = n_examples // batch_size

    return input_fn, steps, sizes

  def serialize_examples(self, examples, is_training, output_file):
    """Convert a set of `InputExample`s to a TFRecord file."""
    with tf.python_io.TFRecordWriter(output_file) as writer:
      for (ex_index, example) in enumerate(examples):
        # Periodic progress logging only; 50000 is an arbitrary interval.
        if ex_index % 50000 == 0:
          utils.log("Writing example {:} of {:}".format(
              ex_index, len(examples)))
        tf_example = self._example_to_tf_example(example, is_training)
        writer.write(tf_example.SerializeToString())

  def _example_to_tf_example(self, example, is_training):
    # Padding examples carry no real features; they are tagged with the
    # out-of-range task id `n_tasks` and all other features take their
    # spec defaults (see _make_tf_example).
    if isinstance(example, PaddingInputExample):
      return self._make_tf_example(task_id=self._config.n_tasks)
    else:
      # Delegate featurization to the task that owns this example; the
      # returned dict's keys must match registered feature-spec names.
      return self._make_tf_example(
          **self._name_to_task[example.task_name].featurize(
              example, is_training))

  def _make_tf_example(self, **kwargs):
    """Construct a tf.train.Example from the provided arguments."""
    for k in kwargs:
      if k not in self._name_to_feature_config:
        raise ValueError("Unknown feature", k)
    features = collections.OrderedDict()
    for spec in self._feature_specs:
      if spec.name in kwargs:
        values = kwargs[spec.name]
      else:
        # Features not supplied by the caller (e.g. for padding examples)
        # fall back to the spec's default value.
        values = spec.get_default_value()
      # Wrap scalars (including 1-element numpy arrays) in a list so the
      # Int64List/FloatList constructors below always receive a sequence.
      if (isinstance(values, int) or isinstance(values, bool) or
          isinstance(values, float) or isinstance(values, np.float32) or
          (isinstance(values, np.ndarray) and values.size == 1)):
        values = [values]
      if spec.is_int_feature:
        feature = tf.train.Feature(int64_list=tf.train.Int64List(
            value=list(values)))
      else:
        feature = tf.train.Feature(float_list=tf.train.FloatList(
            value=list(values)))
      features[spec.name] = feature
    return tf.train.Example(features=tf.train.Features(feature=features))

  def _input_fn_builder(self, input_file, is_training):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    def input_fn(params):
      """The actual input function."""
      d = tf.data.TFRecordDataset(input_file)
      if is_training:
        # Repeat indefinitely for training; step count is controlled by the
        # caller (see _serialize_dataset). The file was already shuffled at
        # write time, so a small shuffle buffer suffices here.
        d = d.repeat()
        d = d.shuffle(buffer_size=100)
      # batch_size comes from TPUEstimator via `params`; drop_remainder is
      # required for fixed TPU batch shapes (padding guarantees no data is
      # actually lost).
      return d.apply(
          tf.data.experimental.map_and_batch(
              self._decode_tfrecord,
              batch_size=params["batch_size"],
              drop_remainder=True))

    return input_fn

  def _decode_tfrecord(self, record):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, self._name_to_feature_config)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name, tensor in example.items():
      if tensor.dtype == tf.int64:
        example[name] = tf.to_int32(tensor)
      else:
        example[name] = tensor
    return example