llama-index

Форк
0
328 строк · 11.0 Кб
1
import asyncio
2
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
3

4
from llama_index.legacy.async_utils import asyncio_module
5
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
6
from llama_index.legacy.core.response.schema import RESPONSE_TYPE, Response
7
from llama_index.legacy.evaluation.base import BaseEvaluator, EvaluationResult
8

9

10
async def eval_response_worker(
    semaphore: asyncio.Semaphore,
    evaluator: BaseEvaluator,
    evaluator_name: str,
    query: Optional[str] = None,
    response: Optional[Response] = None,
    eval_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[str, EvaluationResult]:
    """Run one ``aevaluate_response`` call gated by the shared semaphore.

    Returns a ``(evaluator_name, result)`` pair so callers can group
    results by evaluator afterwards.
    """
    kwargs = {} if not eval_kwargs else eval_kwargs
    async with semaphore:
        result = await evaluator.aevaluate_response(
            query=query, response=response, **kwargs
        )
    return evaluator_name, result
27

28

29
async def eval_worker(
    semaphore: asyncio.Semaphore,
    evaluator: BaseEvaluator,
    evaluator_name: str,
    query: Optional[str] = None,
    response_str: Optional[str] = None,
    contexts: Optional[Sequence[str]] = None,
    eval_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[str, EvaluationResult]:
    """Run one string-based ``aevaluate`` call gated by the shared semaphore.

    Returns a ``(evaluator_name, result)`` pair so callers can group
    results by evaluator afterwards.
    """
    kwargs = {} if not eval_kwargs else eval_kwargs
    async with semaphore:
        result = await evaluator.aevaluate(
            query=query, response=response_str, contexts=contexts, **kwargs
        )
    return evaluator_name, result
47

48

49
async def response_worker(
    semaphore: asyncio.Semaphore,
    query_engine: BaseQueryEngine,
    query: str,
) -> RESPONSE_TYPE:
    """Run one ``aquery`` call against the engine, gated by the semaphore."""
    async with semaphore:
        response = await query_engine.aquery(query)
    return response
57

58

59
class BatchEvalRunner:
    """Batch evaluation runner.

    Fans out async evaluation calls for every (input row, evaluator) pair,
    bounded by a shared semaphore, and groups results per evaluator name.

    Args:
        evaluators (Dict[str, BaseEvaluator]): Dictionary of evaluators.
        workers (int): Number of workers to use for parallelization.
            Defaults to 2.
        show_progress (bool): Whether to show progress bars. Defaults to False.

    """

    def __init__(
        self,
        evaluators: Dict[str, BaseEvaluator],
        workers: int = 2,
        show_progress: bool = False,
    ):
        self.evaluators = evaluators
        self.workers = workers
        # One shared semaphore caps concurrent evaluator / query-engine calls.
        self.semaphore = asyncio.Semaphore(self.workers)
        self.show_progress = show_progress
        # asyncio_module provides a progress-bar-aware gather when requested.
        self.asyncio_mod = asyncio_module(show_progress=self.show_progress)

    def _format_results(
        self, results: List[Tuple[str, EvaluationResult]]
    ) -> Dict[str, List[EvaluationResult]]:
        """Group ``(evaluator_name, result)`` pairs into a dict keyed by name.

        NOTE: the parameter was previously annotated ``List[EvaluationResult]``,
        but the body unpacks ``(name, result)`` tuples produced by the workers.
        """
        results_dict: Dict[str, List[EvaluationResult]] = {
            name: [] for name in self.evaluators
        }
        for name, result in results:
            results_dict[name].append(result)

        return results_dict

    def _validate_and_clean_inputs(
        self,
        *inputs_list: Any,
    ) -> List[Any]:
        """Validate and clean input lists.

        Enforce that at least one of the inputs is not None.
        Make sure that all non-None inputs have the same length.
        Replace None inputs with ``[None] * input_len``.

        Raises:
            ValueError: If nothing is provided, or lengths are inconsistent.

        """
        # Explicit error instead of `assert`, which is stripped under -O.
        if len(inputs_list) == 0:
            raise ValueError("inputs_list must not be empty.")
        # The first non-None input establishes the expected length.
        input_len: Optional[int] = None
        for inputs in inputs_list:
            if inputs is not None:
                input_len = len(inputs)
                break
        if input_len is None:
            raise ValueError("At least one item in inputs_list must be provided.")

        new_inputs_list = []
        for inputs in inputs_list:
            if inputs is None:
                new_inputs_list.append([None] * input_len)
            else:
                if len(inputs) != input_len:
                    raise ValueError("All inputs must have the same length.")
                new_inputs_list.append(inputs)
        return new_inputs_list

    def _validate_eval_kwargs_lists(
        self, eval_kwargs_lists: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate the per-row extra-kwarg lists passed as ``**kwargs``.

        Each value must be a list (one entry per input row). Shared helper for
        ``aevaluate_response_strs`` and ``aevaluate_responses``, which
        previously duplicated this loop verbatim.

        Raises:
            ValueError: If any value is not a list.

        """
        for k, v in eval_kwargs_lists.items():
            if not isinstance(v, list):
                raise ValueError(
                    f"Each value in eval_kwargs must be a list. Got {k}: {v}"
                )
            eval_kwargs_lists[k] = self._validate_and_clean_inputs(v)[0]
        return eval_kwargs_lists

    def _get_eval_kwargs(
        self, eval_kwargs_lists: Dict[str, Any], idx: int
    ) -> Dict[str, Any]:
        """Get eval kwargs from eval_kwargs_lists at a given idx.

        Since eval_kwargs_lists is a dict of lists, we need to get the
        value at idx for each key.

        """
        return {k: v[idx] for k, v in eval_kwargs_lists.items()}

    async def aevaluate_response_strs(
        self,
        queries: Optional[List[str]] = None,
        response_strs: Optional[List[str]] = None,
        contexts_list: Optional[List[List[str]]] = None,
        **eval_kwargs_lists: List,
    ) -> Dict[str, List[EvaluationResult]]:
        """Evaluate query, response pairs.

        This evaluates queries, responses, contexts as string inputs.
        Can supply additional kwargs to the evaluator in eval_kwargs_lists.

        Args:
            queries (Optional[List[str]]): List of query strings. Defaults to None.
            response_strs (Optional[List[str]]): List of response strings.
                Defaults to None.
            contexts_list (Optional[List[List[str]]]): List of context lists.
                Defaults to None.
            **eval_kwargs_lists (List): Dict of lists of kwargs to
                pass to evaluator. Defaults to None.

        """
        queries, response_strs, contexts_list = self._validate_and_clean_inputs(
            queries, response_strs, contexts_list
        )
        eval_kwargs_lists = self._validate_eval_kwargs_lists(eval_kwargs_lists)

        # Fan out one eval job per (input row, evaluator) pair.
        eval_jobs = []
        for idx, query in enumerate(cast(List[str], queries)):
            response_str = cast(List, response_strs)[idx]
            contexts = cast(List, contexts_list)[idx]
            eval_kwargs = self._get_eval_kwargs(eval_kwargs_lists, idx)
            for name, evaluator in self.evaluators.items():
                eval_jobs.append(
                    eval_worker(
                        self.semaphore,
                        evaluator,
                        name,
                        query=query,
                        response_str=response_str,
                        contexts=contexts,
                        eval_kwargs=eval_kwargs,
                    )
                )
        results = await self.asyncio_mod.gather(*eval_jobs)

        # Format results
        return self._format_results(results)

    async def aevaluate_responses(
        self,
        queries: Optional[List[str]] = None,
        responses: Optional[List[Response]] = None,
        **eval_kwargs_lists: List,
    ) -> Dict[str, List[EvaluationResult]]:
        """Evaluate query, response pairs.

        This evaluates queries and response objects.

        Args:
            queries (Optional[List[str]]): List of query strings. Defaults to None.
            responses (Optional[List[Response]]): List of response objects.
                Defaults to None.
            **eval_kwargs_lists (List): Dict of lists of kwargs to
                pass to evaluator. Defaults to None.

        """
        queries, responses = self._validate_and_clean_inputs(queries, responses)
        eval_kwargs_lists = self._validate_eval_kwargs_lists(eval_kwargs_lists)

        # Fan out one eval job per (input row, evaluator) pair.
        eval_jobs = []
        for idx, query in enumerate(cast(List[str], queries)):
            response = cast(List, responses)[idx]
            eval_kwargs = self._get_eval_kwargs(eval_kwargs_lists, idx)
            for name, evaluator in self.evaluators.items():
                eval_jobs.append(
                    eval_response_worker(
                        self.semaphore,
                        evaluator,
                        name,
                        query=query,
                        response=response,
                        eval_kwargs=eval_kwargs,
                    )
                )
        results = await self.asyncio_mod.gather(*eval_jobs)

        # Format results
        return self._format_results(results)

    async def aevaluate_queries(
        self,
        query_engine: BaseQueryEngine,
        queries: Optional[List[str]] = None,
        **eval_kwargs_lists: List,
    ) -> Dict[str, List[EvaluationResult]]:
        """Evaluate queries.

        Runs each query through the engine (concurrently, semaphore-bounded),
        then evaluates the resulting responses.

        Args:
            query_engine (BaseQueryEngine): Query engine.
            queries (Optional[List[str]]): List of query strings. Defaults to None.
            **eval_kwargs_lists (List): Dict of lists of kwargs to
                pass to evaluator. Defaults to None.

        """
        if queries is None:
            raise ValueError("`queries` must be provided")

        # gather responses
        response_jobs = []
        for query in queries:
            response_jobs.append(response_worker(self.semaphore, query_engine, query))
        responses = await self.asyncio_mod.gather(*response_jobs)

        return await self.aevaluate_responses(
            queries=queries,
            responses=responses,
            **eval_kwargs_lists,
        )

    def evaluate_response_strs(
        self,
        queries: Optional[List[str]] = None,
        response_strs: Optional[List[str]] = None,
        contexts_list: Optional[List[List[str]]] = None,
        **eval_kwargs_lists: List,
    ) -> Dict[str, List[EvaluationResult]]:
        """Evaluate query, response pairs.

        Sync version of aevaluate_response_strs.

        """
        return asyncio.run(
            self.aevaluate_response_strs(
                queries=queries,
                response_strs=response_strs,
                contexts_list=contexts_list,
                **eval_kwargs_lists,
            )
        )

    def evaluate_responses(
        self,
        queries: Optional[List[str]] = None,
        responses: Optional[List[Response]] = None,
        **eval_kwargs_lists: List,
    ) -> Dict[str, List[EvaluationResult]]:
        """Evaluate query, response objs.

        Sync version of aevaluate_responses.

        """
        return asyncio.run(
            self.aevaluate_responses(
                queries=queries,
                responses=responses,
                **eval_kwargs_lists,
            )
        )

    def evaluate_queries(
        self,
        query_engine: BaseQueryEngine,
        queries: Optional[List[str]] = None,
        **eval_kwargs_lists: List,
    ) -> Dict[str, List[EvaluationResult]]:
        """Evaluate queries.

        Sync version of aevaluate_queries.

        """
        return asyncio.run(
            self.aevaluate_queries(
                query_engine=query_engine,
                queries=queries,
                **eval_kwargs_lists,
            )
        )
329

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.