# dream
# (GitHub file-listing artifact: "147 строк · 5.8 Кб" = "147 lines · 5.8 KB" — not source code)
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import time
from logging import getLogger
from typing import Any, List, Tuple

import numpy as np
import pymorphy2

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.estimator import Component
from deeppavlov.models.vectorizers.hashing_tfidf_vectorizer import HashingTfIdfVectorizer

logger = getLogger(__name__)
@register("tfidf_ranker")
class TfidfRanker(Component):
    """Rank documents from the vectorizer's index by TF-IDF similarity to input queries.

    Args:
        vectorizer: a vectorizer instance holding the TF-IDF matrix and doc index
        top_n: a number of candidate doc ids to rank
        out_top_n: a number of doc ids to return per query
        active: whether to rank only a number specified by :attr:`top_n` (``True``)
            or all ids (``False``)
        filter_flag: whether to restrict results to docs that match extracted entities

    Attributes:
        top_n: a number of candidate doc ids to rank
        out_top_n: a number of doc ids to return per query
        vectorizer: an instance of vectorizer class
        active: whether to rank a number specified by :attr:`top_n` or all ids
        filter_flag: whether entity-based filtering of results is enabled
        re_tokenizer: compiled regex used to tokenize questions
        lemmatizer: pymorphy2 morphological analyzer (Russian)
    """

    def __init__(
        self,
        vectorizer: HashingTfIdfVectorizer,
        top_n: int = 5,
        out_top_n: int = 5,
        active: bool = True,
        filter_flag: bool = False,
        **kwargs,
    ) -> None:
        self.top_n = top_n
        self.out_top_n = out_top_n
        self.vectorizer = vectorizer
        self.active = active
        # Words or single non-word, non-space characters.
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.lemmatizer = pymorphy2.MorphAnalyzer()
        self.filter_flag = filter_flag
        self.numbers = 0

    def __call__(
        self,
        questions: List[str],
        entity_substr_batch: List[List[str]] = None,
        tags_batch: List[List[str]] = None,
    ) -> Tuple[List[Any], List[float]]:
        """Rank documents and return top n document titles with scores.

        Args:
            questions: list of queries used in ranking
            entity_substr_batch: per-query entity substrings used for filtering
                (only consulted when :attr:`filter_flag` is set)
            tags_batch: per-query NER tags aligned with ``entity_substr_batch``

        Returns:
            a tuple of selected doc ids and their scores
        """
        tm_st = time.time()
        batch_doc_ids, batch_docs_scores = [], []

        q_tfidfs = self.vectorizer(questions)
        # Default missing entity/tag batches to empty lists so the zip below
        # stays aligned with questions (also covers entity_substr_batch given
        # without tags_batch, which previously raised a TypeError).
        if entity_substr_batch is None:
            entity_substr_batch = [[] for _ in questions]
        if tags_batch is None:
            tags_batch = [[] for _ in questions]

        for question, q_tfidf, entity_substr_list, tags_list in zip(
            questions, q_tfidfs, entity_substr_batch, tags_batch
        ):
            # Always defined, even when filtering is off (avoids NameError below).
            entity_substr_for_search = []
            if self.filter_flag:
                # Untagged entities are treated as generic nouns.
                if entity_substr_list and not tags_list:
                    tags_list = ["NOUN" for _ in entity_substr_list]
                # First pass: prefer "strong" entity types.
                for entity_substr, tag in zip(entity_substr_list, tags_list):
                    if tag in {"PER", "PERSON", "PRODUCT", "WORK_OF_ART", "COUNTRY", "ORGANIZATION", "NOUN"}:
                        entity_substr_for_search.append(entity_substr)
                # Second pass: fall back to location/organization entities.
                if not entity_substr_for_search:
                    for entity_substr, tag in zip(entity_substr_list, tags_list):
                        if tag in {"LOCATION", "LOC", "ORG"}:
                            entity_substr_for_search.append(entity_substr)
                # Last resort: take nouns from the question itself, skipping overly
                # generic Russian lemmas ("world", "earth", "planet", "person").
                if not entity_substr_for_search:
                    question_tokens = re.findall(self.re_tokenizer, question)
                    for question_token in question_tokens:
                        parsed = self.lemmatizer.parse(question_token)[0]
                        if parsed.tag.POS == "NOUN" and parsed.normal_form not in {
                            "мир",
                            "земля",
                            "планета",
                            "человек",
                        }:
                            entity_substr_for_search.append(question_token)

            # Doc indices that have any TF-IDF overlap with the found entities.
            nonzero_scores = set()
            if entity_substr_for_search:
                ent_tfidf = self.vectorizer([", ".join(entity_substr_for_search)])[0]
                ent_scores = ent_tfidf * self.vectorizer.tfidf_matrix
                ent_scores = np.squeeze(ent_scores.toarray())
                nonzero_scores = set(np.nonzero(ent_scores)[0])

            scores = q_tfidf * self.vectorizer.tfidf_matrix
            scores = np.squeeze(scores.toarray() + 0.0001)  # add a small value to eliminate zero scores

            # When inactive, rank every document in the index.
            thresh = self.top_n if self.active else len(self.vectorizer.doc_index)

            # argpartition is O(n); only the top `thresh` indices get fully sorted.
            if thresh >= len(scores):
                o = np.argpartition(-scores, len(scores) - 1)[0:thresh]
            else:
                o = np.argpartition(-scores, thresh)[0:thresh]
            o_sort = o[np.argsort(-scores[o])]

            filtered_o_sort = []
            if self.filter_flag and nonzero_scores:
                filtered_o_sort = [elem for elem in o_sort if elem in nonzero_scores]
                if filtered_o_sort:
                    filtered_o_sort = np.array(filtered_o_sort)
            # Still a list means filtering was off or removed everything:
            # fall back to the unfiltered ranking.
            if isinstance(filtered_o_sort, list):
                filtered_o_sort = o_sort

            doc_scores = scores[filtered_o_sort].tolist()
            doc_ids = [self.vectorizer.index2doc.get(i, "") for i in filtered_o_sort]

            batch_doc_ids.append(doc_ids[: self.out_top_n])
            batch_docs_scores.append(doc_scores[: self.out_top_n])

        tm_end = time.time()
        # Guard against empty input batches (previously IndexError on batch_doc_ids[0]).
        num_doc_ids = len(batch_doc_ids[0]) if batch_doc_ids else 0
        logger.info(f"tfidf ranking time: {tm_end - tm_st} num doc_ids {num_doc_ids}")

        return batch_doc_ids, batch_docs_scores