google-research
241 lines · 8.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Log data and metrics. Inspired by combini/tools/logger."""
17
18import itertools
19import os.path
20from typing import Mapping, NamedTuple, Optional, Sequence, Tuple, Type, Union
21
22from absl import logging
23import gin
24import tensorflow.compat.v2 as tf
25
26from dedal import multi_task
27from dedal.train import timer
28
29
# For each head at each level (embeddings/alignments) a list of metrics.
# A metric spec is a tf.metrics.Metric *class* (instantiated later by
# `metric_factory`), optionally paired with the metadata keys its
# `update_state` method needs.
MetricCLS = Type[tf.metrics.Metric]
# A single str-valued metadata key, or a sequence of them.
MetaKeys = Union[str, Sequence[str]]
# Either a bare metric class, or a (metric class, metadata keys) pair.
MetricCLSWithOptionalMetaKeys = Union[MetricCLS, Tuple[MetricCLS, MetaKeys]]
# The full per-head, per-level metric configuration for a multi-task backbone.
MultiTaskMetrics = multi_task.Backbone[Sequence[MetricCLSWithOptionalMetaKeys]]
35
36
class MetricWithMetaKeys(NamedTuple):
  """Pairs an instantiated metric with the metadata keys it consumes."""

  # The instantiated tf.metrics.Metric object.
  metric: tf.metrics.Metric
  # Optional str-valued keys indexing the metadata tensors to forward to the
  # metric's `update_state`; None means no metadata is passed.
  metakeys: Optional[Sequence[str]] = None
40
41
def metric_factory(m):
  """Builds a `MetricWithMetaKeys` from a metric spec.

  A spec is either a `tf.metrics.Metric` subclass, or a pair
  (metric_cls, metakeys) where `metakeys` is one str key or a sequence of
  str keys indexing extra metadata. When keys are present, the logger will
  later pass the corresponding metadata tensors to the metric's
  `update_state` as a `metadata` kwarg (a tuple of tf.Tensor, one per key);
  otherwise `update_state` is invoked as usual.

  Args:
    m: the metric spec described above.

  Returns:
    A `MetricWithMetaKeys` namedtuple holding the instantiated metric and
    the (possibly None) sequence of metadata keys.
  """
  if not isinstance(m, Sequence):
    # Bare metric class: instantiate it, no metadata will be forwarded.
    return MetricWithMetaKeys(metric=m(), metakeys=None)
  metric_cls, metakeys = m
  if isinstance(metakeys, str):
    # Normalize a single key to a one-element tuple.
    metakeys = (metakeys,)
  return MetricWithMetaKeys(metric=metric_cls(), metakeys=metakeys)
74
75
@gin.configurable
class Logger:
  """Accumulates metrics over steps and periodically writes tf summaries."""

  def __init__(self,
               workdir,
               strategy,
               split = None,
               task = None,
               scalars = multi_task.Backbone(),
               images = multi_task.Backbone(),
               means = (),
               every = 1000,
               reset_every_step = False,
               start_clock = True):
    """Initialization.

    Args:
      workdir: the parent directory where to store data.
      strategy: distribution strategy.
      split: usually the name of the phase (train, test, valid).
      task: usually the name of the task (train, evaluate, downstream).
      scalars: the scalar metrics to be computed and dumped.
      images: the image metrics to be computed and dumped.
      means: the name of the scalar metrics that will be means. At the very
        least, "loss" and "gradient_norm" will be present.
      every: the periodicity to log the metrics.
      reset_every_step: whether to reset the metrics at every step.
      start_clock: whether or not to start the clock at instantiation.
    """
    split = '' if split is None else split
    self.workdir = os.path.join(workdir, split).rstrip('/')
    self._split = split
    self._task = task
    self._timer = timer.Timer()
    self._reset_every_step = reset_every_step
    self.training = task == 'train'

    # Take the bigger network structure: an envelope shape covering both the
    # scalar and the image metric configurations.
    shape = tuple(max(scalars.shape[i], images.shape[i]) for i in range(2))
    envelope = multi_task.Backbone.constant_from_shape([], shape)

    means = set(means).union(['loss'])
    if self.training:
      means = means.union(['gradient_norm'])

    # Metric variables must be created under the distribution strategy scope.
    with strategy.scope():
      self._scalars = envelope.pack(
          [[metric_factory(m) for m in ms] for ms in scalars], default_value=[])
      self._images = envelope.pack(
          [[metric_factory(m) for m in ms] for ms in images], default_value=[])
      self._means = {name: tf.keras.metrics.Mean(name) for name in means}

    self._summary_writer = tf.summary.create_file_writer(self.workdir)
    self._every = every
    # When training, exactly `every` steps elapse between logs; otherwise the
    # step gap is measured against the last logged step.
    self._last_step = None if self.training else 0

    if start_clock:
      self.restart_clock()

  def update_mean(self, name, loss):
    """Accumulates `loss` into the running mean `name`, creating it lazily."""
    if name not in self._means:
      self._means[name] = tf.keras.metrics.Mean(name=name)
    self._means[name].update_state(loss)

  def restart_clock(self):
    """Restarts the timer and returns the time elapsed since the last restart."""
    return self._timer.restart()

  def update(self,
             y_true,
             y_pred,
             weights,
             metadata):
    """Update the different metrics with the new values."""
    # TODO(oliviert): improve this flatten/unflatten dance.
    # TODO(fllinares): raise exception if key not in metadata?
    y_true = y_pred.unflatten(y_true)
    weights = y_pred.unflatten(weights)
    all_metrics_with_metakeys = self._scalars.pack(
        [a + b for a, b in zip(self._scalars, self._images)])
    for metrics_with_metakeys, label, pred, batch_w in zip(
        all_metrics_with_metakeys, y_true, y_pred, weights):
      for metric, metakeys in metrics_with_metakeys:
        # Metrics declared with metadata keys receive the corresponding
        # metadata tensors via an extra `metadata` kwarg.
        kwargs = ({} if metakeys is None else
                  dict(metadata=tuple(metadata.get(k) for k in metakeys)))
        metric.update_state(label, pred, sample_weight=batch_w, **kwargs)

  def reset(self):
    """Resets the state of every tracked metric."""
    for metric in self.metrics:
      metric.reset_states()

  def log(self, step):
    """Log the tf summaries."""
    delta = self.restart_clock()
    with self._summary_writer.as_default():
      n_steps = self._every if self.training else (step - self._last_step)
      tf.summary.scalar('steps_per_sec', n_steps / delta, step=step)
      for metric in self.scalars:
        curr = metric.result()
        # Some metrics return a Mapping of named scalars instead of one value.
        curr = curr if isinstance(curr, Mapping) else {metric.name: curr}
        for name, value in curr.items():
          tf.summary.scalar(name, value, step=step)
      for metric in self.images:
        tf.summary.image(metric.name, metric.result(), step=step)
    self._last_step = None if self.training else step

  @property
  def metrics(self):
    """All tracked metrics: images, scalars and running means."""
    return self.images + self.scalars

  @property
  def images(self):
    """The image metrics, flattened into a list."""
    return list(m.metric for m in itertools.chain.from_iterable(self._images))

  @property
  def scalars(self):
    """The scalar metrics, flattened, with the running means appended."""
    without_means = list(
        m.metric for m in itertools.chain.from_iterable(self._scalars))
    return without_means + list(self._means.values())

  def debug(self, step):
    """Returns a one-line human-readable summary of the scalar metrics."""
    def metric_to_str(m):
      result = m.result()
      if isinstance(result, Mapping):
        return ', '.join(f'{k}: {v:.3f}' for k, v in result.items())
      # Reuse `result` instead of invoking `m.result()` a second time.
      return f'{m.name}: {result:.3f}'

    metrics_str = ', '.join(metric_to_str(m) for m in self.scalars)
    return f'{self._split} step {step}: {metrics_str}'

  def log_and_reset(self, step, force = True):
    """Log the metrics to summaries if the step allows, and reset them.

    Args:
      step: the step where we are at now.
      force: should we force the behavior (typically for the last step).

    Returns:
      True if the metrics have been logged, False otherwise.
    """
    if step % self._every == 0 or force:
      logging.info(self.debug(step))
      self.log(step)
      self.reset()
      return True
    if self._reset_every_step:
      self.reset()
    return False
225
226
@gin.configurable
class DummyLogger:
  """A drop-in replacement for `Logger` that silently discards everything."""

  def update_mean(self, name, value):
    """No-op: accepts and ignores a scalar update."""
    _ = (name, value)  # Intentionally unused.
    return None

  def update(self, *args):
    """No-op: accepts and ignores metric updates."""
    _ = args  # Intentionally unused.
    return None

  def log_and_reset(self, step, force = True):
    """No-op: never logs; always reports that nothing was written."""
    _ = (step, force)  # Intentionally unused.
    return False
242