google-research
356 строк · 11.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Running a sampling method on the base and the reduced search spaces."""
17import functools
18import operator
19from typing import Any, Callable
20
21
22import jax
23import jax.numpy as jnp
24import spaceopt.bo_utils as bo
25import spaceopt.gp_utils as gp
26import spaceopt.scores as scores
27import spaceopt.search_spaces as search_spaces
28
29# pylint: disable=g-long-lambda
30
31
def point_in_search_space(points, search_space):
  """Return a boolean mask of the points lying inside a box search space.

  Args:
    points: 2-D array with one point per row; its second dimension must match
      the number of rows of `search_space`.
    search_space: 2-D array whose column 0 is compared as a lower bound and
      column 1 as an upper bound for each dimension (both bounds inclusive).

  Returns:
    A 1-D boolean array with one entry per row of `points`, True where every
    coordinate of the point is within the corresponding bounds.

  Raises:
    ValueError: if the dimensionality of `points` and `search_space` differ.
  """
  # An explicit raise (rather than `assert`) keeps the check alive under -O.
  if points.shape[1] != search_space.shape[0]:
    raise ValueError(
        'points has dimension {} but search_space has dimension {}.'.format(
            points.shape[1], search_space.shape[0]))
  above_lower = jnp.all(search_space[:, 0] <= points, axis=1)
  below_upper = jnp.all(search_space[:, 1] >= points, axis=1)
  return jnp.logical_and(above_lower, below_upper)
38
39
def prep_utility_fn(params,
                    x_b1,
                    y_b1,
                    utility_type,
                    steps=1000,
                    percentile=0,
                    y_incumb=None):
  """Prepare the utility function used by the search-space score.

  Fits the GP parameters on the observed data and wraps the requested utility
  of `scores.UtilityMeasure` into a callable.

  Args:
    params: initial GP parameters, passed to `GPUtils.fit_gp`.
    x_b1: observed inputs used to fit the GP.
    y_b1: observed outputs used to fit the GP and, when `y_incumb` is None, to
      compute the incumbent.
    utility_type: name of the `scores.UtilityMeasure` attribute to call.
    steps: number of steps handed to `GPUtils.fit_gp`.
    percentile: percentile of `y_b1` (taken along axis 0) used as incumbent
      when `y_incumb` is None.
    y_incumb: optional explicit incumbent value; overrides `percentile`.

  Returns:
    A tuple `(utility_measure_fn, params)` where `utility_measure_fn(key,
    x_batch, y_batch)` evaluates the chosen utility on `y_batch`, and `params`
    are the fitted GP parameters.
  """
  fitted_params = gp.GPUtils().fit_gp(x_b1, y_b1, params, steps=steps)
  if y_incumb is None:
    y_incumb = jnp.percentile(y_b1, percentile, axis=0)
  utility_measure = scores.UtilityMeasure(
      incumbent=y_incumb, params=fitted_params)

  def utility_measure_fn(key, x_batch, y_batch):
    # `key` and `x_batch` are accepted but unused: only `y_batch` is forwarded.
    # The attribute is looked up at call time, so an unknown `utility_type`
    # only fails when the utility is evaluated.
    del key, x_batch
    return getattr(utility_measure, utility_type)(y_batch)

  return utility_measure_fn, fitted_params
56
57
def draw_y_values(key,
                  params,
                  x_obs,
                  y_obs,
                  x_test,
                  num_samples,
                  sampling_gp_method):
  """Draw y samples from the GP posterior at the test inputs.

  Args:
    key: random key forwarded to `GPUtils.draw_gp_samples`.
    params: GP parameters.
    x_obs: observed inputs the posterior is conditioned on.
    y_obs: observed outputs the posterior is conditioned on.
    x_test: inputs at which the posterior mean and covariance are computed.
    num_samples: number of samples to draw.
    sampling_gp_method: value forwarded as `method` to
      `GPUtils.draw_gp_samples`.

  Returns:
    The samples returned by `GPUtils.draw_gp_samples`.
  """
  gp_util = gp.GPUtils()
  posterior_mean, posterior_cov = gp_util.posterior_mean_cov(
      params, x_obs, y_obs, x_test)
  return gp_util.draw_gp_samples(
      key,
      posterior_mean,
      posterior_cov,
      num_samples=num_samples,
      method=sampling_gp_method)
71
72
def eval_score(key,
               search_space,
               budget,
               utility_measure_fn,
               draw_y_values_fn,
               centrality_over_x,
               x_drawn_num=500,
               y_drawn_num=500):
  """Evaluate the improvement and is-improvement scores of a search space.

  Args:
    key: random key forwarded to `scores.mean_utility`.
    search_space: the search space to score.
    budget: budget forwarded to `scores.mean_utility`.
    utility_measure_fn: utility callable forwarded to `scores.mean_utility`.
    draw_y_values_fn: y-sampling callable forwarded to `scores.mean_utility`.
    centrality_over_x: centrality specification forwarded to `scores.scores`.
    x_drawn_num: forwarded to `scores.mean_utility` as `x_drawn_num`.
    y_drawn_num: forwarded to `scores.mean_utility` as `y_drawn_num`.

  Returns:
    A dict with two entries: 'imp', the scores of the mean utility values, and
    'is_imp', the scores of the indicator that the mean utility is positive.
  """
  mean_utility_val = scores.mean_utility(
      key,
      search_space,
      budget,
      utility_measure_fn,
      draw_y_values_fn,
      x_drawn_num=x_drawn_num,
      y_drawn_num=y_drawn_num)
  # Boolean indicator of a strictly positive mean utility.
  has_positive_utility = mean_utility_val > 0
  return {
      'imp': scores.scores(mean_utility_val, centrality_over_x),
      'is_imp': scores.scores(has_positive_utility, centrality_over_x),
  }
97
98
def extract_best_search_space(scores_dict, centrality_key,
                              search_spaces_reduced):
  """Select the search space with the maximal score.

  Args:
    scores_dict: mapping from centrality key to an array of scores, one per
      candidate search space.
    centrality_key: key of `scores_dict` whose scores are maximized.
    search_spaces_reduced: 3-D array of candidate search spaces stacked along
      the first axis.

  Returns:
    The slice of `search_spaces_reduced` at the arg-max score index.
  """
  best_index = jnp.argmax(scores_dict[centrality_key])
  return search_spaces_reduced[best_index, :, :]
103
104
class SamplingMethod:
  """Sampling methods that spend an evaluation budget over a search space."""

  def __init__(self,
               search_space,
               objective_fn,
               x_precollect=None,
               y_precollect=None,
               additional_info_precollect=None):
    """Store the search space, the objective and optional precollected data.

    Args:
      search_space: 2-D array; column 0 is used as the lower and column 1 as
        the upper bound of each dimension.
      objective_fn: callable mapping a batch of inputs to a tuple
        `(y, additional_info_dict)`.
      x_precollect: optional precollected inputs.
      y_precollect: optional precollected outputs matching `x_precollect`.
      additional_info_precollect: optional extra precollected information; it
        is stored but not read by the methods of this class.
    """
    self.search_space = search_space
    self.objective_fn = objective_fn
    self.x_precollect = x_precollect
    self.y_precollect = y_precollect
    self.additional_info_precollect = additional_info_precollect

  def rs(self, key, budget):
    """Random search sampling method with uniform distribution.

    Without precollected data, `budget` points are drawn uniformly inside the
    search space and evaluated with the objective. With precollected data,
    `budget` of the precollected points lying inside the search space are
    picked at random without replacement.

    Args:
      key: random key.
      budget: number of points to sample.

    Returns:
      A tuple `(x, y, additional_info_dict, ind_chosen)`. `ind_chosen` is None
      when the objective is evaluated; otherwise it holds the sorted indices
      of the chosen points *within the precollected points that lie inside
      the search space*, and `additional_info_dict` is empty.

    Raises:
      ValueError: if only one of `x_precollect` / `y_precollect` is set, or if
        fewer than `budget` precollected points lie inside the search space.
    """
    if operator.xor(self.x_precollect is None, self.y_precollect is None):
      raise ValueError(
          'Both x_precollect and y_precollect need to be provided.')

    if self.x_precollect is None:
      # No precollected data: draw fresh points and query the objective.
      x = jax.random.uniform(
          key,
          shape=(budget, self.search_space.shape[0]),
          minval=self.search_space[:, 0],
          maxval=self.search_space[:, 1])
      y, additional_info_dict = self.objective_fn(x)
      return x, y, additional_info_dict, None

    # Restrict the precollected pool to the points inside the search space.
    inside = point_in_search_space(self.x_precollect, self.search_space)
    x_pool = self.x_precollect[inside, :]
    y_pool = self.y_precollect[inside, :]
    num_available = y_pool.shape[0]
    if num_available < budget:
      raise ValueError(
          'budget is larger than the precollected data in the search space.')
    ind_chosen = jnp.sort(
        jax.random.choice(
            key, num_available, shape=(budget,), replace=False))
    return x_pool[ind_chosen, :], y_pool[ind_chosen, :], {}, ind_chosen

  def bo(self,
         key,
         budget,
         params,
         x_obs,
         y_obs,
         batch_size=1,
         num_points=500,
         num_steps=1000,
         sampling_gp_method='tfp'):
    """Bayesian optimization sampling method.

    Args:
      key: random key forwarded to `bo.bo`.
      budget: total budget; the number of rows of `x_obs` is subtracted from
        it before it is handed to `bo.bo`.
      params: GP parameters forwarded to `bo.bo`.
      x_obs: already observed inputs.
      y_obs: already observed outputs.
      batch_size: forwarded to `bo.bo` as `batch_size`.
      num_points: forwarded to `bo.bo` as `num_points`.
      num_steps: forwarded to `bo.bo` as `num_steps`.
      sampling_gp_method: forwarded to `bo.bo` as `method`.

    Returns:
      A tuple `(x, y, additional_info_dict, ind_chosen)` where the first three
      entries come from `bo.bo` and `ind_chosen` is always None.
    """
    # Inside this method `bo` resolves to the imported module, not to this
    # method, because class attributes are not part of the function scope.
    remaining_budget = budget - x_obs.shape[0]
    x, y, additional_info_dict = bo.bo(
        key,
        x_obs,
        y_obs,
        self.objective_fn,
        params,
        self.search_space,
        remaining_budget,
        batch_size=batch_size,
        num_points=num_points,
        num_steps=num_steps,
        method=sampling_gp_method)
    return x, y, additional_info_dict, None
179
180
def run_sampling_method(key,
                        objective_fn,
                        search_space,
                        budget,
                        sampling_method='RS',
                        params=None,
                        x_init=None,
                        y_init=None,
                        num_init_for_bo=1,
                        batch_size_for_bo=1,
                        num_pnts_for_af=500,
                        num_steps_for_gp=1000,
                        sampling_gp_method='tfp',
                        x_precollect=None,
                        y_precollect=None,
                        additional_info_precollect=None):
  """Run the sampling method on a search space given a budget.

  Args:
    key: random key.
    objective_fn: objective handed to `SamplingMethod`.
    search_space: search space handed to `SamplingMethod`.
    budget: budget handed to the chosen sampling method.
    sampling_method: either 'RS' (random search) or 'BO' (Bayesian
      optimization).
    params: GP parameters, only used by 'BO'.
    x_init: optional initial inputs for 'BO'; ignored by 'RS'.
    y_init: optional initial outputs for 'BO'; ignored by 'RS'.
    num_init_for_bo: number of random-search points drawn to initialize 'BO'
      when no initial data is given.
    batch_size_for_bo: batch size forwarded to `SamplingMethod.bo`.
    num_pnts_for_af: forwarded to `SamplingMethod.bo` as `num_points`.
    num_steps_for_gp: forwarded to `SamplingMethod.bo` as `num_steps`.
    sampling_gp_method: forwarded to `SamplingMethod.bo`.
    x_precollect: optional precollected inputs handed to `SamplingMethod`.
    y_precollect: optional precollected outputs handed to `SamplingMethod`.
    additional_info_precollect: optional extra precollected information
      handed to `SamplingMethod`.

  Returns:
    A tuple `(x_sampled, y_sampled, ind_sampled_in_precollected)`.

  Raises:
    ValueError: if `sampling_method` is neither 'RS' nor 'BO', or if 'BO' is
      requested with only one of `x_init` / `y_init`.
  """
  # Validate up front so that bad arguments fail before any work is done.
  if sampling_method not in ('RS', 'BO'):
    raise ValueError('Sampling method should be either RS or BO.')
  if sampling_method == 'BO' and ((x_init is None) != (y_init is None)):
    # A half-specified pair would otherwise reach BO with a None observation.
    raise ValueError('Both x_init and y_init need to be provided.')

  sampling = SamplingMethod(search_space, objective_fn, x_precollect,
                            y_precollect, additional_info_precollect)
  if sampling_method == 'RS':
    x_sampled, y_sampled, _, ind_sampled_in_precollected = sampling.rs(
        key, budget)
    return x_sampled, y_sampled, ind_sampled_in_precollected

  if x_init is None:
    # No initial data: spend part of the randomness on a random-search
    # initialization and keep an independent key for BO itself.
    key_init, key_rest = jax.random.split(key)
    x_init, y_init, _, _ = sampling.rs(key_init, num_init_for_bo)
  else:
    key_rest = key
  x_sampled, y_sampled, _, ind_sampled_in_precollected = sampling.bo(
      key_rest,
      budget,
      params,
      x_init,
      y_init,
      batch_size=batch_size_for_bo,
      num_points=num_pnts_for_af,
      num_steps=num_steps_for_gp,
      sampling_gp_method=sampling_gp_method)
  return x_sampled, y_sampled, ind_sampled_in_precollected
222
223
def run_method_on_base_and_reduced_search_spaces(
    key,
    objective_fn,
    search_space,
    budget,
    budget_b1,
    params,
    centrality_over_x,
    sampling_method_primary='RS',
    num_init_for_bo=1,
    batch_size_for_bo=1,
    num_pnts_for_af=500,
    num_steps_for_gp=1000,
    num_x_for_score=500,
    num_y_for_score=500,
    reduce_rates=None,
    num_ss_per_rate=50,
    sampling_method_secondary=None,
    acquisition_method='improvement',
    sampling_gp_method='tfp',
    percentile=0,
    y_incumb=None,
    x_precollect=None,
    y_precollect=None,
    additional_info_precollect=None):
  """Run the sampling method on base and best reduced search space.

  The primary sampling method first spends the whole `budget` on the base
  search space. The first `budget_b1` of those observations are then used to
  fit a GP and to score randomly generated reduced-volume search spaces. For
  every (score type, centrality) pair the best-scoring reduced search space is
  selected and the secondary sampling method is run on it with the remaining
  budget `budget - budget_b1`.

  Args:
    key: random key.
    objective_fn: objective function forwarded to `run_sampling_method`.
    search_space: the base search space.
    budget: total budget.
    budget_b1: number of leading base observations used to fit the GP and to
      score the reduced search spaces.
    params: initial GP parameters.
    centrality_over_x: centrality specification forwarded to `eval_score`.
    sampling_method_primary: sampling method for the base search space.
    num_init_for_bo: forwarded to `run_sampling_method`.
    batch_size_for_bo: forwarded to `run_sampling_method`.
    num_pnts_for_af: forwarded to `run_sampling_method`.
    num_steps_for_gp: forwarded to `run_sampling_method` and used as the
      number of GP fitting steps in `prep_utility_fn`.
    num_x_for_score: forwarded to `eval_score` as `x_drawn_num`.
    num_y_for_score: forwarded to `eval_score` as `y_drawn_num`.
    reduce_rates: volume reduction rates; required. Each rate is repeated
      `num_ss_per_rate` times to generate the candidate search spaces.
    num_ss_per_rate: number of candidate search spaces per reduction rate.
    sampling_method_secondary: sampling method for the reduced search space;
      defaults to `sampling_method_primary`.
    acquisition_method: utility name forwarded to `prep_utility_fn`.
    sampling_gp_method: GP sampling method used both for scoring and for the
      sampling runs.
    percentile: forwarded to `prep_utility_fn`.
    y_incumb: forwarded to `prep_utility_fn`.
    x_precollect: optional precollected inputs.
    y_precollect: optional precollected outputs.
    additional_info_precollect: optional extra precollected information.

  Returns:
    A dict with 'x_base' and 'y_base' (the base observations) and 'secondary',
    a dict holding 'search_spaces_reduced', 'vols_reduced' and, for every
    `(score type, centrality)` key, a dict with 'x_secondary', 'y_secondary',
    'search_space_best' and 'vol_best' (volume relative to the base space).

  Raises:
    ValueError: if `reduce_rates` is None.
  """
  if reduce_rates is None:
    # The default cannot be used: there is nothing to repeat below.
    raise ValueError('reduce_rates needs to be provided.')

  key_sampling_prim, key_sampling_sec, key_ss, key_score = jax.random.split(
      key, 4)

  # Options shared by the primary and the secondary sampling runs.
  # `sampling_gp_method` is forwarded so that BO honours the caller's choice.
  shared_sampling_kwargs = dict(
      params=params,
      num_init_for_bo=num_init_for_bo,
      batch_size_for_bo=batch_size_for_bo,
      num_pnts_for_af=num_pnts_for_af,
      num_steps_for_gp=num_steps_for_gp,
      sampling_gp_method=sampling_gp_method,
      x_precollect=x_precollect,
      y_precollect=y_precollect,
      additional_info_precollect=additional_info_precollect)

  x_base, y_base, _ = run_sampling_method(
      key_sampling_prim,
      objective_fn,
      search_space,
      budget,
      sampling_method=sampling_method_primary,
      x_init=None,
      y_init=None,
      **shared_sampling_kwargs)

  # `jnp.asarray` lets plain Python sequences be used as reduction rates.
  reduce_rates_repeated = jnp.repeat(jnp.asarray(reduce_rates), num_ss_per_rate)
  keys_ss = jax.random.split(key_ss, reduce_rates_repeated.shape[0])
  search_spaces_reduced = jax.vmap(
      search_spaces.generate_search_space_reduce_vol,
      in_axes=(0, None, 0))(keys_ss, search_space, reduce_rates_repeated)
  vols_reduced = jax.vmap(search_spaces.eval_vol)(search_spaces_reduced)

  x_b1 = x_base[:budget_b1, :]
  y_b1 = y_base[:budget_b1, :]

  utility_measure_fn, params_optimized = prep_utility_fn(
      params,
      x_b1,
      y_b1,
      acquisition_method,
      steps=num_steps_for_gp,
      percentile=percentile,
      y_incumb=y_incumb)

  # Preparing the inputs for eval_score function to evaluate the score of
  # randomly generated reduced-volume search spaces at the remaining budget.
  def draw_y_values_fn(key, x_test, num_samples):
    return draw_y_values(
        key=key,
        params=params_optimized,
        x_obs=x_b1,
        y_obs=y_b1,
        x_test=x_test,
        num_samples=num_samples,
        sampling_gp_method=sampling_gp_method)

  budget_b2 = budget - budget_b1
  partial_eval_score = functools.partial(
      eval_score,
      budget=budget_b2,
      utility_measure_fn=utility_measure_fn,
      draw_y_values_fn=draw_y_values_fn,
      centrality_over_x=centrality_over_x,
      x_drawn_num=num_x_for_score,
      y_drawn_num=num_y_for_score)
  keys_score = jax.random.split(key_score, search_spaces_reduced.shape[0])
  # The vmap evaluates the score for the generated search spaces.
  imp_method_to_stats = jax.vmap(partial_eval_score)(keys_score,
                                                     search_spaces_reduced)

  if sampling_method_secondary is None:
    sampling_method_secondary = sampling_method_primary

  result_secondary = {
      'search_spaces_reduced': search_spaces_reduced,
      'vols_reduced': vols_reduced,
  }

  # Loop-invariant: volume of the base search space used for normalization.
  vol_base = search_spaces.eval_vol(search_space)

  for imp, imp_dict in imp_method_to_stats.items():
    for centrality in imp_dict:
      search_space_best = extract_best_search_space(imp_dict, centrality,
                                                    search_spaces_reduced)
      # NOTE(review): the same `key_sampling_sec` is reused for every
      # (imp, centrality) pair -- presumably on purpose so that the runs share
      # their randomness; confirm.
      # NOTE(review): with 'BO', `SamplingMethod.bo` subtracts the number of
      # rows of `x_init` from the budget it receives, so passing `budget_b2`
      # together with `x_init=x_b1` may leave only `budget_b2 - budget_b1`
      # new evaluations -- verify against `bo.bo`.
      x_secondary, y_secondary, _ = run_sampling_method(
          key_sampling_sec,
          objective_fn,
          search_space_best,
          budget_b2,
          sampling_method=sampling_method_secondary,
          x_init=x_b1,
          y_init=y_b1,
          **shared_sampling_kwargs)
      if sampling_method_secondary == 'RS':
        # Random search ignores the initial data, so the first-stage
        # observations are prepended explicitly.
        x_secondary = jnp.vstack((x_b1, x_secondary))
        y_secondary = jnp.vstack((y_b1, y_secondary))
      result_secondary[(imp, centrality)] = {
          'x_secondary': x_secondary,
          'y_secondary': y_secondary,
          'search_space_best': search_space_best,
          'vol_best': search_spaces.eval_vol(search_space_best) / vol_base,
      }

  return {
      'x_base': x_base,
      'y_base': y_base,
      'secondary': result_secondary,
  }
357