# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tokenization classes. Copied from bert."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import unicodedata

from absl import flags
import six
import tensorflow.compat.v1 as tf

FLAGS = flags.FLAGS

flags.DEFINE_bool(
    "preserve_unused_tokens", False,
    "If True, Wordpiece tokenization will not be applied to words in the vocab."
)

_UNUSED_TOKEN_RE = re.compile("^\\[unused\\d+\\]$")


def preserve_token(token, vocab):
  """Returns True if the token should forgo tokenization and be preserved."""
  if not FLAGS.preserve_unused_tokens:
    return False
  if token not in vocab:
    return False
  return bool(_UNUSED_TOKEN_RE.search(token))

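# Illustrative behavior (editor's sketch, not part of the original code): with
# --preserve_unused_tokens=True and a vocab that contains "[unused5]", that
# token is kept whole instead of being split into wordpieces, e.g.
#
#   preserve_token("[unused5]", vocab)  # -> True
#   preserve_token("hello", vocab)      # -> False (not an [unusedN] token)
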
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
  """Checks whether the casing config is consistent with the checkpoint name."""

  # The casing has to be passed in by the user and there is no explicit check
  # as to whether it matches the checkpoint. The casing information probably
  # should have been stored in the bert_config.json file, but it's not, so
  # we have to heuristically detect it to validate.

  if not init_checkpoint:
    return

  m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
  if m is None:
    return

  model_name = m.group(1)

  lower_models = [
      "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
      "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
  ]

  cased_models = [
      "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
      "multi_cased_L-12_H-768_A-12"
  ]

  is_bad_config = False
  if model_name in lower_models and not do_lower_case:
    is_bad_config = True
    actual_flag = "False"
    case_name = "lowercased"
    opposite_flag = "True"

  if model_name in cased_models and do_lower_case:
    is_bad_config = True
    actual_flag = "True"
    case_name = "cased"
    opposite_flag = "False"

  if is_bad_config:
    raise ValueError(
        "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
        "However, `%s` seems to be a %s model, so you "
        "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
        "how the model was pre-trained. If this error is wrong, please "
        "just comment out this check." % (actual_flag, init_checkpoint,
                                          model_name, case_name, opposite_flag))

def convert_to_unicode(text):
  """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
  if six.PY3:
    if isinstance(text, str):
      return text
    elif isinstance(text, bytes):
      return text.decode("utf-8", "ignore")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  elif six.PY2:
    if isinstance(text, str):
      return text.decode("utf-8", "ignore")
    elif isinstance(text, unicode):
      return text
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  else:
    raise ValueError("Not running on Python2 or Python 3?")


def printable_text(text):
  """Returns text encoded in a way suitable for print or `tf.logging`."""

  # These functions want `str` for both Python2 and Python3, but in one case
  # it's a Unicode string and in the other it's a byte string.
  if six.PY3:
    if isinstance(text, str):
      return text
    elif isinstance(text, bytes):
      return text.decode("utf-8", "ignore")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  elif six.PY2:
    if isinstance(text, str):
      return text
    elif isinstance(text, unicode):
      return text.encode("utf-8")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  else:
    raise ValueError("Not running on Python2 or Python 3?")


def load_vocab(vocab_file):
  """Loads a vocabulary file into a dictionary."""
  vocab = collections.OrderedDict()
  with tf.io.gfile.GFile(vocab_file, "r") as reader:
    while True:
      token = convert_to_unicode(reader.readline())
      if not token:
        break
      token = token.strip()
      if token not in vocab:
        vocab[token] = len(vocab)
  return vocab

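# Illustrative note (editor's sketch, not part of the original code): the vocab
# file is plain text with one token per line, and load_vocab maps each token to
# its line index. For example, a file containing the lines
#   [PAD]
#   [UNK]
#   the
# yields OrderedDict([("[PAD]", 0), ("[UNK]", 1), ("the", 2)]).
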
def convert_by_vocab(vocab, items):
  """Converts a sequence of [tokens|ids] using the vocab."""
  output = []
  for item in items:
    output.append(vocab[item])
  return output


def convert_tokens_to_ids(vocab, tokens):
  return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
  return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
  """Runs basic whitespace cleaning and splitting on a piece of text."""
  text = text.strip()
  if not text:
    return []
  tokens = text.split()
  return tokens


class FullTokenizer(object):
  """Runs end-to-end tokenization."""

  def __init__(self, vocab_file, do_lower_case=True):
    self.vocab = load_vocab(vocab_file)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}
    self.basic_tokenizer = BasicTokenizer(
        do_lower_case=do_lower_case, vocab=self.vocab)
    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

  def tokenize(self, text):
    split_tokens = []
    for token in self.basic_tokenizer.tokenize(text):
      if preserve_token(token, self.vocab):
        split_tokens.append(token)
        continue
      for sub_token in self.wordpiece_tokenizer.tokenize(token):
        split_tokens.append(sub_token)

    return split_tokens

  def convert_tokens_to_ids(self, tokens):
    return convert_by_vocab(self.vocab, tokens)

  def convert_ids_to_tokens(self, ids):
    return convert_by_vocab(self.inv_vocab, ids)

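# Illustrative usage (editor's sketch, not part of the original module;
# "vocab.txt" is a placeholder path, and absl flags are assumed to have been
# parsed, e.g. via app.run, since tokenize() reads FLAGS.preserve_unused_tokens):
#
#   tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)
#   tokens = tokenizer.tokenize(u"John Johanson's house")
#   ids = tokenizer.convert_tokens_to_ids(tokens)
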
class BasicTokenizer(object):
  """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

  def __init__(self, do_lower_case=True, vocab=tuple()):
    """Constructs a BasicTokenizer.

    Args:
      do_lower_case: Whether to lower case the input.
      vocab: A container of tokens to not mutate during tokenization.
    """
    self.do_lower_case = do_lower_case
    self.vocab = vocab

  def tokenize(self, text):
    """Tokenizes a piece of text."""
    text = convert_to_unicode(text)
    text = self._clean_text(text)

    # This was added on November 1st, 2018 for the multilingual and Chinese
    # models. This is also applied to the English models now, but it doesn't
    # matter since the English models were not trained on any Chinese data
    # and generally don't have any Chinese data in them (there are Chinese
    # characters in the vocabulary because Wikipedia does have some Chinese
    # words in the English Wikipedia).
    text = self._tokenize_chinese_chars(text)

    orig_tokens = whitespace_tokenize(text)
    split_tokens = []
    for token in orig_tokens:
      if preserve_token(token, self.vocab):
        split_tokens.append(token)
        continue
      if self.do_lower_case:
        token = token.lower()
        token = self._run_strip_accents(token)
      split_tokens.extend(self._run_split_on_punc(token))

    output_tokens = whitespace_tokenize(" ".join(split_tokens))
    return output_tokens

  def _run_strip_accents(self, text):
    """Strips accents from a piece of text."""
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
      cat = unicodedata.category(char)
      if cat == "Mn":
        continue
      output.append(char)
    return "".join(output)

  def _run_split_on_punc(self, text):
    """Splits punctuation on a piece of text."""
    chars = list(text)
    i = 0
    start_new_word = True
    output = []
    while i < len(chars):
      char = chars[i]
      if _is_punctuation(char):
        output.append([char])
        start_new_word = True
      else:
        if start_new_word:
          output.append([])
        start_new_word = False
        output[-1].append(char)
      i += 1

    return ["".join(x) for x in output]

  def _tokenize_chinese_chars(self, text):
    """Adds whitespace around any CJK character."""
    output = []
    for char in text:
      cp = ord(char)
      if self._is_chinese_char(cp):
        output.append(" ")
        output.append(char)
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

  def _is_chinese_char(self, cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean
    # characters, despite its name. The modern Korean Hangul alphabet is a
    # different block, as is Japanese Hiragana and Katakana. Those alphabets
    # are used to write space-separated words, so they are not treated
    # specially and are handled like all of the other languages.
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
        (cp >= 0x3400 and cp <= 0x4DBF) or  #
        (cp >= 0x20000 and cp <= 0x2A6DF) or  #
        (cp >= 0x2A700 and cp <= 0x2B73F) or  #
        (cp >= 0x2B740 and cp <= 0x2B81F) or  #
        (cp >= 0x2B820 and cp <= 0x2CEAF) or
        (cp >= 0xF900 and cp <= 0xFAFF) or  #
        (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
      return True

    return False

  def _clean_text(self, text):
    """Performs invalid character removal and whitespace cleanup on text."""
    output = []
    for char in text:
      cp = ord(char)
      if cp == 0 or cp == 0xfffd or _is_control(char):
        continue
      if _is_whitespace(char):
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

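# Illustrative behavior (editor's sketch, not part of the original code;
# assumes absl flags have been parsed):
#
#   BasicTokenizer(do_lower_case=True).tokenize(u"Hello, World!")
#   # -> ["hello", ",", "world", "!"]
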
class WordpieceTokenizer(object):
  """Runs WordPiece tokenization."""

  def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

  def tokenize(self, text):
    """Tokenizes a piece of text into its word pieces.

    This uses a greedy longest-match-first algorithm to perform tokenization
    using the given vocabulary.

    For example:
      input = "unaffable"
      output = ["un", "##aff", "##able"]

    Args:
      text: A single token or whitespace separated tokens. This should have
        already been passed through `BasicTokenizer`.

    Returns:
      A list of wordpiece tokens.
    """

    text = convert_to_unicode(text)

    output_tokens = []
    for token in whitespace_tokenize(text):
      chars = list(token)
      if len(chars) > self.max_input_chars_per_word:
        output_tokens.append(self.unk_token)
        continue

      is_bad = False
      start = 0
      sub_tokens = []
      while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
          substr = "".join(chars[start:end])
          if start > 0:
            substr = "##" + substr
          if substr in self.vocab:
            cur_substr = substr
            break
          end -= 1
        if cur_substr is None:
          is_bad = True
          break
        sub_tokens.append(cur_substr)
        start = end

      if is_bad:
        output_tokens.append(self.unk_token)
      else:
        output_tokens.extend(sub_tokens)
    return output_tokens

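# Illustrative usage (editor's sketch, not part of the original code): with a
# vocab containing "un", "##aff", and "##able", the greedy longest-match-first
# loop above splits a word into the longest pieces found in the vocab:
#
#   WordpieceTokenizer(vocab).tokenize(u"unaffable")
#   # -> ["un", "##aff", "##able"]
#
# A word for which no sequence of vocab pieces covers every character comes
# back as the single unk_token, "[UNK]".
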
def _is_whitespace(char):
  """Checks whether `char` is a whitespace character."""
  # \t, \n, and \r are technically control characters but we treat them
  # as whitespace since they are generally considered as such.
  if char == " " or char == "\t" or char == "\n" or char == "\r":
    return True
  cat = unicodedata.category(char)
  if cat == "Zs":
    return True
  return False


def _is_control(char):
  """Checks whether `char` is a control character."""
  # These are technically control characters but we count them as whitespace
  # characters.
  if char == "\t" or char == "\n" or char == "\r":
    return False
  cat = unicodedata.category(char)
  if cat in ("Cc", "Cf"):
    return True
  return False


def _is_punctuation(char):
  """Checks whether `char` is a punctuation character."""
  cp = ord(char)
  # We treat all non-letter/number ASCII as punctuation.
  # Characters such as "^", "$", and "`" are not in the Unicode
  # Punctuation class but we treat them as punctuation anyways, for
  # consistency.
  if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
      (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
    return True
  cat = unicodedata.category(char)
  if cat.startswith("P"):
    return True
  return False