dream
175 lines · 6.0 KB
1#!/usr/bin/env python
2
3import logging
4import os
5from typing import List, Any
6
7import sentry_sdk
8from sentry_sdk.integrations.logging import ignore_logger
9from dff.script import Context, MultiMessage
10from dff.pipeline import Pipeline
11from pydantic import BaseModel, Field, Extra, root_validator
12
13from common.constants import CAN_NOT_CONTINUE, CAN_CONTINUE_SCENARIO
14from common.dff_api_v1.integration.context import get_last_human_utterance
15from common.dff_api_v1.integration.message import DreamMessage
16
17
# Suppress noisy root-logger events from being forwarded to Sentry.
ignore_logger("root")

sentry_sdk.init(os.getenv("SENTRY_DSN"))
# Name of this skill service; used to namespace the per-skill state key
# (f"{SERVICE_NAME}_state") in the shared human attributes.
SERVICE_NAME = os.getenv("SERVICE_NAME")

logger = logging.getLogger(__name__)
25
class ExtraIgnoreModel(BaseModel):
    """Base model whose subclasses silently drop any undeclared input fields."""

    class Config:
        # pydantic v1 behavior switch: unknown keys are ignored instead of
        # raising a ValidationError, so models can be fed the full agent dict.
        extra = Extra.ignore
29
30
class HumanAttr(ExtraIgnoreModel):
    """Human-attribute fields this skill reports back to the dialog agent.

    Parsed from the agent dict (extra keys ignored); serialized in
    get_response and merged with the namespaced skill state.
    """

    dff_shared_state: dict  # state shared across all DFF skills
    used_links: dict  # presumably link-to-skill bookkeeping — populated by the agent, not here
    age_group: str
    disliked_skills: list
36
37
class HypeAttr(ExtraIgnoreModel):
    """Hypothesis attributes attached to a skill response candidate."""

    can_continue: str

    @root_validator(pre=True)
    def calculate_can_continue(cls, values):
        # An explicit "can_continue" in the response dict wins. Otherwise derive
        # it from confidence truthiness: any non-zero confidence (default 0.85)
        # maps to CAN_CONTINUE_SCENARIO, exactly 0 maps to CAN_NOT_CONTINUE.
        # NOTE(review): raises KeyError if "response" is absent from the input —
        # callers always pass the Agent dict, which defaults it to {}.
        confidence = values["response"].get("confidence", 0.85)
        can_continue = CAN_CONTINUE_SCENARIO if confidence else CAN_NOT_CONTINUE
        values["can_continue"] = values["response"].get("can_continue", can_continue)
        return values
47
48
class State(ExtraIgnoreModel):
    """Serializable snapshot of the DFF skill state for one human utterance.

    Built in get_response as State(context=ctx, **agent); the pre=True root
    validators run in definition order and mutate the raw input before field
    validation.
    """

    context: Context
    previous_human_utter_index: int = -1
    current_turn_dff_suspended: bool = False
    history: dict = Field(default_factory=dict)  # human_utter_index (str) -> last visited script label
    shared_memory: dict = Field(default_factory=dict)

    @root_validator(pre=True)
    def shift_state_history(cls, values):
        # Record which utterance this state belongs to, and remember the last
        # visited DFF script label under that utterance index.
        values["previous_human_utter_index"] = values["human_utter_index"]
        values["history"][str(values["human_utter_index"])] = list(values["context"].labels.values())[-1]
        return values

    @root_validator(pre=True)
    def validate_context(cls, values):
        # Trim the context to the last 2 turns of requests/responses/labels to
        # keep the serialized state small, and drop the "agent" blob injected
        # by load_ctxs (it is rebuilt every turn and must not be persisted).
        # NOTE(review): `del` raises KeyError if "agent" was never set.
        context = values["context"]
        context.clear(2, ["requests", "responses", "labels"])
        del context.misc["agent"]
        return values
68
69
class Agent(ExtraIgnoreModel):
    """Per-turn bundle of agent-provided data stored in ctx.misc["agent"].

    Aggregates the request batch fields (dialog, entities, used_links, ...)
    with the skill state saved on the previous turn.
    """

    previous_human_utter_index: int = -1
    human_utter_index: int
    dialog: Any
    entities: dict = Field(default_factory=dict)
    shared_memory: dict = Field(default_factory=dict)
    current_turn_dff_suspended: bool = False
    previous_turn_dff_suspended: bool = False
    response: dict = Field(default_factory=dict)
    dff_shared_state: dict = Field(default_factory=dict)
    cache: dict = Field(default_factory=dict)
    history: dict = Field(default_factory=dict)
    used_links: dict = Field(default_factory=dict)
    age_group: str = ""
    disliked_skills: list = Field(default_factory=list)
    clarification_request_flag: bool = False

    @root_validator(pre=True)
    def get_state_props(cls, values):
        # Flatten the previously saved skill state into the top-level input.
        # dict union gives the right-hand operand precedence, so saved state
        # keys (history, shared_memory, ...) override same-named raw values.
        state = values.get("state", {})
        values = values | state
        return values
92
93
def load_ctxs(requested_data) -> List[Context]:
    """Build one DFF Context per dialog in the request batch.

    Each batch field falls back to a list of per-dialog neutral defaults when
    absent from ``requested_data``. The previous turn's serialized context is
    restored via ``Context.cast`` and the per-turn agent payload is attached
    under ``ctx.misc["agent"]``.

    :param requested_data: request payload with parallel ``*_batch`` lists.
    :return: list of contexts, one per entry of ``dialog_batch``.
    """
    dialog_batch = requested_data.get("dialog_batch", [])
    human_utter_index_batch = requested_data.get("human_utter_index_batch", [0] * len(dialog_batch))
    state_batch = requested_data.get(f"{SERVICE_NAME}_state_batch", [{}] * len(dialog_batch))
    dff_shared_state_batch = requested_data.get("dff_shared_state_batch", [{}] * len(dialog_batch))
    entities_batch = requested_data.get("entities_batch", [{}] * len(dialog_batch))
    used_links_batch = requested_data.get("used_links_batch", [{}] * len(dialog_batch))
    age_group_batch = requested_data.get("age_group_batch", [""] * len(dialog_batch))
    # BUGFIX: default each entry to an empty list, not an empty dict —
    # Agent.disliked_skills is typed `list`, and pydantic v1 rejects a dict
    # for a list field, so the old `[{}]` fallback raised ValidationError.
    disliked_skills_batch = requested_data.get("disliked_skills_batch", [[]] * len(dialog_batch))
    clarification_request_flag_batch = requested_data.get(
        "clarification_request_flag_batch",
        [False] * len(dialog_batch),
    )
    ctxs = []
    for (
        human_utter_index,
        dialog,
        state,
        dff_shared_state,
        entities,
        used_links,
        age_group,
        disliked_skills,
        clarification_request_flag,
    ) in zip(
        human_utter_index_batch,
        dialog_batch,
        state_batch,
        dff_shared_state_batch,
        entities_batch,
        used_links_batch,
        age_group_batch,
        disliked_skills_batch,
        clarification_request_flag_batch,
    ):
        # Restore the serialized context from the previous turn ({} on turn 1).
        ctx = Context.cast(state.get("context", {}))
        agent = Agent(
            human_utter_index=human_utter_index,
            dialog=dialog,
            state=state,
            dff_shared_state=dff_shared_state,
            entities=entities,
            used_links=used_links,
            age_group=age_group,
            disliked_skills=disliked_skills,
            clarification_request_flag=clarification_request_flag,
        )
        ctx.misc["agent"] = agent.dict()
        ctxs += [ctx]
    return ctxs
144
145
def get_response(ctx: Context, _):
    """Convert the finished DFF context into the agent response tuple(s).

    Returns ``(text, confidence, human_attr, bot_attr, hyp_attr)`` for a
    single response, or the same five fields as parallel lists when the last
    response is a MultiMessage.
    """
    agent = ctx.misc["agent"]
    parts = agent.get("response_parts", [])
    default_conf = agent["response"].get("confidence", 0.85)

    # Serialize the skill state and namespace it inside the human attributes.
    state = State(context=ctx, **agent).dict(exclude_none=True)
    human_attr = HumanAttr.parse_obj(agent).dict() | {f"{SERVICE_NAME}_state": state}
    hype_attr = HypeAttr.parse_obj(agent).dict()
    if parts:
        hype_attr |= {"response_parts": parts}

    last_response = ctx.last_response
    if not isinstance(last_response, MultiMessage):
        return (last_response.text, default_conf, human_attr, {}, hype_attr)

    # Multi-hypothesis answer: per-message overrides fall back to the shared
    # confidence/attribute dicts, then the tuples are transposed into columns.
    hypotheses = [
        (
            msg.text or "",
            msg.confidence or default_conf,
            human_attr | (msg.human_attr or {}),
            msg.bot_attr or {},
            hype_attr | (msg.hype_attr or {}),
        )
        for msg in last_response.messages
    ]
    return list(zip(*hypotheses))
167
168
def run_dff(ctx: Context, pipeline: Pipeline):
    """Run one pipeline turn for the latest human utterance and collect the reply.

    The context is registered in the pipeline's storage for the duration of
    the turn and removed afterwards so nothing leaks between requests.
    """
    request_text = get_last_human_utterance(ctx, pipeline.actor)["text"]
    pipeline.context_storage[ctx.id] = ctx
    ctx = pipeline(DreamMessage(text=request_text), ctx.id)
    reply = get_response(ctx, pipeline.actor)
    del pipeline.context_storage[ctx.id]
    return reply
176