lmops

Форк
0
/
retrieve_bm25.py 
135 строк · 4.3 Кб
1
import hydra
2
import tqdm
3
import numpy as np
4
import json
5
from rank_bm25 import BM25Okapi
6
import multiprocessing
7
from DPR.dpr.utils.tasks import task_map, get_prompt_files
8
from DPR.dpr.utils.data_utils import read_data_from_json_files
9
import os
10
import logging
11

12
# Module-level logger for this script (stdlib per-module logging convention).
logger = logging.getLogger(__name__)
13

14

15
class BM25Finder:
    """Builds a BM25 index over a prompt pool and tokenizes prompts/queries.

    The prompt pool is read from JSON files (optionally restricted to the
    configured training clusters) and each entry is rendered to text
    according to ``prompt_setup_type``:

    * ``"q"``  -- question only
    * ``"a"``  -- answer only
    * ``"qa"`` -- question followed by the answer
    """

    def __init__(self, cfg) -> None:
        self.prompt_setup_type = cfg.prompt_setup_type
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # and a misconfigured run should fail loudly either way.
        if self.prompt_setup_type not in ("q", "qa", "a"):
            raise ValueError(
                f"invalid prompt_setup_type: {self.prompt_setup_type!r}"
            )

        # prompt_pool: restrict to the requested training clusters if given.
        if cfg.train_clusters is not None:
            prompt_pool_path = get_prompt_files(cfg.prompt_pool_path, cfg.train_clusters)
        else:
            prompt_pool_path = cfg.prompt_pool_path
        logger.info("prompt files: %s", prompt_pool_path)
        self.prompt_pool = read_data_from_json_files(prompt_pool_path)
        logger.info("prompt passages num : %d", len(self.prompt_pool))

        logger.info("started creating the corpus")
        self.corpus = [self.tokenize_prompt(prompt) for prompt in self.prompt_pool]
        self.bm25 = BM25Okapi(self.corpus)
        logger.info("finished creating the corpus")

    def tokenize_prompt(self, entry):
        """Render a prompt-pool entry to text per ``prompt_setup_type`` and tokenize it."""
        task = task_map.cls_dic[entry["task_name"]]()
        if self.prompt_setup_type == "q":
            prompt = task.get_question(entry)
        elif self.prompt_setup_type == "a":
            prompt = task.get_answer(entry)
        elif self.prompt_setup_type == "qa":
            prompt = task.get_question(entry) + " " + task.get_answer(entry)
        else:
            # Unreachable after __init__ validation, but keeps `prompt` from
            # ever being unbound if the attribute is mutated after construction.
            raise ValueError(
                f"invalid prompt_setup_type: {self.prompt_setup_type!r}"
            )
        return self.tokenize(prompt)

    def tokenize(self, text):
        """Whitespace tokenization (the input form BM25Okapi is built on here)."""
        return text.strip().split()

    def detokenize(self, tokens):
        """Inverse of :meth:`tokenize` up to whitespace normalization."""
        return " ".join(tokens)
53

54

55
def search(tokenized_query, idx, n_docs, bm25=None):
    """Score ``tokenized_query`` against a BM25 index and return the top hits.

    Args:
        tokenized_query: list of whitespace tokens for one query.
        idx: query index, echoed back so out-of-order ``imap_unordered``
            results can be reassembled by the caller.
        n_docs: number of top documents to return.
        bm25: optional BM25 index; defaults to ``bm25_global``, which the
            pool initializer installs in each worker process (see ``find``).

    Returns:
        ``(near_ids, sorted_scores, idx)`` -- the top ``n_docs`` corpus
        indices in descending score order, their scores, and ``idx``.
    """
    if bm25 is None:
        # Worker-process path: use the index installed by the pool initializer.
        bm25 = bm25_global
    scores = np.asarray(bm25.get_scores(tokenized_query))
    # One argsort instead of argsort + a second full sort: the scores are
    # sorted only once, and ids/scores are guaranteed to stay aligned.
    order = np.argsort(scores)[::-1][:n_docs]
    return list(order), list(scores[order]), idx
61

62

63
def _search(args):
    """Adapter for ``Pool.imap_unordered``, which supplies a single argument:
    unpack the ``(tokenized_query, idx, n_docs)`` triple and delegate."""
    return search(*args)
66

67

68
class GlobalState:
    """Minimal holder for a shared BM25 index.

    NOTE(review): not referenced anywhere in this file -- ``find`` shares the
    index via the pool initializer instead; presumably kept for compatibility.
    """
    def __init__(self, bm25) -> None:
        self.bm25 = bm25
71

72

73
def find(cfg):
    """Retrieve BM25 contexts for every example of the task's evaluation
    split and write the merged results to ``cfg.out_file`` as JSON.

    Each output record holds the detokenized question (``instruction``),
    the original example (``meta_data``), and the top ``cfg.n_docs``
    prompt-pool passages with their scores (``ctxs``).
    """
    finder = BM25Finder(cfg)

    def set_global_object(bm25):
        # Installed once per worker process so `search` can reach the
        # (large) BM25 index without pickling it for every query.
        global bm25_global
        bm25_global = bm25

    task_name = cfg.task_name
    logger.info("search for %s", task_name)
    task = task_map.cls_dic[task_name]()
    # Get the evaluation data split.
    dataset = task.get_dataset(cache_dir=cfg.cache_dir)
    get_question = task.get_question
    tokenized_queries = []
    for example_id, entry in enumerate(dataset):
        # `example_id` (not `id`) to avoid shadowing the builtin.
        entry["id"] = example_id
        tokenized_queries.append(finder.tokenize(get_question(entry)))
    cntx_pre = [
        [tokenized_query, idx, cfg.n_docs]
        for idx, tokenized_query in enumerate(tokenized_queries)
    ]
    cntx_post = {}
    # Context manager terminates the worker processes on exit; the original
    # never closed the pool, leaking workers on error.
    with multiprocessing.Pool(
        processes=None, initializer=set_global_object, initargs=(finder.bm25,)
    ) as pool:
        with tqdm.tqdm(total=len(cntx_pre)) as pbar:
            # imap_unordered yields results as they complete; the echoed
            # `idx` restores the original order below.
            for ctx_ids, ctx_scores, idx in pool.imap_unordered(_search, cntx_pre):
                pbar.update()
                cntx_post[idx] = (ctx_ids, ctx_scores)
    merged_data = []
    for idx, data in enumerate(dataset):
        ctx_ids, ctx_scores = cntx_post[idx]
        merged_data.append(
            {
                "instruction": finder.detokenize(tokenized_queries[idx]),
                "meta_data": data,
                "ctxs": [
                    {
                        "prompt_pool_id": str(prompt_pool_id),
                        "passage": finder.detokenize(finder.corpus[prompt_pool_id]),
                        "score": str(ctx_scores[i]),
                        "meta_data": finder.prompt_pool[prompt_pool_id],
                    }
                    for i, prompt_pool_id in enumerate(ctx_ids)
                ],
            }
        )
    with open(cfg.out_file, "w") as writer:
        writer.write(json.dumps(merged_data, indent=4) + "\n")
    logger.info("Saved results * scores to %s", cfg.out_file)
125

126

127
@hydra.main(config_path="configs", config_name="bm25_retriever")
def main(cfg):
    """Hydra entry point: ensure the output directory exists, then retrieve."""
    print(cfg)
    # os.path.dirname returns "" for a bare filename, and makedirs("")
    # raises FileNotFoundError -- only create the directory when there is one.
    out_dir = os.path.dirname(cfg.out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    find(cfg)


if __name__ == "__main__":
    main()
136

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.