Langchain-Chatchat
79 строк · 3.3 Кб
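'''
This script adds custom keywords to the embedding model so that each keyword
gets its own embedding inside the model.
It works by extending the tokenizer of the embedding model with the keywords.
It only affects the model selected by the EMBEDDING_MODEL setting; the merged
model is saved next to the original model.
Thanks to @CharlesJu1 and @charlesyju, who proposed the idea and contributed
the initial PR.

The merged model is saved in the parent directory of the original embedding
model path, under the name "<EMBEDDING_MODEL>_Merge_Keywords_<timestamp>".
'''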
9import sys10
11sys.path.append("..")12import os13import torch14
15from datetime import datetime16from configs import (17MODEL_PATH,18EMBEDDING_MODEL,19EMBEDDING_KEYWORD_FILE,20)
21
22from safetensors.torch import save_model23from sentence_transformers import SentenceTransformer24from langchain_core._api import deprecated25
26
@deprecated(
    since="0.3.0",
    message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
    removal="0.3.0"
)
def get_keyword_embedding(bert_model, tokenizer, key_words):
    """Return one embedding vector per keyword.

    Each keyword is tokenized with the model's current tokenizer and the
    input (word) embeddings of its sub-word tokens are averaged.

    Special tokens and padding are excluded from the average. Slicing off the
    first and last column is not enough: with ``padding=True`` every keyword
    shorter than the longest one in the batch still carries its separator
    token and pad tokens inside the slice, which would skew its mean.

    Args:
        bert_model: Model exposing ``embeddings.word_embeddings``.
        tokenizer: Tokenizer callable that accepts the Hugging Face style
            arguments used below and returns ``input_ids``,
            ``attention_mask`` and ``special_tokens_mask``.
        key_words: List of keyword strings.

    Returns:
        A tensor with one row per keyword holding the mean word embedding of
        that keyword's real (non-special, non-padding) tokens.
    """
    tokenizer_output = tokenizer(
        key_words,
        return_tensors="pt",
        padding=True,
        truncation=True,
        return_special_tokens_mask=True,
    )
    word_embeddings = bert_model.embeddings.word_embeddings
    # The model may live on a different device than the freshly built ids.
    device = word_embeddings.weight.device
    input_ids = tokenizer_output['input_ids'].to(device)

    # 1 for real sub-word tokens, 0 for special tokens and for padding.
    token_mask = (
        (1 - tokenizer_output['special_tokens_mask'])
        * tokenizer_output['attention_mask']
    ).to(device)

    token_embeddings = word_embeddings(input_ids)
    weights = token_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * weights).sum(dim=1)
    # clamp avoids a division by zero for a keyword that yields no tokens
    token_counts = weights.sum(dim=1).clamp(min=1)
    keyword_embedding = summed / token_counts
    return keyword_embedding
41
def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", output_model_path: str = None):
    """Add every keyword in ``keyword_file`` to the model as a single token.

    Each keyword is registered as a new tokenizer token. Its embedding row is
    initialised with the mean word embedding of the sub-word tokens the
    keyword was split into before being added.

    Args:
        model_name: Name or path passed to ``SentenceTransformer``.
        keyword_file: Path of a UTF-8 text file with one keyword per line.
            Blank lines and repeated keywords are ignored.
        output_model_path: Directory to save the merged model to. Nothing is
            saved when it is empty or ``None``.

    Raises:
        ValueError: If ``keyword_file`` contains no keywords.
    """
    # Read as UTF-8 explicitly: the platform default encoding is
    # locale-dependent and breaks on non-ASCII keywords on some systems.
    with open(keyword_file, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        # dict.fromkeys drops duplicates while keeping the file order.
        key_words = list(dict.fromkeys(word for word in stripped_lines if word))
    if not key_words:
        raise ValueError("no keywords found in {!r}".format(keyword_file))

    st_model = SentenceTransformer(model_name)
    word_embedding_model = st_model._first_module()
    bert_model = word_embedding_model.auto_model
    tokenizer = word_embedding_model.tokenizer
    # Must run BEFORE the keywords are added to the tokenizer: afterwards a
    # keyword maps to its own new token instead of its sub-word pieces.
    key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)

    old_vocab_size = len(tokenizer)
    tokenizer.add_tokens(key_words)
    bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
    embedding_weight = bert_model.embeddings.word_embeddings.weight

    # Look up the id each keyword actually received instead of assuming the
    # keywords occupy consecutive rows after the old embedding matrix:
    # keywords already in the vocabulary are not re-added, and the embedding
    # matrix may be larger than the tokenizer. Only rows of newly created
    # tokens (id >= old tokenizer size) are overwritten.
    token_ids = tokenizer.convert_tokens_to_ids(key_words)
    new_entries = [
        (row, token_id)
        for row, token_id in enumerate(token_ids)
        if token_id is not None and token_id >= old_vocab_size
    ]
    if new_entries:
        rows, ids = zip(*new_entries)
        row_index = torch.tensor(rows, device=key_words_embedding.device)
        id_index = torch.tensor(ids, device=embedding_weight.device)
        with torch.no_grad():
            embedding_weight[id_index] = key_words_embedding[row_index].to(
                device=embedding_weight.device, dtype=embedding_weight.dtype
            )

    if output_model_path:
        os.makedirs(output_model_path, exist_ok=True)
        word_embedding_model.save(output_model_path)
        # Re-export the resized weights over the model.safetensors file.
        safetensors_file = os.path.join(output_model_path, "model.safetensors")
        metadata = {'format': 'pt'}
        save_model(bert_model, safetensors_file, metadata)
        print("save model to {}".format(output_model_path))
71
def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
    """Merge the keywords listed in ``path`` into the configured embedding model.

    The model path is read from ``MODEL_PATH["embed_model"][EMBEDDING_MODEL]``.
    The merged model is written to the parent directory of that path, in a
    directory named ``<EMBEDDING_MODEL>_Merge_Keywords_<YYYYmmdd_HHMMSS>``.

    Args:
        path: Path of the keyword file handed to ``add_keyword_to_model``.
    """
    source_file = os.path.join(path)
    base_model = MODEL_PATH["embed_model"][EMBEDDING_MODEL]
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    target_dir = os.path.join(
        os.path.dirname(base_model),
        f"{EMBEDDING_MODEL}_Merge_Keywords_{stamp}",
    )
    add_keyword_to_model(base_model, source_file, target_dir)