aws-genai-llm-chatbot

Форк
0
150 строк · 4.6 Кб
1
import os
2
import torch
3
import logging
4
import torch.nn.functional as F
5
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
6

7
# Module-level logger for the SageMaker inference container.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Example request payloads this endpoint accepts (dispatched in predict_fn
# on the "type" field):
"""
{
    "type": "embeddings",
    "model": "intfloat/multilingual-e5-large",
    "input": "I love Berlin",
}

{
    "type": "cross-encoder",
    "model": "cross-encoder/ms-marco-MiniLM-L-12-v2",
    "input": "I love Berlin",
    "passages": ["I love Paris", "I love London"]
}

"""

# Hugging Face model ids bundled with this endpoint. model_fn keys the
# loaded models by their short name (the text after the final "/").
embeddings_models = [
    "intfloat/multilingual-e5-large",
    "sentence-transformers/all-MiniLM-L6-v2",
]
cross_encoder_models = ["cross-encoder/ms-marco-MiniLM-L-12-v2"]
31

32

33
def process_model_list(model_list):
    """Strip the org/namespace prefix from each Hugging Face model id.

    E.g. "intfloat/multilingual-e5-large" -> "multilingual-e5-large".
    Ids without a "/" are returned unchanged.

    :param model_list: iterable of model id strings
    :return: list of short model names, in input order
    """
    # A comprehension is the idiomatic replacement for list(map(lambda ...)).
    return [model_id.split("/")[-1] for model_id in model_list]
35

36

37
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, counting only non-padded positions.

    :param model_output: model forward output; element 0 holds the token
        embeddings, assumed (batch, seq_len, hidden) -- TODO confirm with model
    :param attention_mask: (batch, seq_len) mask with 1 for real tokens
    :return: (batch, hidden) tensor of per-sequence masked means
    """
    token_embeddings = model_output[0]
    # Broadcast the mask across the hidden dimension so padded tokens
    # contribute nothing to the sum.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, 1)
    # Clamp guards against division by zero for fully-padded rows.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts
47

48

49
def _load_model(model_class, model_path, device):
    """Load one tokenizer/model pair from disk, in eval mode on `device`."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = model_class.from_pretrained(model_path)
    model.eval()
    model.to(device)
    return {"model": model, "tokenizer": tokenizer}


def model_fn(model_dir):
    """SageMaker model-loading hook.

    Loads every embeddings and cross-encoder model bundled under
    `model_dir` (one subdirectory per short model name).

    :param model_dir: root directory containing the model subdirectories
    :return: dict mapping short model id -> {"model": ..., "tokenizer": ...}
    """
    logger.info("model_fn")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = {}
    # Embedding models use AutoModel: raw hidden states, pooled later
    # by mean_pooling in predict_fn.
    for model_id in process_model_list(embeddings_models):
        # os.path.join used consistently for both model families
        # (the original mixed an f-string here with os.path.join below).
        config[model_id] = _load_model(
            AutoModel, os.path.join(model_dir, model_id), device
        )

    # Cross-encoders are sequence-classification heads that emit
    # relevance logits for (query, passage) pairs.
    for model_id in process_model_list(cross_encoder_models):
        config[model_id] = _load_model(
            AutoModelForSequenceClassification,
            os.path.join(model_dir, model_id),
            device,
        )

    return config
86

87

88
def predict_fn(input_object, config):
    """SageMaker inference hook: dispatch on the request's "type" field.

    "embeddings": returns a list of L2-normalized embedding vectors for
        input_object["input"] (a string or list of strings).
    "cross-encoder": returns one relevance score per entry in
        input_object["passages"], scored against input_object["input"].
    Any other type returns an empty list.

    :param input_object: deserialized request payload (see module examples)
    :param config: short-model-id -> {"model", "tokenizer"} map from model_fn
    :raises ValueError: if the requested model was not loaded
    """
    logger.info("predict_fn")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_id = input_object["model"].split("/")[-1]
    entry = config.get(model_id)
    if not entry:
        raise ValueError(f"Model {model_id} not found")

    model = entry["model"]
    tokenizer = entry["tokenizer"]
    request_type = input_object["type"]

    if request_type == "embeddings":
        texts = input_object["input"]
        # E5-family models expect a "query: " prefix on every input text.
        if model_id == "multilingual-e5-large":
            if isinstance(texts, list):
                texts = ["query: " + text for text in texts]
            else:
                texts = "query: " + texts

        with torch.inference_mode():
            tokens = tokenizer(
                texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(device)

            output = model(**tokens)
            pooled = mean_pooling(output, tokens["attention_mask"])
            # Unit-normalize so downstream cosine similarity is a dot product.
            pooled = F.normalize(pooled, p=2, dim=1)
            return pooled.cpu().numpy().tolist()

    if request_type == "cross-encoder":
        query = input_object["input"]
        passages = input_object["passages"]
        pairs = [[query, passage] for passage in passages]

        with torch.inference_mode():
            features = tokenizer(
                pairs, padding=True, truncation=True, return_tensors="pt"
            ).to(device)

            raw_scores = model(**features).logits.cpu().numpy().tolist()
            # Single-logit heads yield one-element lists per pair; unwrap them.
            return [
                score[-1] if isinstance(score, list) else score
                for score in raw_scores
            ]

    return []
151

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.