Langchain-Chatchat / add_embedding_keywords.py
'''
This utility adds custom keywords to the embedding model so that the keywords can be
embedded directly as single tokens by the model.
It is implemented by modifying the embedding model's tokenizer.
It only applies to the model referenced by the EMBEDDING_MODEL setting; the merged
model is saved alongside the original model.
Thanks to @CharlesJu1 and @charlesyju for proposing the idea and contributing the initial PR.

The merged model is saved under the original embedding model's directory, with the name
<original model name>_Merge_Keywords_<timestamp>.
'''
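# Note (added for clarity): the keyword file referenced by EMBEDDING_KEYWORD_FILE is
# expected to be plain text with one keyword per line (it is read line by line in
# add_keyword_to_model below). A purely hypothetical example file:
#
#     Langchain-Chatchat
#     知识库
#     向量检索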
import sys

sys.path.append("..")
import os
import torch

from datetime import datetime
from configs import (
    MODEL_PATH,
    EMBEDDING_MODEL,
    EMBEDDING_KEYWORD_FILE,
)

from safetensors.torch import save_model
from sentence_transformers import SentenceTransformer
from langchain_core._api import deprecated


@deprecated(
        since="0.3.0",
        message="Custom keywords have been rewritten in Langchain-Chatchat 0.3.x; the related 0.2.x functionality is deprecated",
        removal="0.3.0"
    )
def get_keyword_embedding(bert_model, tokenizer, key_words):
    # Tokenize all keywords in one padded batch.
    tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)
    input_ids = tokenizer_output['input_ids']
    # Drop the special tokens added at the start and end of each sequence
    # (e.g. [CLS]/[SEP] for BERT-style tokenizers).
    input_ids = input_ids[:, 1:-1]

    # Look up the sub-token embeddings and mean-pool across the (padded) sub-token
    # dimension so that each keyword gets a single initial embedding vector.
    keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
    keyword_embedding = torch.mean(keyword_embedding, 1)
    return keyword_embedding


def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", output_model_path: str = None):
    # Read the keywords, one per line.
    key_words = []
    with open(keyword_file, "r") as f:
        for line in f:
            key_words.append(line.strip())

    # Load the SentenceTransformer and grab its underlying transformer model and tokenizer.
    st_model = SentenceTransformer(model_name)
    key_words_len = len(key_words)
    word_embedding_model = st_model._first_module()
    bert_model = word_embedding_model.auto_model
    tokenizer = word_embedding_model.tokenizer
    # Initial embeddings for the keywords, derived from their sub-token embeddings.
    key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)

    # Register the keywords as new tokens, grow the word-embedding matrix, and copy
    # the pre-computed keyword embeddings into the newly added rows.
    embedding_weight = bert_model.embeddings.word_embeddings.weight
    embedding_weight_len = len(embedding_weight)
    tokenizer.add_tokens(key_words)
    bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
    embedding_weight = bert_model.embeddings.word_embeddings.weight
    with torch.no_grad():
        embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding

    if output_model_path:
        # Save the extended tokenizer/module, then re-serialize the transformer
        # weights to safetensors so the merged embeddings are persisted.
        os.makedirs(output_model_path, exist_ok=True)
        word_embedding_model.save(output_model_path)
        safetensors_file = os.path.join(output_model_path, "model.safetensors")
        metadata = {'format': 'pt'}
        save_model(bert_model, safetensors_file, metadata)
        print("save model to {}".format(output_model_path))


def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
    keyword_file = os.path.join(path)
    # Resolve the local path of the configured embedding model and build an output
    # directory named <EMBEDDING_MODEL>_Merge_Keywords_<timestamp> next to it.
    model_name = MODEL_PATH["embed_model"][EMBEDDING_MODEL]
    model_parent_directory = os.path.dirname(model_name)
    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
    output_model_path = os.path.join(model_parent_directory, output_model_name)
    add_keyword_to_model(model_name, keyword_file, output_model_path)
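

# Minimal usage sketch (not part of the upstream script): assuming configs provides a
# valid EMBEDDING_KEYWORD_FILE in the format described above, the merge can be run
# directly. The printed output directory can then be used in place of the original
# embedding model path.
if __name__ == "__main__":
    add_keyword_to_embedding_model(EMBEDDING_KEYWORD_FILE)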