MemGPT
llm_judge_doc_qa.py

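"""LLM-as-judge scoring for the document QA (doc_qa) results.

Each response is first checked for a verbatim match against the true answers;
inconclusive cases are sent to a GPT-4 judge. Per-question results and overall
accuracy are written to a results_*.json file.
"""
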
import argparse
import json
import re

from openai import OpenAI
from tqdm import tqdm

from memgpt.credentials import MemGPTCredentials

# Note: did not end up using since no cases of cheating were observed
# CHEATING_PROMPT = \
#    """
#    You are evaluating an LLM trying to do a task.
#    The main point of failure is that the LLM may cheat and hallucinate the document it is meant to provide as evidence.
#    For example, even though the LLM has been instructed to provide the document that it's using as evidence to answer the question, in this case the LLM hallucinates a document:
#
#    The format of the LLM response will be:
#    ANSWER: <answer>, DOCUMENT: <document_text>
#
#    We also provide the list of documents that the LLM should have used to answer the question.
#    This is provided with the format:
#    ALL DOCUMENTS: <list of documents>
#
#    Your task is to ensure the LLM did not cheat and hallucinate the document.
#    You can ensure this by checking if the document provided by the LLM is in the list of documents provided.
#
#    Write a short reason for your answer before stating the final answer (format <your reason> - CHEAT/CORRECT).
#    """

EVAL_PROMPT = """
    Your task is to evaluate whether an LLM correctly answered a question.
    The LLM response should be in the format 'ANSWER: <answer>, DOCUMENT: <document_text>' or say 'INSUFFICIENT INFORMATION'.
    The true answer is provided in the format 'TRUE ANSWER: <list of possible answers>'.
    The question is provided in the format 'QUESTION: <question>'.
    If the LLM response contains both the correct answer and the corresponding document text, the response is correct.
    Even if the LLM's answer and the true answer are slightly different in wording, the response is still correct.
    For example, if the answer is more specific than the true answer or uses a different phrasing that is still correct, the response is correct.
    If the LLM response is 'INSUFFICIENT INFORMATION', or the 'DOCUMENT' field is missing, the response is incorrect.
    Respond with a single token: 'CORRECT' or 'INCORRECT'.
    """

EVAL_MODEL = "gpt-4-0613"


def evaluate_response(output: str):
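    """Ask the judge model whether the graded output (question, true answers, and
    LLM response, concatenated by the caller) should be counted as correct.

    Returns True for a 'CORRECT' verdict, False for 'INCORRECT' or an unparseable reply.
    """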
    credentials = MemGPTCredentials().load()
    assert credentials.openai_key is not None, credentials.openai_key

    client = OpenAI(api_key=credentials.openai_key)

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "\n".join([EVAL_PROMPT, "\n", output, "\n"]),
            },
        ],
        model=EVAL_MODEL,
    )

    response = chat_completion.choices[0].message.content
    print("llm judge", response)
    # check 'INCORRECT' first, since 'CORRECT' is a substring of 'INCORRECT'
    if "INCORRECT" in response:
        return False
    elif "CORRECT" in response:
        return True
    else:
        print("INVALID RESPONSE", response)
        return False


# Grab the last thing MemGPT generated, treat it as the reply
def extract_final_memgpt_response(memgpt_responses: list) -> str:
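    """Pick the last entry of a MemGPT response list (skipping a trailing
    function_return entry) and return its final field as the reply text."""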
    final_index = -1
    if "function_return" in memgpt_responses[final_index]:
        final_index = -2
    final_memgpt_response = [v for k, v in memgpt_responses[final_index].items()]
    final_memgpt_response = final_memgpt_response[-1]
    return final_memgpt_response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate doc QA results with string matching and an LLM judge")
    parser.add_argument("--file", type=str, help="File data to evaluate")
    parser.add_argument("--baseline", action="store_true", help="Whether the results file comes from the baseline model (not MemGPT)")
    args = parser.parse_args()

    # load data
    with open(args.file) as f:
        data = json.load(f)

    # counters
    correct = 0
    total = 0

    # Make an initial pass to determine how many documents had the correct answer
    results = []  # store all results
    eval_results = []  # store results that need LLM judge
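    # Results filenames encode the run configuration: the regexes below expect
    # model_<name>_num_docs_<N>.json for the baseline and model_<name>.json for MemGPT.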
    if args.baseline:
        # baseline experiment
        match = re.search(r"model_([^_]+)_num_docs_([^\.]+)\.json", args.file)
        model = match.group(1)
        num_docs = int(match.group(2))
        baseline = "baseline"
    else:
        # model = re.search(r"model_([^\.]+)\.json", args.file).group(1)
        model = re.search(r"model_([-\w.]+)(?:_num_docs_([-\d]+))?\.json", args.file).group(1)
        num_docs = None
        baseline = "memgpt"

    # evaluate data
    for d in tqdm(data):
        answer = d["true_answers"]
        question = d["question"]
        response = d["memgpt_responses"]
        if not args.baseline:
            # need to parse response for memgpt
            response = extract_final_memgpt_response(response)
        else:
            response = response["response"]

        # exact-match pass: did any true answer appear verbatim in the response?
        found = any(a in response for a in answer)

        if not found and "INSUFFICIENT INFORMATION" not in response:
            # inconclusive: pass to llm judge
            print(question)
            print(answer)
            print(response)
            print(args.baseline)
            doc = "QUESTION: " + question + "\n" + "TRUE ANSWER: " + str(answer) + "\n" + response
            judge = "llm"
            judge_result = evaluate_response(doc)
            print("JUDGEMENT", judge_result)
            if judge_result:
                correct += 1
                found = True
        elif found:
            # answer found in text
            correct += 1
            judge = "text"
        else:
            # no match and the response was 'INSUFFICIENT INFORMATION': count as incorrect
            judge = "text"

        results.append({"question": question, "true_answers": answer, "response": response, "correct": found, "judge": judge})

        total += 1

    # Dump aggregated results
    with open(f"results_{model}_{num_docs}_{baseline}.json", "w") as f:
        json.dump(
            {"accuracy": correct / total, "total": total, "results": results},
            f,
            indent=4,
        )
    print(correct / total)
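# Example invocation (hypothetical filename; the --baseline branch expects names
# matching model_<name>_num_docs_<N>.json):
#   python llm_judge_doc_qa.py --file model_gpt-4_num_docs_10.json --baseline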