google-research
421 строка · 15.0 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Custom metrics for sequence alignment.
17
18This module defines the following types, which serve as inputs to all metrics
19implemented here:
20
+ GroundTruthAlignment is a tf.Tensor<int>[batch, 3, align_len] that can be
22written as tf.stack([pos_x, pos_y, enc_trans], 1) such that
23(pos_x[b][i], pos_y[b][i], enc_trans[b][i]) represents the i-th transition
24in the ground-truth alignment for example b in the minibatch.
25Both pos_x and pos_y are assumed to use one-based indexing and enc_trans
26follows the (categorical) 9-state encoding of edge types used throughout
27`learning/brain/research/combini/diff_opt/alignment/tf_ops.py`.
28
29+ SWParams is a tuple (sim_mat, gap_open, gap_extend) parameterizing the
30Smith-Waterman LP such that
31+ sim_mat is a tf.Tensor<float>[batch, len1, len2] (len1 <= len2) with the
32substitution values for pairs of sequences.
33+ gap_open is a tf.Tensor<float>[batch, len1, len2] (len1 <= len2) or
34tf.Tensor<float>[batch] with the penalties for opening a gap. Must agree
35in rank with gap_extend.
36+ gap_extend is a tf.Tensor<float>[batch, len1, len2] (len1 <= len2) or
37tf.Tensor<float>[batch] with the penalties for extending a gap. Must agree
38in rank with gap_open.
39
40+ AlignmentOutput is a tuple (solution_values, solution_paths, sw_params) such
41that
42+ 'solution_values' contains a tf.Tensor<float>[batch] with the (soft) optimal
43Smith-Waterman scores for the batch.
44+ 'solution_paths' contains a tf.Tensor<float>[batch, len1, len2, 9] that
45describes the optimal soft alignments.
46+ 'sw_params' is a SWParams tuple as described above.
47"""
48
49
50import functools51from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Type, Union52
53import gin54import tensorflow as tf55
56from dedal import alignment57
58
# Type aliases for the inputs shared by all metrics in this module; see the
# module docstring for the detailed conventions.
GroundTruthAlignment = tf.Tensor  # [batch, 3, align_len], stacked transitions.
PredictedPaths = tf.Tensor  # [batch, len1, len2, 9] soft alignment paths.
SWParams = Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]
AlignmentOutput = Tuple[tf.Tensor, Optional[PredictedPaths], SWParams]
NaiveAlignmentOutput = Tuple[tf.Tensor, tf.Tensor, SWParams]
65
def confusion_matrix(
    alignments_true,
    sol_paths_pred):
  """Computes true, predicted and actual positives for a batch of alignments."""
  batch_size = tf.shape(alignments_true)[0]

  # Indices (b, i, j) of ground-truth aligned character pairs, one row per
  # aligned pair: shape [n_aligned_chars_true, 3].
  true_match_idx = alignment.alignments_to_state_indices(
      alignments_true, 'match')
  # Dense 0/1 indicator of predicted matches: [batch, len_x, len_y].
  pred_match_ind = alignment.paths_to_state_indicators(
      sol_paths_pred, 'match')
  # Which example of the minibatch each ground-truth pair belongs to.
  example_idx = true_match_idx[:, 0]  # [n_aligned_chars_true]
  # True positives as a sparse "inner product": read the predicted indicator
  # at each ground-truth match position, then sum per example. This avoids
  # ever materializing the dense ground-truth indicator tensor.
  hits = tf.gather_nd(pred_match_ind, true_match_idx)
  true_positives = tf.math.unsorted_segment_sum(
      hits, example_idx, batch_size)  # [batch]

  # Predicted positives (TP + FP) per example.
  pred_positives = tf.reduce_sum(pred_match_ind, axis=[1, 2])
  # Ground-truth positives (TP + FN) per example.
  # Note(fllinares): tf.math.bincount unsupported in TPU :(
  cond_positives = tf.math.unsorted_segment_sum(
      tf.ones_like(example_idx, tf.float32),
      example_idx,
      batch_size)  # [batch]

  return true_positives, pred_positives, cond_positives
94
@gin.configurable
class AlignmentPrecisionRecall(tf.metrics.Metric):
  """Implements precision and recall metrics for sequence alignment."""

  def __init__(self,
               name = 'alignment_pr',
               threshold = None,
               **kwargs):
    """Initializes the metric.

    Args:
      name: name of the metric, also used to prefix the keys of `result()`.
      threshold: optional float. If given, predicted (soft) alignment paths
        are binarized as `path >= threshold` before counting matches. If None,
        the predictions are assumed to be binary already.
      **kwargs: additional keyword arguments for tf.metrics.Metric.
    """
    super().__init__(name=name, **kwargs)
    self._threshold = threshold
    self._true_positives = tf.metrics.Mean()  # TP
    self._pred_positives = tf.metrics.Mean()  # TP + FP
    self._cond_positives = tf.metrics.Mean()  # TP + FN

  def update_state(
      self,
      alignments_true,
      alignments_pred,
      sample_weight = None):
    """Updates TP, TP + FP and TP + FN for a batch of true, pred alignments."""
    if alignments_pred[1] is None:  # No predicted paths: nothing to update.
      return

    sol_paths_pred = alignments_pred[1]
    if self._threshold is not None:  # Otherwise, we assume already binarized.
      sol_paths_pred = tf.cast(sol_paths_pred >= self._threshold, tf.float32)

    true_positives, pred_positives, cond_positives = confusion_matrix(
        alignments_true, sol_paths_pred)

    self._true_positives.update_state(true_positives, sample_weight)
    self._pred_positives.update_state(pred_positives, sample_weight)
    self._cond_positives.update_state(cond_positives, sample_weight)

  def result(self):
    """Returns precision, recall and F1, keyed by `f'{self.name}/...'`."""
    true_positives = self._true_positives.result()
    pred_positives = self._pred_positives.result()
    cond_positives = self._cond_positives.result()
    precision = tf.where(
        true_positives > 0.0, true_positives / pred_positives, 0.0)
    recall = tf.where(
        true_positives > 0.0, true_positives / cond_positives, 0.0)
    # Guards against 0 / 0 = NaN: when precision + recall == 0 (e.g. before
    # any update, or when no true positives were found), F1 is reported as
    # 0 instead of NaN.
    denominator = precision + recall
    f1 = tf.where(
        denominator > 0.0, 2.0 * precision * recall / denominator, 0.0)
    return {
        f'{self.name}/precision': precision,
        f'{self.name}/recall': recall,
        f'{self.name}/f1': f1,
    }

  def reset_states(self):
    """Resets the three underlying mean trackers."""
    self._true_positives.reset_states()
    self._pred_positives.reset_states()
    self._cond_positives.reset_states()
149
@gin.configurable
class NaiveAlignmentPrecisionRecall(tf.metrics.Metric):
  """Implements precision and recall metrics for (naive) sequence alignment."""

  def __init__(self,
               name = 'naive_alignment_pr',
               threshold = None,
               **kwargs):
    """Initializes the metric.

    Args:
      name: name of the metric, also used to prefix the keys of `result()`.
      threshold: optional threshold(s) forwarded to tf.metrics.Precision and
        tf.metrics.Recall to binarize the predicted match indicators.
      **kwargs: additional keyword arguments for tf.metrics.Metric.
    """
    super().__init__(name=name, **kwargs)
    self._precision = tf.metrics.Precision(thresholds=threshold)
    self._recall = tf.metrics.Recall(thresholds=threshold)

  def update_state(
      self,
      alignments_true,
      alignments_pred,
      sample_weight = None):
    """Updates precision, recall for a batch of true, pred alignments."""
    if alignments_pred[1] is None:  # No predicted matches: nothing to update.
      return

    _, match_indicators_pred, sw_params = alignments_pred
    sim_mat, _, _ = sw_params
    shape, dtype = sim_mat.shape, match_indicators_pred.dtype

    # Scatters the sparse ground-truth match indices into a dense 0/1
    # indicator tensor with the same shape as the predictions.
    match_indices_true = alignment.alignments_to_state_indices(
        alignments_true, 'match')
    updates_true = tf.ones([tf.shape(match_indices_true)[0]], dtype=dtype)
    match_indicators_true = tf.scatter_nd(
        match_indices_true, updates_true, shape=shape)

    # Masks out entries corresponding to padding and, if given, weighs each
    # example by sample_weight. Previously `tf.shape(sample_weight)` was
    # evaluated unconditionally, crashing for the default sample_weight=None.
    weights = alignment.mask_from_similarities(sim_mat, dtype=dtype)
    if sample_weight is not None:
      batch = tf.shape(sample_weight)[0]
      weights = tf.reshape(sample_weight, [batch, 1, 1]) * weights

    self._precision.update_state(
        match_indicators_true, match_indicators_pred, weights)
    self._recall.update_state(
        match_indicators_true, match_indicators_pred, weights)

  def result(self):
    """Returns precision, recall and F1, keyed by `f'{self.name}/...'`."""
    precision, recall = self._precision.result(), self._recall.result()
    # Guards against 0 / 0 = NaN when precision + recall == 0.
    denominator = precision + recall
    f1 = tf.where(
        denominator > 0.0, 2.0 * precision * recall / denominator, 0.0)
    return {
        f'{self.name}/precision': precision,
        f'{self.name}/recall': recall,
        f'{self.name}/f1': f1,
    }

  def reset_states(self):
    """Resets the wrapped precision and recall metrics."""
    self._precision.reset_states()
    self._recall.reset_states()
203
@gin.configurable
class AlignmentMSE(tf.metrics.Mean):
  """Implements mean squared error metric for sequence alignment."""

  def __init__(self, name = 'alignment_mse', **kwargs):
    super().__init__(name=name, **kwargs)

  def update_state(
      self,
      alignments_true,
      alignments_pred,
      sample_weight = None):
    """Updates mean squared error for a batch of true vs pred alignments."""
    sol_paths_pred = alignments_pred[1]
    if sol_paths_pred is None:  # No predicted paths: nothing to score.
      return

    # Expands the ground-truth alignments into dense paths matching the
    # spatial dimensions of the predicted paths.
    pred_shape = tf.shape(sol_paths_pred)
    sol_paths_true = alignment.alignments_to_paths(
        alignments_true, pred_shape[1], pred_shape[2])
    # Per-example squared error, summed over positions and edge types.
    squared_diff = tf.math.squared_difference(sol_paths_pred, sol_paths_true)
    per_example_mse = tf.reduce_sum(squared_diff, axis=[1, 2, 3])
    super().update_state(per_example_mse, sample_weight)
227
@gin.configurable
class MeanList(tf.metrics.Metric):
  """Means over ground-truth and predictions for positive and negative pairs."""

  def __init__(self,
               positive_keys = ('true', 'pred_pos'),
               negative_keys = ('pred_neg',),
               **kwargs):
    super().__init__(**kwargs)
    self._keys = tuple(positive_keys) + tuple(negative_keys)
    # Negatives are processed only when at least one negative key was given.
    self._process_negatives = len(negative_keys) > 0
    # Subclasses populate this mapping from key to a tf.metrics.Mean tracker.
    self._means = {}

  def _split(
      self,
      inputs,
      return_neg = True,
  ):
    """Splits `inputs` along the batch axis into (positives[, negatives])."""
    if not self._process_negatives:
      return (inputs,)
    # Positive pairs occupy the first half of the batch, negatives the second.
    pos = tf.nest.map_structure(lambda t: t[:tf.shape(t)[0] // 2], inputs)
    if not return_neg:
      return (pos,)
    neg = tf.nest.map_structure(lambda t: t[tf.shape(t)[0] // 2:], inputs)
    return (pos, neg)

  def result(self):
    """Returns current means, keyed by `f'{self.name}/{key}'`."""
    return {f'{self.name}/{key}': mean.result()
            for key, mean in self._means.items()}

  def reset_states(self):
    """Resets every underlying mean tracker."""
    for mean in self._means.values():
      mean.reset_states()
260
@gin.configurable
class AlignmentStats(MeanList):
  """Tracks alignment length, number of matches and number of gaps."""
  STATS = ('length', 'n_match', 'n_gap')

  def __init__(self,
               name = 'alignment_stats',
               process_negatives = True,
               **kwargs):
    negative_keys = ('pred_neg',) if process_negatives else ()
    super().__init__(name=name, negative_keys=negative_keys, **kwargs)
    # One mean tracker per (statistic, key) combination.
    for stat in self.STATS:
      for key in self._keys:
        self._means[f'{stat}/{key}'] = tf.metrics.Mean()
    self._stat_fn = {
        'length': alignment.length,
        'n_match': functools.partial(alignment.state_count, states='match'),
        'n_gap': functools.partial(alignment.state_count, states='gap_open'),
    }

  def update_state(self,
                   alignments_true,
                   alignments_pred,
                   sample_weight = None):
    """Updates alignment stats for a batch of true and predicted alignments."""
    del sample_weight  # Logic in this metric controlled by process_negatives.
    if alignments_pred[1] is None:  # No predicted paths: nothing to track.
      return

    # Ground-truth (positives only) followed by predicted (pos[, neg]) paths,
    # aligned one-to-one with self._keys.
    tensors = self._split(alignments_true, False) + self._split(
        alignments_pred[1])
    for stat in self.STATS:
      stat_fn = self._stat_fn[stat]
      for key, tensor in zip(self._keys, tensors):
        self._means[f'{stat}/{key}'].update_state(stat_fn(tensor))
295
@gin.configurable
class AlignmentScore(MeanList):
  """Tracks alignment score / solution value."""

  def __init__(self,
               name = 'alignment_score',
               process_negatives = True,
               **kwargs):
    negative_keys = ('pred_neg',) if process_negatives else ()
    super().__init__(name=name, negative_keys=negative_keys, **kwargs)
    for key in self._keys:
      self._means[key] = tf.metrics.Mean()

  def update_state(self,
                   alignments_true,
                   alignments_pred,
                   sample_weight = None):
    """Updates alignment scores for a batch of true and predicted alignments."""
    del sample_weight  # Logic in this metric controlled by process_negatives.

    # Scores the ground-truth alignments (positive pairs only) under the
    # corresponding Smith-Waterman parameters.
    args_true = (self._split(alignments_pred[2], False) +
                 self._split(alignments_true, False))
    self._means[self._keys[0]].update_state(alignment.sw_score(*args_true))

    # Tracks the predicted solution values for positive (and, if enabled,
    # negative) pairs.
    for key, values in zip(self._keys[1:], self._split(alignments_pred[0])):
      self._means[key].update_state(values)
324
@gin.configurable
class SWParamsStats(MeanList):
  """Tracks Smith-Waterman substitution costs and gap penalties."""
  PARAMS = ('sim_mat', 'gap_open', 'gap_extend')

  def __init__(self,
               name = 'sw_params_stats',
               process_negatives = True,
               **kwargs):
    negative_keys = ('pred_neg',) if process_negatives else ()
    super().__init__(name=name,
                     positive_keys=('pred_pos',),
                     negative_keys=negative_keys,
                     **kwargs)
    # One mean tracker per (SW parameter, key) combination.
    for param in self.PARAMS:
      for key in self._keys:
        self._means[f'{param}/{key}'] = tf.metrics.Mean()

  def update_state(self,
                   alignments_true,
                   alignments_pred,
                   sample_weight = None):
    """Updates SW param stats for a batch of true and predicted alignments."""
    del alignments_true  # Present for compatibility with SeqAlign.
    del sample_weight  # Logic in this metric controlled by process_negatives.

    for key, sw_params in zip(self._keys, self._split(alignments_pred[2])):
      for param, tensor in zip(self.PARAMS, sw_params):
        # Prevents entries corresponding to padding from being tracked.
        mask = alignment.mask_from_similarities(tensor)
        self._means[f'{param}/{key}'].update_state(tensor, sample_weight=mask)
359
@gin.configurable
class StratifyByPID(tf.metrics.Metric):
  """Wraps Keras metric, accounting only for examples in given PID bins."""

  def __init__(self,
               metric_cls,
               lower = None,
               upper = None,
               step = None,
               pid_definition = '3',
               **kwargs):
    """Builds one wrapped metric instance per PID bin.

    Args:
      metric_cls: zero-argument callable returning a Keras metric instance;
        one instance is created per bin.
      lower: lower end of the full PID range; defaults to 0.0.
      upper: upper end of the full PID range; defaults to 1.0. Ignored when
        `step` is a sequence.
      step: either a single bin width, a sequence of bin widths, or None for
        one bin spanning the whole [lower, upper] range.
      pid_definition: label used in the metric name and result keys
        (presumably which PID definition the metadata was computed with —
        verify against the data pipeline).
      **kwargs: additional keyword arguments for tf.metrics.Metric.
    """
    self._lower = lower if lower is not None else 0.0
    if isinstance(step, Sequence):
      # A sequence of steps defines variable-width bins; the overall upper
      # bound is then derived from the steps.
      self._upper = self._lower + sum(step)  # Ignores arg. Not used, remove?
      self._steps = step
    else:
      self._upper = upper if upper is not None else 1.0
      # A single (or missing) step yields uniform bins over [lower, upper].
      step = step if step is not None else self._upper - self._lower
      self._steps = (step,)

    # One independent (metric, lower, upper) triple per bin; bins tile the
    # range contiguously, each covering (lower, upper].
    self._stratified_metrics = []
    lower = self._lower
    for step in self._steps:
      upper = lower + step
      self._stratified_metrics.append((metric_cls(), lower, upper))
      lower = upper

    self._pid_definition = pid_definition

    stratify_by_pid_str = f'stratify_by_pid{self._pid_definition}'
    super().__init__(
        name=f'{stratify_by_pid_str}/{self._stratified_metrics[0][0].name}',
        **kwargs)

  def update_state(self,
                   y_true,
                   y_pred,
                   sample_weight,
                   metadata):
    """Updates each per-bin metric, zeroing weights of out-of-bin examples."""
    # Assumes metadata[0] holds per-example PID values, with -1 encoding
    # "no PID information available" — TODO confirm against caller.
    pid = metadata[0]
    no_pid_info = pid == -1
    for metric, lower, upper in self._stratified_metrics:
      # Bin membership is (lower, upper]. NOTE(review): the first disjunct
      # compares against the *global* lower bound (`self._lower`, not the
      # bin's `lower`), so examples with pid exactly equal to the global
      # lower bound are counted in every bin — verify this is intended.
      in_bin = tf.logical_and(
          tf.logical_or(pid == self._lower, pid > lower), pid <= upper)
      # Examples without PID information are always kept.
      keep_mask = tf.logical_or(in_bin, no_pid_info)
      metric.update_state(
          y_true, y_pred, sample_weight=tf.where(keep_mask, sample_weight, 0.0))

  def result(self):
    """Returns per-bin results, suffixed with the bin's PID range."""
    res = {}
    for metric, lower, upper in self._stratified_metrics:
      res_i = metric.result()
      suffix = f'PID{self._pid_definition}:{lower:.2f}-{upper:.2f}'
      if isinstance(res_i, Mapping):
        # Dict-valued metrics: tag every entry with the bin suffix.
        res.update({f'{k}/{suffix}': v for k, v in res_i.items()})
      else:
        # Scalar metrics: key on the wrapped metric's name (all bins wrap the
        # same metric class, so the first instance's name is shared).
        res[f'{self._stratified_metrics[0][0].name}/{suffix}'] = res_i
    return res

  def reset_states(self):
    """Resets every per-bin wrapped metric."""
    for metric, _, _ in self._stratified_metrics:
      metric.reset_states()