# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code for serializing raw data into tfrecords."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import random
import numpy as np
import tensorflow.compat.v1 as tf

from bam.data import feature_spec
from bam.helpers import utils


class PaddingInputExample(object):
  """Fake example so the num input examples is a multiple of the batch size.

  When running eval/predict on the TPU, we need to pad the number of examples
  to be a multiple of the batch size, because the TPU requires a fixed batch
  size. The alternative is to drop the last batch, which is bad because it means
  the entire output data won't be generated.

  We use this class instead of `None` because treating `None` as padding
  batches could cause silent errors.
  """
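
# Illustrative note (not from the original file): the padding logic in
# Preprocessor._serialize_dataset pads up to the next batch-size multiple.
# With a hypothetical eval_batch_size of 8 and a 10-example dev set:
#
#   examples = task.get_examples("dev")       # 10 real examples
#   while len(examples) % 8 != 0:
#     examples.append(PaddingInputExample())  # 16 examples -> 2 full batches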


class Preprocessor(object):
  """Class for loading, preprocessing, and serializing datasets."""

  def __init__(self, config, tasks):
    self._config = config
    self._tasks = tasks
    self._name_to_task = {task.name: task for task in tasks}

    # Combined feature spec: features shared across tasks plus task-specific
    # features; names must be unique across the full spec.
    self._feature_specs = feature_spec.get_shared_feature_specs(config)
    for task in tasks:
      self._feature_specs += task.get_feature_specs()
    self._name_to_feature_config = {
        spec.name: spec.get_parsing_spec()
        for spec in self._feature_specs
    }
    assert len(self._name_to_feature_config) == len(self._feature_specs)

  def prepare_train(self):
    return self._serialize_dataset(self._tasks, True, "train")

  def prepare_eval(self, task):
    return self._serialize_dataset([task], False, "dev")

  def prepare_predict(self, tasks, split):
    return self._serialize_dataset(tasks, False, split)

  def _serialize_dataset(self, tasks, is_training, split):
    """Writes out tfrecord examples for the specified tasks.

    Returns:
      A (input_fn, steps, sizes) tuple: an input function over the written
      tfrecords, the number of batches to run (scaled by num_train_epochs
      when training), and a dict mapping each task name to its number of
      (unpadded) examples.
    """
    dataset_name = "_".join(sorted([task.name for task in tasks]))
    dataset_name += "_" + split
    if self._config.distill:
      dataset_name += "_distill"
    dataset_prefix = os.path.join(
        self._config.preprocessed_data_dir, dataset_name)
    tfrecords_path = dataset_prefix + ".tfrecord"
    metadata_path = dataset_prefix + ".metadata"
    batch_size = (self._config.train_batch_size if is_training else
                  self._config.eval_batch_size)

    utils.log("Loading dataset", dataset_name)
    n_examples = None
    sizes = {}
    if (self._config.use_tfrecords_if_existing and
        tf.gfile.Exists(metadata_path)):
      utils.log("Using already-written tfrecords")
      metadata = utils.load_json(metadata_path)
      n_examples = metadata["n_examples"]
      sizes = metadata["sizes"]

    if n_examples is None:
      utils.log("Existing tfrecords not found so creating")
      examples = []
      for task in tasks:
        task_examples = task.get_examples(split)
        sizes[task.name] = len(task_examples)
        examples += task_examples
      while len(examples) % batch_size != 0:
        examples.append(PaddingInputExample())
      if is_training:
        random.shuffle(examples)
      n_examples = len(examples)
      assert n_examples % batch_size == 0
      utils.mkdir(tfrecords_path.rsplit("/", 1)[0])
      self.serialize_examples(examples, is_training, tfrecords_path)
      utils.write_json({"n_examples": n_examples,
                        "sizes": sizes}, metadata_path)

    input_fn = self._input_fn_builder(tfrecords_path, is_training)
    if is_training:
      steps = int(n_examples // batch_size * self._config.num_train_epochs)
    else:
      steps = n_examples // batch_size

    return input_fn, steps, sizes
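
  # The metadata file written above makes reruns cheap: when it already
  # exists (and config.use_tfrecords_if_existing is set), serialization is
  # skipped. Its JSON shape, with purely illustrative values:
  #   {"n_examples": 1024, "sizes": {"rte": 996}}
  # where "sizes" counts real examples per task and "n_examples" includes
  # the padding examples.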

  def serialize_examples(self, examples, is_training, output_file):
    """Convert a set of `InputExample`s to a TFRecord file."""
    with tf.python_io.TFRecordWriter(output_file) as writer:
      for (ex_index, example) in enumerate(examples):
        if ex_index % 50000 == 0:
          utils.log("Writing example {:} of {:}".format(
              ex_index, len(examples)))
        tf_example = self._example_to_tf_example(example, is_training)
        writer.write(tf_example.SerializeToString())

  def _example_to_tf_example(self, example, is_training):
    if isinstance(example, PaddingInputExample):
      # Padding examples carry only the reserved task id `n_tasks` (one past
      # the real task ids); all other features take their default values.
      return self._make_tf_example(task_id=self._config.n_tasks)
    else:
      return self._make_tf_example(
          **self._name_to_task[example.task_name].featurize(
              example, is_training))

  def _make_tf_example(self, **kwargs):
    """Construct a tf.train.Example from the provided arguments."""
    for k in kwargs:
      if k not in self._name_to_feature_config:
        raise ValueError("Unknown feature", k)
    features = collections.OrderedDict()
    for spec in self._feature_specs:
      if spec.name in kwargs:
        values = kwargs[spec.name]
      else:
        values = spec.get_default_value()
      # Wrap scalars in a list so they can be stored as length-1 feature
      # lists.
      if (isinstance(values, int) or isinstance(values, bool) or
          isinstance(values, float) or isinstance(values, np.float32) or
          (isinstance(values, np.ndarray) and values.size == 1)):
        values = [values]
      if spec.is_int_feature:
        feature = tf.train.Feature(int64_list=tf.train.Int64List(
            value=list(values)))
      else:
        feature = tf.train.Feature(float_list=tf.train.FloatList(
            value=list(values)))
      features[spec.name] = feature
    return tf.train.Example(features=tf.train.Features(feature=features))
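
  # Illustrative call (the feature names are hypothetical; real names come
  # from feature_spec and the tasks' featurize methods):
  #   self._make_tf_example(input_ids=[101, 2023, 102], task_id=0)
  # Any feature spec not passed as a kwarg is filled with its default value,
  # so every serialized example carries the full, fixed set of features.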

  def _input_fn_builder(self, input_file, is_training):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    def input_fn(params):
      """The actual input function."""
      d = tf.data.TFRecordDataset(input_file)
      if is_training:
        d = d.repeat()
        d = d.shuffle(buffer_size=100)
      return d.apply(
          tf.data.experimental.map_and_batch(
              self._decode_tfrecord,
              batch_size=params["batch_size"],
              drop_remainder=True))

    return input_fn
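
  # `params["batch_size"]` is filled in by TPUEstimator at call time (it is
  # not fixed when the closure is built), so the same input_fn pattern serves
  # both training and eval; see the usage sketch at the end of the file.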

  def _decode_tfrecord(self, record):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, self._name_to_feature_config)
    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name, tensor in example.items():
      if tensor.dtype == tf.int64:
        example[name] = tf.to_int32(tensor)
      else:
        example[name] = tensor
    return example
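

# End-to-end usage sketch (illustrative; `config` and `tasks` are assumed to
# be constructed by the surrounding BAM training script, so this is not
# runnable on its own):
#
#   preprocessor = Preprocessor(config, tasks)
#   train_input_fn, train_steps, _ = preprocessor.prepare_train()
#   eval_input_fn, eval_steps, sizes = preprocessor.prepare_eval(tasks[0])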
