google-research

Форк
0
/
synthetic_objective_functions.py 
82 строки · 2.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Synthetic objective functions."""
17

18
import jax.numpy as jnp
19
import numpy as np
20

21

22
def branin(x):
23
  """Return the output of Branin function at input x.
24

25
  https://www.sfu.ca/~ssurjano/branin.html
26

27
  Args:
28
    x: (n, 2) shaped array of n x values in 2d.
29
  Returns:
30
    (n, 1) shaped array of Branin values at x
31
  """
32
  pi = jnp.pi
33
  a = 1
34
  b = 5.1/(4*(pi**2))
35
  c = 5/pi
36
  r = 6
37
  s = 10
38
  t = 1/(8*pi)
39

40
  y = a * (x[:, 1] - b * (x[:, 0]**2) + c * x[:, 0] -
41
           r)**2 + s * (1 - t) * jnp.cos(x[:, 0]) + s
42
  additional_info = {}
43
  return y[:, None], additional_info
44

45

46
def hartman6d(x):
47
  """Return the output of Hartman function at input x.
48

49
  https://www.sfu.ca/~ssurjano/hart6.html
50

51
  Args:
52
    x: (n, 6) shaped array of n x values in 6d.
53
  Returns:
54
    (n, 1) shaped array of Hartman values at x.
55
  """
56
  assert x.shape[1] == 6
57
  n = x.shape[0]
58
  y = np.zeros(n)
59
  for i in range(n):
60
    alpha = jnp.array([1.0, 1.2, 3.0, 3.2])
61
    a = jnp.array([[10, 3, 17, 3.5, 1.7, 8], [0.05, 10, 17, 0.1, 8, 14],
62
                   [3, 3.5, 1.7, 10, 17, 8], [17, 8, 0.05, 10, 0.1, 14]])
63
    p = 1e-4 * jnp.array([[1312, 1696, 5569, 124, 8283, 5886],
64
                          [2329, 4135, 8307, 3736, 1004, 9991],
65
                          [2348, 1451, 3522, 2883, 3047, 6650],
66
                          [4047, 8828, 8732, 5743, 1091, 381]])
67

68
    outer = 0
69
    for ii in range(4):
70
      inner = 0
71
      for jj in range(6):
72
        xj = x[i, jj]
73
        aij = a[ii, jj]
74
        pij = p[ii, jj]
75
        inner = inner + aij*(xj-pij)**2
76

77
      new = alpha[ii] * jnp.exp(-inner)
78
      outer = outer + new
79

80
    y[i] = -outer
81
  additional_info = {}
82
  return y.reshape((n, 1)), additional_info
83

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.