llama-index

Форк
0
77 строк · 2.3 Кб
1
"""Notebook utils."""
2

3
from collections import defaultdict
4
from typing import List, Optional, Tuple
5

6
import pandas as pd
7

8
from llama_index.legacy.evaluation import EvaluationResult
9
from llama_index.legacy.evaluation.retrieval.base import RetrievalEvalResult
10

11
DEFAULT_METRIC_KEYS = ["hit_rate", "mrr"]
12

13

14
def get_retrieval_results_df(
15
    names: List[str],
16
    results_arr: List[List[RetrievalEvalResult]],
17
    metric_keys: Optional[List[str]] = None,
18
) -> pd.DataFrame:
19
    """Display retrieval results."""
20
    metric_keys = metric_keys or DEFAULT_METRIC_KEYS
21

22
    avg_metrics_dict = defaultdict(list)
23
    for name, eval_results in zip(names, results_arr):
24
        metric_dicts = []
25
        for eval_result in eval_results:
26
            metric_dict = eval_result.metric_vals_dict
27
            metric_dicts.append(metric_dict)
28
        results_df = pd.DataFrame(metric_dicts)
29

30
        for metric_key in metric_keys:
31
            if metric_key not in results_df.columns:
32
                raise ValueError(f"Metric key {metric_key} not in results_df")
33
            avg_metrics_dict[metric_key].append(results_df[metric_key].mean())
34

35
    return pd.DataFrame({"retrievers": names, **avg_metrics_dict})
36

37

38
def get_eval_results_df(
39
    names: List[str], results_arr: List[EvaluationResult], metric: Optional[str] = None
40
) -> Tuple[pd.DataFrame, pd.DataFrame]:
41
    """Organizes EvaluationResults into a deep dataframe and computes the mean
42
    score.
43

44
    result:
45
        result_df: pd.DataFrame representing all the evaluation results
46
        mean_df: pd.DataFrame of average scores groupby names
47
    """
48
    if len(names) != len(results_arr):
49
        raise ValueError("names and results_arr must have same length.")
50

51
    qs = []
52
    ss = []
53
    fs = []
54
    rs = []
55
    cs = []
56
    for res in results_arr:
57
        qs.append(res.query)
58
        ss.append(res.score)
59
        fs.append(res.feedback)
60
        rs.append(res.response)
61
        cs.append(res.contexts)
62

63
    deep_df = pd.DataFrame(
64
        {
65
            "rag": names,
66
            "query": qs,
67
            "answer": rs,
68
            "contexts": cs,
69
            "scores": ss,
70
            "feedbacks": fs,
71
        }
72
    )
73
    mean_df = pd.DataFrame(deep_df.groupby(["rag"])["scores"].mean()).T
74
    if metric:
75
        mean_df.index = [f"mean_{metric}_score"]
76

77
    return deep_df, mean_df
78

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.