dream

Форк
0
155 строк · 5.1 Кб
1
import requests
2
from deeppavlov_kg import TerminusdbKnowledgeGraph
3

4

5
def formulate_utt_annotations(dog_id=None, park_id=None):
6
    utt_annotations = {
7
        "property_extraction": [
8
            {
9
                "triplets": [
10
                    {"subject": "user", "relation": "HAVE PET", "object": "dog"},
11
                    {"subject": "user", "relation": "LIKE GOTO", "object": "park"},
12
                ]
13
            }
14
        ],
15
        "custom_entity_linking": [],
16
    }
17

18
    # if dog is in kg add it to custom_el annotations
19
    if dog_id is not None:
20
        utt_annotations["custom_entity_linking"].append(
21
            {
22
                "entity_substr": "dog",
23
                "entity_ids": [dog_id],
24
                "confidences": [1.0],
25
                "tokens_match_conf": [1.0],
26
                "entity_id_tags": ["Animal"],
27
            },
28
        )
29
    if park_id is not None:
30
        utt_annotations["custom_entity_linking"].append(
31
            {
32
                "entity_substr": "park",
33
                "entity_ids": [park_id],
34
                "confidences": [1.0],
35
                "tokens_match_conf": [1.0],
36
                "entity_id_tags": ["Place"],
37
            },
38
        )
39

40
    return utt_annotations
41

42

43
def prepare_for_comparison(results):
44
    for result in results:
45
        if uttrs := result["added_to_graph"]:
46
            for utt in uttrs:
47
                for triplet in utt:
48
                    triplet[2] = triplet[2].split("/")[0]
49
        if uttrs := result["triplets_already_in_graph"]:
50
            for utt in uttrs:
51
                for triplet in utt:
52
                    triplet[2] = triplet[2].split("/")[0]
53

54
    return results
55

56

57
def compare_results(results, golden_results) -> bool:
58
    def compare(uttrs, golden_result):
59
        for idx, utt in enumerate(uttrs):
60
            for triplet in utt:
61
                if triplet not in golden_result[idx]:
62
                    return False
63
        return True
64

65
    is_successfull = []
66
    for result, golden_result in zip(results, golden_results):
67
        is_added = compare(result["added_to_graph"], golden_result["added_to_graph"])
68
        is_in_graph = compare(result["triplets_already_in_graph"], golden_result["triplets_already_in_graph"])
69
        is_successfull.append(is_added)
70
        is_successfull.append(is_in_graph)
71
    return all(is_successfull)
72

73

74
def main():
75
    TERMINUSDB_SERVER_URL = "http://0.0.0.0:6363"
76
    TERMINUSDB_SERVER_TEAM = "admin"
77
    TERMINUSDB_SERVER_DB = "bot_knowledge_db"
78
    TERMINUSDB_SERVER_PASSWORD = "root"
79
    BOT_KNOWLEDGE_MEMORIZER_PORT = 8044
80

81
    BOT_KNOWLEDGE_MEMORIZER_URL = f"http://0.0.0.0:{BOT_KNOWLEDGE_MEMORIZER_PORT}/respond"
82

83
    graph = TerminusdbKnowledgeGraph(
84
        db_name=TERMINUSDB_SERVER_DB,
85
        team=TERMINUSDB_SERVER_TEAM,
86
        server=TERMINUSDB_SERVER_URL,
87
        password=TERMINUSDB_SERVER_PASSWORD,
88
    )
89

90
    BOT_ID = "Bot/514b2c3d-bb73-4294-9486-04f9e099835e"
91
    # get dog_id and park_id from KG
92
    dog_id, park_id = None, None
93
    try:
94
        user_props = graph.get_properties_of_entity(BOT_ID)
95
        entities_info = graph.get_properties_of_entities(
96
            [*user_props["HAVE PET/Animal"], *user_props["LIKE GOTO/Place"]]
97
        )
98
        for entity_info in entities_info:
99
            if entity_info.get("substr") == "dog":
100
                dog_id = entity_info["@id"]
101
            elif entity_info.get("substr") == "park":
102
                park_id = entity_info["@id"]
103
        print(f"Found park_id: '{park_id}' and dog_ig: '{dog_id}'")
104
        added_new_entities = False
105
    except Exception:
106
        print("Adding new entities and rels")
107
        added_new_entities = True
108

109
    request_data = [
110
        {
111
            "utterances": [
112
                {
113
                    "text": "i have a dog and a cat",
114
                    "user": {"id": BOT_ID.split("/")[1]},
115
                    "annotations": formulate_utt_annotations(dog_id, park_id),
116
                },
117
                {
118
                    "text": "",
119
                    "user": {"id": ""},
120
                    "annotations": {
121
                        "property_extraction": [{}],
122
                        "custom_entity_linking": [],
123
                    },
124
                },
125
            ],
126
            "human_utterances": [
127
                {
128
                    "text": "What's your dog's name?",
129
                },
130
                {
131
                    "text": "",
132
                },
133
            ],
134
        }
135
    ]
136

137
    golden_triplets = [[[BOT_ID, "LIKE GOTO", "Place"], [BOT_ID, "HAVE PET", "Animal"]], []]
138
    if added_new_entities:
139
        golden_results = [[{"added_to_graph": golden_triplets, "triplets_already_in_graph": [[], []]}]]
140
    else:
141
        golden_results = [[{"added_to_graph": [[], []], "triplets_already_in_graph": golden_triplets}]]
142

143
    count = 0
144
    for data, golden_result in zip(request_data, golden_results):
145
        result = requests.post(BOT_KNOWLEDGE_MEMORIZER_URL, json=data).json()
146
        print(result)
147
        result = prepare_for_comparison(result)
148
        if compare_results(result, golden_result):
149
            count += 1
150
    assert count == len(request_data)
151
    print("Success")
152

153

154
if __name__ == "__main__":
155
    main()
156

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.