# Source: google-research (345 lines, 13.3 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Scoring functions for a search space given an iteration budget."""

import functools
import operator
from typing import Any, Callable, Dict, List

import jax
import jax.numpy as jnp

class UtilityMeasure:
  """Class for utility measures used in the scoring functions.

  Supported measures are improvement-based utilities (`is_improvement`,
  `improvement`) relative to an incumbent y value, and an entropy-based
  information-gain utility (`information_gain`) for locating the x-location
  of the global minimum.
  """

  def __init__(self,
               params,
               x_locations_for_p_min=None,
               fantasize_y_values_for_p_min=None,
               y_fantasized_num_for_p_min=500,
               initial_entropy_key=None,
               eps=1e-16,
               incumbent=None):
    """Set the required arguments for calculating the utility measures.

    Args:
      params: a dictionary from names to values to specify GP hyperparameters.
      x_locations_for_p_min: (n, d) shaped array of n x-locations in d
        dimensions. We estimate the pmf of the global min over
        x_locations_for_p_min.
      fantasize_y_values_for_p_min: a function generating y value draws at
        x_locations_for_p_min. Inputs of this function include a PRNG key, the
        observed data, params, x_locations and the number of desired draws.
      y_fantasized_num_for_p_min: number of desired draws of y. This parameter
        will be passed to fantasize_y_values_for_p_min function.
      initial_entropy_key: PRNG key for jax.random. We use this key to evaluate
        the initial entropy of p_min given the observed data.
      eps: optional float tolerance to avoid numerical issues of entropy.
      incumbent: float y value for improvement-based utility measures. One
        typical choice under noise-less evals is the best observed y value.
    """
    # Fix: the original combined the two None-checks with `*` (integer
    # multiplication of booleans); `and` is the idiomatic, short-circuiting
    # equivalent with identical truthiness.
    if (x_locations_for_p_min is not None and
        fantasize_y_values_for_p_min is not None):
      # Pre-bind everything except the PRNG key and observed data, so later
      # calls only supply key/x_obs/y_obs.
      self.fantasize_y_values_for_p_min = functools.partial(
          fantasize_y_values_for_p_min,
          x_locations_for_p_min=x_locations_for_p_min,
          params=params,
          y_fantasized_num_for_p_min=y_fantasized_num_for_p_min)
    self.initial_entropy_key = initial_entropy_key
    self.eps = eps
    self.incumbent = incumbent

  def is_improvement(self, y_batch):
    """Return whether a batch of y values can improve over the incumbent.

    Args:
      y_batch: (q, t) shaped array of t y values corresponding to a batch of
        size q of x-locations. Each column of this array corresponds to one
        realization of y values at q x-locations evaluated/predicted t times.

    Returns:
      (t,) shaped array of boolean values indicating whether the best of
      q y value within each of t realizations have improved over the incumbent.
    """
    # Minimization setting: improvement means beating (going below) the
    # incumbent.
    return jnp.min(y_batch, axis=0) < self.incumbent

  def improvement(self, y_batch):
    """Return how much a batch of y values can improve over the incumbent.

    Args:
      y_batch: (q, t) shaped array of t y values corresponding to a batch of
        size q of x-locations. Each column of this array corresponds to one
        realization of y values at q x-locations evaluated/predicted t times.

    Returns:
      (t,) shaped array of non-negative float values indicating the
      improvement the best of q y value within each of t realizations achieves
      over the incumbent.
    """
    difference = self.incumbent - jnp.min(y_batch, axis=0)
    # Clamp at zero: realizations that fail to beat the incumbent contribute
    # no improvement.
    return jnp.maximum(0.0, difference)

  def _p_x_min(self, y_fantasized):
    """Estimate a probability mass function over the x-location of global min.

    Args:
      y_fantasized: (n, m) shaped array of m fantasized y values over a common
        set of n x-locations.

    Returns:
      Estimated (n,) shaped array of pmf of the global min over x-location
      where the domain of the pmf is the common set of previous x-locations.
    """
    # Count, across the m fantasized draws, how often each of the n
    # x-locations attains the minimum; normalize to a pmf.
    counts = jnp.bincount(
        jnp.argmin(y_fantasized, axis=0), length=y_fantasized.shape[0])
    return counts / jnp.sum(counts)

  def _entropy(self, p, eps=1e-16):
    """Evaluate the entropy of an empirical probability distribution.

    Args:
      p: (n,) shaped array of probability values over n x-locations.
      eps: optional float tolerance to avoid numerical issues.

    Returns:
      Estimated entropy of p.
    """
    # Entries below eps are treated as exact zeros (0 * log 0 := 0); the
    # eps shift inside the log guards against log(0).
    return -jnp.sum(
        jnp.where(p < eps, 0., (p + eps) * jnp.log(p + eps)), axis=0)

  @functools.partial(jax.jit, static_argnums=(0,))
  def _entropy_of_p_x_min_given_data(self,
                                     for_i_index,
                                     key,
                                     x_obs,
                                     y_obs,
                                     x_batch=None,
                                     y_batch=None):
    """Compute information nats for locating global min given 1 batch data pair.

    Args:
      for_i_index: int used to fold in over the key and to slice y_batch.
      key: PRNG key for jax.random.
      x_obs: (k, d) shaped array of k observed x-locations in d dimensions.
      y_obs: (k, 1) shaped array of observed y values at x_obs.
      x_batch: (q, d) shaped array of q candidate x-locations in d dimensions.
      y_batch: (q, 1) shaped array of y values corresponding to x_batch.

    Returns:
      Non-negative float value of the nats of gained information from
      x_batch and y_batch for locating the global min.

    Raises:
      ValueError: if only one of x_batch / y_batch is provided.
    """
    if operator.xor(x_batch is None, y_batch is None):
      raise ValueError("Both x_batch and y_batch need to be provided.")

    if (x_batch is not None) and (y_batch is not None):
      if y_batch.ndim == 1:
        y_batch = y_batch[:, None]
      # Condition on the candidate batch: append it (and the for_i_index-th
      # realization of its y values) to the observed data.
      x_obs = jnp.vstack((x_obs, x_batch))
      y_obs = jnp.vstack((y_obs, y_batch[:, for_i_index][:, None]))
      # Derive a distinct key per realization.
      key = jax.random.fold_in(key, for_i_index)

    fantasized_y_values = self.fantasize_y_values_for_p_min(
        key=key, x_obs=x_obs, y_obs=y_obs)
    p_x_min = self._p_x_min(fantasized_y_values)
    entropy = self._entropy(p_x_min, self.eps)
    return entropy

  def information_gain(self,
                       key,
                       x_obs,
                       y_obs,
                       x_batch,
                       y_batch,
                       include_initial_entropy=True,
                       vectorized=False):
    """Compute information nats for locating the global min given batch data.

    In the below function, k refers to the number of observations,
    d is the data dimension, q is the size of the batch of interest,
    and t refers to the number of predictions/evaluations of y performed at
    the batch of x-locations.

    Args:
      key: PRNG key for jax.random.
      x_obs: (k, d) shaped array of k observed x-locations in d dimensions.
      y_obs: (k, 1) shaped array of observed y values at x_obs.
      x_batch: (q, d) shaped array of q candidate x-locations in d dimensions.
      y_batch: (q, t) shaped array of y values corresponding to x_batch. Each
        column of this array corresponds to one realization of q y values at
        x_batch evaluated/predicted for t times.
      include_initial_entropy: bool which decides whether initial entropy
        should be included in the calculations.
      vectorized: bool to set whether to evaluate conditional entropy over
        y_batch in a vectorized manner. Note: vectorizing requires
        significantly more memory.

    Returns:
      (t,) shaped array of float values indicating the nats of gained
      information from the x_batch and y_batch for locating the x-location of
      the global min.

    Raises:
      ValueError: if include_initial_entropy is True but no
        initial_entropy_key was supplied at construction time.
    """
    for_i_index_all = jnp.arange(y_batch.shape[1])
    partial_conditional_entropy = functools.partial(
        self._entropy_of_p_x_min_given_data,
        key=key,
        x_obs=x_obs,
        y_obs=y_obs,
        x_batch=x_batch,
        y_batch=y_batch)

    if vectorized:
      conditional_entropy = jax.vmap(partial_conditional_entropy)(
          for_i_index_all)
    else:
      # jax.lax.map trades speed for memory relative to vmap.
      conditional_entropy = jax.lax.map(partial_conditional_entropy,
                                        for_i_index_all)
    if include_initial_entropy:
      if self.initial_entropy_key is None:
        raise ValueError("initial_entropy_key needs to be initialized.")
      initial_entropy = self._entropy_of_p_x_min_given_data(
          for_i_index=0,  # dummy int
          key=self.initial_entropy_key,
          x_obs=x_obs,
          y_obs=y_obs)
      return initial_entropy - conditional_entropy
    else:
      return -conditional_entropy
220
def _mean_utility(for_i_index,
                  key_xs,
                  key_ys,
                  key_utilities,
                  search_space,
                  budget,
                  utility_measure_fn,
                  draw_x_location_fn,
                  draw_y_values_fn,
                  y_drawn_num=1000):
  """Evaluate the mean utility of a search space at a budget.

  Args:
    for_i_index: int index of the Monte-Carlo iteration; selects one row of
      each key array.
    key_xs: (x_drawn_num, 2) shaped array of PRNG keys for x-location draws.
    key_ys: (x_drawn_num, 2) shaped array of PRNG keys for y-value draws.
    key_utilities: (x_drawn_num, 2) shaped array of PRNG keys for the utility.
    search_space: (d, 2) shaped array of lower/upper bounds for x-locations.
    budget: integer number of x-locations one can afford to query.
    utility_measure_fn: function (key, x_batch, y_batch) -> utility values.
    draw_x_location_fn: function (key, search_space, budget) -> x-locations.
    draw_y_values_fn: function (key, x_locations, y_drawn_num) -> y values.
    y_drawn_num: number of y-value draws per set of x-locations.

  Returns:
    Utility at the drawn x-locations averaged over their y-values.
  """
  key_x = key_xs[for_i_index, :]
  key_y = key_ys[for_i_index, :]
  key_utility = key_utilities[for_i_index, :]

  x_locations = draw_x_location_fn(key_x, search_space, budget)
  y_values = draw_y_values_fn(key_y, x_locations, y_drawn_num)
  utility_values = utility_measure_fn(key_utility, x_locations, y_values)
  return jnp.mean(utility_values, axis=0)


def uniform(key_x, search_space, budget):
  """Draw `budget` x-locations uniformly at random within the search space."""
  return jax.random.uniform(
      key_x,
      shape=(budget, search_space.shape[0]),
      minval=search_space[:, 0],
      maxval=search_space[:, 1])


def mean_utility(key,
                 search_space,
                 budget,
                 utility_measure_fn,
                 draw_y_values_fn,
                 x_drawn_num,
                 y_drawn_num,
                 draw_x_location_fn=None,
                 vectorized=False):
  """Evaluate the mean utility of budget # of x-locations from a search space.

  This function evaluates the utility (e.g., improvement) of budget # of
  x-locations drawn from the search space via sampling fn draw_x_location_fn,
  where the utility is averaged over the y values of the x-locations. The y
  values are calculated via draw_y_values_fn.

  Args:
    key: PRNG key for jax.random.
    search_space: (d, 2) shaped array of lower and upper bounds for
      x-locations.
    budget: integer number of x-locations one can afford to query.
    utility_measure_fn: a function of an instance of the UtilityMeasure class.
    draw_y_values_fn: a function which inputs the x-locations and evaluates or
      predicts their corresponding y-values. This function can be either the
      ground-truth function or a predictive function (e.g.,
      fantasize_y_values) of a probabilistic model (e.g., GP).
    x_drawn_num: number of MC iterations for x-location.
    y_drawn_num: number of MC iterations for y-values of each x-location
      draws.
    draw_x_location_fn: a function which inputs the search space and the
      budget and draws budget of x-locations within the search space. Default
      uniform.
    vectorized: bool to set whether to evaluate mean utility over keys in a
      vectorized manner. Note: vectorizing requires significantly more memory.

  Returns:
    (x_drawn_num,) shaped array of utilities at x-locations averaged over
    their y-values.

  Examples of how to set the utility_measure:
  ```
  utility_measure = scores.UtilityMeasure()
  # if utility is is-improvement and improvement
  utility_measure_fn = lambda key, x_batch, y_batch: (
      utility_measure.is_improvement(y_batch))
  utility_measure_fn = lambda key, x_batch, y_batch: (
      utility_measure.improvement(y_batch))
  # if utility is information gain
  utility_measure_fn = partial(utility_measure.information_gain,
                               x_obs=x_obs, y_obs=y_obs)
  ```
  """
  key_x, key_y, key_utility = jax.random.split(key, 3)

  # One independent key per MC iteration for each source of randomness.
  key_x_mc = jax.random.split(key_x, x_drawn_num)
  key_y_mc = jax.random.split(key_y, x_drawn_num)
  key_utility_mc = jax.random.split(key_utility, x_drawn_num)

  if draw_x_location_fn is None:
    draw_x_location_fn = uniform
  partial_mean_utility = functools.partial(
      _mean_utility,
      key_xs=key_x_mc,
      key_ys=key_y_mc,
      key_utilities=key_utility_mc,
      search_space=search_space,
      budget=budget,
      utility_measure_fn=utility_measure_fn,
      draw_x_location_fn=draw_x_location_fn,
      draw_y_values_fn=draw_y_values_fn,
      y_drawn_num=y_drawn_num)
  for_i_index_all = jnp.arange(x_drawn_num)

  if vectorized:
    # Bug fix: the original called jax.vmap(partial_mean_utility,
    # for_i_index_all), which passes the index array as vmap's `in_axes`
    # argument and never applies the mapped function. Apply the vmapped
    # function to the indices instead, matching the vectorized branch of
    # UtilityMeasure.information_gain.
    mean_utility_arr = jax.vmap(partial_mean_utility)(for_i_index_all)
  else:
    mean_utility_arr = jax.lax.map(partial_mean_utility, for_i_index_all)

  return mean_utility_arr
324
def scores(mean_utility_arr, statistic_fns):
  """Evaluate the statistics of mean utility of a search space and a budget.

  Args:
    mean_utility_arr: (x_drawn_num,) shaped array of utilities at x-locations
      averaged over their y-values. This array is the output of score
      function.
    statistic_fns: a list of statistics to be evaluated across the
      mean_utility including generic functions (e.g., jnp.mean & jnp.median)
      and user-defined functions.

  Returns:
    A dict mapping each statistic function's `__name__` to its value on
    mean_utility_arr, i.e., the scores.
  """
  # Keyed by function name so callers can look up each statistic directly.
  return {fn.__name__: fn(mean_utility_arr) for fn in statistic_fns}
# pylint: enable=g-doc-return-or-yield
346