llama-index

Форк
0
280 строк · 8.9 Кб
1
"""Param tuner."""
2

3
import asyncio
4
from abc import abstractmethod
5
from copy import deepcopy
6
from typing import Any, Awaitable, Callable, Dict, List, Optional
7

8
from llama_index.legacy.bridge.pydantic import BaseModel, Field, PrivateAttr
9
from llama_index.legacy.utils import get_tqdm_iterable
10

11

12
class RunResult(BaseModel):
    """Run result.

    Outcome of evaluating a single parameter combination.
    """

    # Objective value of the run; the tuners in this module rank runs by it in
    # descending order (higher is better).
    score: float
    # Parameters associated with this run — presumably the ones the user's
    # param function was called with; it is filled in by that function.
    params: Dict[str, Any]
    # Free-form extra information about the run; defaults to an empty dict
    # (a fresh one per instance via default_factory).
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Metadata.")
18

19

20
class TunedResult(BaseModel):
    """Outcome of a tuning session: every run result plus the index of the best one."""

    # Results of all runs; the tuners in this module store them sorted by
    # descending score.
    run_results: List[RunResult]
    # Index into ``run_results`` of the best run (the tuners in this module
    # always set it to 0, since the list is pre-sorted).
    best_idx: int

    @property
    def best_run_result(self) -> RunResult:
        """Get best run result.

        Raises:
            IndexError: if ``best_idx`` is out of range for ``run_results``,
                e.g. when ``run_results`` is empty (which happens if some
                parameter in the search space had no candidate values).
        """
        return self.run_results[self.best_idx]
28

29

30
def generate_param_combinations(param_dict: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Expand a grid of candidate values into every parameter combination.

    Args:
        param_dict: Maps each parameter name to an iterable of candidate values.

    Returns:
        One dict per point of the Cartesian product of the candidate values.
        The last key of ``param_dict`` varies slowest and the first key varies
        fastest; each returned dict lists its keys in reverse ``param_dict``
        order. Every returned dict holds deep copies of the values, so callers
        may mutate it freely. An empty ``param_dict`` yields ``[{}]``; a
        parameter with no candidates yields ``[]``.
    """
    # Private copy so the caller's dict (and the values inside it) is never touched.
    grid = deepcopy(param_dict)

    partial_combos: List[Dict[str, Any]] = [{}]
    # Walk the keys back to front: the key handled first ends up varying slowest.
    for name in reversed(list(grid)):
        candidates = grid[name]
        partial_combos = [
            {**combo, name: value} for combo in partial_combos for value in candidates
        ]

    # Hand back fully independent dicts.
    return [deepcopy(combo) for combo in partial_combos]
50

51

52
class BaseParamTuner(BaseModel):
    """Base param tuner.

    Holds the search space shared by all tuners. Subclasses must implement
    ``tune``; ``atune`` by default just calls ``tune``.
    """

    # Maps each tuned parameter name to the values to try for it.
    param_dict: Dict[str, Any] = Field(
        ..., description="A dictionary of parameters to iterate over."
    )
    # Parameters passed unchanged to every run. In the tuners of this module,
    # tuned values from param_dict override these on key collisions.
    fixed_param_dict: Dict[str, Any] = Field(
        default_factory=dict,
        description="A dictionary of fixed parameters passed to each job.",
    )
    # Whether to show a progress bar while tuning; honoring it is up to each
    # subclass.
    show_progress: bool = False

    @abstractmethod
    def tune(self) -> TunedResult:
        """Tune parameters."""

    async def atune(self) -> TunedResult:
        """Async Tune parameters.

        Override if you implement a native async method.

        NOTE: this default calls the synchronous ``tune`` directly inside the
        coroutine, so it blocks the running event loop until tuning finishes.
        """
        return self.tune()
75

76

77
class ParamTuner(BaseParamTuner):
    """Parameter tuner.

    Calls ``param_fn`` synchronously once per parameter combination and ranks
    the results by score, highest first.

    Args:
        param_dict(Dict): A dictionary of parameters to iterate over.
            Example param_dict:
            {
                "num_epochs": [10, 20],
                "batch_size": [8, 16, 32],
            }
        fixed_param_dict(Dict): A dictionary of fixed parameters passed to each job.
        param_fn (Callable): Called with the merged parameter dict of one
            combination; must return a ``RunResult``.
        show_progress (bool): Whether to show a progress bar over combinations.

    """

    param_fn: Callable[[Dict[str, Any]], RunResult] = Field(
        ..., description="Function to run with parameters."
    )

    def tune(self) -> TunedResult:
        """Run ``param_fn`` on every combination and return the ranked results.

        Returns:
            A ``TunedResult`` whose ``run_results`` are sorted by descending
            score, so ``best_idx`` is always 0. If some parameter has no
            candidate values, ``run_results`` is empty.
        """
        # One dict per point of the grid spanned by param_dict's value lists.
        param_combinations = generate_param_combinations(self.param_dict)

        combos_with_progress = get_tqdm_iterable(
            param_combinations, self.show_progress, "Param combinations."
        )

        # Tuned values override fixed ones on key collisions.
        all_run_results = [
            self.param_fn({**self.fixed_param_dict, **param_combination})
            for param_combination in combos_with_progress
        ]

        # Highest score first, so the best run is at index 0.
        sorted_run_results = sorted(
            all_run_results, key=lambda x: x.score, reverse=True
        )

        return TunedResult(run_results=sorted_run_results, best_idx=0)
127

128

129
class AsyncParamTuner(BaseParamTuner):
    """Async Parameter tuner.

    Runs ``aparam_fn`` concurrently over every parameter combination, with at
    most ``num_workers`` calls in flight, and ranks results by score.

    Args:
        param_dict(Dict): A dictionary of parameters to iterate over.
            Example param_dict:
            {
                "num_epochs": [10, 20],
                "batch_size": [8, 16, 32],
            }
        fixed_param_dict(Dict): A dictionary of fixed parameters passed to each job.
        aparam_fn (Callable): An async function to run with parameters.
        num_workers (int): Number of workers to use (must be >= 1).

    """

    aparam_fn: Callable[[Dict[str, Any]], Awaitable[RunResult]] = Field(
        ..., description="Async function to run with parameters."
    )
    # ge=1: a zero-permit semaphore would make atune wait forever.
    num_workers: int = Field(2, ge=1, description="Number of workers to use.")

    # Semaphore of the most recent atune call. It is created per call inside
    # the running event loop rather than in __init__: a semaphore created
    # outside the loop can end up bound to a different loop (Python < 3.10) or
    # stay bound to the loop of a previous asyncio.run call (3.10+), which
    # breaks tune() on the first or second invocation respectively.
    _semaphore: Optional[asyncio.Semaphore] = PrivateAttr(default=None)

    async def atune(self) -> TunedResult:
        """Run tuning.

        Returns:
            A ``TunedResult`` whose ``run_results`` are sorted by descending
            score, so ``best_idx`` is always 0.
        """
        # One dict per point of the grid spanned by param_dict's value lists.
        param_combinations = generate_param_combinations(self.param_dict)

        semaphore = asyncio.Semaphore(self.num_workers)
        self._semaphore = semaphore

        async def aparam_fn_worker(
            semaphore: asyncio.Semaphore,
            full_param_dict: Dict[str, Any],
        ) -> RunResult:
            """Run ``aparam_fn`` while holding one semaphore permit."""
            async with semaphore:
                return await self.aparam_fn(full_param_dict)

        # Tuned values override fixed ones on key collisions.
        run_jobs = [
            aparam_fn_worker(
                semaphore, {**self.fixed_param_dict, **param_combination}
            )
            for param_combination in param_combinations
        ]

        if self.show_progress:
            # Imported lazily so tqdm is only needed when progress is shown.
            from tqdm.asyncio import tqdm_asyncio

            all_run_results = await tqdm_asyncio.gather(*run_jobs)
        else:
            all_run_results = await asyncio.gather(*run_jobs)

        # Highest score first, so the best run is at index 0.
        sorted_run_results = sorted(
            all_run_results, key=lambda x: x.score, reverse=True
        )

        return TunedResult(run_results=sorted_run_results, best_idx=0)

    def tune(self) -> TunedResult:
        """Run tuning synchronously by driving ``atune`` in a new event loop.

        Raises:
            RuntimeError: if called while an event loop is already running in
                this thread (``asyncio.run`` does not allow nesting); use
                ``await atune()`` there instead.
        """
        return asyncio.run(self.atune())
202

203

204
class RayTuneParamTuner(BaseParamTuner):
    """Parameter tuner powered by Ray Tune.

    Each value list in ``param_dict`` becomes a ``tune.grid_search``, so Ray
    runs one trial per parameter combination. Results are ranked by score,
    highest first.

    Args:
        param_dict(Dict): A dictionary of parameters to iterate over.
            Example param_dict:
            {
                "num_epochs": [10, 20],
                "batch_size": [8, 16, 32],
            }
        fixed_param_dict(Dict): A dictionary of fixed parameters passed to each job.
        param_fn (Callable): Called with the merged parameter dict of one
            trial; must return a ``RunResult``.
        run_config_dict (Optional[dict]): Keyword arguments for
            ``ray.train.RunConfig``; when empty or None, no run config is passed.

    """

    param_fn: Callable[[Dict[str, Any]], RunResult] = Field(
        ..., description="Function to run with parameters."
    )

    run_config_dict: Optional[dict] = Field(
        default=None, description="Run config dict for Ray Tune."
    )

    def tune(self) -> TunedResult:
        """Run tuning.

        Returns:
            A ``TunedResult`` whose ``run_results`` are sorted by descending
            score, so ``best_idx`` is always 0.

        Raises:
            ValueError: if a Ray Tune trial failed (chained from the trial's
                error).
        """
        # Imported lazily so ray is only required when this tuner is used.
        from ray import tune
        from ray.train import RunConfig

        # Convert every value list in param_dict to a tune.grid_search.
        ray_param_dict = {
            param_name: tune.grid_search(param_vals)
            for param_name, param_vals in self.param_dict.items()
        }

        def param_fn_wrapper(
            config: Dict, fixed_param_dict: Optional[Dict] = None
        ) -> Dict:
            # Ray passes the trial's grid point as ``config``; the fixed params
            # arrive through tune.with_parameters. Tuned values win on collisions.
            full_param_dict = {**(fixed_param_dict or {}), **config}
            # Ray Tune's API expects a plain dict, not a RunResult.
            return self.param_fn(full_param_dict).dict()

        run_config = RunConfig(**self.run_config_dict) if self.run_config_dict else None

        tuner = tune.Tuner(
            tune.with_parameters(
                param_fn_wrapper, fixed_param_dict=self.fixed_param_dict
            ),
            param_space=ray_param_dict,
            run_config=run_config,
        )

        results = tuner.fit()
        all_run_results = []
        for idx, result in enumerate(results):
            # A failed trial has no usable score; surface its error instead of
            # an opaque validation error from parse_obj below.
            if result.error is not None:
                raise ValueError(
                    f"Ray Tune trial {idx} failed: {result.error!r}"
                ) from result.error
            # Rebuild the RunResult from the dict reported by param_fn_wrapper.
            run_result = RunResult.parse_obj(result.metrics)
            # Keep the trial's reported timestamp (None if it was not reported).
            run_result.metadata["timestamp"] = result.metrics.get("timestamp")
            all_run_results.append(run_result)

        # Highest score first, so the best run is at index 0.
        sorted_run_results = sorted(
            all_run_results, key=lambda x: x.score, reverse=True
        )

        return TunedResult(run_results=sorted_run_results, best_idx=0)
281

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.