google-research

Форк
0
/
bert_client.py 
119 строк · 4.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""API to compute BERT tokenization and embeddings."""
17

18
import os
19
from typing import Dict, List, Tuple
20

21
import numpy as np
22
import tensorflow as tf
23
import tensorflow_hub as hub
24
import tensorflow_text as text  # pylint: disable=unused-import
25

26
DEFAULT_TOKENIZER_TFHUB_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
27
DEFAULT_ENCODER_TFHUB_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3'
28

29

30
class Tokenizer(object):
31
  """Tokenizes words to token_ids."""
32

33
  def __init__(self, tfhub_handle = DEFAULT_TOKENIZER_TFHUB_HANDLE):
34
    preprocessor = hub.load(tfhub_handle)
35
    model_path = hub.resolve(tfhub_handle)
36
    vocab_file_path = os.path.join(model_path, 'assets/vocab.txt')
37
    with tf.io.gfile.GFile(vocab_file_path, 'r') as vocab_file:
38
      self.vocab = [token.strip() for token in vocab_file.readlines()]
39

40
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
41
    tokenize_layer = hub.KerasLayer(preprocessor.tokenize)
42
    outputs = tokenize_layer(text_input)
43

44
    self.model = tf.keras.Model(
45
        inputs=text_input, outputs=outputs, name='tokenizer')
46

47
  def tokenize(self, utterance):
48
    """Returns tokens and token_ids in utterance (used for BERT embeddings)."""
49
    model_input = tf.constant([utterance], dtype=tf.string)
50
    token_ids = self.model(model_input).numpy()
51
    token_ids = list(np.concatenate(token_ids[0]).flat)
52
    tokens = [self.vocab[tid] for tid in token_ids]
53
    return tokens, token_ids
54

55

56
class BertClient(object):
57
  """Computes BERT embeddings for input tokens."""
58

59
  def __init__(self, tf_hub_source = DEFAULT_ENCODER_TFHUB_HANDLE):
60
    self.batch_size = 1
61
    self.max_seq_length = 128
62
    input_shape = (self.max_seq_length)
63
    encoder_inputs = {
64
        'input_type_ids': tf.keras.layers.Input(input_shape, dtype=tf.int32),
65
        'input_word_ids': tf.keras.layers.Input(input_shape, dtype=tf.int32),
66
        'input_mask': tf.keras.layers.Input(input_shape, dtype=tf.int32)
67
    }
68
    encoder = hub.KerasLayer(tf_hub_source, trainable=True)
69
    encoder_outputs = encoder(encoder_inputs)
70
    sequence_output = encoder_outputs['sequence_output']
71

72
    self.model = tf.keras.Model(
73
        inputs=encoder_inputs, outputs=sequence_output, name='bert_model')
74

75
  def predict_batch(self, msg_ids, token_batch,
76
                    mask_batch):
77
    """Run BERT on one batch of input data."""
78
    res = {}
79
    model_input = {
80
        'input_type_ids':
81
            tf.zeros([self.batch_size, self.max_seq_length], dtype=tf.int32),
82
        'input_word_ids':
83
            tf.constant(token_batch, dtype=tf.int32),
84
        'input_mask':
85
            tf.constant(mask_batch, dtype=tf.int32)
86
    }
87
    output = self.model(model_input).numpy()
88
    for i, msg_id in enumerate(msg_ids):
89
      res[msg_id] = output[i, :, :].tolist()
90
    return res
91

92
  def lookup(self, messages):
93
    """Look up BERT embeddings for the tokens in messages."""
94
    result = {}
95
    token_batch = []
96
    mask_batch = []
97
    msg_ids = []
98
    for msg_id, msg in messages.items():
99
      token_batch.append(msg + [0] * (self.max_seq_length - len(msg)))
100
      mask_batch.append([1] * len(msg) + [0] * (self.max_seq_length - len(msg)))
101
      msg_ids.append(msg_id)
102

103
      if len(msg_ids) == self.batch_size:
104
        res = self.predict_batch(msg_ids, token_batch, mask_batch)
105
        result.update(res)
106
        token_batch = []
107
        mask_batch = []
108
        msg_ids = []
109

110
    if msg_ids and self.batch_size > len(msg_ids):
111
      to_add = self.batch_size - len(msg_ids)
112
      msg_ids.extend([-1] * to_add)
113
      token_batch.extend([0] * to_add)
114
      mask_batch.extend([0] * to_add)
115

116
      res = self.predict_batch(msg_ids, token_batch, mask_batch)
117
      result.update(res)
118

119
    return result
120

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.