google-research

Форк
0
/
search_spaces.py 
70 строк · 2.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Generating random search spaces given a base search space."""
17
import jax
18
import jax.numpy as jnp
19

20
# pylint: disable=g-doc-return-or-yield
21
# pylint: disable=line-too-long
22

23

24
def random_subinterval(key, interval, reduce_rate):
25
  """Generate a reduced interval from a base interval.
26

27
  Args:
28
    key: PRNG key for jax.random.
29
    interval: (2,) shaped array of min and max values of the interval.
30
    reduce_rate: interval reduction rate in (0, 1].
31

32
  Returns: (2,) shaped array of min and max values of the new interval.
33
  """
34
  lower = interval[0]
35
  upper = interval[1]
36
  target_length = reduce_rate * (upper - lower)
37
  lower_new = jax.random.uniform(
38
      key, minval=lower, maxval=upper - target_length)
39
  upper_new = lower_new + target_length
40
  return jnp.array([lower_new, upper_new])
41

42

43
def generate_search_space_reduce_vol(key, search_space, reduce_rate=1/2):
44
  """Generate a reduced volumed search space from a base search space.
45

46
  Args:
47
    key: PRNG key for jax.random.
48
    search_space: (d,2) shaped array of min and max values.
49
    reduce_rate: volume reduction rate in (0, 1].
50

51
  Returns: (d,2) shaped array of min and max values of the new search space.
52
  """
53
  reduce_rate_dim = reduce_rate**(1/search_space.shape[0])
54
  keys = jax.random.split(key, search_space.shape[0])
55
  search_space_reduced = jax.vmap(
56
      random_subinterval, in_axes=(0, 0, None))(keys, search_space,
57
                                                reduce_rate_dim)
58
  condition = reduce_rate == 1
59
  return jnp.where(condition, search_space, search_space_reduced)
60

61

62
def eval_vol(search_space):
63
  """Compute volume of a hyperrectangular search space.
64

65
  Args:
66
    search_space: (d,2) shaped array of min and max values.
67

68
  Returns: volume of the search space.
69
  """
70
  return jnp.prod(search_space[:, 1]-search_space[:, 0])
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.