google-research

Форк
0
/
scalarization.py 
77 строк · 2.9 Кб
1
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optimizers based on scalarization.

One of the simplest approaches to optimizing multi-loss problems is to scalarize
to a real objective by combining the individual losses. Depending on how the
scalarization is performed, different optimization algorithms arise.
"""
22

23
import gin
24
import tensorflow.compat.v1 as tf
25

26

27
from yoto.optimizers import base as optimizers_base
28
from yoto.optimizers import distributions
29

30

31
@gin.configurable("LinearlyScalarizedOptimizer")
class LinearlyScalarizedOptimizer(optimizers_base.MultiLossOptimizer):
  r"""An optimizer that linearly scalarizes the losses.

  Namely, if the losses are loss_1, ..., loss_n, then it minimizes
    \sum_i loss_i * weight_i,
  for fixed weights. The weights can be either randomly drawn from one of the
  supported distributions, or fixed.
  """

  def __init__(self, problem, weights,
               batch_size=None, seed=17):
    """Initializes the optimizer.

    Args:
      problem: An instance of `problems.Problem`.
      weights: Either `distributions.DistributionSpec` class or a
        dictionary mapping the loss names to their corresponding
        weights.
      batch_size: Passed to the initializer of `MultiLossOptimizer`.
      seed: random seed to be used for sampling the weights.
    """
    super(LinearlyScalarizedOptimizer, self).__init__(
        problem, batch_size=batch_size)
    # Draw a single weight dictionary, deterministically given `seed`.
    # When `weights` is already a fixed dict, the "sample" reproduces it.
    sampled_weights = distributions.get_samples_as_dicts(
        weights, names=self._losses_names, seed=seed)[0]
    # Validate that the dict covers exactly the problem's loss names.
    self._check_weights_dict(sampled_weights)
    self._weights = sampled_weights

  def compute_train_loss_and_update_op(self, inputs, base_optimizer):
    """Builds the scalarized training loss and the corresponding update op.

    Args:
      inputs: A batch of inputs, forwarded to the problem's
        `losses_and_metrics`.
      base_optimizer: A `tf.train.Optimizer` instance used to minimize the
        scalarized loss.

    Returns:
      A pair (scalar loss tensor, train op).
    """
    losses, metrics = self._problem.losses_and_metrics(inputs, training=True)
    del metrics  # Unused for training.
    # Weighted sum of the per-loss means: \sum_i weight_i * mean(loss_i).
    linearized_loss = 0.
    for loss_name, loss_value in losses.items():
      linearized_loss += tf.reduce_mean(loss_value * self._weights[loss_name])
    train_op = base_optimizer.minimize(
        linearized_loss, global_step=tf.train.get_or_create_global_step())
    # Snapshot of the trainable variables created up to this point; exposed
    # for callers that need to distinguish them from later additions.
    self.normal_vars = tf.trainable_variables()
    return linearized_loss, train_op

  def compute_eval_loss(self, inputs):
    """Builds the scalarized evaluation loss (no update op).

    Args:
      inputs: A batch of inputs, forwarded to the problem's
        `losses_and_metrics`.

    Returns:
      A scalar loss tensor: the weighted sum of the mean per-loss values,
      using the same weights as during training.
    """
    losses, metrics = self._problem.losses_and_metrics(inputs, training=False)
    del metrics  # Unused for evaluation.
    linearized_loss = 0.
    for loss_name, loss_value in losses.items():
      linearized_loss += tf.reduce_mean(loss_value * self._weights[loss_name])
    return linearized_loss
78

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.