import json
import logging
import os

import faiss
import hydra
import hydra.utils as hu
import numpy as np
import torch
import tqdm
from hydra.core.hydra_config import HydraConfig
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from src.data.collators import DataCollatorWithPaddingAndCuda
from src.dataset_readers.indexer_dsr import IndexerDatasetReader
from DPR.dpr.utils.tasks import task_map, test_cluster_map, train_cluster_map, get_prompt_files
from DPR.dpr.utils.data_utils import read_data_from_json_files
logger = logging.getLogger(__name__)


class KNNFinder:
    def __init__(self, cfg) -> None:
        self.cfg = cfg
        self.cuda_device = cfg.cuda_device
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, cache_dir=cfg.cache_dir)
        self.model = hu.instantiate(cfg.model).to(self.cuda_device)
        # Inner-product index over 768-dim embeddings, wrapped so entries keep explicit ids.
        self.index = faiss.IndexIDMap(faiss.IndexFlatIP(768))
        self.co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer, device=self.cuda_device)
        self.n_docs = cfg.n_docs
        self.prompt_setup_type = cfg.prompt_setup_type

    def get_prompt_loader(self):
        # Gather the prompt pool: either per-cluster files or a single configured path.
        if self.cfg.train_clusters is not None:
            prompt_pool_path = get_prompt_files(self.cfg.prompt_pool_path, self.cfg.train_clusters)
        else:
            prompt_pool_path = self.cfg.prompt_pool_path
        logger.info("prompt files: %s", prompt_pool_path)
        self.prompt_pool = read_data_from_json_files(prompt_pool_path)
        logger.info("prompt passages num : %d", len(self.prompt_pool))

        self.corpus = [{'instruction': self.format_prompt(entry)} for entry in self.prompt_pool]
        prompt_reader = IndexerDatasetReader(self.tokenizer, self.corpus)
        prompt_loader = DataLoader(prompt_reader, batch_size=self.cfg.batch_size, collate_fn=self.co)
        return prompt_loader

    def create_index(self):
        prompt_loader = self.get_prompt_loader()
        for entry in tqdm.tqdm(prompt_loader):
            with torch.no_grad():
                metadata = entry.pop("metadata")
                res = self.model(**entry)
            # faiss's add_with_ids expects int64 ids.
            id_list = np.array([m['id'] for m in metadata], dtype=np.int64)
            self.index.add_with_ids(res.cpu().detach().numpy(), id_list)

    def format_prompt(self, entry):
        task = task_map.cls_dic[entry['task_name']]()
        # Index the question, the answer, or both, depending on prompt_setup_type.
        if self.prompt_setup_type == 'q':
            prompt = task.get_question(entry)
        elif self.prompt_setup_type == 'a':
            prompt = task.get_answer(entry)
        elif self.prompt_setup_type == 'qa':
            prompt = task.get_question(entry) + ' ' + task.get_answer(entry)
        else:
            raise ValueError(f"unknown prompt_setup_type: {self.prompt_setup_type}")
        return prompt
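
    # E.g. with prompt_setup_type='qa', each pool entry is indexed by its question
    # and answer concatenated with a single space; the exact strings depend on each
    # task's get_question/get_answer implementation.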

    def forward(self):
        # Encode every query batch and keep (embedding, metadata) pairs.
        res_list = []
        for i, entry in enumerate(tqdm.tqdm(self.dataloader)):
            with torch.no_grad():
                metadata = entry.pop("metadata")  # popped so it is not passed to the model
                res = self.model(**entry)
            res = res.cpu().detach().numpy()
            res_list.extend([{"res": r, "metadata": m} for r, m in zip(res, metadata)])
        return res_list

    def search(self, entry):
        res = np.expand_dims(entry['res'], axis=0)
        scores, near_ids = self.index.search(res, self.n_docs)
        return near_ids[0], scores[0]
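
    # faiss's search returns (scores, ids), each of shape (n_queries, n_docs); with
    # IndexFlatIP the scores are raw inner products, which equal cosine similarity
    # when the encoder's embeddings are L2-normalized.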

    def _find(self):
        # Assumption: the index is built here, since no other call site for
        # create_index() is visible in this file.
        self.create_index()
        res_list = self.forward()
        cntx_post = {}
        for entry in tqdm.tqdm(res_list):
            id = entry['metadata']['id']
            ctx_ids, ctx_scores = self.search(entry)
            cntx_post[id] = (ctx_ids, ctx_scores)
        return cntx_post


def find(cfg):
    finder = KNNFinder(cfg)

    task_name = cfg.task_name
    logger.info("search for %s", task_name)
    task = task_map.cls_dic[task_name]()
    dataset = task.get_dataset(cache_dir=cfg.cache_dir)
    get_question = task.get_question

    # Build one retrieval query per dataset entry from its formatted question.
    queries = []
    for id, entry in enumerate(dataset):
        question = get_question(entry)
        queries.append({'instruction': question})
    dataset_reader = IndexerDatasetReader(finder.tokenizer, queries)
    finder.dataloader = DataLoader(dataset_reader, batch_size=cfg.batch_size, collate_fn=finder.co)
    cntx_post = finder._find()

    # Merge retrieved contexts back onto each query. The "ctxs" list key wrapping
    # each retrieved prompt is an assumed name for the reconstructed wrapper.
    merged_data = []
    for idx, data in enumerate(dataset):
        ctx_ids, ctx_scores = cntx_post[idx]
        merged_data.append({
            "instruction": queries[idx]['instruction'],
            "ctxs": [
                {
                    "prompt_pool_id": str(prompt_pool_id),
                    "passage": finder.corpus[prompt_pool_id],
                    "score": str(ctx_scores[i]),
                    "meta_data": finder.prompt_pool[prompt_pool_id],
                }
                for i, prompt_pool_id in enumerate(ctx_ids)
            ],
        })

    with open(cfg.out_file, "w") as writer:
        writer.write(json.dumps(merged_data, indent=4) + "\n")
    logger.info("Saved results * scores to %s", cfg.out_file)


@hydra.main(config_path="configs", config_name="sbert_retriever")
def main(cfg):
    os.makedirs(os.path.dirname(cfg.out_file), exist_ok=True)
    find(cfg)


if __name__ == "__main__":
    main()
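
# Example invocation via hydra overrides (the script name and override values are
# assumptions; only the `sbert_retriever` config name comes from this file):
#   python sbert_retriever.py task_name=<task> prompt_pool_path=<path> out_file=<path>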