# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
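
# Preprocess a raw JSON/JSONL (optionally .zst-compressed) corpus into token ids
# for ppfleetx pretraining: each input line is one JSON document, the text under
# --json_key is optionally split into sentences, segmented for Chinese whole word
# masking, tokenized, and written to <output_prefix>_ids.npy together with a
# sentence/document index in <output_prefix>_idx.npz.
#
# Example invocation (illustrative values only; the script name and the
# model/tokenizer pair depend on your setup):
#
#   python preprocess_corpus.py \
#       --model_name <pretrained-model-name> \
#       --tokenizer_name ErnieTokenizer \
#       --input_path ./raw_corpus \
#       --output_prefix ./corpus \
#       --workers 4 \
#       --chinese \
#       --cn_whole_word_segment \
#       --cn_seg_func jieba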

import argparse
import io
import json
import multiprocessing
import os
import re
import sys
import time

import numpy as np
from tqdm import tqdm

try:
    from ppfleetx.data import tokenizers as tfs
except ImportError:
    __dir__ = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(os.path.abspath(os.path.join(__dir__, "../../../../")))
    from ppfleetx.data import tokenizers as tfs

try:
    import nltk

    nltk_available = True
except ImportError:
    nltk_available = False

CHINESE_SEG_FUNC = {}


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True, help="What model to use.")
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        required=True,
        choices=["ErnieTokenizer", "BertTokenizer", "GPTTokenizer", "GPTChineseTokenizer", "ElectraTokenizer"],
        help="What type of tokenizer to use.",
    )
    group = parser.add_argument_group(title="data input/output")
    group.add_argument("--input_path", type=str, required=True, help="Path to input JSON files.")
    group.add_argument("--output_prefix", type=str, required=True, help="Output prefix for the generated files.")
    group.add_argument(
        "--data_format",
        type=str,
        default="text",
        choices=["JSON"],
        help="Only JSON format is supported for now. One document per line.",
    )
    group.add_argument(
        "--json_key",
        type=str,
        default="text",
        help="For JSON format. Space-separated list of keys to extract from the json.",
    )
    group.add_argument("--split_sentences", action="store_true", help="Split documents into sentences.")

    group = parser.add_argument_group(title="chinese words")
    group.add_argument(
        "--chinese", action="store_true", help="Whether the corpus needs a word segmentation step for Chinese text."
    )
    group.add_argument(
        "--cn_whole_word_segment",
        action="store_true",
        help="Whether the corpus needs a word segmentation step for Chinese whole word masking (WWM).",
    )
    group.add_argument(
        "--cn_seg_func",
        type=str,
        default="jieba",
        choices=["lac", "seg", "jieba"],
        help="Word segmentation function for Chinese text.",
    )
    group.add_argument("--cn_splited", action="store_true", help="Whether the Chinese corpus is already split into words.")
    group.add_argument("--cn_split_dimer", type=str, default=" ", help="Delimiter between pre-split Chinese words.")

    group = parser.add_argument_group(title="common config")
    group.add_argument("--append_eos", action="store_true", help="Append an <eos> token to the end of a document.")
    group.add_argument("--log_interval", type=int, default=100, help="Interval between progress updates.")
    group.add_argument("--workers", type=int, default=1, help="Number of worker processes to launch.")

    args = parser.parse_args()
    if args.chinese:
        global CHINESE_SEG_FUNC
        CHINESE_SEG_FUNC["lac"] = lexical_analysis_fn()
        CHINESE_SEG_FUNC["seg"] = chinese_segmentation_fn()
        CHINESE_SEG_FUNC["jieba"] = jieba_segmentation_fn()

    return args


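# Each factory below lazily imports its segmentation backend (LAC or jieba) and
# returns a closure that maps one line of text to a list of words.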
def lexical_analysis_fn():
    from LAC import LAC

    lac = LAC(mode="lac")

    def process(line):
        words, _ = lac.run(line)
        return words

    return process


def chinese_segmentation_fn():
    from LAC import LAC

    lac_cws = LAC(mode="seg")

    def process(line):
        words = lac_cws.run(line)
        return words

    return process


def jieba_segmentation_fn():
    import jieba

    def process(line):
        words = jieba.cut(line)
        return list(words)

    return process


def get_whole_word_mask_tokens(tokens, words, max_word_length=4):
    """
    Do whole word masking on Chinese words.
    First, we do Chinese word segmentation on the text from which the WordPiece tokens come.
    Then, we add the '##' mark to Chinese characters that sit in the middle of a Chinese word.
    If a token is not a Chinese character, we simply keep the WordPiece tokenization result as the word.
    For example:
         - text line : 通过利用mercer核,将样本从输入空间映射到高维特征空间,使原来没有显现的特征突现出来,取得了很好的图像分割效果。
         - the input tokens (after WordPiece):
            ['通', '过', '利', '用', 'me', '##rc', '##er', '核', ',', '将', '样', '本', '从', '输', '入', '空', '间', '映',
            '射', '到', '高', '维', '特', '征', '空', '间', ',', '使', '原', '来', '没', '有', '显', '现', '的', '特', '征',
            '突', '现', '出', '来', ',', '取', '得', '了', '很', '好', '的', '图', '像', '分', '割', '效', '果', '。']
        - the Chinese words (after Chinese word segmentation, e.g. with jieba):
            ['通过', '利用', 'mercer', '核', ',', '将', '样本', '从', '输入', '空间', '映射', '到', '高维', '特征',
            '空间', ',', '使', '原来', '没有', '显现', '的', '特征', '突现', '出来', ',', '取得', '了', '很', '好',
            '的', '图像', '分割', '效果', '。']
        - the output whole word mask tokens:
            ['通', '##过', '利', '##用', 'me', '##rc', '##er', '核', ',', '将', '样', '##本', '从', '输', '##入',
            '空', '##间', '映', '##射', '到', '高', '##维', '特', '##征', '空', '##间', ',', '使', '原', '##来',
            '没', '##有', '显', '##现', '的', '特', '##征', '突', '##现', '出', '##来', ',', '取', '##得', '了',
            '很', '好', '的', '图', '##像', '分', '##割', '效', '##果', '。']
    Args:
        tokens(list(str)): The sequence of tokens produced by WordPiece tokenization.
        words(list(str)): The sequence of Chinese words.
        max_word_length(int, optional):
            The maximum number of Chinese characters in a word; longer words are not whole-word masked.
            Defaults to 4.
    Returns:
         new_tokens(list(str)): The tokens annotated for the whole word masking strategy.
    """

    new_tokens = []
    # use a set for fast membership checks on long documents
    words_set = set(words)
    i = 0
    while i < len(tokens):
        # non-Chinese token: keep the WordPiece result as is
        if len(re.findall("[\u4E00-\u9FA5]", tokens[i])) == 0:
            new_tokens.append(tokens[i])
            i += 1
            continue

        # add the "##" mark on the non-leading tokens of Chinese words,
        # such as ["通过", "利用"] -> ["通", "##过", "利", "##用"]
        has_add = False
        for length in range(max_word_length, 0, -1):
            if i + length > len(tokens):
                continue
            if "".join(tokens[i : i + length]) in words_set:
                new_tokens.append(tokens[i])
                for l in range(1, length):
                    new_tokens.append("##" + tokens[i + l])
                i += length
                has_add = True
                break

        if not has_add:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens


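# Sentence splitters used by Converter: IdentitySplitter keeps the document as a
# single piece, while NewlineSplitter treats every line as one sentence (used for
# Chinese corpora that are assumed to have one sentence per line).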
class IdentitySplitter(object):
    def tokenize(self, *text):
        return text


class NewlineSplitter:
    def tokenize(self, text):
        return text.split("\n")


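# Converter builds the tokenizer, sentence splitter, and Chinese segmentation
# function once per multiprocessing worker (in initializer) and stores them as
# class attributes so that encode() can reuse them for every line.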
class Converter(object):
    def __init__(self, args):
        self.args = args

    def initializer(self):
        Converter.tokenizer = getattr(tfs, self.args.tokenizer_name).from_pretrained(self.args.model_name)

        # Split documents into sentences.
        if self.args.split_sentences:
            if self.args.chinese:
                Converter.splitter = NewlineSplitter()
            else:
                if not nltk_available:
                    print("NLTK is not available to split sentences.")
                    exit()
                splitter = nltk.load("tokenizers/punkt/english.pickle")
                Converter.splitter = splitter
        else:
            Converter.splitter = IdentitySplitter()

        # Word segmentation for Chinese whole word masking.
        if self.args.cn_whole_word_segment:
            if self.args.cn_splited:
                Converter.segment_func = lambda text: text.split(self.args.cn_split_dimer)
            else:
                Converter.segment_func = CHINESE_SEG_FUNC[self.args.cn_seg_func]
            Converter.whole_word_mask = get_whole_word_mask_tokens
        else:
            Converter.segment_func = lambda x: x
            Converter.whole_word_mask = lambda x, y: x

        def process(text):
            words = Converter.segment_func(text)
            tokens = Converter.tokenizer.tokenize("".join(words))
            tokens = Converter.whole_word_mask(tokens, words)
            tokens = Converter.tokenizer.convert_tokens_to_ids(tokens)
            return tokens

        Converter.process = process

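    # encode() maps one JSON line to (doc_ids, num_bytes): doc_ids is a list of
    # token-id lists, one per non-empty sentence, and num_bytes is the UTF-8 size
    # of the raw text (used for throughput logging).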
    def encode(self, json_line):
        text = json.loads(json_line)[self.args.json_key]
        doc_ids = []
        for sentence in Converter.splitter.tokenize(text):
            sentence_ids = Converter.process(sentence.strip())
            if len(sentence_ids) > 0:
                doc_ids.append(sentence_ids)

        if len(doc_ids) > 0 and self.args.append_eos:
            doc_ids[-1].append(Converter.tokenizer.eos_token_id)

        return doc_ids, len(text.encode("utf-8"))


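# main() walks the input path, tokenizes every .jsonl / .zst file with a worker
# pool, accumulates the ids and index data in in-memory byte streams, and finally
# writes <output_prefix>_ids.npy and <output_prefix>_idx.npz.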
def main():
    args = get_args()

    file_paths = []
    if os.path.isfile(args.input_path):
        file_paths.append(args.input_path)
    else:
        for root, _, fs in os.walk(args.input_path):
            for f in fs:
                file_paths.append(os.path.join(root, f))
    if len(file_paths) == 0:
        print("No input file found!")
        exit(-1)

    convert = Converter(args)

    # Check that the tokenizer is available and pick a dtype that fits its vocab size.
    sample_tokenizer = getattr(tfs, args.tokenizer_name).from_pretrained(args.model_name)
    if sample_tokenizer.vocab_size < 2**16 - 1:
        save_dtype = np.uint16
    else:
        save_dtype = np.int32

    pool = multiprocessing.Pool(args.workers, initializer=convert.initializer)

    # We use BytesIO to store the ids.
    token_ids_stream = io.BytesIO()
    sentlens_stream = io.BytesIO()
    # # Cumsum on tokens num
    # sent_cumsum_stream = io.BytesIO()
    # sent_cumsum_stream.write((0).to_bytes(8, byteorder='little', signed=True))
    # Cumulative sentence count at each document boundary, type=np.int64
    doc_cumsum_stream = io.BytesIO()
    doc_cumsum_stream.write((0).to_bytes(8, byteorder="little", signed=True))
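    # sentlens_stream holds the int32 length of every sentence; doc_cumsum_stream
    # holds the int64 cumulative sentence count at each document boundary (starting
    # at 0), so document d spans sentences docs[d]:docs[d + 1].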

    sent_count = 0
    # token_count = 0

    file_paths.sort()

    step = 0
    total_bytes_processed = 0
    startup_start = time.time()
    for file_path in tqdm(file_paths):
        if file_path.endswith(".zst"):
            import zstandard

            cctx = zstandard.ZstdDecompressor()
            fh = open(file_path, "rb")
            text = io.BufferedReader(cctx.stream_reader(fh))
        elif file_path.endswith(".jsonl"):
            text = open(file_path, "r", encoding="utf-8")
        else:
            print("Unexpected data format, skipped %s" % file_path)
            continue

        encoded_docs = pool.imap(convert.encode, text, 256)
        print("Processing %s" % file_path)
        for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
            step += 1
            total_bytes_processed += bytes_processed
            if len(doc) == 0:
                continue

            for sentence in doc:
                sentence_len = len(sentence)
                if sentence_len == 0:
                    continue
                sentlens_stream.write(sentence_len.to_bytes(4, byteorder="little", signed=True))
                # token_count += sentence_len
                # sent_cumsum_stream.write(
                #     token_count.to_bytes(
                #         8, byteorder='little', signed=True))
                sent_count += 1
                token_ids_stream.write(np.array(sentence, dtype=save_dtype).tobytes(order="C"))

            doc_cumsum_stream.write(sent_count.to_bytes(8, byteorder="little", signed=True))

            if step % args.log_interval == 0:
                current = time.time()
                elapsed = current - startup_start
                mbs = total_bytes_processed / elapsed / 1024 / 1024
                print(f"Processed {step} documents", f"({step/elapsed:.2f} docs/s, {mbs:.4f} MB/s).", file=sys.stderr)

    pool.close()
    print("Saving tokens to files...")
    all_doc_ids = np.frombuffer(token_ids_stream.getbuffer(), dtype=save_dtype)
    lens = np.frombuffer(sentlens_stream.getbuffer(), dtype=np.int32)
    # sents = np.frombuffer(sent_cumsum_stream.getbuffer(), dtype=np.int64)
    docs = np.frombuffer(doc_cumsum_stream.getbuffer(), dtype=np.int64)
    np.save(args.output_prefix + "_ids.npy", all_doc_ids)
    # np.savez(args.output_prefix + "_idx.npz", lens=lens, sents=sents, docs=docs)
    np.savez(args.output_prefix + "_idx.npz", lens=lens, docs=docs)

    print("Total sentences num: %d" % len(lens))
    print("Total documents num: %d" % (len(docs) - 1))
    print("Total tokens num: %d" % len(all_doc_ids))
    print("Average tokens per sentence: %.2f" % (len(all_doc_ids) / len(lens)))
    print("Average tokens per document: %.2f" % (len(all_doc_ids) / (len(docs) - 1)))
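

# The script produces two files per run:
#   <output_prefix>_ids.npy  - all token ids, flattened, dtype uint16 or int32
#   <output_prefix>_idx.npz  - "lens" (int32 length of each sentence) and "docs"
#                              (int64 cumulative sentence count per document,
#                              starting at 0)
# A minimal sketch of reading them back (illustrative only; the file names depend
# on the --output_prefix you passed):
#
#     import numpy as np
#     ids = np.load("corpus_ids.npy")
#     idx = np.load("corpus_idx.npz")
#     lens, docs = idx["lens"], idx["docs"]
#     sent_offsets = np.concatenate([[0], np.cumsum(lens)])  # sentence boundaries in `ids`
#     # document d covers sentences docs[d]:docs[d + 1]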

if __name__ == "__main__":
    main()
