google-research

Форк
0
/
run_comparison.py 
356 строк · 11.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Running a sampling method on the base and the reduced search spaces."""
17
import functools
18
import operator
19
from typing import Any, Callable
20

21

22
import jax
23
import jax.numpy as jnp
24
import spaceopt.bo_utils as bo
25
import spaceopt.gp_utils as gp
26
import spaceopt.scores as scores
27
import spaceopt.search_spaces as search_spaces
28

29
# pylint: disable=g-long-lambda
30

31

32
def point_in_search_space(points, search_space):
  """Return a mask flagging which points lie inside a box search space.

  Args:
    points: 2-D array with one point per row; its second dimension must match
      the number of rows of `search_space`.
    search_space: 2-D array with one row per dimension; column 0 is read as
      the lower bound and column 1 as the upper bound (both inclusive).

  Returns:
    Boolean array with one entry per row of `points`, True when every
    coordinate of that point lies within its [lower, upper] bounds.
  """
  assert points.shape[1] == search_space.shape[0]
  lower = search_space[:, 0]
  upper = search_space[:, 1]
  above_lower = jnp.all(points >= lower, axis=1)
  below_upper = jnp.all(points <= upper, axis=1)
  # Multiplying two boolean masks acts as an element-wise logical AND.
  return above_lower * below_upper
38

39

40
def prep_utility_fn(params,
                    x_b1,
                    y_b1,
                    utility_type,
                    steps=1000,
                    percentile=0,
                    y_incumb=None):
  """Prepare the utility functions for the score.

  Fits the GP on the first-stage data, then builds a `scores.UtilityMeasure`
  anchored at an incumbent value.

  Args:
    params: initial GP parameters passed to `GPUtils.fit_gp`.
    x_b1: first-stage inputs used to fit the GP.
    y_b1: first-stage objective values used to fit the GP.
    utility_type: name of the `UtilityMeasure` method to call on y samples.
    steps: number of GP fitting steps.
    percentile: percentile of `y_b1` (along axis 0) used as the incumbent
      when `y_incumb` is not given.
    y_incumb: explicit incumbent value; overrides `percentile` when set.

  Returns:
    Tuple `(utility_measure_fn, fitted_params)`. `utility_measure_fn` has
    signature `(key, x_batch, y_batch)` and only uses `y_batch`.
  """
  fitted_params = gp.GPUtils().fit_gp(x_b1, y_b1, params, steps=steps)
  if y_incumb is None:
    incumbent = jnp.percentile(y_b1, percentile, axis=0)
  else:
    incumbent = y_incumb
  measure = scores.UtilityMeasure(incumbent=incumbent, params=fitted_params)

  def utility_measure_fn(key, x_batch, y_batch):
    # key and x_batch are part of the expected callback signature only.
    del key, x_batch
    return getattr(measure, utility_type)(y_batch)

  return utility_measure_fn, fitted_params
56

57

58
def draw_y_values(key,
                  params,
                  x_obs,
                  y_obs,
                  x_test,
                  num_samples,
                  sampling_gp_method):
  """Draw y samples from the GP posterior.

  Args:
    key: JAX PRNG key for the sampling.
    params: GP parameters.
    x_obs: observed inputs conditioning the posterior.
    y_obs: observed objective values conditioning the posterior.
    x_test: inputs at which the posterior is sampled.
    num_samples: number of posterior samples to draw.
    sampling_gp_method: sampling backend forwarded as `method` to
      `GPUtils.draw_gp_samples`.

  Returns:
    Whatever `GPUtils.draw_gp_samples` returns for the posterior at `x_test`.
  """
  utils = gp.GPUtils()
  posterior_mean, posterior_cov = utils.posterior_mean_cov(
      params, x_obs, y_obs, x_test)
  return utils.draw_gp_samples(
      key,
      posterior_mean,
      posterior_cov,
      num_samples=num_samples,
      method=sampling_gp_method)
71

72

73
def eval_score(key,
               search_space,
               budget,
               utility_measure_fn,
               draw_y_values_fn,
               centrality_over_x,
               x_drawn_num=500,
               y_drawn_num=500):
  """Evaluate the (is)improvement scores of one search space.

  Args:
    key: JAX PRNG key forwarded to `scores.mean_utility`.
    search_space: search space to score.
    budget: remaining budget the score is computed for.
    utility_measure_fn: callback `(key, x_batch, y_batch) -> utility`.
    draw_y_values_fn: callback `(key, x_test, num_samples) -> y samples`.
    centrality_over_x: forwarded to `scores.scores`.
    x_drawn_num: number of x points drawn by `scores.mean_utility`.
    y_drawn_num: number of y samples drawn by `scores.mean_utility`.

  Returns:
    Dict with key 'imp' (scores of the mean utility) and 'is_imp' (scores of
    the indicator `mean utility > 0`).
  """
  utility_mean = scores.mean_utility(
      key,
      search_space,
      budget,
      utility_measure_fn,
      draw_y_values_fn,
      x_drawn_num=x_drawn_num,
      y_drawn_num=y_drawn_num)
  return {
      'imp': scores.scores(utility_mean, centrality_over_x),
      'is_imp': scores.scores(utility_mean > 0, centrality_over_x),
  }
97

98

99
def extract_best_search_space(scores_dict, centrality_key,
                              search_spaces_reduced):
  """Select the arg max score search space.

  Args:
    scores_dict: mapping from centrality name to a 1-D array of scores, one
      per candidate search space.
    centrality_key: which entry of `scores_dict` to maximize.
    search_spaces_reduced: 3-D stack of candidate search spaces, indexed by
      candidate along axis 0.

  Returns:
    The candidate search space with the highest score (first one on ties,
    following `jnp.argmax`).
  """
  best_index = jnp.argmax(scores_dict[centrality_key])
  return search_spaces_reduced[best_index, :, :]
103

104

105
class SamplingMethod:
  """Class for the sampling method to spend a budget over a search space.

  Attributes:
    search_space: 2-D array with one row per dimension; column 0 is the lower
      and column 1 the upper bound.
    objective_fn: callable mapping a batch of points to a tuple
      `(y, additional_info_dict)`.
    x_precollect: optional pool of already-evaluated inputs; when set, random
      search draws from this pool instead of calling `objective_fn`.
    y_precollect: objective values matching `x_precollect`.
    additional_info_precollect: stored only; not read by the methods here.
  """

  def __init__(self,
               search_space,
               objective_fn,
               x_precollect=None,
               y_precollect=None,
               additional_info_precollect=None):
    self.search_space = search_space
    self.objective_fn = objective_fn
    self.x_precollect = x_precollect
    self.y_precollect = y_precollect
    self.additional_info_precollect = additional_info_precollect

  def rs(self, key, budget):
    """Random search sampling method with uniform distribution.

    Args:
      key: JAX PRNG key.
      budget: number of points to return.

    Returns:
      Tuple `(x, y, additional_info_dict, ind_chosen)`. `ind_chosen` holds
      the sorted indices picked from the in-search-space subset of the
      precollected pool, or None when points were drawn fresh.

    Raises:
      ValueError: if only one of x_precollect / y_precollect is set, or if
        fewer than `budget` precollected points lie in the search space.
    """
    if operator.xor(self.x_precollect is None, self.y_precollect is None):
      raise ValueError(
          'Both x_precollect and y_precollect need to be provided.')
    if self.x_precollect is None:
      return self._rs_uniform(key, budget)
    return self._rs_from_precollect(key, budget)

  def _rs_uniform(self, key, budget):
    """Draws `budget` uniform points in the box and evaluates them."""
    num_dims = self.search_space.shape[0]
    x = jax.random.uniform(
        key,
        shape=(budget, num_dims),
        minval=self.search_space[:, 0],
        maxval=self.search_space[:, 1])
    y, info = self.objective_fn(x)
    return x, y, info, None

  def _rs_from_precollect(self, key, budget):
    """Subsamples `budget` precollected points that lie in the search space."""
    inside = point_in_search_space(self.x_precollect, self.search_space)
    x_pool = self.x_precollect[inside, :]
    y_pool = self.y_precollect[inside, :]
    pool_size = y_pool.shape[0]
    if pool_size < budget:
      raise ValueError(
          'budget is larger than the precollected data in the search space.')
    picked = jnp.sort(
        jax.random.choice(key, pool_size, shape=(budget,), replace=False))
    return x_pool[picked, :], y_pool[picked, :], {}, picked

  def bo(self,
         key,
         budget,
         params,
         x_obs,
         y_obs,
         batch_size=1,
         num_points=500,
         num_steps=1000,
         sampling_gp_method='tfp'):
    """Bayesian optimization sampling method.

    `budget` is a total that includes the already-observed `x_obs`: only
    `budget - x_obs.shape[0]` is handed to the BO loop as remaining budget.

    Returns:
      Tuple `(x, y, additional_info_dict, None)` with the first three entries
      taken from `bo_utils.bo`.
    """
    remaining_budget = budget - x_obs.shape[0]
    x, y, info = bo.bo(
        key,
        x_obs,
        y_obs,
        self.objective_fn,
        params,
        self.search_space,
        remaining_budget,
        batch_size=batch_size,
        num_points=num_points,
        num_steps=num_steps,
        method=sampling_gp_method)
    return x, y, info, None
179

180

181
def run_sampling_method(key,
                        objective_fn,
                        search_space,
                        budget,
                        sampling_method='RS',
                        params=None,
                        x_init=None,
                        y_init=None,
                        num_init_for_bo=1,
                        batch_size_for_bo=1,
                        num_pnts_for_af=500,
                        num_steps_for_gp=1000,
                        sampling_gp_method='tfp',
                        x_precollect=None,
                        y_precollect=None,
                        additional_info_precollect=None):
  """Run the sampling method on a search space given a budget.

  Args:
    key: JAX PRNG key.
    objective_fn: objective, called as `objective_fn(x) -> (y, info_dict)`.
    search_space: box search space to sample from.
    budget: for 'RS', the number of points to sample. For 'BO', a total that
      includes the initial points: `budget - x_init.shape[0]` is passed to the
      BO loop as the remaining budget.
    sampling_method: 'RS' (random search) or 'BO' (Bayesian optimization).
    params: GP parameters used by 'BO'.
    x_init: optional initial inputs for 'BO' (ignored by 'RS').
    y_init: optional initial values for 'BO'; must be given iff `x_init` is.
      When both are None, `num_init_for_bo` random-search points are used.
    num_init_for_bo: number of random initial points for 'BO'.
    batch_size_for_bo: BO batch size.
    num_pnts_for_af: number of points used by the acquisition function.
    num_steps_for_gp: number of GP fitting steps.
    sampling_gp_method: GP sampling backend forwarded to BO.
    x_precollect: optional pool of precollected inputs used by random search.
    y_precollect: objective values matching `x_precollect`.
    additional_info_precollect: extra precollected data, stored on the
      sampler.

  Returns:
    Tuple `(x_sampled, y_sampled, ind_sampled_in_precollected)`; the last
    entry is None unless random search drew from the precollected pool.

  Raises:
    ValueError: on an unknown `sampling_method`, or when exactly one of
      `x_init` / `y_init` is provided for 'BO'.
  """
  if sampling_method not in ('RS', 'BO'):
    raise ValueError('Sampling method should be either RS or BO.')
  # Previously a lone x_init (or y_init) slipped into the "init provided"
  # branch and passed None into BO; reject it explicitly instead.
  if sampling_method == 'BO' and operator.xor(x_init is None, y_init is None):
    raise ValueError(
        'Both x_init and y_init need to be provided for BO, or neither.')

  sampling = SamplingMethod(search_space, objective_fn, x_precollect,
                            y_precollect, additional_info_precollect)
  if sampling_method == 'RS':
    x_sampled, y_sampled, _, ind_sampled_in_precollected = sampling.rs(
        key, budget)
    return x_sampled, y_sampled, ind_sampled_in_precollected

  # BO: draw random initial points when none were supplied.
  if x_init is None:
    key_init, key_rest = jax.random.split(key)
    x_init, y_init, _, _ = sampling.rs(key_init, num_init_for_bo)
  else:
    key_rest = key
  x_sampled, y_sampled, _, ind_sampled_in_precollected = sampling.bo(
      key_rest,
      budget,
      params,
      x_init,
      y_init,
      batch_size=batch_size_for_bo,
      num_points=num_pnts_for_af,
      num_steps=num_steps_for_gp,
      sampling_gp_method=sampling_gp_method)
  return x_sampled, y_sampled, ind_sampled_in_precollected
222

223

224
def run_method_on_base_and_reduced_search_spaces(
    key,
    objective_fn,
    search_space,
    budget,
    budget_b1,
    params,
    centrality_over_x,
    sampling_method_primary='RS',
    num_init_for_bo=1,
    batch_size_for_bo=1,
    num_pnts_for_af=500,
    num_steps_for_gp=1000,
    num_x_for_score=500,
    num_y_for_score=500,
    reduce_rates=None,
    num_ss_per_rate=50,
    sampling_method_secondary=None,
    acquisition_method='improvement',
    sampling_gp_method='tfp',
    percentile=0,
    y_incumb=None,
    x_precollect=None,
    y_precollect=None,
    additional_info_precollect=None):
  """Run the sampling method on base and best reduced search space.

  The whole `budget` is first spent on the base `search_space` with the
  primary method. The first `budget_b1` of those samples fit a GP that scores
  randomly generated reduced-volume search spaces (`num_ss_per_rate` per entry
  of `reduce_rates`). For every (score type, centrality) pair, the best
  reduced space then receives `budget - budget_b1` new evaluations with the
  secondary method, on top of the shared first-stage data.

  Args:
    key: JAX PRNG key.
    objective_fn: objective, called as `objective_fn(x) -> (y, info_dict)`.
    search_space: base box search space.
    budget: total evaluation budget.
    budget_b1: number of first-stage points used to fit the GP.
    params: initial GP parameters.
    centrality_over_x: forwarded to `scores.scores`.
    sampling_method_primary: 'RS' or 'BO' for the base search space.
    num_init_for_bo: number of random initial points for BO.
    batch_size_for_bo: BO batch size.
    num_pnts_for_af: number of points used by the acquisition function.
    num_steps_for_gp: number of GP fitting steps.
    num_x_for_score: x points drawn per search space when scoring.
    num_y_for_score: y samples drawn per x when scoring.
    reduce_rates: 1-D array of volume-reduction rates; required.
    num_ss_per_rate: number of random search spaces generated per rate.
    sampling_method_secondary: 'RS' or 'BO' for the reduced search space;
      defaults to `sampling_method_primary`.
    acquisition_method: name of the `UtilityMeasure` method used to score.
    sampling_gp_method: GP sampling backend, used for scoring and for BO.
    percentile: percentile of the first-stage y used as incumbent.
    y_incumb: explicit incumbent; overrides `percentile`.
    x_precollect: optional pool of precollected inputs used by random search.
    y_precollect: objective values matching `x_precollect`.
    additional_info_precollect: extra precollected data for the sampler.

  Returns:
    Dict with 'x_base', 'y_base' and 'secondary'. 'secondary' holds
    'search_spaces_reduced', 'vols_reduced' and, per (score type, centrality)
    key, a dict with 'x_secondary', 'y_secondary', 'search_space_best' and
    'vol_best' (volume relative to the base search space).

  Raises:
    ValueError: if `reduce_rates` is None.
  """
  # Fail before the (expensive) primary sampling instead of crashing later
  # inside jnp.repeat.
  if reduce_rates is None:
    raise ValueError(
        'reduce_rates must be provided to generate reduced search spaces.')
  if sampling_method_secondary is None:
    sampling_method_secondary = sampling_method_primary

  key_sampling_prim, key_sampling_sec, key_ss, key_score = jax.random.split(
      key, 4)
  x_base, y_base, _ = run_sampling_method(
      key_sampling_prim,
      objective_fn,
      search_space,
      budget,
      sampling_method=sampling_method_primary,
      params=params,
      x_init=None,
      y_init=None,
      num_init_for_bo=num_init_for_bo,
      batch_size_for_bo=batch_size_for_bo,
      num_pnts_for_af=num_pnts_for_af,
      num_steps_for_gp=num_steps_for_gp,
      sampling_gp_method=sampling_gp_method,
      x_precollect=x_precollect,
      y_precollect=y_precollect,
      additional_info_precollect=additional_info_precollect)

  # Random reduced-volume candidates: num_ss_per_rate per reduction rate.
  reduce_rates_repeated = jnp.repeat(reduce_rates, num_ss_per_rate)
  keys_ss = jax.random.split(key_ss, reduce_rates_repeated.shape[0])
  search_spaces_reduced = jax.vmap(
      search_spaces.generate_search_space_reduce_vol,
      in_axes=(0, None, 0))(keys_ss, search_space, reduce_rates_repeated)
  vols_reduced = jax.vmap(search_spaces.eval_vol)(search_spaces_reduced)

  x_b1 = x_base[:budget_b1, :]
  y_b1 = y_base[:budget_b1, :]

  utility_measure_fn, params_optimized = prep_utility_fn(
      params,
      x_b1,
      y_b1,
      acquisition_method,
      steps=num_steps_for_gp,
      percentile=percentile,
      y_incumb=y_incumb)

  # preparing the inputs for eval_score function to evaluate the score of
  # randomly generated reduced-volume search spaces at the remaining budget.
  draw_y_values_fn = lambda key, x_test, num_samples: draw_y_values(
      key=key,
      params=params_optimized,
      x_obs=x_b1,
      y_obs=y_b1,
      x_test=x_test,
      num_samples=num_samples,
      sampling_gp_method=sampling_gp_method)
  budget_b2 = budget - budget_b1
  partial_eval_score = functools.partial(
      eval_score,
      budget=budget_b2,
      utility_measure_fn=utility_measure_fn,
      draw_y_values_fn=draw_y_values_fn,
      centrality_over_x=centrality_over_x,
      x_drawn_num=num_x_for_score,
      y_drawn_num=num_y_for_score)
  keys_score = jax.random.split(key_score, search_spaces_reduced.shape[0])
  # the vmap evaluates the score for the generated search spaces.
  imp_method_to_stats = jax.vmap(partial_eval_score)(keys_score,
                                                     search_spaces_reduced)

  # SamplingMethod.bo treats its budget as a total that includes x_init and
  # runs only `budget - x_init.shape[0]` new steps, whereas RS samples its
  # whole budget fresh. Passing budget_b2 to BO together with x_init=x_b1
  # would spend only budget_b2 - budget_b1 new evaluations (negative when
  # budget_b1 > budget_b2). Give BO the total so both spend budget_b2.
  if sampling_method_secondary == 'BO':
    secondary_budget = budget
  else:
    secondary_budget = budget_b2
  vol_base = search_spaces.eval_vol(search_space)

  result_secondary = {
      'search_spaces_reduced': search_spaces_reduced,
      'vols_reduced': vols_reduced,
  }
  for imp, imp_dict in imp_method_to_stats.items():
    for centrality in imp_dict:
      search_space_best = extract_best_search_space(imp_dict, centrality,
                                                    search_spaces_reduced)
      # The same key is reused on purpose-agnostic grounds: every best space
      # is sampled with identical randomness.
      x_secondary, y_secondary, _ = run_sampling_method(
          key_sampling_sec,
          objective_fn,
          search_space_best,
          secondary_budget,
          sampling_method=sampling_method_secondary,
          params=params,
          x_init=x_b1,
          y_init=y_b1,
          num_init_for_bo=num_init_for_bo,
          batch_size_for_bo=batch_size_for_bo,
          num_pnts_for_af=num_pnts_for_af,
          num_steps_for_gp=num_steps_for_gp,
          sampling_gp_method=sampling_gp_method,
          x_precollect=x_precollect,
          y_precollect=y_precollect,
          additional_info_precollect=additional_info_precollect)
      # RS ignores x_init, so prepend the shared first-stage data here. The
      # BO output is not stacked (presumably bo_utils.bo already includes
      # x_init -- confirm against bo_utils).
      if sampling_method_secondary == 'RS':
        x_secondary = jnp.vstack((x_b1, x_secondary))
        y_secondary = jnp.vstack((y_b1, y_secondary))
      result_secondary[(imp, centrality)] = {
          'x_secondary': x_secondary,
          'y_secondary': y_secondary,
          'search_space_best': search_space_best,
          'vol_best': search_spaces.eval_vol(search_space_best) / vol_base,
      }

  results = {
      'x_base': x_base,
      'y_base': y_base,
      'secondary': result_secondary,
  }
  return results
357

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.