google-research / align_metrics.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom metrics for sequence alignment.

This module defines the following types, which serve as inputs to all metrics
implemented here:

+ GroundTruthAlignment is a tf.Tensor<int>[batch, 3, align_len] that can be
  written as tf.stack([pos_x, pos_y, enc_trans], 1) such that
    (pos_x[b][i], pos_y[b][i], enc_trans[b][i]) represents the i-th transition
  in the ground-truth alignment for example b in the minibatch.
  Both pos_x and pos_y are assumed to use one-based indexing and enc_trans
  follows the (categorical) 9-state encoding of edge types used throughout
  `learning/brain/research/combini/diff_opt/alignment/tf_ops.py`.

+ SWParams is a tuple (sim_mat, gap_open, gap_extend) parameterizing the
  Smith-Waterman LP such that
  + sim_mat is a tf.Tensor<float>[batch, len1, len2] (len1 <= len2) with the
    substitution values for pairs of sequences.
  + gap_open is a tf.Tensor<float>[batch, len1, len2] (len1 <= len2) or
    tf.Tensor<float>[batch] with the penalties for opening a gap. Must agree
    in rank with gap_extend.
  + gap_extend is a tf.Tensor<float>[batch, len1, len2] (len1 <= len2) or
    tf.Tensor<float>[batch] with the penalties for extending a gap. Must agree
    in rank with gap_open.

+ AlignmentOutput is a tuple (solution_values, solution_paths, sw_params) such
  that
  + 'solution_values' contains a tf.Tensor<float>[batch] with the (soft) optimal
    Smith-Waterman scores for the batch.
  + 'solution_paths' contains a tf.Tensor<float>[batch, len1, len2, 9] that
    describes the optimal soft alignments.
  + 'sw_params' is a SWParams tuple as described above.
"""


import functools
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Type, Union

import gin
import tensorflow as tf

from dedal import alignment


GroundTruthAlignment = tf.Tensor
PredictedPaths = tf.Tensor
SWParams = Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]
AlignmentOutput = Tuple[tf.Tensor, Optional[PredictedPaths], SWParams]
NaiveAlignmentOutput = Tuple[tf.Tensor, tf.Tensor, SWParams]
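

# The helper below is an illustrative sketch added for documentation purposes
# (`_example_alignment_output` is hypothetical and not part of the original
# module): it builds zero-filled dummy tensors with the shapes documented in
# the module docstring, which can be handy for smoke-testing the metrics below.
def _example_alignment_output(batch = 2, len1 = 4, len2 = 6):
  """Returns a dummy AlignmentOutput with the documented tensor shapes."""
  sim_mat = tf.zeros([batch, len1, len2])  # Substitution values.
  gap_open = tf.fill([batch], 11.0)  # Per-example gap-open penalties.
  gap_extend = tf.fill([batch], 1.0)  # Per-example gap-extend penalties.
  sw_params = (sim_mat, gap_open, gap_extend)  # SWParams.
  solution_values = tf.zeros([batch])  # (Soft) optimal Smith-Waterman scores.
  solution_paths = tf.zeros([batch, len1, len2, 9])  # Soft alignment paths.
  return solution_values, solution_paths, sw_params  # AlignmentOutput.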


def confusion_matrix(
    alignments_true,
    sol_paths_pred):
  """Computes true, predicted and actual positives for a batch of alignments."""
  batch_size = tf.shape(alignments_true)[0]

  # Computes the number of true positives per example as a (sparse) inner
  # product of two binary tensors of shape (batch_size, len_x, len_y) via
  # indexing. Entirely avoids materializing one of the two tensors explicitly.
  match_indices_true = alignment.alignments_to_state_indices(
      alignments_true, 'match')  # [n_aligned_chars_true, 3]
  match_indicators_pred = alignment.paths_to_state_indicators(
      sol_paths_pred, 'match')  # [batch, len_x, len_y]
  batch_indicators = match_indices_true[:, 0]  # [n_aligned_chars_true]
  matches_flat = tf.gather_nd(
      match_indicators_pred, match_indices_true)  # [n_aligned_chars_true]
  true_positives = tf.math.unsorted_segment_sum(
      matches_flat, batch_indicators, batch_size)  # [batch]

  # Compute number of predicted and ground-truth positives per example.
  pred_positives = tf.reduce_sum(match_indicators_pred, axis=[1, 2])
  # Note(fllinares): tf.math.bincount unsupported on TPU :(
  cond_positives = tf.math.unsorted_segment_sum(
      tf.ones_like(batch_indicators, tf.float32),
      batch_indicators,
      batch_size)  # [batch]
  return true_positives, pred_positives, cond_positives
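

# A minimal, self-contained sketch of the gather_nd + unsorted_segment_sum
# pattern used by `confusion_matrix` above (`_example_true_positive_count` is
# added for illustration only and is not part of the original module).
def _example_true_positive_count():
  """Counts per-example true positives for a toy batch of two alignments."""
  batch_size = 2
  # One ground-truth 'match' per example, at (b, i, j) = (0, 1, 2) and (1, 0, 0).
  match_indices_true = tf.constant([[0, 1, 2], [1, 0, 0]])
  # Dense predicted match indicators of shape [batch, len_x, len_y]; only the
  # first example's ground-truth match is predicted.
  match_indicators_pred = tf.scatter_nd(
      tf.constant([[0, 1, 2]]), tf.constant([1.0]), shape=[2, 3, 4])
  matches_flat = tf.gather_nd(match_indicators_pred, match_indices_true)
  return tf.math.unsorted_segment_sum(
      matches_flat, match_indices_true[:, 0], batch_size)  # == [1.0, 0.0]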


@gin.configurable
class AlignmentPrecisionRecall(tf.metrics.Metric):
  """Implements precision and recall metrics for sequence alignment."""

  def __init__(self,
               name = 'alignment_pr',
               threshold = None,
               **kwargs):
    super().__init__(name=name, **kwargs)
    self._threshold = threshold
    self._true_positives = tf.metrics.Mean()  # TP
    self._pred_positives = tf.metrics.Mean()  # TP + FP
    self._cond_positives = tf.metrics.Mean()  # TP + FN

  def update_state(
      self,
      alignments_true,
      alignments_pred,
      sample_weight = None):
    """Updates TP, TP + FP and TP + FN for a batch of true, pred alignments."""
    if alignments_pred[1] is None:
      return

    sol_paths_pred = alignments_pred[1]
    if self._threshold is not None:  # Otherwise, we assume already binarized.
      sol_paths_pred = tf.cast(sol_paths_pred >= self._threshold, tf.float32)

    true_positives, pred_positives, cond_positives = confusion_matrix(
        alignments_true, sol_paths_pred)

    self._true_positives.update_state(true_positives, sample_weight)
    self._pred_positives.update_state(pred_positives, sample_weight)
    self._cond_positives.update_state(cond_positives, sample_weight)

  def result(self):
    true_positives = self._true_positives.result()
    pred_positives = self._pred_positives.result()
    cond_positives = self._cond_positives.result()
    precision = tf.where(
        true_positives > 0.0, true_positives / pred_positives, 0.0)
    recall = tf.where(
        true_positives > 0.0, true_positives / cond_positives, 0.0)
    f1 = 2.0 * (precision * recall) / (precision + recall)
    return {
        f'{self.name}/precision': precision,
        f'{self.name}/recall': recall,
        f'{self.name}/f1': f1,
    }

  def reset_states(self):
    self._true_positives.reset_states()
    self._pred_positives.reset_states()
    self._cond_positives.reset_states()
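

# Worked example of `AlignmentPrecisionRecall.result` (illustrative numbers
# only, not taken from the original code): if the running means are
# TP = 3.0, TP + FP = 4.0 and TP + FN = 6.0, then
#   precision = 3 / 4 = 0.75
#   recall    = 3 / 6 = 0.50
#   f1        = 2 * 0.75 * 0.50 / (0.75 + 0.50) = 0.60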


@gin.configurable
class NaiveAlignmentPrecisionRecall(tf.metrics.Metric):
  """Implements precision and recall metrics for (naive) sequence alignment."""

  def __init__(self,
               name = 'naive_alignment_pr',
               threshold = None,
               **kwargs):
    super().__init__(name=name, **kwargs)
    self._precision = tf.metrics.Precision(thresholds=threshold)
    self._recall = tf.metrics.Recall(thresholds=threshold)

  def update_state(
      self,
      alignments_true,
      alignments_pred,
      sample_weight = None):
    """Updates precision, recall for a batch of true, pred alignments."""
    if alignments_pred[1] is None:
      return

    _, match_indicators_pred, sw_params = alignments_pred
    sim_mat, _, _ = sw_params
    shape, dtype = sim_mat.shape, match_indicators_pred.dtype

    match_indices_true = alignment.alignments_to_state_indices(
        alignments_true, 'match')
    updates_true = tf.ones([tf.shape(match_indices_true)[0]], dtype=dtype)
    match_indicators_true = tf.scatter_nd(
        match_indices_true, updates_true, shape=shape)

    batch = tf.shape(sample_weight)[0]
    sample_weight = tf.reshape(sample_weight, [batch, 1, 1])
    mask = alignment.mask_from_similarities(sim_mat, dtype=dtype)

    self._precision.update_state(
        match_indicators_true, match_indicators_pred, sample_weight * mask)
    self._recall.update_state(
        match_indicators_true, match_indicators_pred, sample_weight * mask)

  def result(self):
    precision, recall = self._precision.result(), self._recall.result()
    f1 = 2.0 * (precision * recall) / (precision + recall)
    return {
        f'{self.name}/precision': precision,
        f'{self.name}/recall': recall,
        f'{self.name}/f1': f1,
    }

  def reset_states(self):
    self._precision.reset_states()
    self._recall.reset_states()
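

# Note on thresholding (added for clarity; this is the standard behavior of
# tf.metrics.Precision/Recall): with thresholds=0.5, soft predicted match
# indicators [0.8, 0.3] compared against labels [1, 0] yield one true positive
# and no false positives or false negatives, i.e. precision = recall = 1.0.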


@gin.configurable
class AlignmentMSE(tf.metrics.Mean):
  """Implements mean squared error metric for sequence alignment."""

  def __init__(self, name = 'alignment_mse', **kwargs):
    super().__init__(name=name, **kwargs)

  def update_state(
      self,
      alignments_true,
      alignments_pred,
      sample_weight = None):
    """Updates mean squared error for a batch of true vs pred alignments."""
    if alignments_pred[1] is None:
      return

    sol_paths_pred = alignments_pred[1]
    len_x, len_y = tf.shape(sol_paths_pred)[1], tf.shape(sol_paths_pred)[2]
    sol_paths_true = alignment.alignments_to_paths(
        alignments_true, len_x, len_y)
    mse = tf.reduce_sum((sol_paths_pred - sol_paths_true) ** 2, axis=[1, 2, 3])
    super().update_state(mse, sample_weight)
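

# Note (added for clarity, not in the original file): sol_paths_true is a
# binary tensor, so when sol_paths_pred is also binarized the per-example value
# above equals the number of [len_x, len_y, 9] cells on which the two paths
# disagree; for soft predictions it is the sum of squared differences between
# the two path tensors.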


@gin.configurable
class MeanList(tf.metrics.Metric):
  """Means over ground-truth and predictions for positive and negative pairs."""

  def __init__(self,
               positive_keys = ('true', 'pred_pos'),
               negative_keys = ('pred_neg',),
               **kwargs):
    super().__init__(**kwargs)
    self._keys = tuple(positive_keys) + tuple(negative_keys)
    self._process_negatives = bool(len(negative_keys))
    self._means = {}

  def _split(
      self,
      inputs,
      return_neg = True,
  ):
    if not self._process_negatives:
      return (inputs,)
    pos = tf.nest.map_structure(lambda t: t[:tf.shape(t)[0] // 2], inputs)
    if return_neg:
      neg = tf.nest.map_structure(lambda t: t[tf.shape(t)[0] // 2:], inputs)
    return (pos, neg) if return_neg else (pos,)

  def result(self):
    return {f'{self.name}/{k}': m.result() for k, m in self._means.items()}

  def reset_states(self):
    for mean in self._means.values():
      mean.reset_states()
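

# Example of the `_split` convention above (illustrative; it reflects the
# layout in which positive pairs occupy the first half of the batch and
# negative pairs the second half): with process_negatives enabled, a tensor t
# with batch size 4 is split as (t[:2], t[2:]); with it disabled, `_split`
# returns (t,) unchanged.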


@gin.configurable
class AlignmentStats(MeanList):
  """Tracks alignment length, number of matches and number of gaps."""
  STATS = ('length', 'n_match', 'n_gap')

  def __init__(self,
               name = 'alignment_stats',
               process_negatives = True,
               **kwargs):
    negative_keys = ('pred_neg',) if process_negatives else ()
    super().__init__(name=name, negative_keys=negative_keys, **kwargs)
    for stat in self.STATS:
      self._means.update({f'{stat}/{k}': tf.metrics.Mean() for k in self._keys})
    self._stat_fn = {
        'length': alignment.length,
        'n_match': functools.partial(alignment.state_count, states='match'),
        'n_gap': functools.partial(alignment.state_count, states='gap_open'),
    }

  def update_state(
      self,
      alignments_true,
      alignments_pred,
      sample_weight = None):
    """Updates alignment stats for a batch of true and predicted alignments."""
    del sample_weight  # Logic in this metric controlled by process_negatives.
    if alignments_pred[1] is None:
      return

    vals = self._split(alignments_true, False) + self._split(alignments_pred[1])
    for stat in self.STATS:
      for k, tensor in zip(self._keys, vals):
        self._means[f'{stat}/{k}'].update_state(self._stat_fn[stat](tensor))
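

# For reference (derived from MeanList.result above): with process_negatives
# enabled, this metric reports nine scalars named
# 'alignment_stats/<stat>/<key>', for <stat> in ('length', 'n_match', 'n_gap')
# and <key> in ('true', 'pred_pos', 'pred_neg').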


@gin.configurable
class AlignmentScore(MeanList):
  """Tracks alignment score / solution value."""

  def __init__(self,
               name = 'alignment_score',
               process_negatives = True,
               **kwargs):
    negative_keys = ('pred_neg',) if process_negatives else ()
    super().__init__(name=name, negative_keys=negative_keys, **kwargs)
    self._means.update({k: tf.metrics.Mean() for k in self._keys})

  def update_state(
      self,
      alignments_true,
      alignments_pred,
      sample_weight = None):
    """Updates alignment scores for a batch of true and predicted alignments."""
    del sample_weight  # Logic in this metric controlled by process_negatives.

    vals_true = (self._split(alignments_pred[2], False) +
                 self._split(alignments_true, False))
    self._means[self._keys[0]].update_state(alignment.sw_score(*vals_true))

    vals_pred = self._split(alignments_pred[0])
    for k, tensor in zip(self._keys[1:], vals_pred):
      self._means[k].update_state(tensor)


@gin.configurable
class SWParamsStats(MeanList):
  """Tracks Smith-Waterman substitution costs and gap penalties."""
  PARAMS = ('sim_mat', 'gap_open', 'gap_extend')

  def __init__(self,
               name = 'sw_params_stats',
               process_negatives = True,
               **kwargs):
    positive_keys = ('pred_pos',)
    negative_keys = ('pred_neg',) if process_negatives else ()
    super().__init__(name=name,
                     positive_keys=positive_keys,
                     negative_keys=negative_keys,
                     **kwargs)
    for p in self.PARAMS:
      self._means.update({f'{p}/{k}': tf.metrics.Mean() for k in self._keys})

  def update_state(
      self,
      alignments_true,
      alignments_pred,
      sample_weight = None):
    """Updates SW param stats for a batch of true and predicted alignments."""
    del alignments_true  # Present for compatibility with SeqAlign.
    del sample_weight  # Logic in this metric controlled by process_negatives.

    vals = self._split(alignments_pred[2])
    for k, sw_params in zip(self._keys, vals):
      for p, t in zip(self.PARAMS, sw_params):
        # Prevents entries corresponding to padding from being tracked.
        mask = alignment.mask_from_similarities(t)
        self._means[f'{p}/{k}'].update_state(t, sample_weight=mask)


@gin.configurable
class StratifyByPID(tf.metrics.Metric):
  """Wraps Keras metric, accounting only for examples in given PID bins."""

  def __init__(self,
               metric_cls,
               lower = None,
               upper = None,
               step = None,
               pid_definition = '3',
               **kwargs):
    self._lower = lower if lower is not None else 0.0
    if isinstance(step, Sequence):
      self._upper = self._lower + sum(step)  # Ignores arg. Not used, remove?
      self._steps = step
    else:
      self._upper = upper if upper is not None else 1.0
      step = step if step is not None else self._upper - self._lower
      self._steps = (step,)

    self._stratified_metrics = []
    lower = self._lower
    for step in self._steps:
      upper = lower + step
      self._stratified_metrics.append((metric_cls(), lower, upper))
      lower = upper

    self._pid_definition = pid_definition

    stratify_by_pid_str = f'stratify_by_pid{self._pid_definition}'
    super().__init__(
        name=f'{stratify_by_pid_str}/{self._stratified_metrics[0][0].name}',
        **kwargs)

  def update_state(self,
                   y_true,
                   y_pred,
                   sample_weight,
                   metadata):
    pid = metadata[0]
    no_pid_info = pid == -1
    for metric, lower, upper in self._stratified_metrics:
      in_bin = tf.logical_and(
          tf.logical_or(pid == self._lower, pid > lower), pid <= upper)
      keep_mask = tf.logical_or(in_bin, no_pid_info)
      metric.update_state(
          y_true, y_pred, sample_weight=tf.where(keep_mask, sample_weight, 0.0))

  def result(self):
    res = {}
    for metric, lower, upper in self._stratified_metrics:
      res_i = metric.result()
      suffix = f'PID{self._pid_definition}:{lower:.2f}-{upper:.2f}'
      if isinstance(res_i, Mapping):
        res.update({f'{k}/{suffix}': v for k, v in res_i.items()})
      else:
        res[f'{self._stratified_metrics[0][0].name}/{suffix}'] = res_i
    return res

  def reset_states(self):
    for metric, _, _ in self._stratified_metrics:
      metric.reset_states()
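

# Worked example (parameter values are illustrative only): constructing
#   StratifyByPID(metric_cls=AlignmentPrecisionRecall, lower=0.0,
#                 step=(0.3, 0.4, 0.3))
# replicates the wrapped metric over the PID bins 0.0-0.3, 0.3-0.7 and 0.7-1.0.
# Examples with pid == -1 (no PID information) contribute to every bin, and
# each reported key is suffixed with its bin, e.g. 'PID3:0.00-0.30'.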