dream
91 строка · 3.2 Кб
1# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import json
16from logging import getLogger
17from typing import List, Any, Tuple
18
19from pyserini.search import SimpleSearcher
20
21from deeppavlov.core.commands.utils import expand_path
22from deeppavlov.core.common.registry import register
23from deeppavlov.core.models.estimator import Component
24
# Module-level logger, named after this module via ``__name__``.
logger = getLogger(__name__)
26
27
@register("pyserini_ranker")
class PyseriniRanker(Component):
    """Fetch the top-scoring documents for each question with a Pyserini ``SimpleSearcher``.

    Args:
        index_folder: path to the index; it is resolved with ``expand_path`` and
            passed to ``SimpleSearcher`` as a string.
        n_threads: number of questions sent in one ``batch_search`` call; the same
            value is passed to ``batch_search`` as its thread count.
        top_n: number of hits requested per question.
        text_column_name: key under which the document text is stored in each
            hit's raw JSON.
        return_scores: if True, ``__call__`` returns hit scores along with the
            documents.
    """

    def __init__(
        self,
        index_folder: str,
        n_threads: int = 1,
        top_n: int = 5,
        text_column_name: str = "contents",
        return_scores: bool = False,
        *args,
        **kwargs,
    ):
        self.searcher = SimpleSearcher(str(expand_path(index_folder)))
        self.n_threads = n_threads
        self.top_n = top_n
        self.text_column_name = text_column_name
        self.return_scores = return_scores

    def __call__(self, questions: List[str]) -> Tuple[List[Any], List[float]]:
        """Search the index for every question.

        Args:
            questions: batch of query strings.

        Returns:
            ``(docs_batch, scores_batch)`` when ``return_scores`` is True, otherwise
            only ``docs_batch`` (the return annotation describes the first case only).
            Each list holds one entry per question: the list of document texts and
            the list of matching scores respectively.
        """
        docs_batch = []
        scores_batch = []
        doc_ids_batch = []

        for res in self._iter_search_results(questions):
            # Read the text from the configured key instead of a hard-coded "contents".
            docs, doc_ids, scores = self._processing_search_result(res, self.text_column_name)
            docs_batch.append(docs)
            scores_batch.append(scores)
            doc_ids_batch.append(doc_ids)

        # Lazy %-formatting: the id list is rendered only if DEBUG logging is enabled.
        logger.debug("found docs %s", doc_ids_batch)

        if self.return_scores:
            return docs_batch, scores_batch
        return docs_batch

    def _iter_search_results(self, questions: List[str]):
        """Yield the raw search result for each question, in input order."""
        if len(questions) == 1:
            yield self.searcher.search(questions[0], self.top_n)
            return

        # Ceiling division: a trailing partial chunk still needs its own call.
        full_batches, remainder = divmod(len(questions), self.n_threads)
        n_batches = full_batches + int(remainder > 0)
        for batch_idx in range(n_batches):
            first = batch_idx * self.n_threads
            questions_cur = questions[first:first + self.n_threads]
            # Query ids are positions inside the current chunk, so they restart at 0 for each call.
            qids_cur = list(range(len(questions_cur)))
            res_batch = self.searcher.batch_search(questions_cur, qids_cur, self.top_n, self.n_threads)
            for qid in qids_cur:
                # NOTE(review): assumes batch_search keys its result mapping by the
                # same qid objects passed in - confirm against the Pyserini version in use.
                yield res_batch.get(qid)

    @staticmethod
    def _processing_search_result(res, text_column_name: str = "contents"):
        """Split the hits of one query into parallel lists of texts, ids and scores.

        Args:
            res: iterable of hits; each hit exposes ``raw`` (a JSON string) and ``score``.
            text_column_name: key of the document text in the decoded JSON.

        Returns:
            Tuple ``(docs, doc_ids, scores)``. Hits whose decoded JSON is empty or is
            not a dict are skipped; a missing text or ``"id"`` key yields ``""``.
        """
        docs = []
        doc_ids = []
        scores = []
        for elem in res:
            doc = json.loads(elem.raw)
            score = elem.score
            if doc and isinstance(doc, dict):
                docs.append(doc.get(text_column_name, ""))
                doc_ids.append(doc.get("id", ""))
                scores.append(score)

        return docs, doc_ids, scores
92