# google-research
# 119 lines · 4.2 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""API to compute BERT tokenization and embeddings."""

import os
from typing import Dict, List, Tuple

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
# Imported for its side effects: presumably registers the custom TF ops the
# preprocessing model needs — TODO confirm.
import tensorflow_text as text  # pylint: disable=unused-import

# TF-Hub handles of the default BERT preprocessing (tokenizer) and encoder
# models.
DEFAULT_TOKENIZER_TFHUB_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
DEFAULT_ENCODER_TFHUB_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3'
class Tokenizer(object):
  """Tokenizes words to token_ids."""

  def __init__(self, tfhub_handle = DEFAULT_TOKENIZER_TFHUB_HANDLE):
    """Loads the preprocessing model and its wordpiece vocabulary.

    Args:
      tfhub_handle: TF-Hub handle of the BERT preprocessing model.
    """
    preprocessor = hub.load(tfhub_handle)
    # The vocabulary ships as an asset inside the resolved preprocessing
    # model directory; one token per line, indexed by token id.
    vocab_path = os.path.join(hub.resolve(tfhub_handle), 'assets/vocab.txt')
    with tf.io.gfile.GFile(vocab_path, 'r') as vocab_file:
      self.vocab = [line.strip() for line in vocab_file.readlines()]

    # Wrap the preprocessor's `tokenize` signature in a small Keras model
    # taking a batch of raw strings.
    string_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
    tokenized = hub.KerasLayer(preprocessor.tokenize)(string_input)
    self.model = tf.keras.Model(
        inputs=string_input, outputs=tokenized, name='tokenizer')

  def tokenize(self, utterance):
    """Returns tokens and token_ids in utterance (used for BERT embeddings)."""
    ragged_ids = self.model(tf.constant([utterance], dtype=tf.string)).numpy()
    # Flatten the per-word wordpiece groups of the single input utterance
    # into one flat id sequence.
    token_ids = list(np.concatenate(ragged_ids[0]).flat)
    tokens = [self.vocab[tid] for tid in token_ids]
    return tokens, token_ids
class BertClient(object):
  """Computes BERT embeddings for input tokens.

  Wraps a TF-Hub BERT encoder in a Keras model and exposes a dict-based
  API: message id -> per-token embedding matrix.
  """

  def __init__(self, tf_hub_source = DEFAULT_ENCODER_TFHUB_HANDLE):
    """Builds a Keras model around the TF-Hub BERT encoder.

    Args:
      tf_hub_source: TF-Hub handle of the BERT encoder to load.
    """
    self.batch_size = 1
    self.max_seq_length = 128
    # Fix: the original `(self.max_seq_length)` is just a parenthesized int,
    # not a tuple; Keras expects the shape as a sequence.
    input_shape = (self.max_seq_length,)
    encoder_inputs = {
        'input_type_ids': tf.keras.layers.Input(input_shape, dtype=tf.int32),
        'input_word_ids': tf.keras.layers.Input(input_shape, dtype=tf.int32),
        'input_mask': tf.keras.layers.Input(input_shape, dtype=tf.int32)
    }
    encoder = hub.KerasLayer(tf_hub_source, trainable=True)
    encoder_outputs = encoder(encoder_inputs)
    # Per-token activations of the encoder's last layer.
    sequence_output = encoder_outputs['sequence_output']

    self.model = tf.keras.Model(
        inputs=encoder_inputs, outputs=sequence_output, name='bert_model')

  def predict_batch(self, msg_ids, token_batch,
                    mask_batch):
    """Run BERT on one batch of input data.

    Args:
      msg_ids: ids identifying each row of the batch; used as result keys.
      token_batch: [batch_size, max_seq_length] zero-padded token ids.
      mask_batch: [batch_size, max_seq_length] 1/0 attention mask.

    Returns:
      Dict mapping each msg_id to its per-token embedding matrix as nested
      Python lists.
    """
    res = {}
    model_input = {
        # Single-segment inputs: all segment (type) ids are zero.
        'input_type_ids':
            tf.zeros([self.batch_size, self.max_seq_length], dtype=tf.int32),
        'input_word_ids':
            tf.constant(token_batch, dtype=tf.int32),
        'input_mask':
            tf.constant(mask_batch, dtype=tf.int32)
    }
    output = self.model(model_input).numpy()
    for i, msg_id in enumerate(msg_ids):
      res[msg_id] = output[i, :, :].tolist()
    return res

  def lookup(self, messages):
    """Look up BERT embeddings for the tokens in messages.

    Args:
      messages: dict mapping a message id to its list of token ids; each
        list must be at most max_seq_length long.

    Returns:
      Dict mapping each message id to its per-token embedding matrix.
    """
    result = {}
    token_batch = []
    mask_batch = []
    msg_ids = []
    for msg_id, msg in messages.items():
      padding = self.max_seq_length - len(msg)
      token_batch.append(msg + [0] * padding)
      mask_batch.append([1] * len(msg) + [0] * padding)
      msg_ids.append(msg_id)

      if len(msg_ids) == self.batch_size:
        result.update(self.predict_batch(msg_ids, token_batch, mask_batch))
        token_batch = []
        mask_batch = []
        msg_ids = []

    # Fix: only run the trailing partial batch when it is non-empty. The
    # original called predict_batch unconditionally, feeding empty, wrongly
    # shaped tensors to the model whenever len(messages) was a multiple of
    # batch_size (always, with batch_size == 1) or messages was empty.
    if msg_ids:
      to_add = self.batch_size - len(msg_ids)
      msg_ids.extend([-1] * to_add)
      # Fix: pad with full-length rows. The original extended the batches
      # with scalar zeros, producing a ragged list that tf.constant cannot
      # turn into a [batch_size, max_seq_length] tensor.
      token_batch.extend([[0] * self.max_seq_length for _ in range(to_add)])
      mask_batch.extend([[0] * self.max_seq_length for _ in range(to_add)])
      result.update(self.predict_batch(msg_ids, token_batch, mask_batch))
      if to_add:
        # Fix: drop the filler rows keyed by the sentinel id -1 so they do
        # not leak into the caller's result.
        # NOTE(review): assumes -1 is never a real message id — confirm.
        result.pop(-1, None)

    return result