# google-research
# 77 lines · 2.9 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optimizers based on scalarization.

One of the simplest approaches to optimizing multi-loss problems is to scalarize
to a real objective by combining the individual losses. Depending on how the
scalarization is performed, different optimization algorithms arise.
"""

import gin
import tensorflow.compat.v1 as tf

from yoto.optimizers import base as optimizers_base
from yoto.optimizers import distributions

@gin.configurable("LinearlyScalarizedOptimizer")
class LinearlyScalarizedOptimizer(optimizers_base.MultiLossOptimizer):
  r"""An optimizer that linearly scalarizes the losses.

  Namely, if the losses are loss_1, ..., loss_n, then it minimizes
  \sum_i loss_i * weight_i,
  for fixed weights. The weights can be either randomly drawn from one of the
  supported distributions, or fixed.
  """

  def __init__(self, problem, weights, batch_size=None, seed=17):
    """Initializes the optimizer.

    Args:
      problem: An instance of `problems.Problem`.
      weights: Either `distributions.DistributionSpec` class or a
        dictionary mapping the loss names to their corresponding
        weights.
      batch_size: Passed to the initializer of `MultiLossOptimizer`.
      seed: random seed to be used for sampling the weights.
    """
    super(LinearlyScalarizedOptimizer, self).__init__(
        problem, batch_size=batch_size)
    # One weight per named loss; for a DistributionSpec this draws a single
    # sample, for a fixed dict it is presumably passed through unchanged
    # (NOTE(review): confirm against distributions.get_samples_as_dicts).
    sampled_weights = distributions.get_samples_as_dicts(
        weights, names=self._losses_names, seed=seed)[0]
    self._check_weights_dict(sampled_weights)
    self._weights = sampled_weights

  def _linearized_loss(self, losses):
    r"""Returns \sum_i weight_i * reduce_mean(loss_i) for the losses dict.

    Shared by the train and eval paths so the scalarization stays consistent.

    Args:
      losses: Dict mapping loss names to loss tensors; keys must match
        the weight names sampled in `__init__`.

    Returns:
      A scalar tensor with the linearly scalarized loss.
    """
    linearized_loss = 0.
    for loss_name, loss_value in losses.items():
      linearized_loss += tf.reduce_mean(loss_value * self._weights[loss_name])
    return linearized_loss

  def compute_train_loss_and_update_op(self, inputs, base_optimizer):
    """Builds the scalarized training loss and the corresponding train op.

    Args:
      inputs: A batch of inputs, passed through to the problem.
      base_optimizer: A `tf.train.Optimizer` used to minimize the
        scalarized loss.

    Returns:
      A pair (scalar training loss tensor, train op).
    """
    losses, metrics = self._problem.losses_and_metrics(inputs, training=True)
    del metrics  # Unused in training.
    linearized_loss = self._linearized_loss(losses)
    train_op = base_optimizer.minimize(
        linearized_loss, global_step=tf.train.get_or_create_global_step())
    self.normal_vars = tf.trainable_variables()
    return linearized_loss, train_op

  def compute_eval_loss(self, inputs):
    """Builds the scalarized evaluation loss.

    Args:
      inputs: A batch of inputs, passed through to the problem.

    Returns:
      A scalar evaluation loss tensor.
    """
    losses, metrics = self._problem.losses_and_metrics(inputs, training=False)
    del metrics  # Unused in evaluation.
    return self._linearized_loss(losses)