# paddlenlp — 177 строк · 6.5 Кб (file-viewer metadata, kept as a comment)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import math
import time
from collections import defaultdict
from typing import Dict, List, cast

from datasets import load_dataset
from mteb.abstasks import AbsTaskRetrieval

from paddlenlp import Taskflow

# Raise csv's field-size cap to 500 MB; presumably needed because T2Ranking
# passages can exceed the default ~128 KB limit — confirm against the loader.
csv.field_size_limit(500 * 1024 * 1024)
27
class PaddleModel:
    """Adapter exposing two PaddleNLP feature-extraction Taskflows through the
    ``encode_queries`` / ``encode_corpus`` interface that BEIR's dense
    retrieval search expects.

    Args:
        query_model: model name or path used to embed queries.
        corpus_model: model name or path used to embed corpus passages.
        batch_size: default batch size handed to the Taskflow pipelines.
        max_seq_len: maximum tokenized sequence length.
        sep: separator inserted between a document title and its text.
        pooling_mode: sentence-embedding pooling strategy (e.g. "mean_tokens").
    """

    def __init__(
        self, query_model, corpus_model, batch_size=16, max_seq_len=512, sep=" ", pooling_mode="mean_tokens", **kwargs
    ):
        self.query_model = Taskflow(
            "feature_extraction",
            model=query_model,
            pooling_mode=pooling_mode,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            _static_mode=True,
        )
        self.corpus_model = Taskflow(
            "feature_extraction",
            model=corpus_model,
            pooling_mode=pooling_mode,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            _static_mode=True,
        )
        self.sep = sep

    def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
        """Embed a list of query strings and return the features as a numpy array."""
        return self.query_model(queries, batch_size=batch_size, **kwargs)["features"].detach().cpu().numpy()

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
        """Embed corpus documents, joining title and text with ``self.sep``.

        Accepts either a columnar dict ({"title": [...], "text": [...]}) or a
        list of per-document dicts — both layouts occur in BEIR/mteb corpora.
        """
        # isinstance (rather than ``type(...) is dict``) also accepts dict subclasses.
        if isinstance(corpus, dict):
            # Hoisted: the title-presence check is constant across rows.
            has_title = "title" in corpus
            sentences = [
                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
                if has_title
                else corpus["text"][i].strip()
                for i in range(len(corpus["text"]))
            ]
        else:
            sentences = [
                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                for doc in corpus
            ]
        return self.corpus_model(sentences, batch_size=batch_size, **kwargs)["features"].detach().cpu().numpy()
67
68
class T2RRetrieval(AbsTaskRetrieval):
    """mteb retrieval task over the THUIR/T2Ranking dev split."""

    def __init__(self, num_max_passages: "int | None" = None, **kwargs):
        super().__init__(**kwargs)
        # No cap given -> keep the whole collection.
        self.num_max_passages = num_max_passages or math.inf

    @property
    def description(self):
        """Static task metadata consumed by the mteb framework."""
        return {
            "name": "T2RankingRetrieval",
            "reference": "https://huggingface.co/datasets/THUIR/T2Ranking",
            "type": "Retrieval",
            "category": "s2p",
            "eval_splits": ["dev"],
            "eval_langs": ["zh"],
            "main_score": "ndcg_at_10",
        }

    def evaluate(
        self,
        model_query,
        model_corpus,
        split="test",
        batch_size=32,
        corpus_chunk_size=None,
        target_devices=None,
        score_function="cos_sim",
        **kwargs
    ):
        """Run BEIR dense exact search with separate query/corpus encoders and
        return ndcg/map/recall/precision/mrr scores at the standard cutoffs."""
        from beir.retrieval.evaluation import EvaluateRetrieval
        from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

        if not self.data_loaded:
            self.load_data()
        corpus = self.corpus[split]
        queries = self.queries[split]
        relevant_docs = self.relevant_docs[split]

        encoder = PaddleModel(model_query, model_corpus)
        searcher = DRES(
            encoder,
            batch_size=batch_size,
            corpus_chunk_size=50000 if corpus_chunk_size is None else corpus_chunk_size,
            **kwargs,
        )
        retriever = EvaluateRetrieval(searcher, score_function=score_function)  # or "cos_sim" or "dot"

        tic = time.time()
        results = retriever.retrieve(corpus, queries)
        toc = time.time()
        print("Time taken to retrieve: {:.2f} seconds".format(toc - tic))

        ndcg, _map, recall, precision = retriever.evaluate(relevant_docs, results, retriever.k_values)
        mrr = retriever.evaluate_custom(relevant_docs, results, retriever.k_values, "mrr")

        # Flatten "NDCG@10"-style keys into "ndcg_at_10"-style score names.
        scores = {}
        for prefix, metric in (
            ("ndcg", ndcg),
            ("map", _map),
            ("recall", recall),
            ("precision", precision),
            ("mrr", mrr),
        ):
            for key, value in metric.items():
                scores[f"{prefix}_at_{key.split('@')[1]}"] = value
        print(scores)
        return scores

    def load_data(self, **kwargs):
        """Populate corpus/queries/qrels for the single "dev" split."""
        corpus, queries, qrels = load_t2ranking_for_retraviel(self.num_max_passages)
        self.corpus = {"dev": corpus}
        self.queries = {"dev": queries}
        self.relevant_docs = {"dev": qrels}
        self.data_loaded = True
139
140
def load_t2ranking_for_retraviel(num_max_passages: float):
    """Build ``(corpus, queries, qrels)`` dicts for the T2Ranking dev split.

    Passages beyond ``num_max_passages`` are dropped, and only queries that
    retain at least one non-zero relevance judgment are kept.
    """
    collection_dataset = load_dataset("THUIR/T2Ranking", "collection")["train"]  # type: ignore
    dev_queries_dataset = load_dataset("THUIR/T2Ranking", "queries.dev")["train"]  # type: ignore
    dev_rels_dataset = load_dataset("THUIR/T2Ranking", "qrels.dev")["train"]  # type: ignore

    # Corpus: passage id -> {"text": ...}, truncated to the requested size.
    corpus = {}
    for idx in range(min(len(collection_dataset), num_max_passages)):
        passage = cast(dict, collection_dataset[idx])
        corpus[str(passage["pid"])] = {"text": passage["text"]}

    # Queries: qid -> query text.
    queries = {}
    for row in dev_queries_dataset:
        row = cast(dict, row)
        queries[str(row["qid"])] = row["text"]

    # Judgments: qid -> {pid: rel}, skipping passages outside the truncated corpus.
    all_qrels = defaultdict(dict)
    for judgment in dev_rels_dataset:
        judgment = cast(dict, judgment)
        if judgment["pid"] > num_max_passages:
            continue
        all_qrels[str(judgment["qid"])][str(judgment["pid"])] = judgment["rel"]

    # Keep only queries that still have at least one positive (non-zero) label.
    valid_qrels = {qid: rels for qid, rels in all_qrels.items() if set(rels.values()) - {0}}
    valid_queries = {qid: text for qid, text in queries.items() if qid in valid_qrels}
    print(f"valid qrels: {len(valid_qrels)}")
    return corpus, valid_queries, valid_qrels
173
174
if __name__ == "__main__":
    # Smoke-run: evaluate m3e-base on a 10k-passage slice of the T2Ranking dev split.
    task = T2RRetrieval(num_max_passages=10000)
    task.evaluate(model_query="moka-ai/m3e-base", model_corpus="moka-ai/m3e-base", split="dev")
178