# fastrag — BEIR msmarco retrieval benchmark (PLAID index + ColBERT retriever)
1import logging
2import os
3import pathlib
4
5import tqdm
6from beir import LoggingHandler, util
7from beir.datasets.data_loader import GenericDataLoader
8from beir.retrieval.evaluation import EvaluateRetrieval
9from haystack import Pipeline
10
11from fastrag.retrievers.colbert import ColBERTRetriever
12from fastrag.stores import PLAIDDocumentStore
13
# Configure root logging: INFO level, timestamped messages, routed through
# BEIR's LoggingHandler so library output is formatted uniformly.
# setLevel is applied first so records emitted during setup are not dropped.
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
21
# --- Benchmark data: download BEIR's msmarco corpus and load the dev split ---
dataset = "msmarco"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
out_dir = os.path.join(pathlib.Path(".").absolute(), "benchmarks", "datasets")
data_path = util.download_and_unzip(url, out_dir)

logging.info("Loading dataset...")
corpus, queries, qrels = GenericDataLoader(data_path).load(split="dev")

# Cutoffs at which retrieval quality is reported (NDCG@k, MAP@k, ...).
k_values = [10, 20, 50, 100]

beir_retriever = EvaluateRetrieval(k_values=k_values)
33
logging.info("Loading PLAID index...")

# Index/checkpoint/collection locations. The literal "/path/to/..." values are
# placeholders and remain the defaults; override them via the environment
# variables below (or edit here) before running the benchmark.
document_store = PLAIDDocumentStore(
    index_path=os.environ.get("PLAID_INDEX_PATH", "/path/to/index"),
    checkpoint_path=os.environ.get("PLAID_CHECKPOINT_PATH", "/path/to/checkpoint"),
    collection_path=os.environ.get("PLAID_COLLECTION_PATH", "/path/to/collection.tsv"),
)

retriever = ColBERTRetriever(document_store=document_store)

# Single-node Haystack pipeline: Query -> ColBERT retriever.
p = Pipeline()
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
45
46
def retrieve_queries(queries: dict, k: int):
    """Run every query through the pipeline and collect its top-k documents.

    Returns a BEIR-style results mapping:
    ``{query_id: {doc_id: score, ...}, ...}``.
    """
    results = {}
    for query_id, query_text in tqdm.tqdm(queries.items(), "queries"):
        retrieved = p.run(query_text, params={"Retriever": {"top_k": k}})
        results[query_id] = {doc.id: doc.score for doc in retrieved["documents"]}
    return results
55
56
logging.info("Querying documents...")
# Retrieve once with the largest cutoff; EvaluateRetrieval truncates the
# ranked lists to each smaller k during evaluation.
results = retrieve_queries(queries, max(k_values))


#### Evaluate the retrieval using NDCG@k, MAP@K ...
# Lazy %-style args defer string formatting until the record is actually emitted.
logging.info("Retriever evaluation for k in: %s", beir_retriever.k_values)
ndcg, _map, recall, precision = beir_retriever.evaluate(qrels, results, beir_retriever.k_values)
64