google-research
116 lines · 3.5 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Creates token vocabulary using tensor2tensor tokenizer."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import csv
24import operator
25import os
26
27from tensor2tensor.data_generators import tokenizer
28import tensorflow.compat.v1 as tf # tf
29
30_INPUT_DIR = "/tmp"
31_OUTPUT_DIR = "/tmp"
32
33flags = tf.flags
34FLAGS = flags.FLAGS
35gfile = tf.gfile
36
37flags.DEFINE_string(
38"corpus_dir", _INPUT_DIR,
39"Full path to the directory containing the data files for a set of tasks.")
40flags.DEFINE_string(
41"vocab_dir", _OUTPUT_DIR,
42"Full path to the directory for saving the tf record file.")
43flags.DEFINE_string("mode", "write",
44"Flag to indicate read vocab csv or write token csv.")
45
46
47word_count = collections.Counter()
48freq_count = collections.Counter()
49
50
def create_token_id_files(corpus_dir, output_vocab_dir):
  """Creates token id csv files.

  Walks `corpus_dir`, tokenizes every file whose name contains ".txt" or
  "wadata", accumulates token frequencies into the module-level `word_count`
  counter, then writes `vocab.csv` to `output_vocab_dir` with one
  [token, frequency, id] row per token, most frequent first. Ids 0-3 are
  reserved for the special tokens <PAD>, <EOS>, <UKN> and <START>.

  Args:
    corpus_dir: input corpus directory.
    output_vocab_dir: output token vocabulary csv file directory.
  """
  for dir_path, _, filenames in gfile.Walk(corpus_dir):
    valid_filenames = [
        filename for filename in filenames
        if ".txt" in filename or "wadata" in filename
    ]
    for filename in valid_filenames:
      path = os.path.join(dir_path, filename)
      with gfile.Open(path, "r") as f:
        # Lowercase the whole file so the vocabulary is case-insensitive.
        for line in f.read().lower().split("\n"):
          for token in tokenizer.encode(line):
            word_count[token] += 1

  # Ascending sort + reversed() below preserves the original tie ordering
  # for tokens with equal frequency (sorted() is stable).
  sorted_vocab = sorted(word_count.items(), key=operator.itemgetter(1))
  tf.logging.info("%d items in vocab", sum(word_count.values()))

  # "with" guarantees the handle is closed even if a write fails (the
  # original leaked it on error); plain "w" suffices since we only write.
  with gfile.Open(os.path.join(output_vocab_dir, "vocab.csv"), "w") as csv_file:
    csv_writer = csv.writer(csv_file)

    # Special tokens occupy the first four ids; frequency column is 0.
    rows = [["<PAD>", 0, 0], ["<EOS>", 0, 1], ["<UKN>", 0, 2],
            ["<START>", 0, 3]]
    for row in rows:
      csv_writer.writerow(row)

    next_id = len(rows)
    for token, freq in reversed(sorted_vocab):  # most frequent first
      csv_writer.writerow([token, freq, next_id])
      freq_count[freq] += 1
      next_id += 1
    tf.logging.info("vocab_size=%d", next_id)
    tf.logging.info("token frequency count")
    tf.logging.info(sorted(freq_count.items(), key=operator.itemgetter(1)))
94
95
def read_vocab(vocab_path):
  """Reads vocabulary csv file.

  Each row of the csv is expected to be [token, frequency, id], the format
  written by create_token_id_files().

  Args:
    vocab_path: full path of the vocabulary csv file.

  Returns:
    tokens: list of token strings.
    freqs: list of token frequencies.
    ids: list of token ids.
  """
  tokens, freqs, ids = [], [], []
  # "with" closes the file even on a parse error — the original opened the
  # handle via gfile.Open and never closed it.
  with gfile.Open(vocab_path, "r") as csv_file:
    for row in csv.reader(csv_file):
      tokens.append(row[0])
      freqs.append(int(row[1]))
      ids.append(int(row[2]))
  tf.logging.info("Totally %d vocabs", len(tokens))
  return tokens, freqs, ids
117