google-research

Форк
0
345 строк · 13.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Scoring functions for a search space given an iteration budget."""
17

18
import functools
19
import operator
20
from typing import Any, Callable, Dict, List
21

22
import jax
23
import jax.numpy as jnp
24

25

26
class UtilityMeasure:
  """Class for utility measures used in the scoring functions."""

  def __init__(self,
               params,
               x_locations_for_p_min=None,
               fantasize_y_values_for_p_min=None,
               y_fantasized_num_for_p_min=500,
               initial_entropy_key=None,
               eps=1e-16,
               incumbent=None):
    """Set the required arguments for calculating the utility measures.

    Args:
      params: a dictionary from names to values to specify GP hyperparameters.
      x_locations_for_p_min: (n, d) shaped array of n x-locations in d
        dimensions. We estimate the pmf of the global min over
        x_locations_for_p_min.
      fantasize_y_values_for_p_min: a function generating y value draws at
        x_locations_for_p_min. Inputs of this function include a PRNG key, the
        observed data, params, x_locations and the number of desired draws.
      y_fantasized_num_for_p_min: number of desired draws of y. This parameter
        will be passed to fantasize_y_values_for_p_min function.
      initial_entropy_key: PRNG key for jax.random. We use this key to evaluate
        the initial entropy of p_min given the observed data.
      eps: optional float tolerance to avoid numerical issues of entropy.
      incumbent: float y value for improvement-based utility measures. One
        typical choice under noise-less evals is the best observed y value.
    """
    # Pre-bind the fantasizing function only when both the x-locations and the
    # sampler are supplied; otherwise the information-gain methods cannot be
    # used (the attribute is intentionally left unset, as in the original).
    # Fix: use `and` rather than `*` to combine the two boolean conditions.
    if (x_locations_for_p_min is not None) and (
        fantasize_y_values_for_p_min is not None):

      self.fantasize_y_values_for_p_min = functools.partial(
          fantasize_y_values_for_p_min,
          x_locations_for_p_min=x_locations_for_p_min,
          params=params,
          y_fantasized_num_for_p_min=y_fantasized_num_for_p_min)
    self.initial_entropy_key = initial_entropy_key
    self.eps = eps
    self.incumbent = incumbent

  def is_improvement(self, y_batch):
    """Return whether a batch of y values can improve over the incumbent.

    Args:
      y_batch: (q, t) shaped array of t y values corresponding to a batch of
        size q of x-locations. Each column of this array corresponds to one
        realization of y values at q x-locations evaluated/predicted t times.

    Returns:
      (t,) shaped array of boolean values indicating whether the best of
      q y value within each of t realizations have improved over the incumbent.
    """
    return jnp.min(y_batch, axis=0) < self.incumbent

  def improvement(self, y_batch):
    """Return how much a batch of y values can improve over the incumbent.

    Args:
      y_batch: (q, t) shaped array of t y values corresponding to a batch of
        size q of x-locations. Each column of this array corresponds to one
        realization of y values at q x-locations evaluated/predicted t times.

    Returns:
      (t,) shaped array of non-negative float values indicating the
      improvement the best of q y value within each of t realizations achieves
      over the incumbent.
    """
    difference = self.incumbent - jnp.min(y_batch, axis=0)
    # Negative differences mean no improvement; clip them to zero.
    return jnp.maximum(0.0, difference)

  def _p_x_min(self, y_fantasized):
    """Estimate a probability mass function over the x-location of global min.

    Args:
      y_fantasized: (n, m) shaped array of m fantasized y values over a common
        set of n x-locations.

    Returns:
      Estimated (n,) shaped array of pmf of the global min over x-location
      where the domain of the pmf is the common set of previous x-locations.
    """
    # Count, over the m fantasized draws, how often each of the n x-locations
    # attains the minimum; normalizing the counts yields the empirical pmf.
    counts = jnp.bincount(
        jnp.argmin(y_fantasized, axis=0), length=y_fantasized.shape[0])
    return counts / jnp.sum(counts)

  def _entropy(self, p, eps=1e-16):
    """Evaluate the entropy of an empirical probability distribution.

    Args:
      p: (n,) shaped array of probability values over n x_locations.
      eps: optional float tolerance to avoid numerical issues.

    Returns:
      Estimated entropy of p.
    """
    # Treat probabilities below eps as zero to avoid log(0); the eps shift
    # keeps the remaining log arguments strictly positive.
    return -jnp.sum(
        jnp.where(p < eps, 0., (p + eps) * jnp.log(p + eps)), axis=0)

  @functools.partial(jax.jit, static_argnums=(0,))
  def _entropy_of_p_x_min_given_data(
      self,
      for_i_index,
      key,
      x_obs,
      y_obs,
      x_batch=None,
      y_batch=None):
    """Compute information nats for locating global min given 1 batch data pair.

    Args:
      for_i_index: int used to fold in over the key and to slice y_batch.
      key: PRNG key for jax.random.
      x_obs: (k, d) shaped array of k observed x-locations in d dimensions.
      y_obs: (k, 1) shaped array of observed y values at x_obs.
      x_batch: (q, d) shaped array of q candidate x-locations in d dimensions.
      y_batch: (q, 1) shaped array of y values corresponding to x_batch.

    Returns:
      Non-negative float value of the nats of gained information from
      x_batch and y_batch for locating the global min.
    """
    if operator.xor(x_batch is None, y_batch is None):
      raise ValueError("Both x_batch and y_batch need to be provided.")

    if (x_batch is not None) and (y_batch is not None):
      if y_batch.ndim == 1:
        y_batch = y_batch[:, None]
      # Condition on the candidate batch by appending it to the observations,
      # selecting the for_i_index-th realization of the batch y values.
      x_obs = jnp.vstack((x_obs, x_batch))
      y_obs = jnp.vstack((y_obs, y_batch[:, for_i_index][:, None]))
      key = jax.random.fold_in(key, for_i_index)

    fantasized_y_values = self.fantasize_y_values_for_p_min(
        key=key, x_obs=x_obs, y_obs=y_obs)
    p_x_min = self._p_x_min(fantasized_y_values)
    entropy = self._entropy(p_x_min, self.eps)
    return entropy

  def information_gain(self, key, x_obs,
                       y_obs, x_batch,
                       y_batch,
                       include_initial_entropy=True,
                       vectorized=False):
    """Compute information nats for locating the global min given batch data.

    In the below function, k refers to the number of observations,
    d is the data dimension, q is the size of the batch of interest,
    and t refers to the number of predictions/evaluations of y performed at
    the batch of x-locations.

    Args:
      key: PRNG key for jax.random.
      x_obs: (k, d) shaped array of k observed x-locations in d dimensions.
      y_obs: (k, 1) shaped array of observed y values at x_obs.
      x_batch: (q, d) shaped array of q candidate x-locations in d dimensions.
      y_batch: (q, t) shaped array of y values corresponding to x_batch. Each
        column of this array corresponds to one realization of q y values at
        x_batch evaluated/predicted for t times.
      include_initial_entropy: bool which decides whether initial entropy
        should be included in the calculations.
      vectorized: bool to set whether to evaluate conditional entropy over
        y_batch in a vectorized manner. Note: vectorizing requires
        significantly more memory.

    Returns:
      (t,) shaped array of non-negative float values indicating the nats
      of gained information from the x_batch and y_batch for locating the
      x-location of the global min.
    """
    for_i_index_all = jnp.arange(y_batch.shape[1])
    partial_conditional_entropy = functools.partial(
        self._entropy_of_p_x_min_given_data,
        key=key,
        x_obs=x_obs,
        y_obs=y_obs,
        x_batch=x_batch,
        y_batch=y_batch)

    if vectorized:
      conditional_entropy = jax.vmap(partial_conditional_entropy)(
          for_i_index_all)
    else:
      conditional_entropy = jax.lax.map(partial_conditional_entropy,
                                        for_i_index_all)
    if include_initial_entropy:
      if self.initial_entropy_key is None:
        raise ValueError("initial_entropy_key needs to be initialized.")
      initial_entropy = self._entropy_of_p_x_min_given_data(
          for_i_index=0,  # dummy int
          key=self.initial_entropy_key,
          x_obs=x_obs,
          y_obs=y_obs)
      return initial_entropy - conditional_entropy
    else:
      return -conditional_entropy
219

220

221
def _mean_utility(for_i_index,
222
                  key_xs,
223
                  key_ys,
224
                  key_utilities,
225
                  search_space,
226
                  budget,
227
                  utility_measure_fn,
228
                  draw_x_location_fn,
229
                  draw_y_values_fn,
230
                  y_drawn_num = 1000):
231
  """Evaluate the mean utility of a search space at a budget."""
232
  key_x = key_xs[for_i_index, :]
233
  key_y = key_ys[for_i_index, :]
234
  key_utility = key_utilities[for_i_index, :]
235

236
  x_locations = draw_x_location_fn(key_x, search_space, budget)
237
  y_values = draw_y_values_fn(key_y, x_locations, y_drawn_num)
238
  utility_values = utility_measure_fn(key_utility, x_locations, y_values)
239
  return jnp.mean(utility_values, axis=0)
240

241

242
def uniform(key_x, search_space, budget):
  """Draw `budget` x-locations uniformly at random within `search_space`."""
  # Column 0 holds per-dimension lower bounds, column 1 the upper bounds.
  lower_bounds = search_space[:, 0]
  upper_bounds = search_space[:, 1]
  num_dims = search_space.shape[0]
  return jax.random.uniform(
      key_x,
      shape=(budget, num_dims),
      minval=lower_bounds,
      maxval=upper_bounds)
248

249

250
def mean_utility(key,
                 search_space,
                 budget,
                 utility_measure_fn,
                 draw_y_values_fn,
                 x_drawn_num,
                 y_drawn_num,
                 draw_x_location_fn=None,
                 vectorized=False):
  """Evaluate the mean utility of budget # of x-locations from a search space.

  This function evaluates the utility (e.g., improvement) of budget # of
  x-locations drawn from the search space via sampling fn draw_x_location_fn,
  where the utility is averaged over the y values of the x-locations. The y
  values are calculated via draw_y_values_fn.

  Args:
    key: PRNG key for jax.random.
    search_space: (d, 2) shaped array of lower and upper bounds for
      x-locations.
    budget: integer number of x-locations one can afford to query.
    utility_measure_fn: a function of an instance of the UtilityMeasure class.
    draw_y_values_fn: a function which inputs the x-locations and evaluates or
      predicts their corresponding y-values. This function can be either the
      ground-truth function or a predictive function (e.g., fantasize_y_values)
      of a probabilistic model (e.g., GP).
    x_drawn_num: number of MC iterations for x-location.
    y_drawn_num: number of MC iterations for y-values of each x-location draws.
    draw_x_location_fn: a function which inputs the search space and the budget
      and draws budget of x-locations within the search space. Default uniform.
    vectorized: bool to set whether to evaluate mean utility over keys in a
      vectorized manner. Note: vectorizing requires significantly more memory.

  Returns:
    (x_drawn_num,) shaped array of utilities at x-locations averaged over
    their y-values.

  Examples of how to set the utility_measure: ``` utility_measure =
    scores.UtilityMeasure() # if utility is is-improvement and improvement
  utility_measure_fn = lambda key, x_batch, y_batch:
  utility_measure.is_improvement(y_batch)
  utility_measure_fn= lambda key, x_batch, y_batch:
  utility_measure.improvement(y_batch)
  # if utility is information gain
  utility_measure_fn= partial(utility_measure.information_gain,
  x_obs=x_obs, y_obs=y_obs)
  ```
  """
  key_x, key_y, key_utility = jax.random.split(key, 3)

  # One independent key triple per Monte-Carlo iteration over x draws.
  key_x_mc = jax.random.split(key_x, x_drawn_num)
  key_y_mc = jax.random.split(key_y, x_drawn_num)
  key_utility_mc = jax.random.split(key_utility, x_drawn_num)

  if draw_x_location_fn is None:
    draw_x_location_fn = uniform
  partial_mean_utility = functools.partial(
      _mean_utility,
      key_xs=key_x_mc,
      key_ys=key_y_mc,
      key_utilities=key_utility_mc,
      search_space=search_space,
      budget=budget,
      utility_measure_fn=utility_measure_fn,
      draw_x_location_fn=draw_x_location_fn,
      draw_y_values_fn=draw_y_values_fn,
      y_drawn_num=y_drawn_num)
  for_i_index_all = jnp.arange(x_drawn_num)

  if vectorized:
    # Bug fix: jax.vmap transforms the function; the mapped inputs must be
    # passed to the transformed function, not as vmap's second argument
    # (which is in_axes). The original call raised a TypeError.
    mean_utility_arr = jax.vmap(partial_mean_utility)(for_i_index_all)
  else:
    mean_utility_arr = jax.lax.map(partial_mean_utility, for_i_index_all)

  return mean_utility_arr
323

324

325
def scores(
    mean_utility_arr,
    statistic_fns):
  """Evaluate the statistics of mean utility of a search space and a budget.

  Args:
    mean_utility_arr: (x_drawn_num,) shaped array of utilities at x-locations
      averaged over their y-values. This array is the output of the
      mean_utility function.
    statistic_fns: a list of statistics to be evaluated across the mean
      utility including generic functions (e.g., jnp.mean & jnp.median) and
      user-defined functions.

  Returns:
    results: a dict containing statistics of the mean_utility, i.e., the
    scores, keyed by each statistic function's __name__.
  """
  # Key each statistic by its function name so callers can look results up
  # without depending on the order of statistic_fns.
  results = {
      statistic_fn.__name__: statistic_fn(mean_utility_arr)
      for statistic_fn in statistic_fns
  }
  return results
345
# pylint: enable=g-doc-return-or-yield
346

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.