dream
317 строк · 12.8 Кб
1from typing import Dict, List
2import logging
3from copy import deepcopy
4import re
5
6from common.universal_templates import if_chat_about_particular_topic
7from common.utils import get_intents, service_intents
8from common.grounding import BUT_PHRASE, REPEAT_PHRASE
9
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default number of bot turns kept by `get_last_n_turns`; also the upper bound of the
# clarification look-back window in `dff_formatter`.
LAST_N_TURNS = 5  # number of turns to consider in annotator/skill.
12
13
spaces_pat = re.compile(r"\s+")
# Matches every character that is NOT a Latin/Cyrillic letter, a digit, an apostrophe or a space.
# "ё" is listed explicitly: its code point lies outside the contiguous "а-я" range,
# so without it the letter would be treated as a special symbol and removed.
special_symb_pat = re.compile(r"[^a-zа-яё0-9' ]", flags=re.IGNORECASE)


def clean_text(text):
    """Lowercase `text` and reduce it to letters, digits, apostrophes and single spaces.

    Args:
        text: string to normalize.

    Returns:
        The lowercased string in which every disallowed character is replaced by a space,
        runs of whitespace are collapsed into one space, and leading/trailing spaces are removed.
    """
    # Replace symbols first and collapse whitespace afterwards; the reverse order leaves
    # double spaces wherever a symbol was adjacent to a space (e.g. "hi, there").
    without_symbols = special_symb_pat.sub(" ", text.lower())
    return spaces_pat.sub(" ", without_symbols).strip()
20
21
def get_last_n_turns(
    dialog: Dict,
    bot_last_turns=None,
    human_last_turns=None,
    total_last_turns=None,
    excluded_attributes=("entities",),
):
    """Return a deep-copied dialog truncated to its most recent utterances.

    Args:
        dialog: dialog dict; must contain an "utterances" list whose items have a "text"
            value and a `["user"]["user_type"]` value.
        bot_last_turns: number of bot turns used to derive the window size;
            falls back to `LAST_N_TURNS` when falsy.
        human_last_turns: accepted for interface compatibility; it does not influence the result.
        total_last_turns: number of trailing utterances to keep; defaults to
            `bot_last_turns * 2 + 1`.
        excluded_attributes: keys dropped from the "attributes" dict of every top-level
            dict value that has one.

    Returns:
        A new dict holding deep copies of all non-utterance keys, the truncated "utterances"
        list, and "human_utterances" / "bot_utterances" rebuilt from it as independent copies.
    """
    bot_last_turns = bot_last_turns or LAST_N_TURNS
    total_last_turns = total_last_turns or bot_last_turns * 2 + 1

    # Not to lose history on each repeat: every utterance in the initial window that
    # contains "#repeat" widens the window by one more human/bot pair.
    repeats = sum("#repeat" in utterance["text"] for utterance in dialog["utterances"][-total_last_turns:])
    total_last_turns += 2 * repeats

    utterance_keys = ("utterances", "human_utterances", "bot_utterances")
    new_dialog = {}
    for key, value in dialog.items():
        if key in utterance_keys:
            continue
        if isinstance(value, dict) and "attributes" in value:
            new_dialog[key] = {k: deepcopy(v) for k, v in value.items() if k != "attributes"}
            new_dialog[key]["attributes"] = {
                k: deepcopy(v) for k, v in value["attributes"].items() if k not in excluded_attributes
            }
        else:
            new_dialog[key] = deepcopy(value)

    new_dialog["utterances"] = deepcopy(dialog["utterances"][-total_last_turns:])
    new_dialog["human_utterances"] = []
    new_dialog["bot_utterances"] = []

    # The per-type lists get their own copies so that editing an utterance in one list
    # never shows up in another.
    for utt in new_dialog["utterances"]:
        user_type = utt["user"]["user_type"]
        if user_type == "human":
            new_dialog["human_utterances"].append(deepcopy(utt))
        elif user_type == "bot":
            new_dialog["bot_utterances"].append(deepcopy(utt))

    return new_dialog
60
61
def is_human_uttr_repeat_request_or_misheard(utt):
    """Tell whether a human utterance is a repeat request or was poorly recognized.

    Returns True when the "intent_catcher" annotation reports `repeat.detected == 1`
    or the "asr" annotation reports `asr_confidence == "very_low"`; False otherwise.
    Missing annotation levels are treated as "not detected".
    """
    annotations = utt.get("annotations", {})
    repeat_detected = annotations.get("intent_catcher", {}).get("repeat", {}).get("detected", 0)
    asr_confidence = annotations.get("asr", {}).get("asr_confidence", "")
    return asr_confidence == "very_low" or repeat_detected == 1
69
70
def is_bot_uttr_repeated_or_misheard(utt):
    """Tell whether a bot utterance is a clarification/repeat turn.

    Returns True when any of the following holds, False otherwise:
      * the utterance came from the "misheard_asr" skill with confidence exactly 1.0;
      * its text contains the "#+#repeat" marker;
      * its text contains `BUT_PHRASE` or `REPEAT_PHRASE`.
    """
    text = utt.get("text", "")
    from_misheard_asr = utt.get("active_skill", "") == "misheard_asr" and utt.get("confidence", 0.0) == 1.0
    has_repeat_marker = "#+#repeat" in text
    has_interrupt_phrase = any(phrase in text for phrase in (BUT_PHRASE, REPEAT_PHRASE))
    return from_misheard_asr or has_repeat_marker or has_interrupt_phrase
81
82
def remove_clarification_turns_from_dialog(dialog):
    """Return a copy of `dialog` with clarification exchanges removed.

    A bot utterance is treated as a clarification turn when it is neither the first nor the
    last utterance, `is_bot_uttr_repeated_or_misheard` is true for it, and
    `is_human_uttr_repeat_request_or_misheard` is true for the utterance right before it.
    Such a bot utterance is skipped and the most recently kept utterance is dropped with it.
    Utterances whose user type is neither "human" nor "bot" are not kept.

    Args:
        dialog: dialog dict with an "utterances" list; each item needs `["user"]["user_type"]`.

    Returns:
        A new dialog dict. All values, including the kept utterances, are deep copies, so
        changing the result never affects the input. "human_utterances" and "bot_utterances"
        are rebuilt from the filtered "utterances" list as independent copies.
    """
    utterances = dialog["utterances"]
    last_index = len(utterances) - 1

    # Copy only the non-utterance keys: the utterance lists are rebuilt below, so deep-copying
    # them up front would be wasted work.
    utterance_keys = ("utterances", "human_utterances", "bot_utterances")
    new_dialog = {key: deepcopy(value) for key, value in dialog.items() if key not in utterance_keys}

    kept = []
    for i, utt in enumerate(utterances):
        user_type = utt["user"]["user_type"]
        if user_type == "human":
            kept.append(utt)
        elif user_type == "bot":
            if (
                0 < i < last_index
                and is_bot_uttr_repeated_or_misheard(utt)
                and is_human_uttr_repeat_request_or_misheard(utterances[i - 1])
            ):
                # Skip the clarifying bot turn and drop the last kept utterance
                # (a no-op when nothing has been kept yet).
                del kept[-1:]
            else:
                kept.append(utt)

    # Deep copy so the returned utterances do not alias the caller's objects.
    new_dialog["utterances"] = deepcopy(kept)
    new_dialog["human_utterances"] = []
    new_dialog["bot_utterances"] = []

    for utt in new_dialog["utterances"]:
        user_type = utt["user"]["user_type"]
        if user_type == "human":
            new_dialog["human_utterances"].append(deepcopy(utt))
        elif user_type == "bot":
            new_dialog["bot_utterances"].append(deepcopy(utt))

    return new_dialog
111
112
def replace_with_annotated_utterances(dialog, mode="punct_sent"):
    """Replace utterance texts in place with an annotated or cleaned variant.

    The previous text of every processed utterance is stored under "orig_text".
    Utterances without an "annotations" key, or without the annotation a mode needs,
    keep their text (except in "segments" mode, where a string text is wrapped in a list).

    Args:
        dialog: dialog dict with "utterances", "human_utterances" and (for the "segments"
            and "clean_sent" modes) "bot_utterances" lists.
        mode: one of
            "punct_sent" - use `annotations["sentseg"]["punct_sent"]`
                (applied to "utterances" and "human_utterances");
            "segments" - use a copy of `annotations["sentseg"]["segments"]`, or wrap a
                string text into a one-element list (applied to all three lists);
            "modified_sents" - use the last item of `annotations["sentrewrite"]["modified_sents"]`,
                falling back to the sentseg "punct_sent"
                (applied to "utterances" and "human_utterances");
            "clean_sent" - use `clean_text(text)` (applied to all three lists).
            Any other value leaves the dialog untouched.

    Returns:
        The same `dialog` object, modified in place.
    """
    if mode == "punct_sent":
        for utt in dialog["utterances"] + dialog["human_utterances"]:
            utt["orig_text"] = utt["text"]
            annotations = utt.get("annotations", {})
            if "sentseg" in annotations:
                utt["text"] = annotations["sentseg"]["punct_sent"]
    elif mode == "segments":
        for utt in dialog["utterances"] + dialog["human_utterances"] + dialog["bot_utterances"]:
            utt["orig_text"] = utt["text"]
            annotations = utt.get("annotations", {})
            if "sentseg" in annotations:
                utt["text"] = deepcopy(annotations["sentseg"]["segments"])
            elif isinstance(utt["text"], str):
                utt["text"] = [utt["text"]]
    elif mode == "modified_sents":
        for utt in dialog["utterances"] + dialog["human_utterances"]:
            utt["orig_text"] = utt["text"]
            annotations = utt.get("annotations", {})
            # An absent or empty rewrite result falls through to the sentseg text
            # instead of raising on `[-1]`.
            modified_sents = annotations.get("sentrewrite", {}).get("modified_sents")
            if modified_sents:
                utt["text"] = modified_sents[-1]
            elif "sentseg" in annotations:
                utt["text"] = annotations["sentseg"]["punct_sent"]
    elif mode == "clean_sent":
        for utt in dialog["utterances"] + dialog["human_utterances"] + dialog["bot_utterances"]:
            utt["orig_text"] = utt["text"]
            utt["text"] = clean_text(utt["text"])
    return dialog
138
139
def clean_up_utterances_to_avoid_unwanted_keys(
    dialog,
    wanted_keys=("text", "annotations", "active_skill", "user"),
    types_utterances=("human_utterances", "bot_utterances", "utterances"),
    used_annotations=None,
):
    """Build a reduced dialog that keeps only selected utterance lists and utterance keys.

    Attention! It removes all other keys from the dialog: the result contains nothing but
    the lists named in `types_utterances`.

    Args:
        dialog: source dialog dict; a missing utterance list is treated as empty.
        wanted_keys: utterance keys to keep; keys absent from an utterance are skipped.
        types_utterances: names of the utterance lists to include in the result.
        used_annotations: when given as a non-empty list, the "annotations" value of each
            utterance is narrowed to these annotation names (missing names are skipped).
            Any other value keeps the annotations unfiltered.

    Returns:
        A new dict mapping each name in `types_utterances` to a list of new utterance dicts.
        Kept values are not copied, they are the same objects as in `dialog`.
    """
    filter_annotations = bool(used_annotations) and isinstance(used_annotations, list)

    new_dialog = {}
    for key in types_utterances:
        cleaned = []
        for utter in dialog.get(key, []):
            new_utter = {}
            for wanted_key in wanted_keys:
                if wanted_key not in utter:
                    continue
                value = utter[wanted_key]
                if filter_annotations and wanted_key == "annotations":
                    value = {name: value[name] for name in used_annotations if name in value}
                new_utter[wanted_key] = value
            cleaned.append(new_utter)
        new_dialog[key] = cleaned
    return new_dialog
164
165
def last_n_human_utt_dialog_formatter(dialog: Dict, last_n_utts: int, only_last_sentence: bool = False) -> List:
    """Collect texts and detected intents of the last human utterances.

    Args:
        dialog (Dict): full dialog state; only its "human_utterances" list is read
        last_n_utts (int): how many last user utterances to take
        only_last_sentence (bool, optional): take only last sentence in each utterance. Defaults to False.

    Returns:
        A one-element list `[{"sentences_batch": [texts], "intents": [intents]}]`, where
        `texts` holds one string per selected utterance and `intents` holds the matching
        `get_intents(utt, which="all")` results. Both lists are empty when the dialog has
        no human utterances. The input dialog is never modified.
    """
    # Work on a private copy: the first utterance may be rewritten below.
    human_utterances = deepcopy(dialog["human_utterances"])

    # The emptiness check guards the `[0]` access for dialogs without human utterances.
    if (
        human_utterances
        and len(human_utterances) <= last_n_utts
        and not if_chat_about_particular_topic(human_utterances[0])
    ):
        # in all cases when not particular topic, convert first phrase in the dialog to `hello!`
        first_utt = human_utterances[0]
        if "sentseg" in first_utt.get("annotations", {}):
            first_utt["annotations"]["sentseg"]["punct_sent"] = "hello!"
            first_utt["annotations"]["sentseg"]["segments"] = ["hello"]
        else:
            first_utt["text"] = "hello"

    human_utts = []
    detected_intents = []
    for utt in human_utterances[-last_n_utts:]:
        if "sentseg" in utt.get("annotations", {}):
            sentseg_ann = utt["annotations"]["sentseg"]
            if only_last_sentence:
                text = sentseg_ann["segments"][-1] if len(sentseg_ann["segments"]) > 0 else ""
            else:
                text = sentseg_ann["punct_sent"]
        else:
            text = utt["text"]
        human_utts.append(text)
        detected_intents.append(get_intents(utt, which="all"))
    return [{"sentences_batch": [human_utts], "intents": [detected_intents]}]
198
199
def stop_formatter_dialog(dialog: Dict) -> List[Dict]:
    """Build one " [SEP] "-joined dialog string per hypothesis of the last utterance.

    Used by: stop annotator, conv eval annotator

    Each string is the texts of all utterances followed by the text of one hypothesis.
    Returns a one-element list `[{"dialogs": [...]}]`.
    """
    candidates = dialog["utterances"][-1]["hypotheses"]
    separator = " [SEP] "
    dialogs = [
        separator.join([uttr["text"] for uttr in dialog["utterances"]] + [candidate["text"]])
        for candidate in candidates
    ]
    return [{"dialogs": dialogs}]
210
211
def count_ongoing_skill_utterances(bot_utterances: List[Dict], skill: str) -> int:
    """Count how many of the most recent bot utterances in a row came from `skill`.

    Walks `bot_utterances` from the end and stops at the first utterance whose
    "active_skill" differs from `skill`.
    """
    streak = 0
    for uttr in reversed(bot_utterances):
        if uttr["active_skill"] != skill:
            return streak
        streak += 1
    return streak
220
221
def dff_formatter(
    dialog: Dict,
    service_name: str,
    bot_last_turns=1,
    human_last_turns=1,
    used_annotations=None,
    types_utterances=None,
    wanted_keys=None,
) -> List[Dict]:
    """DialoFlow Framework formatter: build the single-item batch payload for a DFF service.

    Reads the service state and shared human attributes from `dialog`, computes the
    clarification-request flag, then shortens the dialog, removes clarification turns,
    substitutes punctuated texts and strips the utterances down to `wanted_keys`.

    Args:
        dialog: full dialog state.
        service_name: service whose state is stored under `"<service_name>_state"`.
        bot_last_turns: forwarded to the second `get_last_n_turns` call.
        human_last_turns: forwarded to the second `get_last_n_turns` call.
        used_annotations: forwarded to `clean_up_utterances_to_avoid_unwanted_keys`.
        types_utterances: utterance lists to keep; defaults to human and bot utterances.
        wanted_keys: utterance keys to keep; defaults to text, annotations, active_skill, user.

    Returns:
        A one-element list with a dict of `*_batch` keys, each holding a one-element list.
    """
    if types_utterances is None:
        types_utterances = ["human_utterances", "bot_utterances"]
    if wanted_keys is None:
        wanted_keys = ["text", "annotations", "active_skill", "user"]

    state_name = f"{service_name}_state"
    human_utter_index = len(dialog["human_utterances"]) - 1

    # Everything below is read from the incoming dialog, before it gets truncated.
    human_attributes = dialog.get("human", {}).get("attributes", {})
    state = human_attributes.get(state_name, {})
    dff_shared_state = human_attributes.get("dff_shared_state", {"cross_states": {}, "cross_links": {}})
    used_links = human_attributes.get("used_links", {})
    age_group = human_attributes.get("age_group", "")
    disliked_skills = human_attributes.get("disliked_skills", {})
    entities = human_attributes.get("entities", {})
    prompts_goals = human_attributes.get("prompts_goals", {})

    # The flag is raised only when exactly one clarification exchange is found among the
    # turns made since the index stored in the service state.
    clarification_request_flag = False
    previous_human_utter_index = state.get("previous_human_utter_index", -1)
    n_turns_to_check = human_utter_index - previous_human_utter_index
    if 1 < n_turns_to_check <= LAST_N_TURNS and previous_human_utter_index != -1:
        recent_pairs = list(
            zip(
                dialog["human_utterances"][-n_turns_to_check:],
                dialog["bot_utterances"][-n_turns_to_check:],
            )
        )
        n_unclarified = sum(
            1
            for human_uttr, bot_uttr in recent_pairs
            if is_human_uttr_repeat_request_or_misheard(human_uttr) and is_bot_uttr_repeated_or_misheard(bot_uttr)
        )
        clarification_request_flag = n_unclarified == 1

    dialog = get_last_n_turns(dialog)
    dialog = remove_clarification_turns_from_dialog(dialog)
    dialog = get_last_n_turns(dialog, bot_last_turns=bot_last_turns, human_last_turns=human_last_turns)
    dialog = replace_with_annotated_utterances(dialog, mode="punct_sent")

    # Keep only the requested utterance lists and, inside each utterance, only the wanted keys.
    new_dialog = clean_up_utterances_to_avoid_unwanted_keys(
        dialog, wanted_keys=wanted_keys, types_utterances=types_utterances, used_annotations=used_annotations
    )

    payload = {
        "human_utter_index_batch": [human_utter_index],
        "dialog_batch": [new_dialog],
        f"{state_name}_batch": [state],
        "dff_shared_state_batch": [dff_shared_state],
        "entities_batch": [entities],
        "used_links_batch": [used_links],
        "age_group_batch": [age_group],
        "disliked_skills_batch": [disliked_skills],
        "prompts_goals_batch": [prompts_goals],
        "clarification_request_flag_batch": [clarification_request_flag],
        "dialog_id_batch": [dialog["dialog_id"]],
    }
    return [payload]
290
291
def programy_post_formatter_dialog(dialog: Dict) -> Dict:
    """Prepare the last human sentences for the program_y skills.

    Used by: program_y, program_y_dangerous, program_y_wide
    Look at skills/program_y*

    Returns `{"sentences_batch": [sentences]}`. A sentence without "?" and without any
    prioritized service intent is replaced by "yes." / "no." when that intent was detected.
    When the truncated dialog holds a single human utterance that is not about a
    particular topic, the sentences are replaced by `["hi."]`.
    """
    dialog = get_last_n_turns(dialog, bot_last_turns=6)
    human_utterances = dialog["human_utterances"]
    first_uttr_hi = bool(len(human_utterances) == 1 and not if_chat_about_particular_topic(human_utterances[-1]))

    dialog = remove_clarification_turns_from_dialog(dialog)
    formatted = last_n_human_utt_dialog_formatter(dialog, last_n_utts=5)[0]
    sentences = formatted["sentences_batch"][0]
    intents = formatted["intents"][0]

    # modify sentences with yes/no intents to yes/no phrase
    # todo: sent may contain multiple sentence, logic here could be improved
    prioritized_intents = service_intents - {"yes", "no"}
    for idx, (sentence, sentence_intents) in enumerate(zip(sentences, intents)):
        intent_set = set(sentence_intents)
        if "?" in sentence or len(intent_set & prioritized_intents) != 0:
            continue
        if "yes" in intent_set:
            sentences[idx] = "yes."
        elif "no" in intent_set:
            sentences[idx] = "no."

    if first_uttr_hi:
        sentences = ["hi."]
    return {"sentences_batch": [sentences]}
318