GenerativeAIExamples

Форк
0
233 строки · 10.2 Кб
1
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
# SPDX-License-Identifier: Apache-2.0
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
# http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
import json
17
import logging
18
import os
19
import statistics
20

21
from datasets import Dataset
22
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
23
from ragas import evaluate
24
from ragas.llms import LangchainLLMWrapper
25
from ragas.embeddings import LangchainEmbeddingsWrapper
26
from ragas.metrics import (
27
    answer_relevancy,
28
    context_precision,
29
    context_recall,
30
    context_relevancy,
31
    faithfulness,
32
    answer_similarity
33
)
34

35
# Reference passage shared by both few-shot examples below.
# It must not contain "{" or "}", because LLAMA_PROMPT_TEMPLATE is filled with str.format.
_EXAMPLE_CONTEXT = """On 8 September 2022, Buckingham Palace released a statement which read: "Following further evaluation this morning, the Queen's doctors are concerned for Her Majesty's health and have recommended she remain under medical supervision. The Queen remains comfortable and at Balmoral."[257][258] Her immediate family rushed to Balmoral to be by her side.[259][260] She died peacefully at 15:10 BST at the age of 96, with two of her children, Charles and Anne, by her side;[261][262] Charles immediately succeeded as monarch. Her death was announced to the public at 18:30,[263][264] setting in motion Operation London Bridge and, because she died in Scotland, Operation Unicorn.[265][266] Elizabeth was the first monarch to die in Scotland since James V in 1542.[267] Her death certificate recorded her cause of death as old age"""

# Few-shot "LLM as a judge" prompt wrapped in [INST] / <<SYS>> tags.
# Placeholders filled via str.format: system_prompt, question, ctx_ref, answer_ref, answer.
# Sections are joined with "\n" so the model sees them as separate lines rather than
# one run-on string. The example verdicts are real JSON objects (braces doubled to
# escape str.format) because the judge is asked to answer in JSON and its reply is
# parsed with json.loads.
LLAMA_PROMPT_TEMPLATE = "\n".join(
    [
        "<s>[INST] <<SYS>>",
        "{system_prompt}",
        "<</SYS>>",
        "",
        "Example 1:",
        "[Question]",
        "When did Queen Elizabeth II die?",
        "[The Start of the Reference Context]",
        _EXAMPLE_CONTEXT,
        "[The End of Reference Context]",
        "[The Start of the Reference Answer]",
        "Queen Elizabeth II died on September 8, 2022.",
        "[The End of Reference Answer]",
        "[The Start of the Assistant's Answer]",
        "She died on September 8, 2022",
        "[The End of Assistant's Answer]",
        '{{"Rating": 5, "Explanation": "The answer is helpful, relevant, accurate, and concise. It matches the information provided in the reference context and answer."}}',
        "",
        "Example 2:",
        "[Question]",
        "When did Queen Elizabeth II die?",
        "[The Start of the Reference Context]",
        _EXAMPLE_CONTEXT,
        "[The End of Reference Context]",
        "[The Start of the Reference Answer]",
        "Queen Elizabeth II died on September 8, 2022.",
        "[The End of Reference Answer]",
        "[The Start of the Assistant's Answer]",
        "Queen Elizabeth II was the longest reigning monarch of the United Kingdom and the Commonwealth.",
        "[The End of Assistant's Answer]",
        '{{"Rating": 1, "Explanation": "The answer is not helpful or relevant. It does not answer the question and instead goes off topic."}}',
        "",
        "Follow the exact same format as above. Put Rating first and Explanation second. Rating must be between 1 and 5. What is the rating and explanation for the following assistant's answer?",
        "Rating and Explanation should be in JSON format.",
        "[Question]",
        "{question}",
        "[The Start of the Reference Context]",
        "{ctx_ref}",
        "[The End of Reference Context]",
        "[The Start of the Reference Answer]",
        "{answer_ref}",
        "[The End of Reference Answer]",
        "[The Start of the Assistant's Answer]",
        "{answer}",
        "[The End of Assistant's Answer][/INST]",
    ]
)

# System prompt injected into the {system_prompt} slot of LLAMA_PROMPT_TEMPLATE.
SYS_PROMPT = """
    You are an impartial judge that evaluates the quality of an assistant's answer to the question provided.
    Your evaluation takes into account helpfulness, relevancy, accuracy, and level of detail of the answer.
    You must use both the reference context and reference answer to guide your evaluation.
    """

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
90

91
def calculate_ragas_score(row):
    """Return the per-sample RAGAS score for one row of RAGAS results.

    The score is the harmonic mean of the row's ``faithfulness``,
    ``context_relevancy``, ``answer_relevancy`` and ``context_recall`` values.
    Other columns (e.g. ``answer_similarity``, ``context_precision``) are ignored.
    Intended for ``DataFrame.apply(calculate_ragas_score, axis=1)``.

    Args:
        row: A row (indexable by column name) of the DataFrame produced from a
            RAGAS evaluation result.

    Returns:
        The harmonic mean of the four selected metric values.
    """
    core_metric_columns = [
        'faithfulness',
        'context_relevancy',
        'answer_relevancy',
        'context_recall',
    ]
    selected = row[core_metric_columns]
    return statistics.harmonic_mean(selected.values)
94

95
def eval_ragas(ev_file_path, ev_result_path, llm_model='ai-mixtral-8x7b-instruct'):
    """Score a RAG evaluation dataset with RAGAS metrics and write the results.

    Reads ``NVIDIA_API_KEY`` from the environment and builds a ``ChatNVIDIA``
    judge model plus ``NVIDIAEmbeddings``, both wrapped for RAGAS. It then
    loads the evaluation records and runs ``ragas.evaluate`` with answer
    similarity, faithfulness, context precision, context relevancy, answer
    relevancy and context recall.

    Args:
        ev_file_path: Path to a JSON file containing a list of records. Each
            record has ``question``, ``generated_answer``, ``retrieved_context``
            and ``ground_truth_answer`` keys.
        ev_result_path: Output path prefix. ``<prefix>.parquet`` receives the
            per-sample scores plus a ``ragas_score`` column, and
            ``<prefix>.json`` receives the aggregate scores plus ``ragas_score``.
        llm_model: Name of the NVIDIA-hosted model used as the RAGAS judge.

    Raises:
        KeyError: If ``NVIDIA_API_KEY`` is not set or a record lacks a required key.
        Exception: Any error raised while opening or parsing ``ev_file_path``
            is logged and re-raised.
    """
    llm_params = {
        "temperature": 0.1,
        "max_tokens": 200,
        "top_p": 1.0,
        "stream": False,
        "nvidia_api_key": os.environ["NVIDIA_API_KEY"],
        "model": llm_model,
    }
    llm = ChatNVIDIA(**llm_params)
    nvpl_llm = LangchainLLMWrapper(langchain_llm=llm)
    embeddings = NVIDIAEmbeddings(model="ai-embed-qa-4", model_type="passage")
    nvpl_embeddings = LangchainEmbeddingsWrapper(embeddings)

    try:
        with open(ev_file_path, "r", encoding="utf-8") as file:
            json_data = json.load(file)
    except Exception as e:
        # Without the data there is nothing to evaluate. Re-raise instead of
        # continuing into an undefined ``json_data``.
        logger.error(f"Error Occured while loading file : {e}")
        raise

    # Map the evaluation-file schema onto the column names RAGAS expects.
    data_samples = {
        'question': [entry["question"] for entry in json_data],
        'answer': [entry["generated_answer"] for entry in json_data],
        'contexts': [entry["retrieved_context"] for entry in json_data],
        'ground_truth': [entry["ground_truth_answer"] for entry in json_data],
    }
    dataset = Dataset.from_dict(data_samples)

    result = evaluate(
        dataset,
        # Pass the RAGAS-wrapped LLM, matching how the embeddings are passed.
        llm=nvpl_llm,
        embeddings=nvpl_embeddings,
        metrics=[
            answer_similarity,
            faithfulness,
            context_precision,
            context_relevancy,
            answer_relevancy,
            context_recall
        ],
    )

    df = result.to_pandas()
    df['ragas_score'] = df.apply(calculate_ragas_score, axis=1)
    df.to_parquet(ev_result_path + '.parquet')

    # The aggregate RAGAS score uses the same four metrics as calculate_ragas_score.
    result['ragas_score'] = statistics.harmonic_mean([
        result['faithfulness'],
        result['context_relevancy'],
        result['answer_relevancy'],
        result['context_recall'],
    ])
    with open(ev_result_path + '.json', "w", encoding="utf-8") as json_file:
        json.dump(result, json_file, indent=2)

    logger.info(f"Results written to {ev_result_path}.json and {ev_result_path}.parquet")
158

159

160
def _parse_judge_response(text):
    """Parse the judge LLM's reply into a dict.

    The prompt asks for a JSON object. If the model surrounds it with extra
    prose, fall back to the span between the first "{" and the last "}".
    Raises json.JSONDecodeError if no JSON object can be recovered.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(text[start:end + 1])


def _answer_key(entry):
    """Return the record key that holds the assistant's answer.

    Uses ``answer`` when present, otherwise ``generated_answer``. That is the
    key eval_ragas reads, so one evaluation file can serve both evaluators.
    """
    return "answer" if "answer" in entry else "generated_answer"


def eval_llm_judge(
    ev_file_path,
    ev_result_path,
    llm_model='ai-mixtral-8x7b-instruct'
):
    """Rate generated answers with a judge LLM on a 1-5 Likert scale.

    For each record, the judge model is prompted with LLAMA_PROMPT_TEMPLATE
    (question, reference context, reference answer and assistant answer) and
    asked for a JSON object with ``Rating`` and ``Explanation``. Failed calls
    or unparsable replies are recorded as ``None`` and excluded from the mean.
    Ratings of 0 are treated as 1.

    Args:
        ev_file_path: Path to a JSON file containing a list of records. Each
            record has ``question``, ``ground_truth_context`` and
            ``ground_truth_answer``, plus ``answer`` (or ``generated_answer``).
        ev_result_path: Path of the JSON output file. It holds one
            ``[rating, explanation, question, answer, ground_truth_answer,
            ground_truth_context]`` list per record.
        llm_model: Name of the NVIDIA-hosted model used as the judge.

    Returns:
        The mean rating over successful judgements, rounded to one decimal,
        or ``None`` if no judgement succeeded.

    Raises:
        KeyError: If ``NVIDIA_API_KEY`` is not set.
        Exception: Any error raised while opening or parsing ``ev_file_path``
            is logged and re-raised.
    """
    llm_params = {
        "temperature": 0.1,
        "max_tokens": 200,
        "top_p": 1.0,
        "stream": False,
        "nvidia_api_key": os.environ["NVIDIA_API_KEY"],
        "model": llm_model,
    }
    # Pass the configured params; a bare ChatNVIDIA() would ignore llm_model
    # and the sampling settings above.
    llm = ChatNVIDIA(**llm_params)

    try:
        with open(ev_file_path, "r", encoding="utf-8") as file:
            data = json.load(file)
    except Exception as e:
        # Without the data there is nothing to judge; do not continue with an
        # undefined ``data``.
        logger.error(f"Error Occured while loading file : {e}")
        raise

    llama_ratings = []
    llama_explanations = []
    total = len(data)
    for index, d in enumerate(data, start=1):
        try:
            context = LLAMA_PROMPT_TEMPLATE.format(
                system_prompt=SYS_PROMPT,
                question=d["question"],
                ctx_ref=d["ground_truth_context"],
                answer_ref=d["ground_truth_answer"],
                answer=d[_answer_key(d)],
            )
            response = llm.invoke(context)
            response_body = _parse_judge_response(response.content)
            # Both keys must be present, or the sample counts as a failed judgement.
            rating, explanation = response_body["Rating"], response_body["Explanation"]
        except Exception as e:
            # Best effort per sample: one bad call or reply must not abort the run.
            logger.info(f"Exception Occured: {e}")
            rating, explanation = None, None
        # Append exactly once per record so both lists stay aligned with ``data``.
        llama_ratings.append(rating)
        llama_explanations.append(explanation)
        logger.info(f"progress: {index}/{total}")

    logger.info(f"Number of judgements: {len(llama_ratings)}")

    llama_ratings = [1 if r == 0 else r for r in llama_ratings]  # Change 0 ratings to 1
    llama_ratings_filtered = [r for r in llama_ratings if r]  # Drop failed (None) judgements

    if llama_ratings_filtered:
        mean = round(statistics.mean(llama_ratings_filtered), 1)
    else:
        # statistics.mean raises on an empty list; report "no valid ratings" instead.
        mean = None
        logger.warning("No valid ratings were produced; mean rating is undefined.")
    logger.info(f"Number of ratings: {len(llama_ratings_filtered)}")
    logger.info(f"Mean rating: {mean}")

    # Use .get so a record with missing fields does not discard every judgement
    # collected above; missing fields serialize as null.
    results = [
        [
            rating,
            explanation,
            d.get("question"),
            d.get(_answer_key(d)),
            d.get("ground_truth_answer"),
            d.get("ground_truth_context"),
        ]
        for rating, explanation, d in zip(llama_ratings, llama_explanations, data)
    ]

    with open(ev_result_path, "w", encoding="utf-8") as json_file:
        json.dump(results, json_file, indent=2)

    logger.info(f"Results written to {ev_result_path}")
    return mean

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.