1
import itertools
import random
import re
from copy import copy
from typing import Callable, Dict, List, Optional, Sequence, Tuple

from core.state_schema import Dialog
def detokenize(tokens):
    """
    Detokenizing a text undoes the tokenizing operation, restores
    punctuation and spaces to the places that people expect them to be.
    Ideally, `detokenize(tokenize(text))` should be identical to `text`,
    except for line breaks.
    """
    text = " ".join(tokens)
    # Penn-Treebank-style detokenization: undo ellipsis/quote escaping,
    # then re-attach punctuation and contractions to the preceding word.
    step0 = text.replace(". . .", "...")
    step1 = step0.replace("`` ", '"').replace(" ''", '"')
    step2 = step1.replace(" ( ", " (").replace(" ) ", ") ")
    # punctuation followed by a space/quote, and punctuation at end of string
    step3 = re.sub(r' ([.,:;?!%]+)([ \'"`])', r"\1\2", step2)
    step4 = re.sub(r" ([.,:;?!%]+)$", r"\1", step3)
    step5 = step4.replace(" '", "'").replace(" n't", "n't").replace(" nt", "nt").replace("can not", "cannot")
    step6 = step5.replace(" ` ", " '")
    return step6.strip()
class PersonNormalizer:
    """
    Detects mentions of mate user's name and either
    (0) converts them to user's name taken from state
    (1) either removes them.

    Parameters:
        person_tag: tag name that corresponds to a person entity
    """

    def __init__(self, person_tag: str = "PER", **kwargs):
        # BIO tag suffix used for person entities (e.g. "B-PER"/"I-PER")
        self.per_tag = person_tag

    def __call__(
        self, tokens: List[List[str]], tags: List[List[str]], names: List[str]
    ) -> Tuple[List[List[str]], List[List[str]]]:
        """Normalize a batch of tokenized utterances.

        For each utterance: re-tag vocative person mentions as MATE-GOOSER,
        then either substitute the user's name (if known) or strip the
        mention (if the name is empty/unknown).
        """
        out_tokens, out_tags = [], []
        for u_name, u_toks, u_tags in zip(names, tokens, tags):
            u_toks, u_tags = self.tag_mate_gooser_name(u_toks, u_tags, person_tag=self.per_tag)
            if u_name:
                u_toks, u_tags = self.replace_mate_gooser_name(u_toks, u_tags, u_name)
                # randomly prepend the user's name as a direct address
                if random.random() < 0.5:
                    u_toks = [u_name, ","] + u_toks
                    # NOTE(review): a multi-word u_name is inserted as one
                    # token here, so tags stay aligned only for 1-token names
                    u_tags = ["B-MATE-GOOSER", "O"] + u_tags
                    # capitalize the prepended name
                    u_toks[0] = u_toks[0][0].upper() + u_toks[0][1:]
                    # decapitalize the former first token (now at index 2);
                    # guard added for originally-empty utterances
                    if len(u_toks) > 2 and u_toks[2]:
                        u_toks[2] = u_toks[2][0].lower() + u_toks[2][1:]
            else:
                u_toks, u_tags = self.remove_mate_gooser_name(u_toks, u_tags)
            out_tokens.append(u_toks)
            out_tags.append(u_tags)
        return out_tokens, out_tags

    @staticmethod
    def tag_mate_gooser_name(
        tokens: List[str], tags: List[str], person_tag: str = "PER", mate_tag: str = "MATE-GOOSER"
    ) -> Tuple[List[str], List[str]]:
        """Re-tag person spans that look like vocative mentions.

        A person span set off by a comma and terminated by punctuation or the
        utterance boundary is considered an address to the mate user, and its
        BIO tags are rewritten from `person_tag` to `mate_tag`.
        """
        assert len(tokens) == len(tags), f"tokens({tokens}) and tags({tags}) should have the same length"
        if "B-" + person_tag not in tags:
            return tokens, tags

        out_tags: List[str] = []
        i = 0
        while i < len(tokens):
            tok, tag = tokens[i], tags[i]
            if i + 1 < len(tokens):
                if (tok == ",") and (tags[i + 1] == "B-" + person_tag):
                    # comma followed by a person span: measure the span length
                    j = 1
                    while (i + j < len(tokens)) and (tags[i + j][2:] == person_tag):
                        j += 1
                    out_tags.append(tag)  # keep the comma's own tag
                    if (i + j == len(tokens)) or (tokens[i + j][0] in ",.!?;)"):
                        # span ends at punctuation / utterance end -> vocative
                        out_tags.extend([t[:2] + mate_tag for t in tags[i + 1 : i + j]])
                    else:
                        out_tags.extend(tags[i + 1 : i + j])
                    i += j
                    continue
            if i > 0:
                if (tok == ",") and (tags[i - 1][2:] == person_tag):
                    # comma preceded by a person span: walk back over the span
                    j = 1
                    while (len(out_tags) >= j) and (out_tags[-j][2:] == person_tag):
                        j += 1
                    if (len(out_tags) < j) or (tokens[i - j][-1] in ",.!?("):
                        # span starts the utterance / follows punctuation -> vocative
                        for k in range(j - 1):
                            out_tags[-k - 1] = out_tags[-k - 1][:2] + mate_tag
            out_tags.append(tag)
            i += 1
        return tokens, out_tags

    @staticmethod
    def replace_mate_gooser_name(
        tokens: List[str], tags: List[str], replacement: str, mate_tag: str = "MATE-GOOSER"
    ) -> Tuple[List[str], List[str]]:
        """Replace every MATE-GOOSER span with `replacement` (split on spaces)."""
        assert len(tokens) == len(tags), f"tokens({tokens}) and tags({tags}) should have the same length"
        if "B-" + mate_tag not in tags:
            return tokens, tags

        repl_tokens = replacement.split()
        repl_tags = ["B-" + mate_tag] + ["I-" + mate_tag] * (len(repl_tokens) - 1)

        out_tokens, out_tags = [], []
        i = 0
        while i < len(tokens):
            tok, tag = tokens[i], tags[i]
            if tag == "B-" + mate_tag:
                out_tokens.extend(repl_tokens)
                out_tags.extend(repl_tags)
                i += 1
                # skip continuation tokens of the replaced span
                # (bugfix: was comparing tokens[i], not tags[i], so I- tokens
                # of a multi-token name were never consumed)
                while (i < len(tokens)) and (tags[i] == "I-" + mate_tag):
                    i += 1
            else:
                out_tokens.append(tok)
                out_tags.append(tag)
                i += 1
        return out_tokens, out_tags

    @staticmethod
    def remove_mate_gooser_name(
        tokens: List[str], tags: List[str], mate_tag: str = "MATE-GOOSER"
    ) -> Tuple[List[str], List[str]]:
        """Drop every MATE-GOOSER span together with its separating comma."""
        assert len(tokens) == len(tags), f"tokens({tokens}) and tags({tags}) should have the same length"
        if "B-" + mate_tag not in tags:
            return tokens, tags

        out_tokens, out_tags = [], []
        i = 0
        while i < len(tokens):
            tok, tag = tokens[i], tags[i]
            if i + 1 < len(tokens):
                # drop the comma that introduces a mate-gooser span
                if (tok == ",") and (tags[i + 1] == "B-" + mate_tag):
                    i += 1
                    continue
            if i > 0:
                # drop the comma that follows a mate-gooser span
                if (tok == ",") and (tags[i - 1][2:] == mate_tag):
                    i += 1
                    continue
            if tag[2:] != mate_tag:
                out_tokens.append(tok)
                out_tags.append(tag)
            i += 1
        return out_tokens, out_tags
# Type alias: a batch of dialog histories (batch -> utterances -> tokens).
LIST_LIST_STR_BATCH = List[List[List[str]]]
163
class HistoryPersonNormalize:
    """
    Takes batch of dialog histories and normalizes only bot responses.

    Detects mentions of mate user's name and either
    (0) converts them to user's name taken from state
    (1) either removes them.

    Parameters:
        per_tag: tag name that corresponds to a person entity
    """

    def __init__(self, per_tag: str = "PER", **kwargs):
        self.per_normalizer = PersonNormalizer(per_tag=per_tag)

    def __call__(
        self, history_tokens: LIST_LIST_STR_BATCH, tags: LIST_LIST_STR_BATCH, states: List[Dict]
    ) -> Tuple[LIST_LIST_STR_BATCH, LIST_LIST_STR_BATCH]:
        out_tokens, out_tags = [], []
        # fall back to empty states when none were provided
        states = states if states else [{}] * len(tags)
        for u_state, u_hist_tokens, u_hist_tags in zip(states, history_tokens, tags):
            # NOTE(review): reconstructed loop body — the user name is read
            # from dialog state the same way DefaultPostprocessor reads it;
            # confirm against the caller's state schema.
            u_name = u_state.get("user", {}).get("profile", {}).get("name", "")
            u_toks, u_tags = self.per_normalizer(
                u_hist_tokens, u_hist_tags, [u_name] * len(u_hist_tokens)
            )
            out_tokens.append(u_toks)
            out_tags.append(u_tags)
        return out_tokens, out_tags
191
# Finds first mention of a name and sets it as a user name.
#
# Parameters:
#     person_tag: tag name that corresponds to a person entity
#     state_slot: name of a state slot corresponding to a user's name
199
def __init__(self, person_tag: str = "PER", **kwargs):
    """Remember which BIO tag suffix marks person entities."""
    self.per_tag = person_tag
202
def __call__(self, tokens: List[List[str]], tags: List[List[str]], states: List[dict]) -> List[str]:
    """Return one user name per utterance in the batch.

    Keeps the name already stored in state unless a person mention is found
    in the utterance, in which case the detected name wins.
    (Restored: accumulator initialization and final return were lost.)
    """
    names = []
    for u_state, u_toks, u_tags in zip(states, tokens, tags):
        cur_name = u_state["user"]["profile"]["name"]
        new_name = copy(cur_name)
        name_found = self.find_my_name(u_toks, u_tags, person_tag=self.per_tag)
        if name_found is not None:
            new_name = name_found
        names.append(new_name)
    return names
215
def find_my_name(tokens: List[str], tags: List[str], person_tag: str) -> Optional[str]:
    """Return the first person span as a space-joined string, or None.

    The span starts at the first "B-<person_tag>" tag and extends over the
    following contiguous "I-<person_tag>" tags.
    (Restored: early `return None` and the span-extension increment were lost;
    return annotation corrected to Optional[str].)
    """
    if "B-" + person_tag not in tags:
        return None
    per_start = tags.index("B-" + person_tag)
    per_excl_end = per_start + 1
    while (per_excl_end < len(tokens)) and (tags[per_excl_end] == "I-" + person_tag):
        per_excl_end += 1
    return " ".join(tokens[per_start:per_excl_end])
225
class NerWithContextWrapper:
    """
    Tokenizes utterance and history of dialogue and gets entity tags for
    utterance's tokens.

    Parameters:
        ner_model: named entity recognition model
        tokenizer: tokenizer to use
        context_delimeter: optional token inserted between history utterances
            and around the current utterance
    """

    def __init__(self, ner_model: Callable, tokenizer: Callable, context_delimeter: Optional[str] = None, **kwargs):
        self.ner_model = ner_model
        self.tokenizer = tokenizer
        self.context_delimeter = context_delimeter

    def __call__(
        self, utterances: List[str], history: List[List[str]] = None, prev_utterances: List[str] = None
    ) -> Tuple[List[List[str]], List[List[str]]]:
        """Tag `utterances` with NER, using dialog history as left context.

        Returns (utterance tokens, tags restricted to the utterance span).
        """
        # bugfix: defaults were mutable ([[]] and []) — shared across calls
        history = [[]] if history is None else history
        prev_utterances = [] if prev_utterances is None else prev_utterances
        if prev_utterances:
            # extend each history with the corresponding previous utterance;
            # itertools.repeat([]) supplies empty histories when none given
            history = history or itertools.repeat([])
            history = [hist + [prev] for prev, hist in zip(prev_utterances, history)]
        history_toks = [
            [tok for toks in self.tokenizer(hist or [""]) for tok in toks + [self.context_delimeter] if tok is not None]
            for hist in history
        ]
        utt_toks = self.tokenizer(utterances)
        texts, ranges = [], []
        for utt, hist in zip(utt_toks, history_toks):
            if self.context_delimeter is not None:
                txt = hist + utt + [self.context_delimeter]
            else:
                txt = hist + utt
            # remember where the current utterance sits inside the context
            ranges.append((len(hist), len(hist) + len(utt)))
            texts.append(txt)
        _, tags = self.ner_model(texts)
        # keep only the tags belonging to the current utterance
        tags = [t[l:r] for t, (l, r) in zip(tags, ranges)]
        return utt_toks, tags
267
class DefaultPostprocessor:
    """Rewrites final bot responses, normalizing person mentions via NER."""

    def __init__(self) -> None:
        self.person_normalizer = PersonNormalizer(per_tag="PER")

    def __call__(self, dialogs: Sequence[Dialog]) -> Sequence[str]:
        """Return one post-processed response text per dialog.

        Only "chitchat" responses with NER annotations are normalized; all
        other responses pass through unchanged.
        (Restored: loop scaffolding, else-branches and final return were lost.)
        """
        new_responses = []
        for d in dialogs:
            response = d["utterances"][-1]
            ner_annotations = response["annotations"].get("ner", {})
            user_name = d["user"]["profile"]["name"]
            if ner_annotations and (response["active_skill"] == "chitchat"):
                response_toks_norm, _ = self.person_normalizer(
                    [ner_annotations["tokens"]], [ner_annotations["tags"]], [user_name]
                )
                response_toks_norm = response_toks_norm[0]
                if response_toks_norm:
                    new_responses.append(detokenize(response_toks_norm))
                else:
                    # normalization stripped everything — keep original text
                    new_responses.append(response["text"])
            else:
                new_responses.append(response["text"])
        return new_responses