llama-index
280 lines · 8.9 KB
1"""Param tuner."""
2
3import asyncio4from abc import abstractmethod5from copy import deepcopy6from typing import Any, Awaitable, Callable, Dict, List, Optional7
8from llama_index.legacy.bridge.pydantic import BaseModel, Field, PrivateAttr9from llama_index.legacy.utils import get_tqdm_iterable10
11
class RunResult(BaseModel):
    """Run result.

    The outcome of a single tuning job: the score achieved, the parameter
    combination that produced it, and free-form metadata.
    """

    # Objective value for this run; tuners sort descending, so higher is better.
    score: float
    # The exact parameter combination (tuned + fixed) used for this run.
    params: Dict[str, Any]
    # Extra run information attached by the tuner (e.g. a timestamp).
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Metadata.")
19
class TunedResult(BaseModel):
    """Tuned result.

    Aggregates every run from a tuning sweep plus the index of the best one.
    """

    # All run results (tuners in this module store these sorted by descending score).
    run_results: List[RunResult]
    # Index into ``run_results`` of the best-scoring run.
    best_idx: int

    @property
    def best_run_result(self) -> RunResult:
        """Get best run result."""
        return self.run_results[self.best_idx]
29
def generate_param_combinations(param_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Generate parameter combinations.

    Expands a mapping of ``name -> list of candidate values`` into the full
    cross-product of parameter assignments, one dict per combination.
    """
    combinations: List[Dict[str, Any]] = [{}]
    # Expand one parameter at a time. Iterating the keys in reverse makes the
    # last key vary slowest, matching the order produced by the previous
    # recursive popitem-based expansion.
    for name in reversed(list(param_dict)):
        values = param_dict[name]
        combinations = [
            {**partial, name: value}
            for partial in combinations
            for value in values
        ]
    # Deep-copy each assignment so callers can mutate results independently.
    return [deepcopy(combination) for combination in combinations]
51
class BaseParamTuner(BaseModel):
    """Base param tuner.

    Subclasses implement ``tune`` (and optionally ``atune``) to search over
    every combination of values in ``param_dict``.
    """

    param_dict: Dict[str, Any] = Field(
        ..., description="A dictionary of parameters to iterate over."
    )
    fixed_param_dict: Dict[str, Any] = Field(
        default_factory=dict,
        description="A dictionary of fixed parameters passed to each job.",
    )
    # Whether to display a progress bar while running jobs.
    show_progress: bool = False

    @abstractmethod
    def tune(self) -> TunedResult:
        """Tune parameters."""

    async def atune(self) -> TunedResult:
        """Async Tune parameters.

        Override if you implement a native async method.

        """
        # Default implementation simply delegates to the sync version.
        return self.tune()
76
class ParamTuner(BaseParamTuner):
    """Parameter tuner.

    Args:
        param_dict(Dict): A dictionary of parameters to iterate over.
            Example param_dict:
            {
                "num_epochs": [10, 20],
                "batch_size": [8, 16, 32],
            }
        fixed_param_dict(Dict): A dictionary of fixed parameters passed to each job.

    """

    param_fn: Callable[[Dict[str, Any]], RunResult] = Field(
        ..., description="Function to run with parameters."
    )

    def tune(self) -> TunedResult:
        """Run tuning."""
        # Each key in param_dict maps to a list of candidate values; expand
        # them into every concrete combination.
        combinations = generate_param_combinations(self.param_dict)
        combos_iter = get_tqdm_iterable(
            combinations, self.show_progress, "Param combinations."
        )

        # Run the user-supplied function once per combination, merging the
        # fixed parameters into each call.
        results = [
            self.param_fn({**self.fixed_param_dict, **combination})
            for combination in combos_iter
        ]

        # Highest score first, so index 0 is the best run.
        results.sort(key=lambda run: run.score, reverse=True)
        return TunedResult(run_results=results, best_idx=0)
128
class AsyncParamTuner(BaseParamTuner):
    """Async Parameter tuner.

    Args:
        param_dict(Dict): A dictionary of parameters to iterate over.
            Example param_dict:
            {
                "num_epochs": [10, 20],
                "batch_size": [8, 16, 32],
            }
        fixed_param_dict(Dict): A dictionary of fixed parameters passed to each job.
        aparam_fn (Callable): An async function to run with parameters.
        num_workers (int): Number of workers to use.

    """

    aparam_fn: Callable[[Dict[str, Any]], Awaitable[RunResult]] = Field(
        ..., description="Async function to run with parameters."
    )
    num_workers: int = Field(2, description="Number of workers to use.")

    # Bounds the number of jobs running concurrently.
    _semaphore: asyncio.Semaphore = PrivateAttr()

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Init params."""
        super().__init__(*args, **kwargs)
        self._semaphore = asyncio.Semaphore(self.num_workers)

    async def atune(self) -> TunedResult:
        """Run tuning."""
        # Recreate the semaphore now that an event loop is running. A
        # semaphore constructed in __init__ (outside any loop) is bound to
        # the then-current default loop on Python < 3.10, so awaiting it
        # after tune() starts a fresh loop via asyncio.run() would raise
        # "got Future attached to a different loop".
        self._semaphore = asyncio.Semaphore(self.num_workers)

        # Each key in param_dict is a parameter to tune; each value is a
        # list of candidate values. Generate every combination of them.
        param_combinations = generate_param_combinations(self.param_dict)

        async def aparam_fn_worker(
            semaphore: asyncio.Semaphore,
            full_param_dict: Dict[str, Any],
        ) -> RunResult:
            """Run a single job, gated by the concurrency semaphore."""
            async with semaphore:
                return await self.aparam_fn(full_param_dict)

        # Build one coroutine per combination, merging in the fixed params.
        run_jobs = []
        for param_combination in param_combinations:
            full_param_dict = {
                **self.fixed_param_dict,
                **param_combination,
            }
            run_jobs.append(aparam_fn_worker(self._semaphore, full_param_dict))

        if self.show_progress:
            # tqdm's asyncio gather renders a progress bar as jobs finish.
            from tqdm.asyncio import tqdm_asyncio

            all_run_results = await tqdm_asyncio.gather(*run_jobs)
        else:
            all_run_results = await asyncio.gather(*run_jobs)

        # Sort descending by score so index 0 is the best run.
        sorted_run_results = sorted(
            all_run_results, key=lambda x: x.score, reverse=True
        )
        return TunedResult(run_results=sorted_run_results, best_idx=0)

    def tune(self) -> TunedResult:
        """Run tuning synchronously by driving the async implementation."""
        return asyncio.run(self.atune())
203
class RayTuneParamTuner(BaseParamTuner):
    """Parameter tuner powered by Ray Tune.

    Args:
        param_dict(Dict): A dictionary of parameters to iterate over.
            Example param_dict:
            {
                "num_epochs": [10, 20],
                "batch_size": [8, 16, 32],
            }
        fixed_param_dict(Dict): A dictionary of fixed parameters passed to each job.

    """

    param_fn: Callable[[Dict[str, Any]], RunResult] = Field(
        ..., description="Function to run with parameters."
    )

    run_config_dict: Optional[dict] = Field(
        default=None, description="Run config dict for Ray Tune."
    )

    def tune(self) -> TunedResult:
        """Run tuning."""
        from ray import tune
        from ray.train import RunConfig

        # Every candidate-value list becomes a Ray Tune grid-search axis.
        ray_param_dict = {
            name: tune.grid_search(values)
            for name, values in self.param_dict.items()
        }

        def param_fn_wrapper(
            ray_param_dict: Dict, fixed_param_dict: Optional[Dict] = None
        ) -> Dict:
            # Wrapper that merges the fixed params into the tuned params and
            # converts the RunResult to a plain dict to obey Ray Tune's API.
            fixed_param_dict = fixed_param_dict or {}
            full_param_dict = {
                **fixed_param_dict,
                **ray_param_dict,
            }
            tuned_result = self.param_fn(full_param_dict)
            return tuned_result.dict()

        if self.run_config_dict:
            run_config = RunConfig(**self.run_config_dict)
        else:
            run_config = None

        tuner = tune.Tuner(
            tune.with_parameters(
                param_fn_wrapper, fixed_param_dict=self.fixed_param_dict
            ),
            param_space=ray_param_dict,
            run_config=run_config,
        )

        results = tuner.fit()
        all_run_results = []
        for pos in range(len(results)):
            grid_result = results[pos]
            # Rebuild a RunResult from the dict Ray stored in result.metrics.
            run_result = RunResult.parse_obj(grid_result.metrics)
            # Attach the Ray-reported timestamp as extra metadata.
            run_result.metadata["timestamp"] = (
                grid_result.metrics["timestamp"] if grid_result.metrics else None
            )
            all_run_results.append(run_result)

        # Sort descending by score so index 0 is the best run.
        best_first = sorted(
            all_run_results, key=lambda r: r.score, reverse=True
        )
        return TunedResult(run_results=best_first, best_idx=0)