google-research
128 строк · 4.0 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Convert a JSON dataset into the TFRecord format.
17
18The resulting TFRecord file will be used when training a RED-ACE model.
19"""
20
21import random22
23from absl import app24from absl import flags25from absl import logging26import example_builder27import redace_flags # pylint: disable=unused-import28import tensorflow as tf29import tokenization30import utils31
32
FLAGS = flags.FLAGS

# Flags defined by this script. Other flags read below (output_file,
# vocab_file, max_seq_length) are not defined in this file; presumably they
# come from the `redace_flags` import -- TODO confirm.
flags.DEFINE_string(
    'input_file', None,
    'Path to the input file containing examples to be converted to tf.Examples.'
)

flags.DEFINE_bool(
    'store_debug_features', False,
    'Debugging information, i.e. source tokens, are stored in the tf.Examples.')

# NOTE: the two adjacent literals are concatenated by the parser, so the first
# one must end with a space to keep the words apart in the --help output.
flags.DEFINE_bool(
    'write_tfrecord_to_file', True,
    'If False no tfrecord is written to file, instead only joint length '
    'information is written to a file.')

flags.DEFINE_integer('max_input_lines', None, 'Number of samples.')
51
def _write_example_count(count, example_path):
  """Stores the number of converted examples in a side file.

  The stored count is used when determining the number of training steps.

  Args:
    count: The number of converted examples.
    example_path: Path to the file where the examples are saved.

  Returns:
    The path of the file that now holds the count, i.e.
    `example_path + '.num_examples.txt'`.
  """
  num_examples_path = example_path + '.num_examples.txt'
  serialized_count = str(count)
  with tf.io.gfile.GFile(num_examples_path, 'w') as out_file:
    out_file.write(serialized_count)
  return num_examples_path
70
def _write_length(length, example_path):
  """Stores the 99 percentile joint insertion length in a side file.

  The stored length is used when determining the number of decoding steps.

  Args:
    length: The 99 percentile length.
    example_path: Path to the file where the examples are saved; the length
      file is created next to it.

  Returns:
    The path of the file that now holds the length, i.e.
    `example_path + '.length.txt'`.
  """
  length_path = example_path + '.length.txt'
  serialized_length = str(length)
  with tf.io.gfile.GFile(length_path, 'w') as out_file:
    out_file.write(serialized_length)
  return length_path
89
def main(argv):
  """Converts the input examples to tf.Examples and writes a TFRecord file.

  Reads examples from `FLAGS.input_file`, builds a RED-ACE example for each
  one, writes every successfully built example to `FLAGS.output_file`, and
  finally stores the number of written examples in
  `FLAGS.output_file + '.num_examples.txt'`.

  If `FLAGS.max_input_lines` is set, at most that many input examples are
  read. The `store_debug_features` and `write_tfrecord_to_file` flags are not
  consulted by this function.

  Args:
    argv: Positional command-line arguments left after flag parsing; only the
      program name is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given or the output
      path contains '@' (sharded output is not supported).
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  if '@' in FLAGS.output_file:
    raise app.UsageError('Output-file sharding is not supported.')

  builder = example_builder.RedAceExampleBuilder(
      tokenization.FullTokenizer(FLAGS.vocab_file), FLAGS.max_seq_length)

  # Number of examples successfully converted to a tagging TF example.
  num_converted = 0
  # Seed Python's RNG so that repeated runs are reproducible.
  random.seed(42)

  output_file = FLAGS.output_file
  max_input_lines = FLAGS.max_input_lines
  with tf.io.TFRecordWriter(output_file) as writer:
    for i, (source, target, confidence_scores,
            utterance_id) in enumerate(utils.read_input(FLAGS.input_file)):
      # Honor --max_input_lines; the default (None) means "read everything".
      if max_input_lines is not None and i >= max_input_lines:
        break
      # Pass the values as lazy %-args so the message is only formatted on
      # the iterations where it is actually emitted.
      logging.log_every_n(
          logging.INFO,
          '%d examples processed, %d converted to tf.Example.',
          10000,
          i,
          num_converted,
      )
      example = builder.build_redace_example(source, confidence_scores, target)
      # The builder returns None for inputs it cannot convert; skip those.
      if example is not None:
        example.debug_features['utterance_id'] = utterance_id
        writer.write(example.to_tf_example().SerializeToString())
        num_converted += 1

  logging.info('Done. %d tagging examples converted to tf.Example.',
               num_converted)
  count_fname = _write_example_count(num_converted, output_file)
  logging.info('\n'.join(['Wrote:', output_file, count_fname]))
123
if __name__ == '__main__':
  # Fail at startup if any of these flags is not provided on the command line.
  for required_flag in ('input_file', 'output_file', 'vocab_file'):
    flags.mark_flag_as_required(required_flag)
  app.run(main)