fastrag
64 строки · 2.0 Кб
1import logging2import os3import pathlib4
5import tqdm6from beir import LoggingHandler, util7from beir.datasets.data_loader import GenericDataLoader8from beir.retrieval.evaluation import EvaluateRetrieval9from haystack import Pipeline10
11from fastrag.retrievers.colbert import ColBERTRetriever12from fastrag.stores import PLAIDDocumentStore13
# Configure logging once for the whole script: the root logger is set to
# INFO, and basicConfig is given a timestamped message format together with
# BEIR's LoggingHandler as its only handler.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
21
# --- Dataset -----------------------------------------------------------------
# Name of the BEIR dataset to benchmark; it is substituted into the archive URL.
dataset = "nq"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(
    dataset
)

# The archive is unpacked under <current working directory>/benchmarks/datasets.
out_dir = os.path.join(pathlib.Path.cwd(), "benchmarks", "datasets")
data_path = util.download_and_unzip(url, out_dir)

logging.info("Loading dataset...")
loader = GenericDataLoader(data_path)
corpus, queries, qrels = loader.load(split="test")

# --- Evaluation cut-offs -----------------------------------------------------
# Cut-off depths handed to the BEIR evaluator; the largest one is also used
# as the retrieval depth when querying.
k_values = [10, 20, 50, 100]
beir_retriever = EvaluateRetrieval(k_values=k_values)
# --- Retrieval pipeline ------------------------------------------------------
logging.info("Loading PLAID index...")

# Placeholder locations: point these at a real PLAID index, the ColBERT
# checkpoint, and the collection TSV before running the script.
plaid_locations = {
    "index_path": "/path/to/index",
    "checkpoint_path": "/path/to/checkpoint",
    "collection_path": "/path/to/collection.tsv",
}
document_store = PLAIDDocumentStore(**plaid_locations)
retriever = ColBERTRetriever(document_store=document_store)

# Single-node pipeline: the query is fed directly into the retriever, which
# is registered under the name "Retriever" (the key used for per-node params).
p = Pipeline()
p.add_node(
    component=retriever,
    name="Retriever",
    inputs=["Query"],
)
46
def retrieve_queries(queries: dict, k: int):
    """Run every query through the module-level pipeline ``p``.

    Args:
        queries: Mapping whose keys identify the queries and whose values are
            passed to ``p.run`` as the query.
        k: Value forwarded to the "Retriever" node as ``top_k``.

    Returns:
        A dict keyed by the same keys as ``queries``. Each value is a dict
        mapping ``"doc" + str(doc.id)`` to ``doc.score`` for every document
        the pipeline returned for that query.
    """
    scores_by_query = {}
    for query_id, query_text in tqdm.tqdm(queries.items(), "queries"):
        output = p.run(query_text, params={"Retriever": {"top_k": k}})
        doc_scores = {}
        for hit in output["documents"]:
            # need to concat "doc" to the PLAID IDs to match the qrels IDs
            doc_scores["doc" + str(hit.id)] = hit.score
        scores_by_query[query_id] = doc_scores
    return scores_by_query
57
# --- Retrieval ---------------------------------------------------------------
logging.info("Querying documents...")
# Retrieve once at the largest cut-off; the same results are then evaluated
# at every cut-off in k_values.
retrieval_depth = max(k_values)
results = retrieve_queries(queries, retrieval_depth)

# --- Evaluation: NDCG@k, MAP@k, Recall@k, Precision@k ------------------------
eval_cutoffs = beir_retriever.k_values
logging.info("Retriever evaluation for k in: {}".format(eval_cutoffs))
ndcg, _map, recall, precision = beir_retriever.evaluate(
    qrels,
    results,
    eval_cutoffs,
)