google-research

Fork
0
241 lines · 8.2 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Log data and metrics. Inspired by combini/tools/logger."""
17

18
import itertools
19
import os.path
20
from typing import Mapping, NamedTuple, Optional, Sequence, Tuple, Type, Union
21

22
from absl import logging
23
import gin
24
import tensorflow.compat.v2 as tf
25

26
from dedal import multi_task
27
from dedal.train import timer
28

29

30
# For each head at each level (embeddings/alignments) a list of metrics.
31
MetricCLS = Type[tf.metrics.Metric]
32
MetaKeys = Union[str, Sequence[str]]
33
MetricCLSWithOptionalMetaKeys = Union[MetricCLS, Tuple[MetricCLS, MetaKeys]]
34
MultiTaskMetrics = multi_task.Backbone[Sequence[MetricCLSWithOptionalMetaKeys]]
35

36

37
class MetricWithMetaKeys(NamedTuple):
  """An instantiated metric paired with its optional metadata keys.

  Attributes:
    metric: the instantiated tf.metrics.Metric object.
    metakeys: optional sequence of str-valued keys indexing extra metadata
      tensors to be forwarded to the metric's update_state method (see
      metric_factory for how these are produced and Logger.update for how
      they are consumed).
  """
  # The instantiated Keras metric.
  metric: tf.metrics.Metric
  # None means update_state is called without any `metadata` kwarg.
  metakeys: Optional[Sequence[str]] = None


def metric_factory(m):
  """Builds a MetricWithMetaKeys from a metric specification.

  Extends the default tf.metrics.Metric update_state protocol by optionally
  tracking metadata keys. When metadata keys are declared, the metadata
  tensors indexed by those keys are later passed to the metric's update_state
  as an extra `metadata` kwarg holding a tuple of tf.Tensor with one entry per
  key. Without metadata keys, update_state is invoked as usual.

  Args:
    m: Either a subclass of tf.metrics.Metric, or a tuple
      (metric_cls, metakeys) where metric_cls is a subclass of
      tf.metrics.Metric and metakeys is a str-valued key (treated as a
      singleton) or a sequence of str-valued keys indexing metadata needed by
      the metric's update_state method.

  Returns:
    A MetricWithMetaKeys namedtuple whose `metric` field holds a freshly
    instantiated metric of the given class and whose `metakeys` field holds
    the (possibly None) sequence of str-valued metadata keys.
  """
  if not isinstance(m, Sequence):  # Bare metric class: no metadata keys.
    return MetricWithMetaKeys(metric=m(), metakeys=None)
  metric_cls, metakeys = m  # m: Tuple[MetricCLS, MetaKeys].
  if isinstance(metakeys, str):  # Promote a lone key to a 1-tuple.
    metakeys = (metakeys,)
  return MetricWithMetaKeys(metric=metric_cls(), metakeys=metakeys)


@gin.configurable
class Logger:
  """A class responsible for logging data and metrics."""

  def __init__(
      self,
      workdir,
      strategy,
      split = None,
      task = None,
      scalars = None,
      images = None,
      means = (),
      every = 1000,
      reset_every_step = False,
      start_clock = True):
    """Initialization.

    Args:
      workdir: the parent directory where to store data.
      strategy: distribution strategy.
      split: usually the name of the phase (train, test, valid).
      task: usually the name of the task (train, evaluate, downstream).
      scalars: the scalar metrics to be computed and dumped. Defaults to an
        empty multi_task.Backbone().
      images: the image metrics to be computed and dumped. Defaults to an
        empty multi_task.Backbone().
      means: the name of the scalar metrics that will be means. At the very
        least, "loss" and "gradient_norm" will be present.
      every: the periodicity to log the metrics.
      reset_every_step: whether to reset the metrics at every step.
      start_clock: whether or not to start the clock at instantiation.
    """
    # None-sentinels instead of mutable default arguments: a default
    # Backbone() instance would be shared by every Logger built with defaults.
    scalars = multi_task.Backbone() if scalars is None else scalars
    images = multi_task.Backbone() if images is None else images

    split = '' if split is None else split
    self.workdir = os.path.join(workdir, split).rstrip('/')
    self._split = split
    self._task = task
    self._timer = timer.Timer()
    self._reset_every_step = reset_every_step
    self.training = task == 'train'

    # Take the bigger network structure (elementwise max over both shapes).
    shape = tuple(max(scalars.shape[i], images.shape[i]) for i in range(2))
    envelope = multi_task.Backbone.constant_from_shape([], shape)

    # "loss" is always tracked as a mean; "gradient_norm" only when training.
    means = set(means).union(['loss'])
    if self.training:
      means = means.union(['gradient_norm'])

    # Metric variables must be created under the distribution strategy scope.
    with strategy.scope():
      self._scalars = envelope.pack(
          [[metric_factory(m) for m in ms] for ms in scalars], default_value=[])
      self._images = envelope.pack(
          [[metric_factory(m) for m in ms] for ms in images], default_value=[])
      self._means = {name: tf.keras.metrics.Mean(name) for name in means}

    self._summary_writer = tf.summary.create_file_writer(self.workdir)
    self._every = every
    self._last_step = None if self.training else 0

    if start_clock:
      self.restart_clock()

  def update_mean(self, name, loss):
    """Updates the running mean `name`, creating it first if needed."""
    if name not in self._means:
      self._means[name] = tf.keras.metrics.Mean(name=name)
    self._means[name].update_state(loss)

  def restart_clock(self):
    """Restarts the timer; returns the time elapsed since the last restart."""
    return self._timer.restart()

  def update(self,
             y_true,
             y_pred,
             weights,
             metadata):
    """Update the different metrics with the new values."""
    # TODO(oliviert): improve this flatten/unflatten dance.
    # TODO(fllinares): raise exception if key not in metadata?
    y_true = y_pred.unflatten(y_true)
    weights = y_pred.unflatten(weights)
    all_metrics_with_metakeys = self._scalars.pack(
        [a + b for a, b in zip(self._scalars, self._images)])
    for metrics_with_metakeys, label, pred, batch_w in zip(
        all_metrics_with_metakeys, y_true, y_pred, weights):
      for metric, metakeys in metrics_with_metakeys:
        # Metrics declared with metakeys receive the corresponding metadata
        # tensors via an extra `metadata` kwarg; the others get a plain call.
        kwargs = ({} if metakeys is None else
                  dict(metadata=tuple(metadata.get(k) for k in metakeys)))
        metric.update_state(label, pred, sample_weight=batch_w, **kwargs)

  def reset(self):
    """Resets the accumulated state of every metric (scalars and images)."""
    for metric in self.metrics:
      metric.reset_states()

  def log(self, step):
    """Log the tf summaries."""
    delta = self.restart_clock()
    with self._summary_writer.as_default():
      # While training, exactly `every` steps elapsed since the last log;
      # otherwise the gap is measured against the recorded last step.
      n_steps = self._every if self.training else (step - self._last_step)
      tf.summary.scalar('steps_per_sec', n_steps / delta, step=step)
      for metric in self.scalars:
        curr = metric.result()
        # A metric may return a Mapping of named scalars instead of one value.
        curr = curr if isinstance(curr, Mapping) else {metric.name: curr}
        for name, value in curr.items():
          tf.summary.scalar(name, value, step=step)
      for metric in self.images:
        tf.summary.image(metric.name, metric.result(), step=step)
    self._last_step = None if self.training else step

  @property
  def metrics(self):
    """All metric objects: images first, then scalars (including means)."""
    return self.images + self.scalars

  @property
  def images(self):
    """Flat list of the instantiated image metrics."""
    return list(m.metric for m in itertools.chain.from_iterable(self._images))

  @property
  def scalars(self):
    """Flat list of instantiated scalar metrics plus the running means."""
    without_means = list(
        m.metric for m in itertools.chain.from_iterable(self._scalars))
    return without_means + list(self._means.values())

  def debug(self, step):
    """Returns a one-line human-readable summary of all scalar metrics."""
    def metric_to_str(m):
      # Compute result() once: recomputing it per format was wasted work.
      result = m.result()
      if isinstance(result, Mapping):
        return ', '.join(f'{k}: {v:.3f}' for k, v in result.items())
      return f'{m.name}: {result:.3f}'

    metrics_str = ', '.join(metric_to_str(m) for m in self.scalars)
    return f'{self._split} step {step}: {metrics_str}'

  def log_and_reset(self, step, force = True):
    """Log the metrics to summaries if the step allows, and reset them.

    Args:
      step: the step where we are at now.
      force: should we force the behavior (typically for the last step).

    Returns:
      True if the metrics have been logged, False otherwise.
    """
    if step % self._every == 0 or force:
      logging.info(self.debug(step))
      self.log(step)
      self.reset()
      return True
    if self._reset_every_step:
      self.reset()
    return False


@gin.configurable
class DummyLogger:
  """A no-op stand-in for Logger: accepts the same calls, records nothing."""

  def update_mean(self, name, value):
    """Discards the value; no running mean is kept."""
    del name, value

  def update(self, *args):
    """Discards all metric updates."""
    del args

  def log_and_reset(self, step, force = True):
    """Never writes summaries; always reports that nothing was logged."""
    del step, force
    return False

Cookie usage

We use cookies in accordance with the Privacy Policy and the Cookie Usage Policy.

By clicking "Accept", you give JSC "SberTech" consent to process your personal data in order to improve our website and the GitVerse service, and to make them more convenient to use.

You can disable the use of cookies yourself in your browser settings.