google-research

Форк
0
/
predict_main.py 
84 строки · 2.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Computes predictions for a JSON dataset."""
17

18
import json
19

20
from absl import app
21
from absl import flags
22
from absl import logging
23

24
import predict
25
import redace_config as configs
26
import redace_flags  # pylint: disable=unused-import
27
import tokenization
28
import utils
29

30
FLAGS = flags.FLAGS
31

32

33
def main(argv):
34
  if len(argv) > 1:
35
    raise app.UsageError('Too many command-line arguments.')
36

37
  predictor = predict.RedAcePredictor(
38
      redace_config=configs.RedAceConfig(),
39
      model_filepath=FLAGS.model_dir,
40
      sequence_length=FLAGS.max_seq_length,
41
      batch_size=FLAGS.predict_batch_size,
42
  )
43

44
  num_predicted = 0
45
  results = []
46
  for (
47
      source_batch,
48
      confidence_scores_batch,
49
      _,
50
      utterance_id_batch,
51
  ) in utils.batch_generator(
52
      FLAGS.predict_input_file,
53
      FLAGS.predict_batch_size,
54
  ):
55
    (
56
        _,
57
        prediction_information,
58
    ) = predictor.predict_end_to_end_batch(source_batch,
59
                                           confidence_scores_batch)
60
    num_predicted += len(source_batch)
61
    logging.log_every_n(logging.INFO, f'{num_predicted} predicted.', 10)
62
    for source, prediction_output, utterance_id, in zip(
63
        source_batch,
64
        prediction_information,
65
        utterance_id_batch,
66
    ):
67
      untokenized_words = tokenization.untokenize(
68
          source, prediction_output.input_tokens, prediction_output.tags)
69
      results.append({
70
          'id':
71
              utterance_id,
72
          'asr': [[word, 0 if tag == 'KEEP' else 1]
73
                  for word, tag in untokenized_words],
74
      })
75

76
  with open(FLAGS.predict_output_file, 'w') as f:
77
    json.dump(results, f)
78

79

80
if __name__ == '__main__':
81
  flags.mark_flag_as_required('predict_input_file')
82
  flags.mark_flag_as_required('predict_output_file')
83
  flags.mark_flag_as_required('vocab_file')
84
  app.run(main)
85

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.