google-research
163 строки · 5.3 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for spaceopt.scores."""
17import functools18
19from absl.testing import absltest20import jax21import jax.numpy as jnp22import jax.scipy as jscipy23from spaceopt import scores24
25
def mocked_fantasize_y_values(key,
                              x_locations_for_p_min,
                              params,
                              x_obs,
                              y_obs,
                              y_fantasized_num_for_p_min=50):
  """Draws fantasized y values from a zero-mean, identity-covariance mock.

  Test stand-in for a real fantasizer: the model/observation arguments are
  ignored, and samples are drawn as L @ z + 0, where L is the lower Cholesky
  factor of I + 1e-4 * I and z is standard normal noise.

  Args:
    key: JAX PRNG key used to draw the noise.
    x_locations_for_p_min: Array whose leading dimension is the number of
      locations to fantasize at.
    params: Ignored.
    x_obs: Ignored.
    y_obs: Ignored.
    y_fantasized_num_for_p_min: Number of fantasized samples per location.

  Returns:
    Array of shape (num_locations, y_fantasized_num_for_p_min).
  """
  del params, x_obs, y_obs  # The mock does not condition on anything.
  num_locations = x_locations_for_p_min.shape[0]
  identity = jnp.eye(num_locations)

  # Factor the (jittered) identity covariance; 1e-4 is added on the diagonal.
  lower_factor = jscipy.linalg.cholesky(identity + 1e-4 * identity, lower=True)
  standard_normals = jax.random.normal(
      key, (num_locations, y_fantasized_num_for_p_min))
  mean = jnp.zeros((num_locations, 1))  # Broadcast across the sample axis.

  return jnp.dot(lower_factor, standard_normals) + mean
44
def draw_x_location_fn(key, search_space, budget):
  """Draws `budget` points uniformly at random inside a box search space.

  Args:
    key: JAX PRNG key.
    search_space: Array of shape (dim, 2); row j holds the [lower, upper]
      bounds of dimension j.
    budget: Number of points to draw.

  Returns:
    Array of shape (budget, dim) whose column j is drawn uniformly between
    search_space[j, 0] and search_space[j, 1].
  """
  # Take the point dimension from the search space rather than hard-coding 1.
  # The per-dimension bounds below already broadcast across it. The
  # one-dimensional case produces exactly the same samples as before.
  dim = search_space.shape[0]
  return jax.random.uniform(
      key,
      shape=(budget, dim),
      minval=search_space[:, 0],
      maxval=search_space[:, 1])
52
# `mocked_fantasize_y_values` with its ignored model/observation arguments
# pre-bound to None. Callers supply only the key, the locations and
# (optionally) the number of samples.
partial_fantasize_y_values = functools.partial(
    mocked_fantasize_y_values,
    y_obs=None,
    x_obs=None,
    params=None,
)
56
def draw_y_values_fn(key, x_locations, y_drawn_num):
  """Draws `y_drawn_num` mock y samples at each of `x_locations`.

  Adapter that renames this helper's arguments onto the keyword names of
  `partial_fantasize_y_values`. `ScoresTest` passes it as the
  `draw_y_values_fn` argument of `scores.mean_utility`.

  Args:
    key: JAX PRNG key.
    x_locations: Locations at which to draw samples.
    y_drawn_num: Number of samples per location.

  Returns:
    Whatever `partial_fantasize_y_values` returns for these arguments.
  """
  sample_kwargs = dict(
      key=key,
      x_locations_for_p_min=x_locations,
      y_fantasized_num_for_p_min=y_drawn_num,
  )
  return partial_fantasize_y_values(**sample_kwargs)
63
class ScoresTest(absltest.TestCase):
  """Tests for `scores.UtilityMeasure` and the `scores` helper functions."""

  def setUp(self):
    super().setUp()

    x_locations = jnp.linspace(-4., 4., 7)[:, None]
    self.x_obs = jnp.linspace(-1., 1., 2)[:, None]
    self.y_obs = jnp.array([[7.], [11.]])

    self.x_batch = jnp.linspace(1., 3., 4)[:, None]
    # Each column is one candidate batch of y values; the column minima are
    # 1, 5 and 9, to be compared against the incumbent of 6 below.
    self.y_batch = jnp.array([[1., 5., 9.],
                              [2., 6., 10.],
                              [3., 7., 11.],
                              [4., 8., 12.]])
    # A single PRNGKey(0) serves both as the utility measure's entropy key
    # and as the test-level key.
    self.key = jax.random.PRNGKey(0)
    self.utility_measure = scores.UtilityMeasure(
        incumbent=6.,
        x_locations_for_p_min=x_locations,
        params=None,
        fantasize_y_values_for_p_min=mocked_fantasize_y_values,
        y_fantasized_num_for_p_min=50,
        initial_entropy_key=self.key)

    self.budget = 10
    self.search_space = jnp.array([[-4., 4.]])

  def test_is_improvement_shape(self):
    """Test that the is_improvement output has the right shape."""
    is_improvement = self.utility_measure.is_improvement(self.y_batch)
    self.assertEqual(is_improvement.shape, (self.y_batch.shape[1],))

  def test_is_improvement_values(self):
    """Test that the is_improvement output has the right value."""
    is_improvement = self.utility_measure.is_improvement(self.y_batch)
    # Comparing lists (not `(a == b).all()`) reports the mismatching values.
    self.assertEqual(is_improvement.tolist(), [True, True, False])

  def test_improvement_shape(self):
    """Test that the improvement output has the right shape."""
    improvement = self.utility_measure.improvement(self.y_batch)
    self.assertEqual(improvement.shape, (self.y_batch.shape[1],))

  def test_improvement_values(self):
    """Test that the improvement output has the right value."""
    improvement = self.utility_measure.improvement(self.y_batch)
    self.assertEqual(improvement.tolist(), [5., 1., 0.])

  def test_information_gain_shape(self):
    """Test that the information_gain output has the right shape."""
    key = jax.random.PRNGKey(1)
    information_gain = self.utility_measure.information_gain(
        key=key,
        x_obs=self.x_obs,
        y_obs=self.y_obs,
        x_batch=self.x_batch,
        y_batch=self.y_batch)
    self.assertEqual(information_gain.shape, (self.y_batch.shape[1],))

  def test_score_values(self):
    """Test that the improvement-based scores cannot be negative."""
    # pylint: disable=unused-argument
    def utility_is_imp_fn(key, x_batch, y_batch):
      return self.utility_measure.is_improvement(y_batch)

    def utility_imp_fn(key, x_batch, y_batch):
      return self.utility_measure.improvement(y_batch)
    # pylint: enable=unused-argument
    statistics_fns = [jnp.mean, jnp.median]
    key = jax.random.PRNGKey(1)

    # Run the same check for both utilities instead of duplicating it.
    for utility_fn in (utility_is_imp_fn, utility_imp_fn):
      mean_utility = scores.mean_utility(
          key,
          self.search_space,
          self.budget,
          utility_fn,
          draw_y_values_fn,
          x_drawn_num=100,
          y_drawn_num=100)
      scores_dict = scores.scores(mean_utility, statistics_fns)

      for statistics_fn in statistics_fns:
        with self.subTest(utility=utility_fn.__name__,
                          statistic=statistics_fn.__name__):
          self.assertGreaterEqual(scores_dict[statistics_fn.__name__], 0.)
161
if __name__ == '__main__':
  # Run this module's tests through absl's test runner.
  absltest.main()