google-research
84 lines · 2.9 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for the scalarization optimizers."""
17
18from absl.testing import parameterized19import tensorflow.compat.v1 as tf20
21from yoto.optimizers import scalarization22import yoto.problems as problems23
24
class ProblemWithConstantLosses(problems.Problem):
  """Stub `Problem` whose losses are fixed constants independent of inputs.

  Each loss equals the constant supplied at construction time. The constant
  is combined with `0 * dummy_variable`, so every returned loss is a tensor
  that depends on a variable (presumably so that the optimizer under test has
  something to take gradients against — TODO confirm).
  """

  def __init__(self, losses_values):
    # Mapping from loss name to the constant value that loss always takes.
    self._losses_values = losses_values
    # Scalar variable whose only job is to give the losses a variable input.
    self._dummy = tf.Variable(0.)

  def losses_and_metrics(self, inputs, inputs_extra=None, training=False):
    """Return the constant losses, ignoring every input.

    Args:
      inputs: Ignored.
      inputs_extra: Ignored.
      training: Unused.

    Returns:
      A `(losses, metrics)` pair. `losses` maps each loss name to its
      constant plus `0 * dummy_variable`; `metrics` is always an empty dict.
    """
    del inputs, inputs_extra  # The losses do not depend on the inputs.
    losses = {}
    for loss_name, constant in self._losses_values.items():
      losses[loss_name] = constant + 0 * self._dummy
    return losses, {}

  @property
  def losses_keys(self):
    """Tuple of all loss names, sorted."""
    return tuple(sorted(self._losses_values))

  def initialize_model(self):
    """No-op: this stub has no model to initialize."""

  def module_spec(self):
    """No-op: this stub has no module spec, so it returns None."""
49
class LinearlyScalarizedOptimizerTest(parameterized.TestCase,
                                      tf.test.TestCase):
  """Tests for `scalarization.LinearlyScalarizedOptimizer`."""

  def test_check_weighted_value_on_constant_losses(self):
    """The train loss equals the weighted sum of the constant losses."""
    weights = {"a": tf.constant(0.5),
               "b": tf.constant(0.3),
               "c": tf.constant(0.4)}
    losses = {"a": -15, "b": .4, "c": .3}
    optimizer = scalarization.LinearlyScalarizedOptimizer(
        problem=ProblemWithConstantLosses(losses), weights=weights)
    # Learning rate 0 and the update op is discarded: only the returned loss
    # is checked here.
    loss, _ = optimizer.compute_train_loss_and_update_op(
        inputs={}, base_optimizer=tf.train.GradientDescentOptimizer(0.))
    expected_loss = sum(weights[key] * losses[key] for key in weights)
    with self.cached_session() as session:
      session.run(tf.initializers.global_variables())
      self.assertAllClose(loss, expected_loss)

  def test_exception_thrown_when_weights_is_of_invalid_type(self):
    """A `weights` argument of an unsupported type raises `TypeError`."""
    # Build the fixture outside `assertRaises` so that an error raised while
    # constructing it cannot make this test pass spuriously.
    problem = ProblemWithConstantLosses({"a": -15, "b": .4, "c": .3})
    # Should fail as `weights` is neither a dict nor in the enum.
    with self.assertRaises(TypeError):
      scalarization.LinearlyScalarizedOptimizer(problem=problem, weights=123)

  def test_throws_exception_when_weights_key_is_missing(self):
    """A `weights` dict that lacks one of the loss keys raises `ValueError`."""
    # Build the fixture outside `assertRaises` so that an error raised while
    # constructing it cannot make this test pass spuriously.
    problem = ProblemWithConstantLosses({"a": -15, "b": .4, "c": .3})
    weights = {"a": tf.constant(0.5),
               "b": tf.constant(0.3)}  # Misses the key "c".
    with self.assertRaises(ValueError):
      scalarization.LinearlyScalarizedOptimizer(problem=problem,
                                                weights=weights)
def _main():
  """Switch to graph mode, then hand control to the TF test runner."""
  # The tests run graphs through `session.run`, so they rely on TF1 graph
  # (non-eager) mode being active before any test executes.
  tf.disable_eager_execution()
  tf.test.main()


if __name__ == "__main__":
  _main()