google-research
84 строки · 2.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Computes predictions for a JSON dataset."""
17
18import json19
20from absl import app21from absl import flags22from absl import logging23
24import predict25import redace_config as configs26import redace_flags # pylint: disable=unused-import27import tokenization28import utils29
30FLAGS = flags.FLAGS31
32
def main(argv):
  """Runs batched prediction over the input file and dumps results as JSON.

  Every batch yielded by `utils.batch_generator` for
  `FLAGS.predict_input_file` is passed to a `predict.RedAcePredictor`. For
  each utterance, the untokenized (word, tag) pairs are turned into
  `[word, flag]` entries, where `flag` is 0 for the tag 'KEEP' and 1 for any
  other tag. The collected list of `{'id': ..., 'asr': ...}` records is
  written to `FLAGS.predict_output_file`.

  Args:
    argv: Positional command-line arguments as handed over by `app.run`. At
      most one element is accepted.

  Raises:
    app.UsageError: If `argv` holds more than one element.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  predictor = predict.RedAcePredictor(
      redace_config=configs.RedAceConfig(),
      model_filepath=FLAGS.model_dir,
      sequence_length=FLAGS.max_seq_length,
      batch_size=FLAGS.predict_batch_size,
  )

  results = []
  num_predicted = 0
  batches = utils.batch_generator(
      FLAGS.predict_input_file,
      FLAGS.predict_batch_size,
  )
  for batch in batches:
    # The third element of each batch is not needed for prediction output.
    source_batch, confidence_scores_batch, _, utterance_id_batch = batch

    # Only the second element of the predictor's return value is consumed.
    _, prediction_information = predictor.predict_end_to_end_batch(
        source_batch, confidence_scores_batch)

    num_predicted += len(source_batch)
    logging.log_every_n(logging.INFO, f'{num_predicted} predicted.', 10)

    per_utterance = zip(source_batch, prediction_information,
                        utterance_id_batch)
    for source, prediction_output, utterance_id in per_utterance:
      word_tag_pairs = tokenization.untokenize(
          source, prediction_output.input_tokens, prediction_output.tags)

      flagged_words = []
      for word, tag in word_tag_pairs:
        flag = 0 if tag == 'KEEP' else 1
        flagged_words.append([word, flag])

      results.append({'id': utterance_id, 'asr': flagged_words})

  with open(FLAGS.predict_output_file, 'w') as output_file:
    json.dump(results, output_file)
79
if __name__ == '__main__':
  # These flags have to be supplied on the command line; absl reports an
  # error while parsing flags if any of them is left unset.
  flags.mark_flags_as_required([
      'predict_input_file',
      'predict_output_file',
      'vocab_file',
  ])
  app.run(main)