# GenerativeAIExamples — 118 lines · 3.6 KB (viewer metadata)
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging

from evaluator import eval_llm_judge, eval_ragas
from llm_answer_generator import generate_answers

# Module-wide logging: route INFO and above through the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

25if __name__ == "__main__":
26
27parser = argparse.ArgumentParser()
28parser.add_argument("--base_url", type=str, help="Specify the base URL to be used.")
29parser.add_argument(
30"--generate_answer",
31type=bool,
32nargs="?",
33default=False,
34help="Specify if 'answer' generation by RAG pipeline is required.",
35)
36parser.add_argument(
37"--evaluate",
38type=bool,
39nargs="?",
40default=True,
41help="Specify if evaluation is required.",
42)
43parser.add_argument(
44"--docs",
45type=str,
46nargs="?",
47default="",
48help="Specify the folder path for dataset.",
49)
50parser.add_argument(
51"--ga_input",
52type=str,
53nargs="?",
54default="",
55help="Specify the .json file with QnA pair for generating answers by RAG pipeline.",
56)
57parser.add_argument(
58"--ga_output",
59type=str,
60nargs="?",
61default="",
62help="Specify the .JSON file path for generated answers along with QnA.",
63)
64parser.add_argument(
65"--ev_input",
66type=str,
67nargs="?",
68default="",
69help="Specify the .JSON file path with 'question','gt_answer','gt_context','answer',and 'contexts'.",
70)
71parser.add_argument(
72"--ev_result",
73type=str,
74nargs="?",
75default="",
76help="Specify the file path to store evaluation results.",
77)
78parser.add_argument(
79"--metrics",
80type=str,
81nargs="?",
82default="judge_llm",
83choices=["ragas", "judge_llm"],
84help="Specify evaluation metrics between ragas and judge-llm.",
85)
86parser.add_argument(
87"--judge_llm_model",
88type=str,
89nargs="?",
90default="ai-mixtral-8x7b-instruct",
91help="Specify the LLM model to be used as judge llm for evaluation from ChatNVIDIA catalog."
92)
93args = parser.parse_args()
94
95if args.generate_answer:
96generate_answers(
97base_url=args.base_url,
98dataset_folder_path=args.docs,
99qa_generation_file_path=args.ga_input,
100eval_file_path=args.ga_output,
101)
102
103logger.info("\nANSWERS GENERATED\n")
104if args.evaluate:
105if args.metrics == "ragas":
106eval_ragas(
107ev_file_path=args.ev_input,
108ev_result_path=args.ev_result,
109llm_model=args.judge_llm_model,
110)
111logger.info("\nRAG EVALUATED WITH RAGAS METRICS\n")
112elif args.metrics == "judge_llm":
113eval_llm_judge(
114ev_file_path=args.ev_input,
115ev_result_path=args.ev_result,
116llm_model=args.judge_llm_model,
117)
118logger.info("\nRAG EVALUATED WITH JUDGE LLM\n")
119