# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements custom losses."""

import functools
from typing import NamedTuple, Optional, Tuple, Union

import gin
import tensorflow as tf

from dedal import alignment
from dedal import multi_task


@gin.configurable
class WeightedLoss(NamedTuple):
  weight: float
  loss: tf.keras.losses.Loss


MaybeWeightedLoss = Union[WeightedLoss, tf.keras.losses.Loss]
NestedWeights = multi_task.Backbone[Optional[tf.Tensor]]
SWParams = Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
AlignmentOutput = Tuple[tf.Tensor,  # Solution values.
                        Optional[tf.Tensor],  # Solution paths.
                        SWParams,  # DP parameters.
                       ]
NaiveAlignmentOutput = Tuple[tf.Tensor, tf.Tensor, SWParams]


@gin.configurable
class SmithWatermanLoss(tf.losses.Loss):
  """Implements a loss for differentiable local sequence alignment."""

  def __init__(self,
               name = 'smith_waterman_loss',
               reduction = tf.losses.Reduction.AUTO):
    super().__init__(name=name, reduction=reduction)

  def call(self, true_alignments_or_paths,
           alignment_output):
    """Computes a loss associated with the Smith-Waterman DP.

    Args:
      true_alignments_or_paths: The ground-truth alignments for the batch. Both
        sparse and dense representations of the alignments are allowed. In the
        sparse case, true_alignments_or_paths is expected to be a
        tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x, pos_y,
        enc_trans], 1) such that (pos_x[b][i], pos_y[b][i], enc_trans[b][i])
        represents the i-th transition in the ground-truth alignment for
        example b in the minibatch. Both pos_x and pos_y are assumed to use
        one-based indexing and enc_trans follows the (categorical) 9-state
        encoding of edge types used throughout alignment.py. In the dense case,
        true_alignments_or_paths is instead expected to be a
        tf.Tensor<float>[batch, len_x, len_y, 9] with binary entries,
        representing the trajectory of the indices along the ground-truth
        alignment paths, with a one along the taken edges, there being nine
        possible edge types for each (i, j).
      alignment_output: An AlignmentOutput, that is, a tuple (solution_values,
        solution_paths, sw_params) such that
        + 'solution_values' contains a tf.Tensor<float>[batch] with the (soft)
          optimal Smith-Waterman scores for the batch.
        + 'solution_paths', which is not used by the loss, optionally contains
          a tf.Tensor<float>[batch, len1, len2, 9] that describes the optimal
          soft alignments, being None otherwise.
        + 'sw_params' contains a tuple (sim_mat, gap_open, gap_extend) of
          tf.Tensor objects parameterizing the Smith-Waterman LP such that
          + sim_mat is a tf.Tensor<float>[batch, len1, len2] (len1 <= len2)
            with the substitution values for pairs of sequences.
          + gap_open is a tf.Tensor<float>[], tf.Tensor<float>[batch] or
            tf.Tensor<float>[batch, len1, len2] (len1 <= len2) with the
            penalties for opening a gap. Must agree in rank with gap_extend.
          + gap_extend is a tf.Tensor<float>[], tf.Tensor<float>[batch] or
            tf.Tensor<float>[batch, len1, len2] (len1 <= len2) with the
            penalties for extending a gap. Must agree in rank with gap_open.

    Returns:
      The loss value for each example in the batch.
    """
    solution_values, _, sw_params = alignment_output
    return (solution_values -
            alignment.sw_score(sw_params, true_alignments_or_paths))
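

# A minimal usage sketch (added for exposition; not part of the original
# module). The loss is the gap between the model's (soft) optimal alignment
# score and the score of the ground-truth alignment under the same
# Smith-Waterman parameters; with exact (hard) dynamic programming it is
# non-negative and vanishes when the ground-truth path is optimal. All shapes
# and values below are hypothetical.
def _smith_waterman_loss_example():
  batch, len1, len2 = 4, 10, 12
  sim_mat = tf.random.normal([batch, len1, len2])
  gap_open = tf.fill([batch], 11.0)
  gap_extend = tf.fill([batch], 1.0)
  # Dense ground-truth paths: binary indicators over 9 edge types per (i, j).
  true_paths = tf.zeros([batch, len1, len2, 9])
  # In practice 'solution_values' comes from a differentiable SW solver.
  solution_values = tf.random.normal([batch])
  loss_fn = SmithWatermanLoss(reduction=tf.losses.Reduction.NONE)
  return loss_fn(true_paths, (solution_values, None,
                              (sim_mat, gap_open, gap_extend)))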


@gin.configurable
class BCEAlignmentLoss(tf.losses.Loss):
  """Implements a brute-force BCE loss for pairwise sequence alignment."""

  def __init__(self,
               name = 'bce_alignment_loss',
               reduction = tf.losses.Reduction.AUTO,
               pad_penalty = 1e8):
    super().__init__(name=name, reduction=reduction)
    self._pad_penalty = pad_penalty

  def call(self, true_alignments,
           alignment_output):
    """Computes a brute-force BCE loss for pairwise sequence alignment.

    Args:
      true_alignments: The ground-truth alignments for the batch, expected to
        be a tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x, pos_y,
        enc_trans], 1) such that (pos_x[b][i], pos_y[b][i], enc_trans[b][i])
        represents the i-th transition in the ground-truth alignment for
        example b in the minibatch. Both pos_x and pos_y are assumed to use
        one-based indexing and enc_trans follows the (categorical) 9-state
        encoding of edge types used throughout alignment.py.
      alignment_output: A NaiveAlignmentOutput, which is a 3-tuple made of:
        + The alignment scores: tf.Tensor<float>[batch].
        + The pairwise match probabilities: tf.Tensor<float>[batch, len, len].
        + A 3-tuple containing the Smith-Waterman parameters: similarities,
          gap open and gap extend. Similarities is a
          tf.Tensor<float>[batch, len, len]; the gap penalties can be either
          tf.Tensor<float>[batch] or tf.Tensor<float>[batch, len, len].

    Returns:
      The loss value for each example in the batch.
    """
    _, match_indicators_pred, sw_params = alignment_output
    sim_mat, _, _ = sw_params
    shape, dtype = sim_mat.shape, match_indicators_pred.dtype

    match_indices_true = alignment.alignments_to_state_indices(
        true_alignments, 'match')
    updates_true = tf.ones([tf.shape(match_indices_true)[0]], dtype=dtype)
    match_indicators_true = tf.scatter_nd(
        match_indices_true, updates_true, shape=shape)

    raw_losses = tf.losses.binary_crossentropy(
        match_indicators_true[Ellipsis, tf.newaxis],
        match_indicators_pred[Ellipsis, tf.newaxis])

    mask = alignment.mask_from_similarities(
        sim_mat, dtype=dtype, pad_penalty=self._pad_penalty)
    return tf.reduce_sum(mask * raw_losses, axis=[1, 2])
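

# A minimal sketch (added for exposition; not part of the original module) of
# how the dense ground-truth match indicator matrix above is assembled: given
# [batch_index, row, column] triples for the match states, tf.scatter_nd
# writes a one at each such position of an all-zeros tensor. The indices
# below are hypothetical.
def _scatter_match_indicators_example():
  # Two matches for example 0 and one match for example 1 in a [2, 4, 4] grid.
  match_indices = tf.constant([[0, 1, 1], [0, 2, 3], [1, 0, 2]])
  updates = tf.ones([tf.shape(match_indices)[0]], dtype=tf.float32)
  return tf.scatter_nd(match_indices, updates, shape=[2, 4, 4])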


@gin.configurable
class ProcrustesLoss(tf.losses.Loss):
  """Implements a loss for embeddings, up to a rigid transformation."""

  def __init__(self,
               name = 'procrustes_loss',
               reduction = tf.losses.Reduction.AUTO):
    super().__init__(name=name, reduction=reduction)

  def call(self, embs_true, embs_pred):
    """Computes the Procrustes loss between two (batches of) sets of vectors.

    Args:
      embs_true: a tf.Tensor<float>[batch_size, num_embs, dims] batch of
        'num_embs' embeddings of dimension 'dims'.
      embs_pred: a tf.Tensor<float>[batch_size, num_embs, dims] batch of
        'num_embs' embeddings of dimension 'dims'.

    Returns:
      The Procrustes loss value between each pair of embeddings in the batch.
    """
    embs_true_bar = embs_true - tf.reduce_mean(embs_true, axis=1, keepdims=True)
    embs_pred_bar = embs_pred - tf.reduce_mean(embs_pred, axis=1, keepdims=True)
    prod = tf.matmul(embs_true_bar, embs_pred_bar, transpose_a=True)
    _, u_left, v_right = tf.linalg.svd(prod, full_matrices=True)
    rotation_opt = tf.matmul(u_left, v_right, transpose_b=True)
    return tf.linalg.norm(
        tf.matmul(embs_true_bar, rotation_opt) - embs_pred_bar, axis=(1, 2))
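

# A sanity-check sketch (added for exposition; not part of the original
# module). The product of U and V^T above is the classical solution to the
# orthogonal Procrustes problem min_R ||embs_true_bar @ R - embs_pred_bar||_F
# over orthogonal R, so the loss should be (numerically) zero whenever the
# predictions are an orthogonal transformation plus translation of the
# ground truth.
def _procrustes_loss_sanity_check():
  embs_true = tf.random.normal([2, 16, 3])
  # Random orthogonal matrices from the QR decomposition of Gaussian noise.
  ortho, _ = tf.linalg.qr(tf.random.normal([2, 3, 3]))
  embs_pred = tf.matmul(embs_true, ortho) + 0.5  # Transform, then translate.
  return ProcrustesLoss()(embs_true, embs_pred)  # Expected: ~0.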


def pairwise_square_dist(embs_1, embs_2):
  """Returns the matrix of square distances.

  Args:
    embs_1: tf.Tensor<float>[batch, len, dim].
    embs_2: tf.Tensor<float>[batch, len, dim].

  Returns:
    A tf.Tensor<float>[batch, len, len] containing the square distances.
  """
  gram_embs = tf.matmul(embs_1, embs_2, transpose_b=True)
  sq_norm_embs_1 = tf.linalg.norm(embs_1, axis=-1, keepdims=True)**2
  sq_norm_embs_2 = tf.linalg.norm(embs_2, axis=-1)**2
  return sq_norm_embs_1 + sq_norm_embs_2[:, tf.newaxis, :] - 2 * gram_embs
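

# A quick numerical check (added for exposition; not part of the original
# module): the expansion used above, ||x - y||^2 = ||x||^2 + ||y||^2 -
# 2<x, y>, should match a direct computation of the pairwise squared
# distances up to floating-point error. Note the helper also broadcasts
# correctly when the two inputs have different lengths.
def _pairwise_square_dist_check():
  embs_1 = tf.random.normal([2, 5, 7])
  embs_2 = tf.random.normal([2, 6, 7])
  direct = tf.reduce_sum(
      (embs_1[:, :, tf.newaxis, :] - embs_2[:, tf.newaxis, :, :])**2, axis=-1)
  return tf.reduce_max(tf.abs(pairwise_square_dist(embs_1, embs_2) - direct))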


@gin.configurable
class ContactLoss(tf.losses.Loss):
  """Implements a loss for contact matrices."""

  def __init__(self,
               name = 'contact_loss',
               reduction = tf.losses.Reduction.NONE,
               weights_fun=tf.identity,
               dist_to_prob=None,
               prob_loss=None,
               from_embs=False,
               threshold=8.,
               n_low=16,
               n_high=23):
    """Loss for predicted positions, based on ground-truth contact information.

    Args:
      name: the name of the loss.
      reduction: how the loss is computed from element-wise losses.
      weights_fun: a weight function, applied to |i-j|, where i, j are the
        matrix indices (see below).
      dist_to_prob: a function mapping predicted pairwise square distances to
        predicted probabilities.
      prob_loss: a function comparing the predicted probability to the ground
        truth; defaults to binary cross-entropy on logits.
      from_embs: whether the loss is computed from predicted embeddings (True)
        or directly from a predicted pairwise distance matrix (False, by
        default).
      threshold: a scaling parameter for the contact functions.
      n_low: an integer parameter of the weight function.
      n_high: an integer parameter of the weight function.
    """
    self._weights_fun = weights_fun
    self._dist_to_prob = dist_to_prob
    self._prob_loss = prob_loss
    if prob_loss is None:
      self._prob_loss = functools.partial(
          tf.keras.losses.binary_crossentropy, from_logits=True)
    self._from_embs = from_embs
    self._threshold = threshold
    self._n_low = n_low
    self._n_high = n_high
    super().__init__(name=name, reduction=reduction)

  def call(self, contact_true, pred):
    """Computes the contact loss between contact / distance matrices.

    Args:
      contact_true: a tf.Tensor<float>[batch_size, num_embs, num_embs, 1]
        batch of binary contact matrices for 'num_embs' embeddings.
      pred: a tf.Tensor<float> of shape either
        + [batch_size, num_embs, dims] if 'from_embs' is True (embeddings
          case): a batch of 'num_embs' embeddings of dimension 'dims';
        + [batch_size, num_embs, num_embs, 1] if 'from_embs' is False (matrix
          case): a batch of pairwise distances for 'num_embs' embeddings.

    Returns:
      The contact loss values between the contact matrices and predictions in
      the batch. For an instance matrix in the batch, this is computed as:
        loss(y, p) = sum_ij w_|i-j| prob_loss(y_ij, p_ij),
      where y is the ground-truth contact matrix and p is the predicted
      contact probability matrix, such that
        + prob_loss(y, p) compares y in {0, 1} to p in [0, 1];
        + w_|i-j| is weights_fun(|i-j|), which is just |i-j| for the default
          identity weight function.
      If 'from_embs' is True, the predicted matrix is the pairwise distance
      matrix of the predicted embeddings.
    """
    if self._from_embs:  # not yet checked
      pairw_dist_pred = pairwise_square_dist(embs_1=pred, embs_2=pred)
    else:
      pairw_dist_pred = pred
    num_embs = tf.shape(pairw_dist_pred)[1]
    weights_range = tf.range(num_embs, dtype=tf.float32)
    weights_range_square = tf.abs(weights_range[tf.newaxis, :, tf.newaxis] -
                                  weights_range[tf.newaxis, tf.newaxis, :])
    weights_square = self._weights_fun(weights_range_square)
    contact_true = tf.cast(contact_true, dtype=pred.dtype)
    if self._dist_to_prob is not None:  # double-check the dummy [1] trail dim.
      pairw_dist_pred = self._dist_to_prob(
          -pairw_dist_pred / self._threshold**2)

    mat_losses = self._prob_loss(contact_true, pairw_dist_pred)
    return weights_square * mat_losses
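

# A minimal usage sketch (added for exposition; not part of the original
# module), using the defaults: 'pred' is a batch of pairwise logits compared
# to the binary contact map with cross-entropy, and each (i, j) term is
# weighted by |i - j| via the identity weight function. Shapes below are
# hypothetical.
def _contact_loss_example():
  contact_true = tf.cast(tf.random.uniform([2, 12, 12, 1]) > 0.8, tf.float32)
  logits_pred = tf.random.normal([2, 12, 12, 1])
  loss_fn = ContactLoss()  # Reduction.NONE: returns per-entry losses.
  return loss_fn(contact_true, logits_pred)  # tf.Tensor<float>[2, 12, 12].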


@gin.configurable
class MultiTaskLoss:
  """Combines multiple losses for a model that outputs a Dict."""

  def __init__(self, losses):
    self._losses = losses
    # Make sure every loss has a weight.
    for level in self._losses.levels:
      for i in range(len(level)):
        if isinstance(level[i], tf.keras.losses.Loss):
          level[i] = (1.0, level[i])

  def _compute_weight_correction(self, labels, weights=None, epsilon=1e-9):
    """Accounts for weight sums for a specific head/loss."""
    replica_ctx = tf.distribute.get_replica_context()
    per_replica = (
        tf.shape(labels)[0] if weights is None else tf.math.reduce_sum(weights))
    total = replica_ctx.all_reduce('sum', per_replica)
    return 1.0 / (tf.cast(total, tf.float32) + epsilon)

  def __call__(self,
               y_true,
               y_pred,
               weights = None):
    # TODO(oliviert): Should we unflatten?
    y_true = multi_task.Backbone.unflatten(y_true)
    weights = multi_task.Backbone.unflatten(weights)
    if y_pred.shape != self._losses.shape:
      raise ValueError(
          f'The SeqAlign MultiTaskLoss shape {self._losses.shape} does not '
          f'match the predictions shape {y_pred.shape}.')

    total_loss = 0.0
    individual_losses = {}
    for weighted_loss, label, pred, batch_w in zip(self._losses, y_true, y_pred,
                                                   weights):
      loss_w, loss_fn = weighted_loss
      if loss_fn is None:
        continue
      loss_w *= self._compute_weight_correction(label, batch_w)
      loss = loss_w * tf.math.reduce_sum(
          loss_fn(label, pred, sample_weight=batch_w))
      total_loss += loss
      individual_losses[loss_fn.name] = loss
    return total_loss, individual_losses
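

# An illustrative construction sketch (added for exposition; not part of the
# original module). It assumes, as elsewhere in dedal, that
# multi_task.Backbone groups per-head entries into 'embeddings' and
# 'alignments' levels; bare Keras losses are wrapped as (1.0, loss) by the
# constructor, while explicitly weighted losses are passed as WeightedLoss
# tuples.
def _multi_task_loss_example():
  losses = multi_task.Backbone(
      embeddings=[],
      alignments=[WeightedLoss(weight=1.0, loss=SmithWatermanLoss())])
  return MultiTaskLoss(losses=losses)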