google-research
328 lines · 13.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Implements custom losses."""
17
18import functools19from typing import NamedTuple, Optional, Tuple, Union20
21import gin22import tensorflow as tf23
24from dedal import alignment25from dedal import multi_task26
27
@gin.configurable
class WeightedLoss(NamedTuple):
  """Bundles a Keras loss with a scalar weight for multi-task training."""
  # Scalar multiplier applied to the loss value.
  weight: float
  # The underlying Keras loss object.
  loss: tf.keras.losses.Loss
33
# A loss may be supplied bare or already paired with a weight.
MaybeWeightedLoss = Union[WeightedLoss, tf.keras.losses.Loss]
# Optional per-output example weights, arranged like the model's outputs.
NestedWeights = multi_task.Backbone[Optional[tf.Tensor]]
# Smith-Waterman parameters: (sim_mat, gap_open, gap_extend).
SWParams = Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
AlignmentOutput = Tuple[tf.Tensor,  # Solution values.
                        Optional[tf.Tensor],  # Solution paths.
                        SWParams,  # DP parameters.
                       ]
# Like AlignmentOutput, but the second entry (match probabilities) is required.
NaiveAlignmentOutput = Tuple[tf.Tensor, tf.Tensor, SWParams]
43
@gin.configurable
class SmithWatermanLoss(tf.losses.Loss):
  """Implements a loss for differentiable local sequence alignment."""

  def __init__(self,
               name = 'smith_waterman_loss',
               reduction = tf.losses.Reduction.AUTO):
    super().__init__(name=name, reduction=reduction)

  def call(self, true_alignments_or_paths,
           alignment_output):
    """Computes a loss associated with the Smith-Waterman DP.

    Args:
      true_alignments_or_paths: The ground-truth alignments for the batch, in
        either sparse or dense form. Sparse: a
        tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x, pos_y,
        enc_trans], 1) such that (pos_x[b][i], pos_y[b][i], enc_trans[b][i])
        is the i-th transition of the ground-truth alignment for example b.
        pos_x and pos_y use one-based indexing and enc_trans follows the
        (categorical) 9-state encoding of edge types used throughout
        alignment.py. Dense: a binary
        tf.Tensor<float>[batch, len_x, len_y, 9] with a one along each taken
        edge of the alignment path, nine possible edges per (i, j).
      alignment_output: An AlignmentOutput, i.e. a tuple (solution_values,
        solution_paths, sw_params) where
        + solution_values is a tf.Tensor<float>[batch] with the (soft) optimal
          Smith-Waterman scores for the batch.
        + solution_paths (not used by this loss) optionally contains a
          tf.Tensor<float>[batch, len1, len2, 9] describing the optimal soft
          alignments, or None.
        + sw_params is a tuple (sim_mat, gap_open, gap_extend) parameterizing
          the Smith-Waterman LP: sim_mat is a
          tf.Tensor<float>[batch, len1, len2] (len1 <= len2) of substitution
          values; gap_open and gap_extend are tf.Tensor<float> of rank 0,
          rank 1 ([batch]) or rank 3 ([batch, len1, len2], len1 <= len2) with
          the penalties for opening, resp. extending, a gap. They must agree
          in rank with each other.

    Returns:
      The loss value for each example in the batch.
    """
    solution_values, _, sw_params = alignment_output
    # The loss is the gap between the (soft) optimum and the score of the
    # ground-truth alignment under the same DP parameters.
    true_scores = alignment.sw_score(sw_params, true_alignments_or_paths)
    return solution_values - true_scores
97
@gin.configurable
class BCEAlignmentLoss(tf.losses.Loss):
  """Implements a brute-force BCE loss for pairwise sequence alignment."""

  def __init__(self,
               name = 'bce_alignment_loss',
               reduction = tf.losses.Reduction.AUTO,
               pad_penalty = 1e8):
    super().__init__(name=name, reduction=reduction)
    self._pad_penalty = pad_penalty

  def call(self, true_alignments,
           alignment_output):
    """Computes a brute-force BCE loss for pairwise sequence alignment.

    Args:
      true_alignments: The ground-truth alignments for the batch, given as a
        tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x, pos_y,
        enc_trans], 1) such that (pos_x[b][i], pos_y[b][i], enc_trans[b][i])
        is the i-th transition of the ground-truth alignment for example b.
        pos_x and pos_y use one-based indexing and enc_trans follows the
        (categorical) 9-state encoding of edge types used throughout
        alignment.py.
      alignment_output: A NaiveAlignmentOutput, a 3-tuple made of:
        + The alignment scores: tf.Tensor<float>[batch].
        + The pairwise match probabilities: tf.Tensor<int>[batch, len, len].
        + A 3-tuple with the Smith-Waterman parameters: similarities, gap
          open and gap extend. Similarities is
          tf.Tensor<float>[batch, len, len]; the gap penalties can be either
          tf.Tensor<float>[batch] or tf.Tensor<float>[batch, len, len].

    Returns:
      The loss value for each example in the batch.
    """
    _, match_indicators_pred, sw_params = alignment_output
    sim_mat = sw_params[0]
    dtype = match_indicators_pred.dtype

    # Scatter ones at the ground-truth match coordinates to obtain a dense
    # binary indicator tensor with the same shape as sim_mat.
    match_indices_true = alignment.alignments_to_state_indices(
        true_alignments, 'match')
    n_matches = tf.shape(match_indices_true)[0]
    match_indicators_true = tf.scatter_nd(
        match_indices_true,
        tf.ones([n_matches], dtype=dtype),
        shape=sim_mat.shape)

    # Elementwise BCE between true and predicted match indicators; the
    # trailing singleton axis makes each (i, j) entry its own "sample".
    raw_losses = tf.losses.binary_crossentropy(
        match_indicators_true[Ellipsis, tf.newaxis],
        match_indicators_pred[Ellipsis, tf.newaxis])

    # Mask out padding positions before summing per example.
    mask = alignment.mask_from_similarities(
        sim_mat, dtype=dtype, pad_penalty=self._pad_penalty)
    return tf.reduce_sum(mask * raw_losses, axis=[1, 2])
150
@gin.configurable
class ProcrustesLoss(tf.losses.Loss):
  """Implements a loss for embeddings, up to a rigid transformation."""

  def __init__(self,
               name = 'procrustes_loss',
               reduction = tf.losses.Reduction.AUTO):
    super().__init__(name=name, reduction=reduction)

  def call(self, embs_true, embs_pred):
    """Computes the Procrustes loss between two (batches of) sets of vectors.

    Args:
      embs_true: a tf.Tensor<float>[batch_size, num_embs, dims] batch of
        'num_embs' embeddings in dimension 'dim'.
      embs_pred: a tf.Tensor<float>[batch_size, num_embs, dims] batch of
        'num_embs' embeddings in dimension 'dim'.

    Returns:
      The Procrustes loss value between each pair of embeddings in the batch.
    """
    # Center both sets: the Procrustes distance is translation-invariant.
    centered_true = embs_true - tf.reduce_mean(embs_true, axis=1, keepdims=True)
    centered_pred = embs_pred - tf.reduce_mean(embs_pred, axis=1, keepdims=True)
    # Cross-covariance between the centered sets; its SVD yields the optimal
    # orthogonal alignment (orthogonal Procrustes problem).
    cross_cov = tf.matmul(centered_true, centered_pred, transpose_a=True)
    _, u_left, v_right = tf.linalg.svd(cross_cov, full_matrices=True)
    best_rotation = tf.matmul(u_left, v_right, transpose_b=True)
    # Frobenius norm of the residual after optimally rotating embs_true.
    residual = tf.matmul(centered_true, best_rotation) - centered_pred
    return tf.linalg.norm(residual, axis=(1, 2))
180
def pairwise_square_dist(embs_1, embs_2):
  """Returns the matrix of square distances.

  Args:
    embs_1: tf.Tensor<float>[batch, len, dim].
    embs_2: tf.Tensor<float>[batch, len, dim].

  Returns:
    A tf.Tensor<float>[batch, len, len] containing the square distances.
  """
  # Uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>.
  inner_products = tf.matmul(embs_1, embs_2, transpose_b=True)
  sq_norms_1 = tf.linalg.norm(embs_1, axis=-1, keepdims=True)**2
  sq_norms_2 = tf.linalg.norm(embs_2, axis=-1)**2
  # sq_norms_1 broadcasts over columns, sq_norms_2 over rows.
  return sq_norms_1 + sq_norms_2[:, tf.newaxis, :] - 2 * inner_products
196
@gin.configurable
class ContactLoss(tf.losses.Loss):
  """Implements a loss for contact matrices."""

  def __init__(self,
               name = 'contact_loss',
               reduction = tf.losses.Reduction.NONE,
               weights_fun=tf.identity,
               dist_to_prob=None,
               prob_loss=None,
               from_embs=False,
               threshold=8.,
               n_low=16,
               n_high=23):
    """Loss for predicted positions, based on ground truth contact information.

    Args:
      name: the name of the loss.
      reduction: how the loss is computed from element-wise losses.
      weights_fun: a weight function, applied on |i-j|, where i, j are the
        matrix indices (see call below). Defaults to the identity.
      dist_to_prob: a function linking predicted pairwise square distance to
        predicted probability. If None, the raw (negated, scaled) distances
        are passed to prob_loss unchanged.
      prob_loss: a function comparing the predicted probability to ground
        truth. Defaults to binary cross-entropy with from_logits=True.
      from_embs: whether the loss is computed from predicted embeddings (True)
        or directly a predicted pairwise distance matrix (False, by default).
      threshold: a scaling parameter for the contact functions.
      n_low: int for the weight function.
      n_high: int for the weight function.
    """
    self._weights_fun = weights_fun
    self._dist_to_prob = dist_to_prob
    self._prob_loss = prob_loss
    if prob_loss is None:
      # Default elementwise comparison: BCE on logits.
      self._prob_loss = functools.partial(
          tf.keras.losses.binary_crossentropy, from_logits=True)
    self._from_embs = from_embs
    self._threshold = threshold
    # NOTE(review): n_low / n_high are stored but never read in this class as
    # shown here — presumably consumed by a gin-configured weights_fun; verify.
    self._n_low = n_low
    self._n_high = n_high
    super().__init__(name=name, reduction=reduction)

  def call(self, contact_true, pred):
    """Computes the Contact loss between contact / distance matrices.

    Args:
      contact_true: a tf.Tensor<float>[batch_size, num_embs, num_embs, 1], a
        batch of binary contact matrices for 'num_embs' embeddings.
      pred: a tf.Tensor<float> of shape either
        + [batch_size, num_embs, dims] if 'from_embs' is True (embeddings
          case) a batch of 'num_embs' embeddings in dimension 'dim'.
        + [batch_size, num_embs, num_embs, 1] if 'from_embs' is False (matrix
          case) a batch of pairwise distances for 'num_embs' embeddings.

    Returns:
      The contact loss values between the contact matrices and predictions
      in the batch. This is computed for an instance matrix in the batch as:
        loss(y, p) = sum_ij w_|i-j| prob_loss(y_ij, p_ij),
      where y is the ground truth contact matrix and p is the predicted
      contact probability matrix.
      + prob_loss(y, p) is a function comparing y in {0,1} to p in [0,1]
      + w_|i-j| is weights_fun(|i-j|), and just |i-j| if None.
      If from_embs is true, the predicted matrix is the pairwise distance of
      the predicted embeddings.
    """
    if self._from_embs:  # not yet checked
      # Embeddings case: derive the pairwise square distances from pred.
      pairw_dist_pred = pairwise_square_dist(embs_1=pred, embs_2=pred)
    else:
      pairw_dist_pred = pred
    num_embs = tf.shape(pairw_dist_pred)[1]
    # Build the |i-j| weight matrix, shape [1, num_embs, num_embs].
    weights_range = tf.range(num_embs, dtype=tf.float32)
    weights_range_square = tf.abs(weights_range[tf.newaxis, :, tf.newaxis] -
                                  weights_range[tf.newaxis, tf.newaxis, :])
    weights_square = self._weights_fun(weights_range_square)
    contact_true = tf.cast(contact_true, dtype=pred.dtype)
    if self._dist_to_prob is not None:  # double-check the dummy [1] trail dim.
      # Map (negated, threshold-scaled) square distances to probabilities.
      pairw_dist_pred = self._dist_to_prob(
          -pairw_dist_pred / self._threshold**2)

    # mat_losses: per-entry losses; weights_square broadcasts over the batch.
    mat_losses = self._prob_loss(contact_true, pairw_dist_pred)
    return weights_square * mat_losses
283
@gin.configurable
class MultiTaskLoss:
  """A loss to combine multiple ones for a model that outputs a Dict."""

  def __init__(self, losses):
    """Initializes the multi-task loss.

    Args:
      losses: a multi_task.Backbone whose entries are either
        tf.keras.losses.Loss instances or (weight, loss) pairs.
    """
    self._losses = losses
    # Make sure every loss has a weight. Note: this mutates the levels of the
    # Backbone passed in, replacing bare losses with (1.0, loss) pairs.
    for level in self._losses.levels:
      for i in range(len(level)):
        if isinstance(level[i], tf.keras.losses.Loss):
          level[i] = (1.0, level[i])

  def _compute_weight_correction(self, labels, weights=None, epsilon=1e-9):
    """Account for weight sums for a specific head/loss."""
    replica_ctx = tf.distribute.get_replica_context()
    # Local normalizer: batch size when unweighted, else sum of weights.
    per_replica = (
        tf.shape(labels)[0] if weights is None else tf.math.reduce_sum(weights))
    # Aggregate across replicas so normalization matches the global batch.
    total = replica_ctx.all_reduce('sum', per_replica)
    # epsilon guards against division by zero when total is 0.
    return 1.0 / (tf.cast(total, tf.float32) + epsilon)

  def __call__(self,
               y_true,
               y_pred,
               weights = None):
    """Computes the weighted sum of the configured per-head losses.

    Args:
      y_true: flattened ground-truth values, one entry per head.
      y_pred: a multi_task.Backbone of predictions, one entry per head.
      weights: optional flattened per-example weights, one entry per head.

    Returns:
      A (total_loss, individual_losses) pair, where individual_losses maps
      each loss name to its (weighted, normalized) contribution.
    """
    # TODO(oliviert): Should we unflatten?
    y_true = multi_task.Backbone.unflatten(y_true)
    # NOTE(review): unflatten is also invoked when weights is None —
    # presumably Backbone.unflatten tolerates None; confirm against callers.
    weights = multi_task.Backbone.unflatten(weights)
    if y_pred.shape != self._losses.shape:
      raise ValueError(
          f'The SeqAlign MultiTaskLoss shape {self._losses.shape} is not '
          f'matching the predictions shape {y_pred.shape}')

    total_loss = 0.0
    individual_losses = {}
    for weighted_loss, label, pred, batch_w in zip(self._losses, y_true, y_pred,
                                                   weights):
      loss_w, loss_fn = weighted_loss
      # A None loss marks a head that does not contribute to training.
      if loss_fn is None:
        continue
      # Fold the global (cross-replica) weight normalization into the weight.
      loss_w *= self._compute_weight_correction(label, batch_w)
      loss = loss_w * tf.math.reduce_sum(
          loss_fn(label, pred, sample_weight=batch_w))
      total_loss += loss
      individual_losses[loss_fn.name] = loss
    return total_loss, individual_losses