"""BM25 retrieval of in-context prompts: build a BM25 index over a prompt pool
and retrieve the top-scoring prompts for every example of an evaluation task."""
import json
import logging
import multiprocessing
import os

import hydra
import numpy as np
import tqdm
from rank_bm25 import BM25Okapi

from DPR.dpr.utils.tasks import task_map, get_prompt_files
from DPR.dpr.utils.data_utils import read_data_from_json_files

logger = logging.getLogger(__name__)

class BM25Finder:
    def __init__(self, cfg) -> None:
        self.prompt_setup_type = cfg.prompt_setup_type
        assert self.prompt_setup_type in ["q", "qa", "a"]

        # Collect the prompt-pool files, optionally restricted to the training clusters.
        if cfg.train_clusters is not None:
            prompt_pool_path = get_prompt_files(cfg.prompt_pool_path, cfg.train_clusters)
        else:
            prompt_pool_path = cfg.prompt_pool_path
        logger.info("prompt files: %s", prompt_pool_path)
        self.prompt_pool = read_data_from_json_files(prompt_pool_path)
        logger.info("prompt passages num : %d", len(self.prompt_pool))

        # Tokenize every prompt once and build the BM25 index over the whole pool.
        logger.info("started creating the corpus")
        self.corpus = [self.tokenize_prompt(prompt) for prompt in self.prompt_pool]
        self.bm25 = BM25Okapi(self.corpus)
        logger.info("finished creating the corpus")
    def tokenize_prompt(self, entry):
        # Render a prompt-pool entry as its question, its answer, or both,
        # depending on prompt_setup_type.
        task = task_map.cls_dic[entry["task_name"]]()
        if self.prompt_setup_type == "q":
            prompt = task.get_question(entry)
        elif self.prompt_setup_type == "a":
            prompt = task.get_answer(entry)
        elif self.prompt_setup_type == "qa":
            prompt = (
                task.get_question(entry)
                + " "
                + task.get_answer(entry)
            )
        return self.tokenize(prompt)
    def tokenize(self, text):
        # Whitespace tokenization is all BM25Okapi needs.
        return text.strip().split()

    def detokenize(self, tokens):
        return " ".join(tokens)

# Set per worker process by set_global_object() below.
bm25 = None


def search(tokenized_query, idx, n_docs):
    # Score the query against every prompt and keep the n_docs best, highest first.
    # The query index is returned so results can be matched back to their query
    # even when they arrive out of order.
    scores = bm25.get_scores(tokenized_query)
    near_ids = list(np.argsort(scores)[::-1][:n_docs])
    sorted_scores = list(np.sort(scores)[::-1][:n_docs])
    return near_ids, sorted_scores, idx


def _search(args):
    tokenized_query, idx, n_docs = args
    return search(tokenized_query, idx, n_docs)

class BM25Holder:
    # Thin container for a shared BM25 index. Only the constructor survives in
    # the excerpt; the class name is a placeholder.
    def __init__(self, bm25) -> None:
        self.bm25 = bm25

def find(cfg):
    # Driver for the whole retrieval run (function name assumed; not shown in the excerpt).
    finder = BM25Finder(cfg)

    def set_global_object(bm25_index):
        # Install the index as the module-level `bm25` that search() reads in each worker.
        global bm25
        bm25 = bm25_index

    pool = multiprocessing.Pool(
        processes=None, initializer=set_global_object, initargs=(finder.bm25,)
    )
    task_name = cfg.task_name
    logger.info("search for %s", task_name)
    task = task_map.cls_dic[task_name]()

    # get the evaluation data split
    dataset = task.get_dataset(cache_dir=cfg.cache_dir)
    get_question = task.get_question
    tokenized_queries = []
    for id, entry in enumerate(dataset):
        question = get_question(entry)
        tokenized_queries.append(finder.tokenize(question))
    # One work item per query: (tokenized query, its index, how many prompts to return).
    cntx_pre = [
        [tokenized_query, idx, cfg.n_docs]
        for idx, tokenized_query in enumerate(tokenized_queries)
    ]
    cntx_post = [None] * len(cntx_pre)
    with tqdm.tqdm(total=len(cntx_pre)) as pbar:
        for res in pool.imap_unordered(_search, cntx_pre):
            pbar.update(1)
            ctx_ids, ctx_scores, idx = res
            cntx_post[idx] = (ctx_ids, ctx_scores)

    # Merge the retrieved prompts and scores back into one record per evaluation example.
    merged_data = []
    for idx, data in enumerate(dataset):
        ctx_ids, ctx_scores = cntx_post[idx]
        merged_data.append(
            {
                "instruction": finder.detokenize(tokenized_queries[idx]),
                # Retrieved prompts, best first; the "ctxs" key follows the DPR output convention.
                "ctxs": [
                    {
                        "prompt_pool_id": str(prompt_pool_id),
                        "passage": finder.detokenize(finder.corpus[prompt_pool_id]),
                        "score": str(ctx_scores[i]),
                        "meta_data": finder.prompt_pool[prompt_pool_id],
                    }
                    for i, prompt_pool_id in enumerate(ctx_ids)
                ],
            }
        )

    with open(cfg.out_file, "w") as writer:
        writer.write(json.dumps(merged_data, indent=4) + "\n")
    logger.info("Saved results * scores to %s", cfg.out_file)

@hydra.main(config_path="configs", config_name="bm25_retriever")
def main(cfg):
    logging.basicConfig(level=logging.INFO)  # make the logger.info progress messages visible
    os.makedirs(os.path.dirname(cfg.out_file), exist_ok=True)
    find(cfg)


if __name__ == "__main__":
    main()
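
# Example invocation -- a sketch only: the script file name, task name, and paths are
# placeholders rather than values from this repo; the config keys are the ones the
# code above reads from cfg, with defaults resolved from configs/bm25_retriever.yaml.
#
#   python bm25_retriever.py \
#       task_name=<task registered in task_map> \
#       prompt_setup_type=qa \
#       prompt_pool_path=<prompt pool json file(s)> \
#       train_clusters=null \
#       cache_dir=<datasets cache dir> \
#       n_docs=<prompts to retrieve per query> \
#       out_file=<output json path>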