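"""Flask service for fact retrieval: given user utterances, noun phrases, and linked entities,
it returns "facts" from a DeepPavlov fact retrieval model and "topic_facts" assembled from the
Wikipedia and wikiHow pages of the detected entities (served at POST /model)."""
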
import json
import logging
import nltk
import os
import pickle
import re
import time

import numpy as np
import sentry_sdk
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
from deeppavlov import build_model

from common.fact_retrieval import topic_titles, find_topic_titles
from common.wiki_skill import find_all_titles, find_paragraph, delete_hyperlinks, WIKI_BADLIST

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])

FILTER_FREQ = False

CONFIG = os.getenv("CONFIG")
CONFIG_PAGE_EXTRACTOR = os.getenv("CONFIG_WIKI")
CONFIG_WOW_PAGE_EXTRACTOR = os.getenv("CONFIG_WHOW")
N_FACTS = int(os.getenv("N_FACTS", 3))

DATA_GOOGLE_10K_ENG_NO_SWEARS = "common/google-10000-english-no-swears.txt"
DATA_SENTENCES = "data/sentences.pickle"

re_tokenizer = re.compile(r"[\w']+|[^\w ]")

with open(DATA_GOOGLE_10K_ENG_NO_SWEARS, "r") as fl:
    lines = fl.readlines()
    freq_words = [line.strip() for line in lines]
    freq_words = set(freq_words[:800])

with open(DATA_SENTENCES, "rb") as fl:
    test_sentences = pickle.load(fl)

try:
    fact_retrieval = build_model(CONFIG, download=True)

    with open("/root/.deeppavlov/downloads/wikidata/entity_types_sets.pickle", "rb") as fl:
        entity_types_sets = pickle.load(fl)

    page_extractor = build_model(CONFIG_PAGE_EXTRACTOR, download=True)
    logger.info("model loaded, test query processed")

    whow_page_extractor = build_model(CONFIG_WOW_PAGE_EXTRACTOR, download=True)

    with open("/root/.deeppavlov/downloads/wikihow/wikihow_topics.json", "r") as fl:
        wikihow_topics = json.load(fl)
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

app = Flask(__name__)


def get_page_content(page_title):
    """Extract the parsed content of a Wikipedia page for the given page title."""
    page_content = {}
    try:
        if page_title:
            page_content_batch, main_pages_batch = page_extractor([[page_title]])
            if page_content_batch and page_content_batch[0]:
                page_content = page_content_batch[0][0]
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception(e)
    return page_content


def get_wikihow_content(page_title):
    """Extract the parsed content of a wikiHow article for the given page title."""
    page_content = {}
    try:
        if page_title:
            page_content_batch = whow_page_extractor([[page_title]])
            if page_content_batch and page_content_batch[0]:
                page_content = page_content_batch[0][0]
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception(e)
    return page_content


def find_sentences(paragraphs):
    """Split the first paragraph into sentences, dropping hyperlinks and badlisted sentences
    and keeping roughly the first 50 tokens."""
    sentences_list = []
    if paragraphs:
        paragraph = paragraphs[0]
        paragraph, mentions, mention_pages = delete_hyperlinks(paragraph)
        sentences = nltk.sent_tokenize(paragraph)
        cur_len = 0
        max_len = 50
        for sentence in sentences:
            words = re.findall(re_tokenizer, sentence)
            if cur_len + len(words) < max_len and not re.findall(WIKI_BADLIST, sentence):
                sentences_list.append(sentence)
                cur_len += len(words)
    return sentences_list


def find_facts(entity_substr_batch, entity_ids_batch, entity_pages_batch):
    """Collect topical facts for each entity: wikiHow intro sentences for food-related entities,
    otherwise sentences from topical sections of the entity's Wikipedia page."""
    facts_batch = []
    for entity_substr_list, entity_ids_list, entity_pages_list in zip(
        entity_substr_batch, entity_ids_batch, entity_pages_batch
    ):
        facts_list = []
        for entity_substr, entity_ids, entity_pages in zip(entity_substr_list, entity_ids_list, entity_pages_list):
            for entity_id, entity_page in zip(entity_ids, entity_pages):
                for entity_types_substr in entity_types_sets:
                    if entity_id in entity_types_sets[entity_types_substr]:
                        logger.info(f"found_entity_types_substr {entity_types_substr} entity_page {entity_page}")
                        if entity_types_substr in {"food", "fruit", "vegetable", "berry"}:
                            # look for a wikiHow page whose title shares a token with the entity
                            found_page_title = ""
                            entity_tokens = set(re.findall(re_tokenizer, entity_substr))
                            food_subtopics = wikihow_topics["Food and Entertaining"]
                            for subtopic in food_subtopics:
                                page_titles = food_subtopics[subtopic]
                                for page_title in page_titles:
                                    page_title_tokens = set(page_title.lower().split("-"))
                                    if entity_tokens.intersection(page_title_tokens):
                                        found_page_title = page_title
                                        break
                                if found_page_title:
                                    break
                            if found_page_title:
                                page_content = get_wikihow_content(found_page_title)
                                if page_content:
                                    page_title_clean = found_page_title.lower().replace("-", " ")
                                    intro = page_content["intro"]
                                    sentences = nltk.sent_tokenize(intro)
                                    facts_list.append(
                                        {
                                            "entity_substr": entity_substr,
                                            "entity_type": entity_types_substr,
                                            "facts": [{"title": page_title_clean, "sentences": sentences}],
                                        }
                                    )
                        else:
                            facts = []
                            page_content = get_page_content(entity_page)
                            all_titles = find_all_titles([], page_content)
                            if entity_types_substr in topic_titles:
                                cur_topic_titles = topic_titles[entity_types_substr]
                                page_titles = find_topic_titles(all_titles, cur_topic_titles)
                                for title, page_title in page_titles:
                                    paragraphs = find_paragraph(page_content, page_title)
                                    sentences_list = find_sentences(paragraphs)
                                    if sentences_list:
                                        facts.append({"title": title, "sentences": sentences_list})
                            if facts:
                                facts_list.append(
                                    {
                                        "entity_substr": entity_substr,
                                        "entity_type": entity_types_substr,
                                        # sample at most N_FACTS facts; min() guards against short fact lists
                                        "facts": list(
                                            np.random.choice(facts, size=min(N_FACTS, len(facts)), replace=False)
                                        ),
                                    }
                                )
        facts_batch.append(
            list(np.random.choice(facts_list, size=min(N_FACTS, len(facts_list)), replace=False))
            if len(facts_list) > 0
            else facts_list
        )
    return facts_batch


@app.route("/model", methods=["POST"])
def respond():
    st_time = time.time()
    cur_utt = request.json.get("human_sentences", [" "])
    dialog_history = request.json.get("dialog_history", [" "])
    # str.lstrip("alexa") would strip any of the characters a/l/e/x, not the wake word,
    # so remove the literal "alexa" prefix instead
    cur_utt = [utt[len("alexa"):].lstrip() if utt.startswith("alexa") else utt for utt in cur_utt]
    nounphr_list = request.json.get("nounphrases", [])
    if FILTER_FREQ:
        nounphr_list = [
            [nounphrase for nounphrase in nounphrases if nounphrase not in freq_words] for nounphrases in nounphr_list
        ]
    if not nounphr_list:
        nounphr_list = [[] for _ in cur_utt]
    entity_substr = request.json.get("entity_substr", [])
    if not entity_substr:
        entity_substr = [[] for _ in cur_utt]
    entity_pages = request.json.get("entity_pages", [])
    if not entity_pages:
        entity_pages = [[] for _ in cur_utt]
    entity_pages_titles = request.json.get("entity_pages_titles", [])
    if not entity_pages_titles:
        entity_pages_titles = [[] for _ in cur_utt]
    entity_ids = request.json.get("entity_ids", [])
    if not entity_ids:
        entity_ids = [[] for _ in cur_utt]
    logger.info(
        f"cur_utt {cur_utt} dialog_history {dialog_history} nounphr_list {nounphr_list} entity_pages {entity_pages}"
    )
    # keep only utterances that are not a bare frequent word and that contain noun phrases;
    # the indices of the filtered-out utterances are remembered so the response stays per-utterance
    nf_numbers, f_utt, f_dh, f_nounphr_list, f_entity_pages = [], [], [], [], []
    for n, (utt, dh, nounphrases, input_pages) in enumerate(zip(cur_utt, dialog_history, nounphr_list, entity_pages)):
        if utt not in freq_words and nounphrases:
            f_utt.append(utt)
            f_dh.append(dh)
            f_nounphr_list.append(nounphrases)
            f_entity_pages.append(input_pages)
        else:
            nf_numbers.append(n)
    out_res = [{"facts": [], "topic_facts": []} for _ in cur_utt]
    try:
        facts_batch = find_facts(entity_substr, entity_ids, entity_pages_titles)
        logger.info(f"f_utt {f_utt}")
        if f_utt:
            fact_res = fact_retrieval(f_utt) if len(f_utt[0].split()) > 3 else fact_retrieval(f_dh)
            if fact_res:
                fact_res = fact_res[0]
            fact_res = [[fact.replace('""', '"') for fact in facts] for facts in fact_res]

            out_res = []
            cnt_fnd = 0
            for i in range(len(cur_utt)):
                if i in nf_numbers:
                    out_res.append({})
                else:
                    if cnt_fnd < len(fact_res):
                        retrieved_facts = fact_res[cnt_fnd]
                        out_res.append(
                            {
                                "topic_facts": facts_batch[cnt_fnd],
                                # sample at most N_FACTS retrieved facts; min() guards against short lists
                                "facts": list(
                                    np.random.choice(
                                        retrieved_facts, size=min(N_FACTS, len(retrieved_facts)), replace=False
                                    )
                                )
                                if len(retrieved_facts) > 0
                                else retrieved_facts,
                            }
                        )
                        cnt_fnd += 1
                    else:
                        out_res.append({"facts": [], "topic_facts": []})
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception(e)
    total_time = time.time() - st_time
    logger.info(f"fact_retrieval exec time: {total_time:.3f}s")
    return jsonify(out_res)


if __name__ == "__main__":
    app.run(debug=False, host="0.0.0.0", port=3000)
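
# Example request (a rough sketch; the values are illustrative only, while the keys and
# nesting follow the request.json.get(...) calls in respond() and the loops in find_facts()):
#
#   curl -X POST http://0.0.0.0:3000/model -H 'Content-Type: application/json' -d '{
#       "human_sentences": ["tell me about harry potter"],
#       "dialog_history": ["tell me about harry potter"],
#       "nounphrases": [["harry potter"]],
#       "entity_substr": [["harry potter"]],
#       "entity_ids": [[["Q8337"]]],
#       "entity_pages": [[["Harry Potter"]]],
#       "entity_pages_titles": [[["Harry Potter"]]]
#   }'
#
# The response is a JSON list with one element per utterance: a dict with "facts" and
# "topic_facts" keys, or an empty dict for utterances filtered out before retrieval.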