google-research
82 строки · 2.2 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Synthetic objective functions (Branin in 2-D, Hartmann in 6-D)."""
17
18import jax.numpy as jnp
19import numpy as np
20
21
def branin(x):
  """Return the output of the Branin function at input x.

  https://www.sfu.ca/~ssurjano/branin.html

  Computes
    f(x) = a * (x2 - b * x1**2 + c * x1 - r)**2 + s * (1 - t) * cos(x1) + s
  with the constants a = 1, b = 5.1 / (4 * pi**2), c = 5 / pi, r = 6,
  s = 10 and t = 1 / (8 * pi).

  Args:
    x: (n, 2) shaped array of n x values in 2d. Column 0 is x1 and column 1
      is x2; any further columns are not read.

  Returns:
    A tuple (y, additional_info): y is an (n, 1) shaped jax array of Branin
    values at x, and additional_info is an empty dict.
  """
  pi = jnp.pi
  a = 1
  b = 5.1 / (4 * (pi**2))
  c = 5 / pi
  r = 6
  s = 10
  t = 1 / (8 * pi)

  x1 = x[:, 0]
  x2 = x[:, 1]
  y = a * (x2 - b * (x1**2) + c * x1 - r)**2 + s * (1 - t) * jnp.cos(x1) + s
  additional_info = {}
  # Add a trailing axis so callers always get a column vector of shape (n, 1).
  return y[:, None], additional_info
44
45
def hartman6d(x):
  """Return the output of the 6-dimensional Hartmann function at input x.

  https://www.sfu.ca/~ssurjano/hart6.html

  Computes
    f(x) = -sum_{i=1..4} alpha_i * exp(-sum_{j=1..6} A_ij * (x_j - P_ij)**2)
  for every row of x at once, in float64 NumPy arithmetic.

  Args:
    x: (n, 6) shaped array (or nested sequence) of n x values in 6d.

  Returns:
    A tuple (y, additional_info): y is an (n, 1) shaped float64 NumPy array
    of Hartmann values at x, and additional_info is an empty dict.

  Raises:
    ValueError: if x is not two-dimensional with exactly 6 columns.
  """
  x = np.asarray(x, dtype=np.float64)
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if x.ndim != 2 or x.shape[1] != 6:
    raise ValueError(f'Expected x of shape (n, 6), got {x.shape}.')
  n = x.shape[0]

  alpha = np.array([1.0, 1.2, 3.0, 3.2])
  a = np.array([[10, 3, 17, 3.5, 1.7, 8],
                [0.05, 10, 17, 0.1, 8, 14],
                [3, 3.5, 1.7, 10, 17, 8],
                [17, 8, 0.05, 10, 0.1, 14]])
  p = 1e-4 * np.array([[1312, 1696, 5569, 124, 8283, 5886],
                       [2329, 4135, 8307, 3736, 1004, 9991],
                       [2348, 1451, 3522, 2883, 3047, 6650],
                       [4047, 8828, 8732, 5743, 1091, 381]])

  # Broadcast (n, 1, 6) against (4, 6) to get every (x_j - P_ij)**2 term as
  # an (n, 4, 6) array, then reduce over j for the inner sums -> (n, 4).
  sq_dist = (x[:, None, :] - p)**2
  inner = np.sum(a * sq_dist, axis=-1)
  # Weighted sum over the 4 terms i -> (n,).
  y = -np.sum(alpha * np.exp(-inner), axis=-1)
  additional_info = {}
  return y.reshape((n, 1)), additional_info
83