dream

Форк
0
333 строки · 11.8 Кб
1
import logging
2
import os
3
import random
4

5
from dff.script import Context
6
from dff.pipeline import Pipeline
7

8
import common.constants as common_constants
9
import common.link as common_link
10
import common.news as common_news
11
import common.utils as common_utils
12

13
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Name of the running service, taken from the environment (None when unset).
SERVICE_NAME = os.environ.get("SERVICE_NAME")

# URL of the news API annotator, taken from the environment (None when unset).
NEWS_API_ANNOTATOR_URL = os.environ.get("NEWS_API_ANNOTATOR_URL")
18

19

20
def get_new_human_labeled_noun_phrase(ctx: Context, pipeline: Pipeline) -> list:
    """Return the `cobot_entities` entities annotated on the last human utterance.

    An empty list is returned when `ctx.validation` is truthy or when any
    level of the annotation structure is missing.
    """
    if ctx.validation:
        return []
    annotations = get_last_human_utterance(ctx, pipeline).get("annotations", {})
    cobot_entities = annotations.get("cobot_entities", {})
    return cobot_entities.get("entities", [])
28

29

30
def get_human_sentiment(ctx: Context, pipeline: Pipeline, negative_threshold=0.5, positive_threshold=0.333) -> str:
    """Classify the last human utterance as "negative", "positive" or "neutral".

    The most probable sentiment label is returned only when it is "negative"
    or "positive" and its probability reaches the matching threshold.
    Everything else (including `ctx.validation` mode and a missing or non-dict
    probability result) yields "neutral".
    """
    if ctx.validation:
        return "neutral"

    probs = common_utils.get_sentiment(get_last_human_utterance(ctx, pipeline), probs=True)
    if not probs or not isinstance(probs, dict):
        return "neutral"

    # `max` keeps the first maximal item, so ties resolve in dict iteration order.
    top_label, top_prob = max(probs.items(), key=lambda item: item[1])
    thresholds = {"negative": negative_threshold, "positive": positive_threshold}
    if top_label in thresholds and top_prob >= thresholds[top_label]:
        return top_label
    return "neutral"
46

47

48
def get_cross_state(ctx: Context, _, service_name=SERVICE_NAME.replace("-", "_")) -> dict:
    """Return the cross state stored for `service_name`.

    Yields an empty dict when `ctx.validation` is truthy or when nothing is
    stored under that name. The default name is `SERVICE_NAME` with dashes
    replaced by underscores.
    """
    if ctx.validation:
        return {}
    cross_states = ctx.misc["agent"]["dff_shared_state"]["cross_states"]
    return cross_states.get(service_name, {})
50

51

52
def save_cross_state(ctx: Context, _, service_name=SERVICE_NAME.replace("-", "_"), new_state=None):
    """Store `new_state` as the cross state of `service_name`.

    Args:
        ctx: dialog context; nothing is written when `ctx.validation` is truthy.
        _: unused (kept for the common `(ctx, pipeline)` call shape).
        service_name: key under which the state is stored; defaults to
            `SERVICE_NAME` with dashes replaced by underscores.
        new_state: state to store. When omitted, a fresh empty dict is stored.
    """
    # A `{}` default would be one dict shared by every call: it would be
    # written into the context and later mutations would leak between calls.
    if new_state is None:
        new_state = {}
    if not ctx.validation:
        ctx.misc["agent"]["dff_shared_state"]["cross_states"][service_name] = new_state
55

56

57
def get_cross_link(ctx: Context, pipeline: Pipeline, service_name=SERVICE_NAME.replace("-", "_")) -> dict:
    """Return the cross link addressed to `service_name` on the previous human turn.

    Stored links are keyed by human utterance index. The first link whose
    index is exactly one less than the current human utterance index is
    returned, or an empty dict when there is none.
    """
    if ctx.validation:
        stored_links = {}
    else:
        stored_links = ctx.misc["agent"]["dff_shared_state"]["cross_links"].get(service_name, {})

    current_index = get_human_utter_index(ctx, pipeline)
    # Keys are converted with int() before comparing with the current index.
    previous_turn_links = [
        link for stored_index, link in stored_links.items() if (current_index - int(stored_index)) == 1
    ]
    if previous_turn_links:
        return previous_turn_links[0]
    return {}
63

64

65
def set_cross_link(
    ctx: Context,
    pipeline: Pipeline,
    to_service_name,
    cross_link_additional_data=None,
    from_service_name=SERVICE_NAME.replace("-", "_"),
):
    """Store a cross link from `from_service_name` to `to_service_name`.

    The link is keyed by the current human utterance index and replaces
    whatever was stored for `to_service_name` before.

    Args:
        ctx: dialog context; nothing is written when `ctx.validation` is truthy.
        pipeline: passed through to `get_human_utter_index`.
        to_service_name: key of the service the link is addressed to.
        cross_link_additional_data: optional mapping merged into the link
            next to the "from_service" entry.
        from_service_name: name recorded as the link origin; defaults to
            `SERVICE_NAME` with dashes replaced by underscores.
    """
    # `None` instead of a `{}` default: a mutable default is one object shared
    # by all calls.
    if cross_link_additional_data is None:
        cross_link_additional_data = {}
    cur_human_index = get_human_utter_index(ctx, pipeline)
    if not ctx.validation:
        ctx.misc["agent"]["dff_shared_state"]["cross_links"][to_service_name] = {
            cur_human_index: {
                "from_service": from_service_name,
                **cross_link_additional_data,
            }
        }
80

81

82
def reset_response_parts(ctx: Context, _):
    """Remove the "response_parts" entry from the agent state, if present.

    Does nothing when `ctx.validation` is truthy.
    """
    if ctx.validation:
        return
    ctx.misc["agent"].pop("response_parts", None)
85

86

87
def add_parts_to_response_parts(ctx: Context, _, parts=()):
    """Merge `parts` into the stored "response_parts" list.

    The stored value is the sorted, de-duplicated union of the previous
    parts and `parts`.

    Args:
        ctx: dialog context; nothing is written when `ctx.validation` is truthy.
        _: unused (kept for the common `(ctx, pipeline)` call shape).
        parts: iterable of part names to add; empty by default.
    """
    if ctx.validation:
        return
    agent = ctx.misc["agent"]
    merged_parts = set(agent.get("response_parts", []))
    merged_parts.update(parts)
    agent["response_parts"] = sorted(merged_parts)
92

93

94
def set_acknowledgement_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Reset the response parts, then add "acknowledgement" as the only part."""
    reset_response_parts(ctx, pipeline)
    only_part = ["acknowledgement"]
    add_parts_to_response_parts(ctx, pipeline, parts=only_part)
97

98

99
def add_acknowledgement_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Add "acknowledgement" to the response parts.

    When no response parts are stored yet (and `ctx.validation` is falsy),
    "body" is added first.
    """
    pending_parts = ["acknowledgement"]
    if not ctx.validation and ctx.misc["agent"].get("response_parts") is None:
        pending_parts.insert(0, "body")
    # One call per part, in the same order as before: "body" (if needed) first.
    for part in pending_parts:
        add_parts_to_response_parts(ctx, pipeline, parts=[part])
103

104

105
def set_body_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Reset the response parts, then add "body" as the only part."""
    reset_response_parts(ctx, pipeline)
    only_part = ["body"]
    add_parts_to_response_parts(ctx, pipeline, parts=only_part)
108

109

110
def add_body_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Add "body" to the stored response parts."""
    extra_parts = ["body"]
    add_parts_to_response_parts(ctx, pipeline, parts=extra_parts)
112

113

114
def set_prompt_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Reset the response parts, then add "prompt" as the only part."""
    reset_response_parts(ctx, pipeline)
    only_part = ["prompt"]
    add_parts_to_response_parts(ctx, pipeline, parts=only_part)
117

118

119
def add_prompt_to_response_parts(ctx: Context, pipeline: Pipeline):
    """Add "prompt" to the stored response parts."""
    extra_parts = ["prompt"]
    add_parts_to_response_parts(ctx, pipeline, parts=extra_parts)
121

122

123
def get_shared_memory(ctx: Context, _) -> dict:
    """Return the agent's "shared_memory" entry, or `{}` when `ctx.validation` is truthy."""
    if ctx.validation:
        return {}
    return ctx.misc["agent"]["shared_memory"]
125

126

127
def get_used_links(ctx: Context, _) -> dict:
    """Return the agent's "used_links" entry, or `{}` when `ctx.validation` is truthy."""
    if ctx.validation:
        return {}
    return ctx.misc["agent"]["used_links"]
129

130

131
def get_age_group(ctx: Context, _) -> dict:
    """Return the agent's "age_group" entry, or `{}` when `ctx.validation` is truthy."""
    if ctx.validation:
        return {}
    return ctx.misc["agent"]["age_group"]
133

134

135
def set_age_group(ctx: Context, _, set_age_group):
    """Store `set_age_group` as the agent's "age_group".

    Does nothing when `ctx.validation` is truthy.
    """
    if ctx.validation:
        return
    agent = ctx.misc["agent"]
    agent["age_group"] = set_age_group
138

139

140
def get_disliked_skills(ctx: Context, _) -> list:
    """Return the agent's "disliked_skills" entry, or `[]` when `ctx.validation` is truthy."""
    if ctx.validation:
        return []
    return ctx.misc["agent"]["disliked_skills"]
142

143

144
def get_human_utter_index(ctx: Context, _) -> int:
    """Return the agent's "human_utter_index" entry, or 0 when `ctx.validation` is truthy."""
    if ctx.validation:
        return 0
    return ctx.misc["agent"]["human_utter_index"]
146

147

148
def get_previous_human_utter_index(ctx: Context, _) -> int:
    """Return the agent's "previous_human_utter_index" entry, or 0 when `ctx.validation` is truthy."""
    if ctx.validation:
        return 0
    return ctx.misc["agent"]["previous_human_utter_index"]
150

151

152
def get_dialog(ctx: Context, _) -> dict:
    """Return the agent's "dialog" entry, or `{}` when `ctx.validation` is truthy."""
    if ctx.validation:
        return {}
    return ctx.misc["agent"]["dialog"]
154

155

156
def get_utterances(ctx: Context, _) -> list:
    """Return the dialog's "utterances" entry, or `[]` when `ctx.validation` is truthy.

    The return annotation is `list`, matching the empty-list fallback
    (it was previously declared as `dict`).
    """
    if ctx.validation:
        return []
    return ctx.misc["agent"]["dialog"]["utterances"]
158

159

160
def get_human_utterances(ctx: Context, _) -> list:
    """Return the dialog's "human_utterances" entry, or `[]` when `ctx.validation` is truthy.

    The return annotation is `list`, matching the empty-list fallback
    (it was previously declared as `dict`).
    """
    if ctx.validation:
        return []
    return ctx.misc["agent"]["dialog"]["human_utterances"]
162

163

164
def get_last_human_utterance(ctx: Context, _) -> dict:
    """Return the most recent human utterance.

    A placeholder `{"text": "", "annotations": {}}` is returned when
    `ctx.validation` is truthy or when the dialog has no human utterances
    yet, mirroring how `get_last_bot_utterance` treats an empty list
    (previously an empty list raised IndexError).
    """
    if not ctx.validation:
        human_utterances = ctx.misc["agent"]["dialog"]["human_utterances"]
        if human_utterances:
            return human_utterances[-1]
    return {"text": "", "annotations": {}}
166

167

168
def get_bot_utterances(ctx: Context, _) -> list:
    """Return the dialog's "bot_utterances" entry, or `[]` when `ctx.validation` is truthy."""
    if ctx.validation:
        return []
    return ctx.misc["agent"]["dialog"]["bot_utterances"]
170

171

172
def get_last_bot_utterance(ctx: Context, _) -> dict:
    """Return the most recent bot utterance.

    A placeholder `{"text": "", "annotations": {}}` is returned when
    `ctx.validation` is truthy or when there are no bot utterances.
    """
    placeholder = {"text": "", "annotations": {}}
    if ctx.validation:
        return placeholder
    bot_utterances = ctx.misc["agent"]["dialog"]["bot_utterances"]
    return bot_utterances[-1] if bot_utterances else placeholder
177

178

179
def save_to_shared_memory(ctx: Context, _, **kwargs):
    """Update the agent's "shared_memory" with the given keyword arguments.

    Does nothing when `ctx.validation` is truthy.
    """
    if ctx.validation:
        return
    shared_memory = ctx.misc["agent"]["shared_memory"]
    shared_memory.update(kwargs)
182

183

184
def update_used_links(ctx: Context, _, linked_skill_name, linking_phrase):
    """Record `linking_phrase` as used for `linked_skill_name`.

    The phrase is appended to a new list stored under the skill name in the
    agent's "used_links"; nothing happens when `ctx.validation` is truthy.
    """
    if ctx.validation:
        return
    used_links = ctx.misc["agent"]["used_links"]
    earlier_phrases = used_links.get(linked_skill_name, [])
    # Concatenation builds a new list, leaving the previously stored one untouched.
    used_links[linked_skill_name] = earlier_phrases + [linking_phrase]
188

189

190
def get_new_link_to(ctx: Context, pipeline: Pipeline, skill_names):
    """Pick a link to one of `skill_names` and record it as used.

    The used links and disliked skills from the context are passed to
    `common_link.link_to` as human attributes. The returned link's "skill"
    and "phrase" are then stored via `update_used_links`.
    """
    human_attributes = {
        "used_links": get_used_links(ctx, pipeline),
        "disliked_skills": get_disliked_skills(ctx, pipeline),
    }
    chosen_link = common_link.link_to(skill_names, human_attributes=human_attributes)
    update_used_links(ctx, pipeline, chosen_link["skill"], chosen_link["phrase"])
    return chosen_link
199

200

201
def set_dff_suspension(ctx: Context, _):
    """Set the agent's "current_turn_dff_suspended" flag to True.

    Does nothing when `ctx.validation` is truthy.
    """
    if ctx.validation:
        return
    agent = ctx.misc["agent"]
    agent["current_turn_dff_suspended"] = True
204

205

206
def reset_dff_suspension(ctx: Context, _):
    """Set the agent's "current_turn_dff_suspended" flag to False.

    Does nothing when `ctx.validation` is truthy.
    """
    if ctx.validation:
        return
    agent = ctx.misc["agent"]
    agent["current_turn_dff_suspended"] = False
209

210

211
def set_confidence(ctx: Context, pipeline: Pipeline, confidence=1.0):
    """Store `confidence` in the agent's response.

    The value is written only when `ctx.validation` is falsy. A confidence
    equal to 0.0 additionally triggers `reset_can_continue`.
    """
    writable = not ctx.validation
    if writable:
        ctx.misc["agent"]["response"]["confidence"] = confidence
    if confidence == 0.0:
        reset_can_continue(ctx, pipeline)
216

217

218
def set_can_continue(ctx: Context, _, continue_flag=common_constants.CAN_CONTINUE_SCENARIO):
    """Store `continue_flag` under "can_continue" in the agent's response.

    Defaults to `common_constants.CAN_CONTINUE_SCENARIO`; does nothing when
    `ctx.validation` is truthy.
    """
    if ctx.validation:
        return
    response = ctx.misc["agent"]["response"]
    response["can_continue"] = continue_flag
221

222

223
def reset_can_continue(ctx: Context, _):
    """Remove "can_continue" from the agent's response, if present.

    Does nothing when `ctx.validation` is truthy.
    """
    if ctx.validation:
        return
    ctx.misc["agent"]["response"].pop("can_continue", None)
226

227

228
def get_named_entities_from_human_utterance(ctx: Context, pipeline: Pipeline):
    """Return labelled named entities found in the last human utterance.

    Delegates to `common_utils.get_entities` with `only_named=True` and
    `with_labels=True`.
    """
    # Each entity is a dict, e.g. {"text": "London", "type": "LOC"}.
    last_utterance = get_last_human_utterance(ctx, pipeline)
    return common_utils.get_entities(last_utterance, only_named=True, with_labels=True)
236

237

238
def get_nounphrases_from_human_utterance(ctx: Context, pipeline: Pipeline):
    """Return the entities of the last human utterance without labels.

    Delegates to `common_utils.get_entities` with `only_named=False` and
    `with_labels=False`.
    """
    last_utterance = get_last_human_utterance(ctx, pipeline)
    return common_utils.get_entities(last_utterance, only_named=False, with_labels=False)
245

246

247
def get_fact_random_annotations_from_human_utterance(ctx: Context, pipeline: Pipeline) -> dict:
    """Return the "fact_random" annotation of the last human utterance.

    Falls back to `{"facts": [], "response": ""}` when `ctx.validation` is
    truthy or when the annotation is absent.
    """
    empty_result = {"facts": [], "response": ""}
    if ctx.validation:
        return empty_result
    annotations = get_last_human_utterance(ctx, pipeline).get("annotations", {})
    return annotations.get("fact_random", empty_result)
256

257

258
def get_fact_for_particular_entity_from_human_utterance(ctx: Context, pipeline: Pipeline, entity) -> list:
    """Return fact texts from the "fact_random" annotation that concern `entity`.

    A fact is kept when its "entity_substr" equals `entity` ignoring case and
    its text does not contain "Sorry, I don't know".
    """
    annotation = get_fact_random_annotations_from_human_utterance(ctx, pipeline)
    matching_facts = []
    for candidate in annotation.get("facts", []):
        mentions_entity = candidate.get("entity_substr", "").lower() == entity.lower()
        is_apology = "Sorry, I don't know" in candidate.get("fact", "")
        if not mentions_entity or is_apology:
            continue
        matching_facts.append(candidate["fact"])
    return matching_facts
268

269

270
def get_news_about_particular_entity_from_human_utterance(ctx: Context, pipeline: Pipeline, entity) -> dict:
    """Return news about `entity`.

    The "news_api_annotator" annotation of the last human utterance is
    searched first. When it has no non-empty news for `entity`, the result
    of `common_news.get_news_about_topic` is returned instead.
    """
    annotations = get_last_human_utterance(ctx, pipeline).get("annotations", {})
    annotated_news = annotations.get("news_api_annotator", [])
    # Take the news of the first annotation entry for this entity, if any.
    found_news = next((item["news"] for item in annotated_news if item["entity"] == entity), {})
    return found_news or common_news.get_news_about_topic(entity, NEWS_API_ANNOTATOR_URL)
282

283

284
def get_facts_from_fact_retrieval(ctx: Context, pipeline: Pipeline) -> list:
    """Return facts from the "fact_retrieval" annotation of the last human utterance.

    A dict annotation yields its "facts" entry (or `[]`). A list annotation
    is returned as is. Any other case yields `[]`.
    """
    annotations = get_last_human_utterance(ctx, pipeline).get("annotations", {})
    if "fact_retrieval" not in annotations:
        return []
    fact_retrieval = annotations["fact_retrieval"]
    if isinstance(fact_retrieval, dict):
        return fact_retrieval.get("facts", [])
    if isinstance(fact_retrieval, list):
        return fact_retrieval
    return []
292

293

294
def get_unrepeatable_index_from_rand_seq(
    ctx: Context, pipeline: Pipeline, seq_name, seq_max, renew_seq_if_empty=False
) -> int:
    """Return a non-repeating index from a shuffled sequence kept in shared memory.

    The sequence is stored in shared memory under `seq_name` and is a random
    permutation of `range(seq_max)`. Each call takes the last element and
    saves the remainder back.

    Args:
        ctx: dialog context.
        pipeline: passed through to the shared-memory helpers.
        seq_name: shared-memory key of the sequence.
        seq_max: length of the permutation to generate.
        renew_seq_if_empty: when True, an exhausted sequence is replaced
            with a fresh permutation.

    Returns:
        The next index, or None when the sequence is exhausted and is not
        renewed (or when `seq_max` is 0).
    """
    shared_memory = get_shared_memory(ctx, pipeline)
    if seq_name in shared_memory:
        seq = shared_memory[seq_name]
    else:
        # Shuffle only when nothing is stored yet, not on every call.
        seq = random.sample(range(seq_max), seq_max)

    if not seq:
        if not renew_seq_if_empty:
            return None
        seq = random.sample(range(seq_max), seq_max)

    next_index = seq[-1] if seq else None
    # `pipeline` must be passed: save_to_shared_memory takes two positional
    # arguments before its keyword arguments.
    save_to_shared_memory(ctx, pipeline, **{seq_name: seq[:-1]})
    return next_index
309

310

311
def get_history(ctx: Context, _):
    """Return the agent's "history" entry, or `{}` when `ctx.validation` is truthy."""
    return {} if ctx.validation else ctx.misc["agent"]["history"]
315

316

317
def get_n_last_state(ctx: Context, pipeline: Pipeline, n) -> str:
    """Return the `n`-th most recent state from the history (1 is the latest).

    History entries are ordered by their key. An empty string is returned
    when the history has fewer than `n` entries or when `n` is less than 1.
    Previously `n == 0` returned the oldest state, because `[-0]` indexes
    the front of the list.

    NOTE(review): keys are compared as stored; if they are numeric strings
    the order is lexicographic ("10" < "9") — confirm against the producer.
    """
    if n < 1:
        return ""
    history_sorted = sorted(get_history(ctx, pipeline).items(), key=lambda item: item[0])
    if len(history_sorted) < n:
        return ""
    return history_sorted[-n][1]
325

326

327
def get_last_state(ctx: Context, pipeline: Pipeline) -> str:
    """Return the state stored under the greatest history key, or "" for an empty history."""
    history = get_history(ctx, pipeline)
    if not history:
        return ""
    # Dict keys are unique, so the maximum key is the last one in sorted order.
    _, latest_state = max(history.items(), key=lambda item: item[0])
    return latest_state
334

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.