dream
333 строки · 11.8 Кб
1import logging2import os3import random4
5from dff.script import Context6from dff.pipeline import Pipeline7
8import common.constants as common_constants9import common.link as common_link10import common.news as common_news11import common.utils as common_utils12
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Name of the service running this module, taken from the environment; None when
# unset. NOTE: several helpers below call SERVICE_NAME.replace(...) in default
# arguments, so importing this module fails when the variable is missing.
SERVICE_NAME = os.environ.get("SERVICE_NAME")


# URL handed to common_news.get_news_about_topic when an utterance carries no
# annotated news for an entity; None when unset.
NEWS_API_ANNOTATOR_URL = os.environ.get("NEWS_API_ANNOTATOR_URL")
19
def get_new_human_labeled_noun_phrase(ctx: Context, pipeline: Pipeline) -> list:
    """Return the ``cobot_entities`` entity list of the latest human utterance.

    During script validation an empty list is returned without reading the dialog.
    """
    if ctx.validation:
        return []
    annotations = get_last_human_utterance(ctx, pipeline).get("annotations", {})
    cobot_entities = annotations.get("cobot_entities", {})
    return cobot_entities.get("entities", [])
29
def get_human_sentiment(ctx: Context, pipeline: Pipeline, negative_threshold=0.5, positive_threshold=0.333) -> str:
    """Classify the latest human utterance as "negative", "positive" or "neutral".

    Takes the label with the highest probability (the first one on ties) from
    ``common_utils.get_sentiment(..., probs=True)``. That label is returned only if
    it is "negative" with probability >= ``negative_threshold`` or "positive" with
    probability >= ``positive_threshold``. Every other case yields "neutral",
    including validation and a missing or non-dict sentiment result.
    """
    if ctx.validation:
        return "neutral"
    probs = common_utils.get_sentiment(get_last_human_utterance(ctx, pipeline), probs=True)
    if not (probs and isinstance(probs, dict)):
        return "neutral"
    top_prob = max(probs.values())
    # None when no label equals the maximum (possible with NaN probabilities).
    top_label = next((label for label in probs if probs[label] == top_prob), None)
    thresholds = {"negative": negative_threshold, "positive": positive_threshold}
    if top_label in thresholds and top_prob >= thresholds[top_label]:
        return top_label
    return "neutral"
47
def get_cross_state(ctx: Context, _, service_name=SERVICE_NAME.replace("-", "_")) -> dict:
    """Return the cross-skill state stored for ``service_name``.

    ``service_name`` defaults to SERVICE_NAME with dashes replaced by underscores;
    the default is computed once, when the function is defined. Returns ``{}``
    during validation or when no state is stored for the service.
    """
    if ctx.validation:
        return {}
    cross_states = ctx.misc["agent"]["dff_shared_state"]["cross_states"]
    return cross_states.get(service_name, {})
51
def save_cross_state(ctx: Context, _, service_name=SERVICE_NAME.replace("-", "_"), new_state=None):
    """Store ``new_state`` as the cross-skill state of ``service_name``.

    No-op during validation.

    Args:
        ctx: dialog context.
        service_name: key under ``dff_shared_state["cross_states"]``. Defaults to
            SERVICE_NAME with dashes replaced by underscores.
        new_state: state to store. ``None`` (the default) stores a fresh empty dict.
    """
    if ctx.validation:
        return
    # The old ``new_state={}`` default was one dict object shared by every call.
    # Storing it in the context let mutations made in one turn leak into later
    # default calls, so a fresh dict is created per call instead.
    if new_state is None:
        new_state = {}
    ctx.misc["agent"]["dff_shared_state"]["cross_states"][service_name] = new_state
56
def get_cross_link(ctx: Context, pipeline: Pipeline, service_name=SERVICE_NAME.replace("-", "_")) -> dict:
    """Return the cross link left for ``service_name`` on the previous human turn.

    Links are stored keyed by the human utterance index at which they were set.
    The first link whose index is exactly one less than the current human
    utterance index is returned. Returns ``{}`` when none matches or during
    validation.
    """
    if ctx.validation:
        links = {}
    else:
        links = ctx.misc["agent"]["dff_shared_state"]["cross_links"].get(service_name, {})
    current_index = get_human_utter_index(ctx, pipeline)
    matches = []
    for stored_index, link in links.items():
        # int() accepts keys stored either as ints or as numeric strings.
        if current_index - int(stored_index) == 1:
            matches.append(link)
    return matches[0] if matches else {}
64
def set_cross_link(
    ctx: Context,
    pipeline: Pipeline,
    to_service_name,
    cross_link_additional_data=None,
    from_service_name=SERVICE_NAME.replace("-", "_"),
):
    """Leave a cross link for ``to_service_name`` at the current human turn.

    Replaces whatever was stored for ``to_service_name`` with
    ``{current_human_utter_index: {"from_service": from_service_name, **extra}}``.
    ``get_cross_link`` returns it on the following human turn. No-op during
    validation.

    Args:
        to_service_name: key under ``dff_shared_state["cross_links"]``.
        cross_link_additional_data: optional mapping merged into the link after
            ``"from_service"`` (so a ``"from_service"`` key here overrides it).
            ``None`` (the default) means no extra data.
        from_service_name: defaults to SERVICE_NAME with dashes replaced by underscores.
    """
    cur_human_index = get_human_utter_index(ctx, pipeline)
    if ctx.validation:
        return
    # None instead of a shared mutable ``{}`` default.
    if cross_link_additional_data is None:
        cross_link_additional_data = {}
    ctx.misc["agent"]["dff_shared_state"]["cross_links"][to_service_name] = {
        cur_human_index: {
            "from_service": from_service_name,
            **cross_link_additional_data,
        }
    }
81
def reset_response_parts(ctx: Context, _):
    """Forget the response parts recorded so far (no-op during validation)."""
    if ctx.validation:
        return
    ctx.misc["agent"].pop("response_parts", None)
86
def add_parts_to_response_parts(ctx: Context, _, parts=()):
    """Merge ``parts`` into the response parts recorded for the current response.

    The stored value stays a sorted, duplicate-free list. No-op during validation.

    Args:
        ctx: dialog context.
        parts: iterable of part names such as "acknowledgement", "body" or "prompt".
            Defaults to an immutable empty tuple; the old ``[]`` was a shared
            mutable default.
    """
    if ctx.validation:
        # Nothing is stored during validation, so there is no point merging.
        return
    agent = ctx.misc["agent"]
    merged = set(agent.get("response_parts", []))
    merged.update(parts)
    agent["response_parts"] = sorted(merged)
93
def set_acknowledgement_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Make "acknowledgement" the only recorded response part."""
    reset_response_parts(ctx, pipeline)
    add_parts_to_response_parts(ctx, pipeline, ["acknowledgement"])
98
def add_acknowledgement_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Add "acknowledgement" to the recorded response parts.

    When no response parts have been recorded yet ("response_parts" missing or
    None), "body" is recorded first and the acknowledgement is then added to it.
    """
    nothing_recorded = not ctx.validation and ctx.misc["agent"].get("response_parts") is None
    if nothing_recorded:
        add_parts_to_response_parts(ctx, pipeline, parts=["body"])
    add_parts_to_response_parts(ctx, pipeline, parts=["acknowledgement"])
104
def set_body_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Make "body" the only recorded response part."""
    reset_response_parts(ctx, pipeline)
    add_parts_to_response_parts(ctx, pipeline, ["body"])
109
def add_body_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Record "body" as a response part, keeping the parts already recorded."""
    add_parts_to_response_parts(ctx, pipeline, ["body"])
113
def set_prompt_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Make "prompt" the only recorded response part."""
    reset_response_parts(ctx, pipeline)
    add_parts_to_response_parts(ctx, pipeline, ["prompt"])
118
def add_prompt_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Record "prompt" as a response part, keeping the parts already recorded."""
    add_parts_to_response_parts(ctx, pipeline, ["prompt"])
122
def get_shared_memory(ctx: Context, _) -> dict:
    """Return the agent's shared memory (the stored object itself); ``{}`` in validation."""
    if ctx.validation:
        return {}
    return ctx.misc["agent"]["shared_memory"]
126
def get_used_links(ctx: Context, _) -> dict:
    """Return the agent's record of used linking phrases; ``{}`` in validation."""
    if ctx.validation:
        return {}
    return ctx.misc["agent"]["used_links"]
130
def get_age_group(ctx: Context, _) -> dict:
    """Return the stored age group; ``{}`` in validation.

    NOTE(review): the value is whatever set_age_group stored and may not be a
    dict despite the annotation; confirm against callers.
    """
    if ctx.validation:
        return {}
    return ctx.misc["agent"]["age_group"]
134
def set_age_group(ctx: Context, _, set_age_group):
    """Store the ``set_age_group`` value as the agent's age group (no-op in validation).

    The parameter shadows this function's own name inside the body.
    """
    if ctx.validation:
        return
    agent = ctx.misc["agent"]
    agent["age_group"] = set_age_group
139
def get_disliked_skills(ctx: Context, _) -> list:
    """Return the agent's list of disliked skills; ``[]`` in validation."""
    if ctx.validation:
        return []
    return ctx.misc["agent"]["disliked_skills"]
143
def get_human_utter_index(ctx: Context, _) -> int:
    """Return the current human utterance index; 0 in validation."""
    if ctx.validation:
        return 0
    return ctx.misc["agent"]["human_utter_index"]
147
def get_previous_human_utter_index(ctx: Context, _) -> int:
    """Return the stored previous human utterance index; 0 in validation."""
    if ctx.validation:
        return 0
    return ctx.misc["agent"]["previous_human_utter_index"]
151
def get_dialog(ctx: Context, _) -> dict:
    """Return the agent's dialog dict; ``{}`` in validation."""
    if ctx.validation:
        return {}
    return ctx.misc["agent"]["dialog"]
155
def get_utterances(ctx: Context, _) -> list:
    """Return the dialog's utterance list; ``[]`` in validation.

    The return annotation was ``dict``, but both branches yield a list, matching
    ``get_bot_utterances``.
    """
    if ctx.validation:
        return []
    return ctx.misc["agent"]["dialog"]["utterances"]
159
def get_human_utterances(ctx: Context, _) -> list:
    """Return the dialog's human utterance list; ``[]`` in validation.

    The return annotation was ``dict``, but both branches yield a list, matching
    ``get_bot_utterances``.
    """
    if ctx.validation:
        return []
    return ctx.misc["agent"]["dialog"]["human_utterances"]
163
def get_last_human_utterance(ctx: Context, _) -> dict:
    """Return the most recent human utterance.

    Returns the placeholder ``{"text": "", "annotations": {}}`` during validation.
    It also returns the placeholder when the dialog has no human utterances,
    mirroring ``get_last_bot_utterance``; previously that case raised IndexError.
    """
    if not ctx.validation:
        human_utterances = ctx.misc["agent"]["dialog"]["human_utterances"]
        if human_utterances:
            return human_utterances[-1]
    return {"text": "", "annotations": {}}
167
def get_bot_utterances(ctx: Context, _) -> list:
    """Return the dialog's bot utterance list; ``[]`` in validation."""
    if ctx.validation:
        return []
    dialog = ctx.misc["agent"]["dialog"]
    return dialog["bot_utterances"]
171
def get_last_bot_utterance(ctx: Context, _) -> dict:
    """Return the most recent bot utterance.

    Returns the placeholder ``{"text": "", "annotations": {}}`` during validation
    or when the bot has not spoken yet.
    """
    bot_utterances = [] if ctx.validation else ctx.misc["agent"]["dialog"]["bot_utterances"]
    return bot_utterances[-1] if bot_utterances else {"text": "", "annotations": {}}
178
def save_to_shared_memory(ctx: Context, _, **kwargs):
    """Merge the keyword arguments into the agent's shared memory (no-op in validation)."""
    if ctx.validation:
        return
    shared_memory = ctx.misc["agent"]["shared_memory"]
    shared_memory.update(kwargs)
183
def update_used_links(ctx: Context, _, linked_skill_name, linking_phrase):
    """Append ``linking_phrase`` to the phrases recorded for ``linked_skill_name``.

    A new list is stored on every call; the previously stored list object is
    left unmodified. No-op during validation.
    """
    if ctx.validation:
        return
    used_links = ctx.misc["agent"]["used_links"]
    previous_phrases = used_links.get(linked_skill_name, [])
    used_links[linked_skill_name] = previous_phrases + [linking_phrase]
189
def get_new_link_to(ctx: Context, pipeline: Pipeline, skill_names):
    """Pick a linking phrase to one of ``skill_names`` and record it as used.

    The choice is delegated to ``common_link.link_to``, which is given the
    already used links and the disliked skills. The returned link's "skill" and
    "phrase" entries are recorded via ``update_used_links``, and the link is
    returned.
    """
    human_attributes = {
        "used_links": get_used_links(ctx, pipeline),
        "disliked_skills": get_disliked_skills(ctx, pipeline),
    }
    link = common_link.link_to(skill_names, human_attributes=human_attributes)
    update_used_links(ctx, pipeline, link["skill"], link["phrase"])
    return link
200
def set_dff_suspension(ctx: Context, _):
    """Set the ``current_turn_dff_suspended`` flag to True (no-op in validation)."""
    if ctx.validation:
        return
    ctx.misc["agent"]["current_turn_dff_suspended"] = True
205
def reset_dff_suspension(ctx: Context, _):
    """Set the ``current_turn_dff_suspended`` flag to False (no-op in validation)."""
    if ctx.validation:
        return
    ctx.misc["agent"]["current_turn_dff_suspended"] = False
210
def set_confidence(ctx: Context, pipeline: Pipeline, confidence=1.0):
    """Store ``confidence`` in the response.

    A confidence of exactly 0.0 also removes the response's "can_continue" flag.
    No-op during validation.
    """
    if ctx.validation:
        return
    response = ctx.misc["agent"]["response"]
    response["confidence"] = confidence
    if confidence == 0.0:
        reset_can_continue(ctx, pipeline)
217
def set_can_continue(ctx: Context, _, continue_flag=common_constants.CAN_CONTINUE_SCENARIO):
    """Store ``continue_flag`` as the response's "can_continue" value (no-op in validation).

    The default is ``common_constants.CAN_CONTINUE_SCENARIO``.
    """
    if ctx.validation:
        return
    response = ctx.misc["agent"]["response"]
    response["can_continue"] = continue_flag
222
def reset_can_continue(ctx: Context, _):
    """Remove "can_continue" from the response if present (no-op in validation)."""
    if ctx.validation:
        return
    ctx.misc["agent"]["response"].pop("can_continue", None)
227
def get_named_entities_from_human_utterance(ctx: Context, pipeline: Pipeline):
    """Return the labelled named entities of the latest human utterance.

    Each entity is a dict, e.g. ``{"text": "London", "type": "LOC"}``; the exact
    shape is whatever ``common_utils.get_entities(..., with_labels=True)`` produces.
    """
    last_human_utterance = get_last_human_utterance(ctx, pipeline)
    return common_utils.get_entities(last_human_utterance, only_named=True, with_labels=True)
237
def get_nounphrases_from_human_utterance(ctx: Context, pipeline: Pipeline):
    """Return all entities (named or not, without labels) of the latest human utterance."""
    last_human_utterance = get_last_human_utterance(ctx, pipeline)
    return common_utils.get_entities(last_human_utterance, only_named=False, with_labels=False)
246
def get_fact_random_annotations_from_human_utterance(ctx: Context, pipeline: Pipeline) -> dict:
    """Return the ``fact_random`` annotation of the latest human utterance.

    Falls back to a fresh ``{"facts": [], "response": ""}`` during validation or
    when the annotation is absent.
    """
    empty_result = {"facts": [], "response": ""}
    if ctx.validation:
        return empty_result
    annotations = get_last_human_utterance(ctx, pipeline).get("annotations", {})
    return annotations.get("fact_random", empty_result)
257
def get_fact_for_particular_entity_from_human_utterance(ctx: Context, pipeline: Pipeline, entity) -> list:
    """Return fact texts from the ``fact_random`` annotation that concern ``entity``.

    A fact matches when its ``entity_substr`` equals ``entity`` case-insensitively.
    The following facts are skipped:
    - facts whose text contains "Sorry, I don't know";
    - facts with no text or empty text. A missing "fact" key used to raise KeyError.

    Args:
        entity: entity string to look up.

    Returns:
        List of matching fact strings, in annotation order.
    """
    fact_random_results = get_fact_random_annotations_from_human_utterance(ctx, pipeline)
    target = entity.lower()  # loop-invariant, computed once
    facts_for_entity = []
    for fact in fact_random_results.get("facts", []):
        fact_text = fact.get("fact", "")
        if not fact_text or "Sorry, I don't know" in fact_text:
            continue
        # ``or ""`` also tolerates an explicit None entity_substr.
        if (fact.get("entity_substr") or "").lower() == target:
            facts_for_entity.append(fact_text)
    return facts_for_entity
269
def get_news_about_particular_entity_from_human_utterance(ctx: Context, pipeline: Pipeline, entity) -> dict:
    """Return news about ``entity`` for the latest human utterance.

    Uses the first ``news_api_annotator`` entry whose "entity" equals ``entity``,
    if its "news" is non-empty. Otherwise falls back to
    ``common_news.get_news_about_topic(entity, NEWS_API_ANNOTATOR_URL)``.
    Returns ``{}`` during validation.
    """
    if ctx.validation:
        # Validation works on a placeholder utterance without annotations, so the
        # old code always reached the fallback here (presumably a request to
        # NEWS_API_ANNOTATOR_URL). Other accessors in this module short-circuit
        # under validation too.
        return {}
    last_uttr = get_last_human_utterance(ctx, pipeline)
    annotated_news = last_uttr.get("annotations", {}).get("news_api_annotator", [])
    curr_news = next((item["news"] for item in annotated_news if item["entity"] == entity), {})
    if not curr_news:
        curr_news = common_news.get_news_about_topic(entity, NEWS_API_ANNOTATOR_URL)
    return curr_news
283
def get_facts_from_fact_retrieval(ctx: Context, pipeline: Pipeline) -> list:
    """Return facts from the ``fact_retrieval`` annotation of the latest human utterance.

    The annotation may be a dict (its "facts" entry is used) or a list (used
    as-is). Any other value, or a missing annotation, yields ``[]``.
    """
    annotations = get_last_human_utterance(ctx, pipeline).get("annotations", {})
    fact_retrieval = annotations.get("fact_retrieval")
    if isinstance(fact_retrieval, dict):
        return fact_retrieval.get("facts", [])
    if isinstance(fact_retrieval, list):
        return fact_retrieval
    return []
293
def get_unrepeatable_index_from_rand_seq(
    ctx: Context, pipeline: Pipeline, seq_name, seq_max, renew_seq_if_empty=False
) -> int:
    """Return an index that is not repeated until a random sequence is exhausted.

    The sequence is a random permutation of ``range(seq_max)`` (``0 .. seq_max - 1``)
    kept in shared memory under ``seq_name``. Each call takes its last element and
    saves the remainder back.

    Behaviour by case:
    - No sequence stored yet: a fresh permutation is used.
    - Stored sequence exhausted and ``renew_seq_if_empty`` true: a fresh permutation
      is started.
    - Stored sequence exhausted and ``renew_seq_if_empty`` false: ``None`` is returned.

    Returns:
        The next index, or ``None`` when no index is available.
    """
    shared_memory = get_shared_memory(ctx, pipeline)
    seq = shared_memory.get(seq_name)
    # Build a permutation only when needed, rather than eagerly as a .get() default.
    if seq is None or (not seq and renew_seq_if_empty):
        seq = random.sample(range(seq_max), seq_max)
    if not seq:
        return None
    # save_to_shared_memory takes the pipeline as its second positional argument.
    # The old call omitted it and raised TypeError.
    save_to_shared_memory(ctx, pipeline, **{seq_name: seq[:-1]})
    return seq[-1]
310
def get_history(ctx: Context, _):
    """Return the agent's history mapping; ``{}`` in validation."""
    return {} if ctx.validation else ctx.misc["agent"]["history"]
316
def _history_order_key(item):
    """Sort key for (index, state) history items that orders numeric-string indices numerically."""
    index = item[0]
    if isinstance(index, str):
        try:
            return (0, int(index), "")
        except ValueError:
            # Non-numeric string keys sort lexicographically, after numeric ones.
            return (1, 0, index)
    return (0, index, "")


def get_n_last_state(ctx: Context, pipeline: Pipeline, n) -> str:
    """Return the state stored ``n`` entries from the end of the history (``n=1`` is the latest).

    History items are ordered by their index key. Numeric strings are compared as
    numbers, so "10" comes after "9".
    NOTE(review): keys may be stored as strings, e.g. after a JSON round-trip of
    the agent state; confirm with the serializer.

    Returns:
        The state, or ``""`` when the history has fewer than ``n`` entries or ``n < 1``.
    """
    if n < 1:
        # history_sorted[-0] would silently return the oldest state.
        return ""
    history_sorted = sorted(get_history(ctx, pipeline).items(), key=_history_order_key)
    if len(history_sorted) < n:
        return ""
    return history_sorted[-n][1]
326
def get_last_state(ctx: Context, pipeline: Pipeline) -> str:
    """Return the latest state stored in the history, or ``""`` when the history is empty."""
    # Identical to picking the last entry of the sorted history: n=1 selects [-1].
    return get_n_last_state(ctx, pipeline, 1)