google-research / create_token_vocab.py
116 lines · 3.5 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Creates token vocabulary using tensor2tensor tokenizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import csv
import operator
import os

from tensor2tensor.data_generators import tokenizer
import tensorflow.compat.v1 as tf  # tf

_INPUT_DIR = "/tmp"
_OUTPUT_DIR = "/tmp"

flags = tf.flags
FLAGS = flags.FLAGS
gfile = tf.gfile

flags.DEFINE_string(
    "corpus_dir", _INPUT_DIR,
    "Full path to the directory containing the data files for a set of tasks.")
flags.DEFINE_string(
    "vocab_dir", _OUTPUT_DIR,
    "Full path to the directory for saving the vocab csv file.")
flags.DEFINE_string("mode", "write",
                    "Whether to read the vocab csv or write the token csv.")

word_count = collections.Counter()
freq_count = collections.Counter()


def create_token_id_files(corpus_dir, output_vocab_dir):
  """Creates token id csv files.

  Args:
    corpus_dir: input corpus directory
    output_vocab_dir: output token vocabulary csv file directory
  """
  walking_iter = gfile.Walk(corpus_dir)
  for iter_rst in walking_iter:
    valid_filenames = [
        filename for filename in iter_rst[2]
        if ".txt" in filename or "wadata" in filename
    ]
    if not valid_filenames:
      continue
    input_file_dir = iter_rst[0]
    for filename in valid_filenames:
      path = os.path.join(input_file_dir, filename)
      with gfile.Open(path, "r") as f:
        for line in f.read().lower().split("\n"):
          tokens = tokenizer.encode(line)
          for token in tokens:
            word_count[token] += 1

  sorted_vocab = sorted(word_count.items(), key=operator.itemgetter(1))
  tf.logging.info("%d tokens in corpus", sum(word_count.values()))

  csv_file = gfile.Open(os.path.join(output_vocab_dir, "vocab.csv"), "w+")
  csv_writer = csv.writer(csv_file)

  # Reserve ids 0-3 for the special tokens.
  rows = [["<PAD>", 0, 0], ["<EOS>", 0, 1], ["<UKN>", 0, 2], ["<START>", 0, 3]]
  for row in rows:
    csv_writer.writerow(row)
  start_index = len(rows)
  # Write real tokens from most to least frequent.
  for word_freq in reversed(sorted_vocab):
    row = [word_freq[0], word_freq[1], start_index]
    freq_count[word_freq[1]] += 1
    start_index += 1
    csv_writer.writerow(row)
  tf.logging.info("vocab_size=%d", start_index)
  tf.logging.info("token frequency count")
  tf.logging.info(sorted(freq_count.items(), key=operator.itemgetter(1)))
  csv_file.close()
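
# The vocab.csv written above holds one [token, frequency, id] row per line,
# with the four special tokens first. Illustrative layout (the token and its
# count below are made-up examples, not real corpus data):
#
#   <PAD>,0,0
#   <EOS>,0,1
#   <UKN>,0,2
#   <START>,0,3
#   the,10482,4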


def read_vocab(vocab_path):
  """Reads vocabulary csv file.

  Args:
    vocab_path: full path of the vocabulary csv file

  Returns:
    tokens: list of token strings
    freqs: list of token frequencies
    ids: list of token ids
  """
  tokens, freqs, ids = [], [], []
  # Use a context manager so the file handle is closed after reading.
  with gfile.Open(vocab_path, "r") as csv_file:
    csv_reader = csv.reader(csv_file)
    for row in csv_reader:
      tokens.append(row[0])
      freqs.append(int(row[1]))
      ids.append(int(row[2]))
  tf.logging.info("Read %d vocab tokens", len(tokens))
  return tokens, freqs, ids
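
The listing defines a "mode" flag but ends without an entry point. Below is a minimal sketch of how the two functions might be wired together under the standard tf.app.run() convention; the dispatch logic is an assumption for illustration, not part of the original file.

def main(_):
  # Assumed dispatch on the "mode" flag; not from the original excerpt.
  if FLAGS.mode == "write":
    create_token_id_files(FLAGS.corpus_dir, FLAGS.vocab_dir)
  elif FLAGS.mode == "read":
    tokens, freqs, ids = read_vocab(
        os.path.join(FLAGS.vocab_dir, "vocab.csv"))
    tf.logging.info("First tokens: %s", tokens[:10])


if __name__ == "__main__":
  tf.app.run(main)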
