lmops

Форк
0
/
retrieve_sbert.py 
138 строк · 5.1 Кб
1
import hydra.utils as hu 
2
import hydra
3
from hydra.core.hydra_config import HydraConfig
4
import torch
5
import tqdm
6
from torch.utils.data import DataLoader
7
from src.data.collators import DataCollatorWithPaddingAndCuda
8
import faiss
9
import numpy as np
10
import json
11
from DPR.dpr.utils.tasks import task_map, test_cluster_map, train_cluster_map,get_prompt_files
12
from DPR.dpr.utils.data_utils import read_data_from_json_files
13
from src.dataset_readers.indexer_dsr import IndexerDatasetReader
14
from transformers import AutoTokenizer
15
import os
16
import logging
17

18
logger = logging.getLogger(__name__)
19

20
class KNNFinder:
    """Dense k-NN retriever over a prompt pool.

    Encodes a pool of prompts with an SBERT-style encoder, stores the
    embeddings in a FAISS inner-product index keyed by prompt id, then
    retrieves the ``n_docs`` nearest prompts for each query batch produced
    by ``self.dataloader``.

    NOTE(review): ``self.dataloader`` is attached externally (see ``find``)
    before ``_find`` is called; it is not set in ``__init__``.
    """

    def __init__(self, cfg) -> None:
        self.cfg = cfg
        self.cuda_device = cfg.cuda_device
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, cache_dir=cfg.cache_dir)
        self.model = hu.instantiate(cfg.model).to(self.cuda_device)
        # 768 = embedding width of the encoder. Inner product over the raw
        # embeddings; assumes vectors are normalized if cosine similarity is
        # intended — TODO confirm the model's output normalization.
        self.index = faiss.IndexIDMap(faiss.IndexFlatIP(768))
        self.co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer, device=self.cuda_device)
        self.n_docs = cfg.n_docs
        self.prompt_setup_type = cfg.prompt_setup_type

    def get_prompt_loader(self):
        """Load the prompt pool from JSON files and wrap it in a DataLoader.

        Side effects: sets ``self.prompt_pool`` (raw entries) and
        ``self.corpus`` (formatted ``{'instruction': text}`` dicts).
        """
        if self.cfg.train_clusters is not None:
            prompt_pool_path = get_prompt_files(self.cfg.prompt_pool_path, self.cfg.train_clusters)
        else:
            prompt_pool_path = self.cfg.prompt_pool_path
        logger.info("prompt files: %s", prompt_pool_path)
        self.prompt_pool = read_data_from_json_files(prompt_pool_path)
        logger.info("prompt passages num : %d", len(self.prompt_pool))

        self.corpus = [{'instruction': self.format_prompt(entry)} for entry in self.prompt_pool]
        prompt_reader = IndexerDatasetReader(self.tokenizer, self.corpus)
        prompt_loader = DataLoader(prompt_reader, batch_size=self.cfg.batch_size, collate_fn=self.co)
        return prompt_loader

    def create_index(self):
        """Encode every pooled prompt and add its embedding to the FAISS
        index, keyed by the ``id`` carried in each example's metadata."""
        prompt_loader = self.get_prompt_loader()
        for entry in tqdm.tqdm(prompt_loader):
            with torch.no_grad():
                metadata = entry.pop("metadata")
                res = self.model(**entry)
            id_list = np.array([m['id'] for m in metadata])
            self.index.add_with_ids(res.cpu().detach().numpy(), id_list)

    def format_prompt(self, entry):
        """Render a prompt-pool entry as text.

        Depending on ``self.prompt_setup_type``: 'q' -> question only,
        'a' -> answer only, 'qa' -> question + ' ' + answer.

        Raises:
            ValueError: for an unrecognized ``prompt_setup_type``.
        """
        task = task_map.cls_dic[entry['task_name']]()
        if self.prompt_setup_type == 'q':
            prompt = task.get_question(entry)
        elif self.prompt_setup_type == 'a':
            prompt = task.get_answer(entry)
        elif self.prompt_setup_type == 'qa':
            prompt = task.get_question(entry) + ' ' + task.get_answer(entry)
        else:
            # Fail fast with a clear error instead of the UnboundLocalError
            # the original code raised on the `return` below.
            raise ValueError(f"unknown prompt_setup_type: {self.prompt_setup_type!r}")
        return prompt

    def forward(self):
        """Encode all queries in ``self.dataloader``.

        Returns:
            list of ``{"res": embedding (np.ndarray), "metadata": dict}``.
        """
        res_list = []
        for entry in tqdm.tqdm(self.dataloader):
            with torch.no_grad():
                res = self.model(**entry)
            res = res.cpu().detach().numpy()
            res_list.extend([{"res": r, "metadata": m} for r, m in zip(res, entry['metadata'])])
        return res_list

    def search(self, entry):
        """Return ``(ids, scores)`` of the ``n_docs`` nearest prompts for one
        encoded query entry."""
        res = np.expand_dims(entry['res'], axis=0)
        scores, near_ids = self.index.search(res, self.n_docs)
        return near_ids[0], scores[0]  # take dim 0 as we expanded dim before

    def _find(self):
        """Run retrieval for every query.

        Returns:
            dict mapping query id -> (ctx prompt ids, ctx scores).
        """
        res_list = self.forward()
        cntx_post = {}
        for entry in tqdm.tqdm(res_list):
            query_id = entry['metadata']['id']  # renamed from `id` (shadowed builtin)
            ctx_ids, ctx_scores = self.search(entry)
            cntx_post[query_id] = (ctx_ids, ctx_scores)
        return cntx_post
89
            
90
def find(cfg):
    """Build the prompt index, retrieve prompts for every example of
    ``cfg.task_name``'s dataset, and dump merged results to ``cfg.out_file``.

    Output JSON: one record per dataset example with its query, original
    example ("meta_data"), and the top ``cfg.n_docs`` retrieved prompts
    ("ctxs") with their scores.
    """
    finder = KNNFinder(cfg)
    finder.create_index()
    task_name = cfg.task_name
    logger.info("search for %s", task_name)
    task = task_map.cls_dic[task_name]()
    dataset = task.get_dataset(cache_dir=cfg.cache_dir)
    get_question = task.get_question
    queries = []
    # `query_id` instead of `id`: avoid shadowing the builtin.
    for query_id, entry in enumerate(dataset):
        entry['id'] = query_id
        question = get_question(entry)
        queries.append({'instruction': question})
    dataset_reader = IndexerDatasetReader(finder.tokenizer, queries)
    finder.dataloader = DataLoader(dataset_reader, batch_size=cfg.batch_size, collate_fn=finder.co)
    cntx_post = finder._find()

    merged_data = []
    for idx, data in enumerate(dataset):
        ctx_ids, ctx_scores = cntx_post[idx]
        merged_data.append(
            {
                # NOTE(review): queries[idx] is the {'instruction': question}
                # dict, not the bare question string — confirm downstream
                # consumers expect the nested form before changing it.
                "instruction": queries[idx],
                "meta_data": data,
                "ctxs": [
                    {
                        "prompt_pool_id": str(prompt_pool_id),
                        # corpus entry is the formatted {'instruction': text} dict.
                        "passage": finder.corpus[prompt_pool_id],
                        "score": str(ctx_scores[i]),
                        "meta_data": finder.prompt_pool[prompt_pool_id],
                    }
                    for i, prompt_pool_id in enumerate(ctx_ids)
                ],
            }
        )
    with open(cfg.out_file, "w") as writer:
        writer.write(json.dumps(merged_data, indent=4) + "\n")
    logger.info("Saved results * scores  to %s", cfg.out_file)
128

129

130
@hydra.main(config_path="configs", config_name="sbert_retriever")
def main(cfg):
    """Hydra entry point: log the resolved config, ensure the output
    directory exists, then run retrieval."""
    logger.info(cfg)
    out_dir = os.path.dirname(cfg.out_file)
    os.makedirs(out_dir, exist_ok=True)
    find(cfg)
135

136

137
# Script entry point; hydra builds `cfg` from the "sbert_retriever" config
# under ./configs (presumably a YAML file — confirm in the repo).
if __name__ == "__main__":
    main()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.