fastrag
58 строк · 2.0 Кб
import logging
import os
import pathlib

import tqdm
from beir import LoggingHandler, util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from haystack import Pipeline
from haystack.document_stores import ElasticsearchDocumentStore
from haystack.nodes import BM25Retriever, SentenceTransformersRanker
# Root logger: INFO level, "<timestamp> - <message>" lines, emitted through
# BEIR's LoggingHandler. The explicit setLevel matters because basicConfig is
# a no-op when the root logger already has handlers attached.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
20
# Download (if needed) and unpack the BEIR MS MARCO archive into
# <cwd>/benchmarks/datasets, then load the "dev" split.
dataset = "msmarco"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
out_dir = str(pathlib.Path.cwd() / "benchmarks" / "datasets")
data_path = util.download_and_unzip(url, out_dir)

logging.info("Loading dataset...")
# NOTE(review): `corpus` is not referenced in the code shown; the Elasticsearch
# index queried below appears to be populated elsewhere — confirm.
corpus, queries, qrels = GenericDataLoader(data_path).load(split="dev")
# Cut-offs at which BEIR reports metrics. The deepest cut-off also decides how
# many BM25 candidates are handed to the cross-encoder reranker.
k_values = [10, 20, 50, 100]
beir_retriever = EvaluateRetrieval(k_values=k_values)

# NOTE(review): port 80 is unusual for Elasticsearch (commonly 9200) —
# presumably a proxy in front of the cluster; confirm. Nothing in this script
# writes documents into "msmarco_index", so it must already be populated.
document_store = ElasticsearchDocumentStore(host="localhost", index="msmarco_index", port=80)

# Retrieve as many candidates as the largest cut-off. Otherwise adding a
# bigger k to k_values would be silently capped by a stale constant, because
# the reranker can only reorder what BM25 returns.
retriever = BM25Retriever(document_store=document_store, top_k=max(k_values))

reranker = SentenceTransformersRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2")

# Two-stage pipeline: query -> BM25 candidates -> cross-encoder reranking.
p = Pipeline()
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
p.add_node(component=reranker, name="Reranker", inputs=["Retriever"])
42
def retrieve_queries(queries: dict, k: int, pipeline=None) -> dict:
    """Run every query through the retrieve-and-rerank pipeline.

    Args:
        queries: Mapping of query id -> query text (as produced by
            ``GenericDataLoader.load`` above).
        k: Number of reranked documents to keep per query; forwarded to the
            "Reranker" node as ``top_k``.
        pipeline: Pipeline to run. Defaults to the module-level ``p``. Any
            object whose ``run(query, params=...)`` returns a dict with a
            ``"documents"`` list of objects exposing ``id`` and ``score`` works.

    Returns:
        ``{query_id: {doc_id: score}}``, the run format passed to
        ``EvaluateRetrieval.evaluate``.
    """
    if pipeline is None:
        pipeline = p
    results = {}
    for qid, query in tqdm.tqdm(queries.items(), desc="queries"):
        # Only the reranker's top_k is set per call; the retriever keeps the
        # candidate depth it was constructed with.
        output = pipeline.run(query, params={"Reranker": {"top_k": k}})
        results[qid] = {doc.id: doc.score for doc in output["documents"]}
    return results
52
logging.info("Querying documents...")
results = retrieve_queries(queries, max(k_values))

# Evaluate the reranked run against the qrels: NDCG@k, MAP@k, Recall@k and
# Precision@k for every cut-off in k_values.
logging.info("Retriever evaluation for k in: %s", beir_retriever.k_values)
ndcg, _map, recall, precision = beir_retriever.evaluate(qrels, results, beir_retriever.k_values)

# Report the scores so the benchmark actually produces output instead of
# computing the metrics and discarding them.
for metric_scores in (ndcg, _map, recall, precision):
    logging.info("%s", metric_scores)