google-research

Форк
0
/
vmsst_encoder.py 
74 строки · 2.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""VMSST encoder."""
17

18
import torch
19
import tqdm
20
import transformers
21

22

23
class VMSSTEncoder:
24
  """VMSST encoder."""
25

26
  def __init__(
27
      self, device="cuda", max_batch_size=32, max_length=512, cache_dir=None
28
  ):
29
    self.max_batch_size = max_batch_size
30
    self.max_length = max_length
31
    if device == "cuda":
32
      self.device = "cuda" if torch.cuda.is_available() else "cpu"
33
    else:
34
      self.device = "cpu"
35

36
    self.tokenizer = transformers.T5Tokenizer.from_pretrained(
37
        "google/mt5-large", cache_dir=cache_dir
38
    )
39
    self.model = transformers.AutoModel.from_pretrained(
40
        "jwieting/vmsst", trust_remote_code=True
41
    )
42

43
    self.model.to(self.device)
44
    self.model.eval()
45

46
  def encode(self, inputs, verbose=False, return_input_ids=False):
47
    """Function to encode VMSST inputs."""
48
    all_embeddings = []
49
    all_input_ids = []
50
    for i in tqdm.tqdm(
51
        range(0, len(inputs), self.max_batch_size),
52
        total=(len(inputs) // self.max_batch_size) + 1,
53
        disable=not verbose,
54
        desc="Encoding inputs:",
55
    ):
56
      tokenized_inputs = self.tokenizer(
57
          inputs[i : i + self.max_batch_size], return_tensors="pt", padding=True
58
      )
59

60
      for k, v in tokenized_inputs.items():
61
        tokenized_inputs[k] = v[:, :self.max_length]
62
      tokenized_inputs = tokenized_inputs.to(self.device)
63

64
      with torch.inference_mode():
65
        batch_embeddings = self.model(**tokenized_inputs)
66
      all_embeddings.append(batch_embeddings)
67

68
      if return_input_ids:
69
        all_input_ids.extend(tokenized_inputs.input_ids.cpu().tolist())
70

71
    return {
72
        "embeddings": torch.cat(all_embeddings, dim=0).detach().cpu().numpy(),
73
        "input_ids": all_input_ids,
74
    }
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.