#!/usr/bin/env python
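"""DFF skill integration glue for the Dream pipeline.

Builds per-dialog DFF contexts from a batched skill request, runs the DFF
pipeline on the last human utterance, and packs the result into the
(reply, confidence, human_attr, bot_attr, hype_attr) payload expected by
the Dream agent.
"""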

import logging
import os
from typing import List, Any

import sentry_sdk
from sentry_sdk.integrations.logging import ignore_logger
from dff.script import Context, MultiMessage
from dff.pipeline import Pipeline
from pydantic import BaseModel, Field, Extra, root_validator

from common.constants import CAN_NOT_CONTINUE, CAN_CONTINUE_SCENARIO
from common.dff_api_v1.integration.context import get_last_human_utterance
from common.dff_api_v1.integration.message import DreamMessage


ignore_logger("root")

sentry_sdk.init(os.getenv("SENTRY_DSN"))
SERVICE_NAME = os.getenv("SERVICE_NAME")

logger = logging.getLogger(__name__)


class ExtraIgnoreModel(BaseModel):
    class Config:
        extra = Extra.ignore


class HumanAttr(ExtraIgnoreModel):
    dff_shared_state: dict
    used_links: dict
    age_group: str
    disliked_skills: list


class HypeAttr(ExtraIgnoreModel):
    can_continue: str

    @root_validator(pre=True)
    def calculate_can_continue(cls, values):
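        # Derive can_continue from the skill response: a non-zero confidence
        # defaults to CAN_CONTINUE_SCENARIO, zero confidence to CAN_NOT_CONTINUE,
        # unless the response sets can_continue explicitly.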
        confidence = values["response"].get("confidence", 0.85)
        can_continue = CAN_CONTINUE_SCENARIO if confidence else CAN_NOT_CONTINUE
        values["can_continue"] = values["response"].get("can_continue", can_continue)
        return values


class State(ExtraIgnoreModel):
    context: Context
    previous_human_utter_index: int = -1
    current_turn_dff_suspended: bool = False
    history: dict = Field(default_factory=dict)
    shared_memory: dict = Field(default_factory=dict)

    @root_validator(pre=True)
    def shift_state_history(cls, values):
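        # Remember which human utterance this state belongs to and record the
        # last visited DFF node label under that utterance index.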
        values["previous_human_utter_index"] = values["human_utter_index"]
        values["history"][str(values["human_utter_index"])] = list(values["context"].labels.values())[-1]
        return values

    @root_validator(pre=True)
    def validate_context(cls, values):
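        # Keep only the last two turns of requests/responses/labels and drop
        # the agent payload before the context is serialized into the state.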
        context = values["context"]
        context.clear(2, ["requests", "responses", "labels"])
        del context.misc["agent"]
        return values


class Agent(ExtraIgnoreModel):
    previous_human_utter_index: int = -1
    human_utter_index: int
    dialog: Any
    entities: dict = Field(default_factory=dict)
    shared_memory: dict = Field(default_factory=dict)
    current_turn_dff_suspended: bool = False
    previous_turn_dff_suspended: bool = False
    response: dict = Field(default_factory=dict)
    dff_shared_state: dict = Field(default_factory=dict)
    cache: dict = Field(default_factory=dict)
    history: dict = Field(default_factory=dict)
    used_links: dict = Field(default_factory=dict)
    age_group: str = ""
    disliked_skills: list = Field(default_factory=list)
    clarification_request_flag: bool = False

    @root_validator(pre=True)
    def get_state_props(cls, values):
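        # Merge the previously saved skill state into the raw agent attributes;
        # values stored in the state override the keyword arguments.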
        state = values.get("state", {})
        values = values | state
        return values


def load_ctxs(requested_data) -> List[Context]:
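    """Build one DFF Context per dialog in the batched skill request.

    Missing batches fall back to neutral defaults of the same length as
    dialog_batch; the assembled Agent payload is stored in ctx.misc["agent"].
    """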
    dialog_batch = requested_data.get("dialog_batch", [])
    human_utter_index_batch = requested_data.get("human_utter_index_batch", [0] * len(dialog_batch))
    state_batch = requested_data.get(f"{SERVICE_NAME}_state_batch", [{}] * len(dialog_batch))
    dff_shared_state_batch = requested_data.get("dff_shared_state_batch", [{}] * len(dialog_batch))
    entities_batch = requested_data.get("entities_batch", [{}] * len(dialog_batch))
    used_links_batch = requested_data.get("used_links_batch", [{}] * len(dialog_batch))
    age_group_batch = requested_data.get("age_group_batch", [""] * len(dialog_batch))
    disliked_skills_batch = requested_data.get("disliked_skills_batch", [[]] * len(dialog_batch))
    clarification_request_flag_batch = requested_data.get(
        "clarification_request_flag_batch",
        [False] * len(dialog_batch),
    )
    ctxs = []
    for (
        human_utter_index,
        dialog,
        state,
        dff_shared_state,
        entities,
        used_links,
        age_group,
        disliked_skills,
        clarification_request_flag,
    ) in zip(
        human_utter_index_batch,
        dialog_batch,
        state_batch,
        dff_shared_state_batch,
        entities_batch,
        used_links_batch,
        age_group_batch,
        disliked_skills_batch,
        clarification_request_flag_batch,
    ):
        ctx = Context.cast(state.get("context", {}))
        agent = Agent(
            human_utter_index=human_utter_index,
            dialog=dialog,
            state=state,
            dff_shared_state=dff_shared_state,
            entities=entities,
            used_links=used_links,
            age_group=age_group,
            disliked_skills=disliked_skills,
            clarification_request_flag=clarification_request_flag,
        )
        ctx.misc["agent"] = agent.dict()
        ctxs += [ctx]
    return ctxs


def get_response(ctx: Context, _):
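    """Convert the DFF output for this context into Dream skill output.

    Returns a (reply, confidence, human_attr, bot_attr, hype_attr) tuple, or,
    for a MultiMessage, the same five fields as parallel sequences.
    """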
    agent = ctx.misc["agent"]
    response_parts = agent.get("response_parts", [])
    confidence = agent["response"].get("confidence", 0.85)
    state = State(context=ctx, **agent).dict(exclude_none=True)
    human_attr = HumanAttr.parse_obj(agent).dict() | {f"{SERVICE_NAME}_state": state}
    hype_attr = HypeAttr.parse_obj(agent).dict() | ({"response_parts": response_parts} if response_parts else {})
    response = ctx.last_response
    if isinstance(response, MultiMessage):
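        # A MultiMessage carries several hypotheses: build one tuple per
        # sub-message, then transpose them into parallel sequences.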
        responses = []
        message: DreamMessage
        for message in response.messages:
            reply = message.text or ""
            conf = message.confidence or confidence
            h_a = human_attr | (message.human_attr or {})
            attr = hype_attr | (message.hype_attr or {})
            b_a = message.bot_attr or {}
            responses += [(reply, conf, h_a, b_a, attr)]
        return list(zip(*responses))
    else:
        return (response.text, confidence, human_attr, {}, hype_attr)


def run_dff(ctx: Context, pipeline: Pipeline):
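    """Run the DFF pipeline on the last human utterance and format the reply."""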
    last_request = get_last_human_utterance(ctx, pipeline.actor)["text"]
    pipeline.context_storage[ctx.id] = ctx
    ctx = pipeline(DreamMessage(text=last_request), ctx.id)
    response = get_response(ctx, pipeline.actor)
    del pipeline.context_storage[ctx.id]
    return response
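

# Usage sketch (not part of the original module): one way load_ctxs() and
# run_dff() could be wired into a batched skill endpoint. The Flask app, the
# /respond route, and the `pipeline` object are illustrative assumptions.
#
#     from flask import Flask, jsonify, request
#
#     app = Flask(__name__)
#
#     @app.route("/respond", methods=["POST"])
#     def respond():
#         # One Context per dialog in the request; each run_dff() call returns
#         # a (reply, confidence, human_attr, bot_attr, hype_attr) payload.
#         ctxs = load_ctxs(request.json)
#         responses = [run_dff(ctx, pipeline) for ctx in ctxs]
#         return jsonify(responses)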