google-research

Форк
0
/
preprocess_main.py 
128 строк · 4.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Convert a JSON dataset into the TFRecord format.
17

18
The resulting TFRecord file will be used when training a RED-ACE model.
19
"""
20

21
import random
22

23
from absl import app
24
from absl import flags
25
from absl import logging
26
import example_builder
27
import redace_flags  # pylint: disable=unused-import
28
import tensorflow as tf
29
import tokenization
30
import utils
31

32

33
FLAGS = flags.FLAGS
34

35
flags.DEFINE_string(
36
    'input_file', None,
37
    'Path to the input file containing examples to be converted to tf.Examples.'
38
)
39

40
flags.DEFINE_bool(
41
    'store_debug_features', False,
42
    'Debugging information, i.e. source tokens, are stored in the tf.Examples.')
43

44
flags.DEFINE_bool(
45
    'write_tfrecord_to_file', True,
46
    'If False no tfrecord is written to file, instead only joint length'
47
    'information is written to a file.')
48

49
flags.DEFINE_integer('max_input_lines', None, 'Number of samples.')
50

51

52
def _write_example_count(count, example_path):
53
  """Saves the number of converted examples to a file.
54

55
  This count is used when determining the number of training steps.
56

57
  Args:
58
    count: The number of converted examples.
59
    example_path: Path to the file where the examples are saved.
60

61
  Returns:
62
    The path to which the example count is saved
63
      (example_path + '.num_examples.txt').
64
  """
65
  count_fname = example_path + '.num_examples.txt'
66
  with tf.io.gfile.GFile(count_fname, 'w') as count_writer:
67
    count_writer.write(str(count))
68
  return count_fname
69

70

71
def _write_length(length, example_path):
72
  """Saves the 99 percentile joint insertion length to a file.
73

74
  This count is used when determining the number of decoding steps.
75

76
  Args:
77
    length: The 99 percentile length.
78
    example_path: Path to the file where the length is saved.
79

80
  Returns:
81
    The path to which the length is saved
82
      (example_path + '.length.txt').
83
  """
84
  count_fname = example_path + '.length.txt'
85
  with tf.io.gfile.GFile(count_fname, 'w') as count_writer:
86
    count_writer.write(str(length))
87
  return count_fname
88

89

90
def main(argv):
91
  if len(argv) > 1:
92
    raise app.UsageError('Too many command-line arguments.')
93
  if FLAGS.output_file.count('@') != 0:
94
    raise app.UsageError('Output-file sharding is not supported.')
95

96
  builder = example_builder.RedAceExampleBuilder(
97
      tokenization.FullTokenizer(FLAGS.vocab_file), FLAGS.max_seq_length)
98

99
  # Number of examples successfully converted to a tagging TF example.
100
  num_converted = 0
101
  random.seed(42)
102

103
  output_file = FLAGS.output_file
104
  with tf.io.TFRecordWriter(output_file) as writer:
105
    for i, (source, target, confidence_scores,
106
            utterance_id) in enumerate(utils.read_input(FLAGS.input_file)):
107
      logging.log_every_n(
108
          logging.INFO,
109
          f'{i} examples processed, {num_converted} converted to tf.Example.',
110
          10000,
111
      )
112
      example = builder.build_redace_example(source, confidence_scores, target)
113
      if example is not None:
114
        example.debug_features['utterance_id'] = utterance_id
115
        writer.write(example.to_tf_example().SerializeToString())
116
        num_converted += 1
117

118
  logging.info('Done. %d tagging examples converted to tf.Example.',
119
               num_converted)
120
  count_fname = _write_example_count(num_converted, FLAGS.output_file)
121
  logging.info('\n'.join(['Wrote:', FLAGS.output_file, count_fname]))
122

123

124
if __name__ == '__main__':
125
  flags.mark_flag_as_required('input_file')
126
  flags.mark_flag_as_required('output_file')
127
  flags.mark_flag_as_required('vocab_file')
128
  app.run(main)
129

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.