fastrag

Форк
0
/
msmarco-bm25-sbert.py 
58 строк · 2.0 Кб
1
import logging
2
import os
3
import pathlib
4

5
import tqdm
6
from beir import LoggingHandler, util
7
from beir.datasets.data_loader import GenericDataLoader
8
from beir.retrieval.evaluation import EvaluateRetrieval
9
from haystack import Pipeline
10
from haystack.document_stores import ElasticsearchDocumentStore
11
from haystack.nodes import BM25Retriever, SentenceTransformersRanker
12

13
logging.getLogger().setLevel(logging.INFO)
14
logging.basicConfig(
15
    format="%(asctime)s - %(message)s",
16
    datefmt="%Y-%m-%d %H:%M:%S",
17
    level=logging.INFO,
18
    handlers=[LoggingHandler()],
19
)
20

21
dataset = "msmarco"
22
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
23
out_dir = os.path.join(pathlib.Path(".").absolute(), "benchmarks", "datasets")
24
data_path = util.download_and_unzip(url, out_dir)
25

26
logging.info("Loading dataset...")
27
corpus, queries, qrels = GenericDataLoader(data_path).load(split="dev")
28

29
k_values = [10, 20, 50, 100]
30
beir_retriever = EvaluateRetrieval(k_values=k_values)
31

32
document_store = ElasticsearchDocumentStore(host="localhost", index="msmarco_index", port=80)
33

34
retriever = BM25Retriever(document_store=document_store, top_k=100)
35

36
reranker = SentenceTransformersRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2")
37

38
p = Pipeline()
39
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
40
p.add_node(component=reranker, name="Reranker", inputs=["Retriever"])
41

42

43
def retrieve_queries(queries: dict, k: int):
44
    ans = {}
45
    for q, query in tqdm.tqdm(queries.items(), "queries"):
46
        ans[q] = {
47
            doc.id: doc.score
48
            for doc in p.run(query, params={"Reranker": {"top_k": k}})["documents"]
49
        }
50
    return ans
51

52

53
logging.info("Querying documents...")
54
results = retrieve_queries(queries, max(k_values))
55

56
#### Evaluate the retrieval using NDCG@k, MAP@K ...
57
logging.info("Retriever evaluation for k in: {}".format(beir_retriever.k_values))
58
ndcg, _map, recall, precision = beir_retriever.evaluate(qrels, results, beir_retriever.k_values)
59

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.