# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16"""VMSST encoder."""
17
18import torch19import tqdm20import transformers21
22
class VMSSTEncoder:
  """Encodes text inputs into VMSST sentence embeddings.

  Wraps the `jwieting/vmsst` checkpoint from the Hugging Face hub together
  with the mT5-large tokenizer and exposes a batched `encode` method.
  """

  def __init__(
      self, device="cuda", max_batch_size=32, max_length=512, cache_dir=None
  ):
    """Loads the tokenizer and model.

    Args:
      device: "cuda" to use a GPU when available; any other value forces CPU.
      max_batch_size: Maximum number of inputs tokenized and encoded per batch.
      max_length: Maximum number of tokens kept per input; longer inputs are
        truncated.
      cache_dir: Optional Hugging Face cache directory for the tokenizer.
    """
    self.max_batch_size = max_batch_size
    self.max_length = max_length
    # Fall back to CPU when CUDA was requested but is not available.
    if device == "cuda":
      self.device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
      self.device = "cpu"

    self.tokenizer = transformers.T5Tokenizer.from_pretrained(
        "google/mt5-large", cache_dir=cache_dir
    )
    # NOTE(review): trust_remote_code executes model code downloaded from the
    # hub checkpoint — acceptable here for this known checkpoint.
    self.model = transformers.AutoModel.from_pretrained(
        "jwieting/vmsst", trust_remote_code=True
    )

    self.model.to(self.device)
    self.model.eval()

  def encode(self, inputs, verbose=False, return_input_ids=False):
    """Encodes a list of strings into VMSST embeddings.

    Args:
      inputs: List of input strings to encode. May be empty.
      verbose: If True, shows a tqdm progress bar.
      return_input_ids: If True, also returns the (padded, truncated) token
        ids for every input.

    Returns:
      A dict with:
        "embeddings": numpy array of shape (len(inputs), dim).
        "input_ids": list of per-input token-id lists (empty unless
          return_input_ids is True).
    """
    all_embeddings = []
    all_input_ids = []
    # Ceiling division. The previous `len // batch_size + 1` overcounted the
    # tqdm total by one whenever len(inputs) was an exact multiple of the
    # batch size.
    num_batches = -(-len(inputs) // self.max_batch_size)
    for i in tqdm.tqdm(
        range(0, len(inputs), self.max_batch_size),
        total=num_batches,
        disable=not verbose,
        desc="Encoding inputs:",
    ):
      tokenized_inputs = self.tokenizer(
          inputs[i : i + self.max_batch_size], return_tensors="pt", padding=True
      )

      # Truncate every tensor in the batch (input_ids, attention_mask, ...)
      # to at most max_length tokens.
      for k, v in tokenized_inputs.items():
        tokenized_inputs[k] = v[:, : self.max_length]
      tokenized_inputs = tokenized_inputs.to(self.device)

      with torch.inference_mode():
        batch_embeddings = self.model(**tokenized_inputs)
      # Move each batch to CPU immediately so embeddings do not accumulate
      # on the encode device across the whole dataset.
      all_embeddings.append(batch_embeddings.detach().cpu())

      if return_input_ids:
        all_input_ids.extend(tokenized_inputs.input_ids.cpu().tolist())

    if all_embeddings:
      embeddings = torch.cat(all_embeddings, dim=0).numpy()
    else:
      # Empty input: torch.cat on an empty list raises; return an empty
      # array instead of crashing.
      embeddings = torch.empty(0, 0).numpy()

    return {
        "embeddings": embeddings,
        "input_ids": all_input_ids,
    }