google-research

Форк
0
/
alignment.py 
481 строка · 19.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Functions used to manipulate alignments and smith-waterman parameters."""
17

18
from typing import Sequence, Tuple, Union
19

20
import tensorflow as tf
21

22
# Type aliases for Smith-Waterman (SW) parameters.
# "Packed" format: a single tensor of edge weights (rank 4, one channel per
# edge type; see `weights_from_sim_mat`).
PackedSWParams = tf.Tensor
# "Unpacked" format: the triple (sim_mat, gap_open, gap_extend).
UnpackedSWParams = Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
# Functions that take `sw_params` accept either representation.
SWParams = Union[tf.Tensor, UnpackedSWParams]
26

27

28
# Edge types of the SW dynamic program, grouped by the kind of edge weight
# (similarity, gap-open penalty or gap-extend penalty) attached to them. The
# integers index the last axis of the 9-state edge encoding; see the docstring
# of `weights_from_sim_mat` for the transition behind each index.
MATCH_STATES = [0, 1, 2, 3]
GAP_OPEN_STATES = [4, 6, 7]
GAP_EXTEND_STATES = [5, 8]
# Lookup from group name to the list of edge types in that group.
STATES = dict(
    match=MATCH_STATES,
    gap_open=GAP_OPEN_STATES,
    gap_extend=GAP_EXTEND_STATES,
)
37

38

39
def large_compatible_positive(tensor_type):
  """Returns a large positive value that is representable in `tensor_type`.

  The value used as "inf" throughout this module (1e9) does not fit in
  tf.float16, so that dtype gets the largest finite float16 value instead.

  NOTE(fllinares): Borrowed from
  tensorflow/python/keras/layers/advanced_activations.py
  which is used already in this codebase indirectly (e.g. in self-attention
  layers).

  Args:
    tensor_type: the dtype the returned value must be compatible with.

  Returns:
    `tf.dtypes.float16.max` (as is, not wrapped by this function) when
    `tensor_type` is tf.float16; otherwise a tf.Tensor of dtype `tensor_type`
    holding 1e9.
  """
  is_half_precision = tensor_type == tf.dtypes.float16
  if not is_half_precision:
    return tf.convert_to_tensor(1e9, dtype=tensor_type)
  return tf.dtypes.float16.max
59

60

61
def top_pad(t, v):
  """Prepends one slice of value `v` along axis 0 of rank-3 tf.Tensor `t`."""
  # One [before, after] pair per axis; only axis 0 grows, at its start.
  paddings = [[1, 0], [0, 0], [0, 0]]
  return tf.pad(t, paddings, constant_values=v)
64

65

66
def left_pad(t, v):
  """Prepends one slice of value `v` along axis 1 of rank-3 tf.Tensor `t`."""
  # One [before, after] pair per axis; only axis 1 grows, at its start.
  paddings = [[0, 0], [1, 0], [0, 0]]
  return tf.pad(t, paddings, constant_values=v)
69

70

71
def right_pad(t, v):
  """Appends one slice of value `v` along axis 1 of rank-3 tf.Tensor `t`."""
  # One [before, after] pair per axis; only axis 1 grows, at its end.
  paddings = [[0, 0], [0, 1], [0, 0]]
  return tf.pad(t, paddings, constant_values=v)
74

75

76
def alignments_to_paths(
    alignments, len_x, len_y):
  """Expands sparse alignments into a dense, binary tensor of paths.

  Args:
    alignments: A tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x, pos_y,
      enc_trans], 1) such that
        (pos_x[b][i], pos_y[b][i], enc_trans[b][i]) represents the i-th
      transition in the alignment for the b-th sequence pair in the minibatch.
      Both pos_x and pos_y are assumed to use one-based indexing and enc_trans
      follows the (categorical) 9-state encoding of edge types used throughout
      alignment/smith_waterman.py.
    len_x: The (padded) length of "X"/"query" sequences in the minibatch.
    len_y: The (padded) length of "Y"/"subject" sequences in the minibatch.

  Returns:
    A tf.Tensor of type tf.float32 and shape (batch_size, len_x, len_y, 9) with
    binary entries, representing the trajectory of the indices along the
    alignment path, by having a one along the taken edges, with nine possible
    edges for each i,j.
  """
  dims = tf.shape(alignments)
  batch_size, align_len = dims[0], dims[-1]

  # batch_ids[b][l] = b for all l in [0, align_len), i.e. it has the same
  # shape as each of pos_x, pos_y and enc_trans.
  batch_ids = tf.expand_dims(tf.range(batch_size), -1) * tf.ones(
      (1, align_len), dtype=tf.int32)

  # Builds one (b, pos_x, pos_y, enc_trans) row per transition for scatter_nd.
  per_transition = tf.transpose(alignments, (0, 2, 1))
  scatter_indices = tf.reshape(
      tf.concat([batch_ids[Ellipsis, None], per_transition], -1), (-1, 4))
  scatter_updates = tf.ones(tf.shape(scatter_indices)[0], dtype=tf.float32)

  # Note(fllinares): this is a (fairly ugly) hack to deal with padding.
  #  - pos_x, pos_y must use one-based indexing instead of zero-based indexing.
  #  - the (b, 0, 0, 0) entries of the scattered tensor act as "padding dumps".
  #  - the extra leading row/col is sliced away before returning.
  padded_shape = (batch_size, len_x + 1, len_y + 1, 9)
  padded_paths = tf.scatter_nd(scatter_indices, scatter_updates, padded_shape)
  return padded_paths[:, 1:, 1:, :]
117

118

119
def alignments_to_state_indices(
    alignments,
    states,
    zero_based_idx = True,
):
  """Retrieves indices of MATCH/GAP OPEN/GAP EXTEND states in alignments.

  Args:
    alignments: A tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x, pos_y,
      enc_trans], 1) such that
        (pos_x[b][i], pos_y[b][i], enc_trans[b][i]) represents the i-th
      transition in the alignment for the b-th sequence pair in the minibatch.
      Both pos_x and pos_y are assumed to use one-based indexing and enc_trans
      follows the (categorical) 9-state encoding of edge types used throughout
      alignment/smith_waterman.py.
    states: A Python list of integers in [0, 9), representing an arbitrary
      subset of (encoded) edge types. Can also be set to 'match', 'gap_open' or
      'gap_extend' to query the set of edge types associated with each of those
      conditions. An empty list selects no entries.
    zero_based_idx: Whether to use zero-based (True) or one-based (False)
      indexing for the function's output. Note that, however, alignment must use
      one-based indexing regardless of the value of this argument.

  Returns:
    A tf.Tensor `state_indices` of type tf.int32 and shape (n_entries, 3) such
    that, for a tf.Tensor `sim_mat` of shape (batch_size, len_x, len_y),
      tf.gather_nd(sim_mat, state_indices)
    returns the set of entries in `sim_mat` along the alignments described by
    `alignment` that correspond to one of the states in `states`.

    Note(fllinares): this function aims to provide a way to avoid materializing
    weights in the crf_loss function in alignment/smith_waterman.py, as
    suggested by @mblondel. Some extra care might be needed to keep per-example
    losses, as tf.gather_nd will flatten the output by default. For
    position-independent gap penalties, only the total number of entries per
    example in state_indices would be needed. See `score_from_alignment` below
    for extra details.

  Raises:
    ValueError: If `states` is a string that is not a key of `STATES`.
  """
  # Only string names are looked up in STATES: a Python list is unhashable, so
  # using it as a dict key (even through `dict.get`) raises a TypeError.
  if isinstance(states, str):
    if states not in STATES:
      raise ValueError(
          f'Unknown states name {states!r}; expected one of {sorted(STATES)}.')
    states = STATES[states]

  pos_x, pos_y, enc_trans = alignments[:, 0], alignments[:, 1], alignments[:, 2]

  # Note(fllinares): another ugly "hack", here we assume one-based idx to encode
  # the padding mask implicitly.
  padding_mask = tf.logical_and(pos_x > 0, pos_y > 0)
  # Starts from an all-False mask so that an empty `states` is handled too.
  hits = tf.zeros_like(padding_mask)
  for state in states:
    hits = tf.logical_or(hits, tf.equal(enc_trans, state))
  hits = tf.logical_and(hits, padding_mask)
  # Rows of `indices` are (batch index, position within the alignment).
  indices = tf.cast(tf.where(hits), tf.int32)

  offset = int(zero_based_idx)  # Subtracts 1 only for zero-based output.
  batch_indices = indices[:, 0]
  # Casts guarantee the documented int32 output (and a valid tf.stack) even if
  # `alignments` uses a different integer dtype; no-op for int32 inputs.
  x_indices = tf.cast(tf.gather_nd(pos_x, indices), tf.int32) - offset
  y_indices = tf.cast(tf.gather_nd(pos_y, indices), tf.int32) - offset
  return tf.stack([batch_indices, x_indices, y_indices], axis=1)
174

175

176
def paths_to_state_indicators(
    paths,
    states,
):
  """Computes (batch_size, len_x, len_y) tensor of binary state indicators.

  Args:
    paths: A tf.Tensor of type tf.float32 and shape (batch_size, len_x, len_y,
      9) with binary entries, representing the trajectory of the indices along
      the alignment path, by having a one along the taken edges, with nine
      possible edges for each i,j.
    states: A Python list of integers in [0, 9), representing an arbitrary
      subset of (encoded) edge types. This can also be set to 'match',
      'gap_open' or 'gap_extend' to query the set of edge types associated with
      each of those conditions.

  Returns:
    A tf.Tensor `state_indicators` of type tf.float32 and shape (batch_size,
    len_x, len_y) with binary entries such that
      state_indicators[b][i][j] = 1.0
    iff the trajectory of the alignment for the b-th sequence pair passes by
    character pair (i, j) under one of the states in `states`.

  Raises:
    ValueError: If `states` is a string that is not a key of `STATES`.
  """
  # Only string names are looked up in STATES: a Python list is unhashable, so
  # using it as a dict key (even through `dict.get`) raises a TypeError.
  if isinstance(states, str):
    if states not in STATES:
      raise ValueError(
          f'Unknown states name {states!r}; expected one of {sorted(STATES)}.')
    states = STATES[states]

  # Keeps only the queried edge-type channels, then flags a position if any of
  # them is set (entries are binary, so max acts as a logical OR).
  selected = tf.gather(paths, indices=states, axis=-1)
  return tf.reduce_max(selected, axis=-1)
201

202

203
def sw_score_from_alignments(
    sw_params,
    alignments,
):
  """Scores sparse `alignments` under the SW DP parameterized by `sw_params`.

  Args:
    sw_params: The parameters (sim_mat, gap_open, gap_extend) for the dynamic
      program underlying the Smith-Waterman algorithm.
      These can be input either as a tuple of tf.Tensor objects or as a single
      "packed" tensor of rank 4. See class `SWParamsFromEmbeddings` in module
      `sw_params_from_embeddings.py` for additional details.
    alignments: A tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x, pos_y,
      enc_trans], 1) such that
        (pos_x[b][i], pos_y[b][i], enc_trans[b][i]) represents the i-th
      transition in the alignment for the b-th sequence pair in the minibatch.
      Both pos_x and pos_y are assumed to use one-based indexing and enc_trans
      follows the (categorical) 9-state encoding of edge types used throughout
      alignment/smith_waterman.py.

  Returns:
    A tf.Tensor of type tf.float32 and shape (batch_size,) containing the SW
    score of each alignment in the batch.
  """
  # Brings the SW params into "unpacked" form, with gap terms already signed
  # so that they can simply be added to the similarity term.
  if isinstance(sw_params, Sequence):  # _UnpackedSWParams format
    sim_mat, open_penalty, extend_penalty = sw_params
    gap_open = -open_penalty
    gap_extend = -extend_penalty
  else:  # _PackedSWParams format
    # Reads one representative channel per group of edge types.
    sim_mat = sw_params[Ellipsis, MATCH_STATES[0]]
    gap_open = sw_params[Ellipsis, GAP_OPEN_STATES[0]]
    gap_extend = sw_params[Ellipsis, GAP_EXTEND_STATES[0]]

  batch_size = tf.shape(sim_mat)[0]  # Assumed consistent with gap_open/extend.

  def dot_by_states(t, states):
    """Sums entries of t along alignments for queried states."""
    state_indices = alignments_to_state_indices(alignments, states)
    batch_indices = state_indices[:, 0]

    if t.shape.rank == 3:
      # Position-dependent t[b,l1,l2]: adds up the entries visited by each
      # alignment.
      entries_along_path = tf.gather_nd(t, state_indices)
      return tf.math.unsorted_segment_sum(
          entries_along_path, batch_indices, batch_size)

    # Position-independent t[b]: scales t by the number of visited entries.
    # Note(fllinares): tf.math.bincount unsupported in TPU :(
    n_entries_along_path = tf.math.unsorted_segment_sum(
        tf.ones_like(batch_indices, tf.float32), batch_indices, batch_size)
    return t * n_entries_along_path

  sim_per_example = dot_by_states(sim_mat, 'match')
  gap_open_per_example = dot_by_states(gap_open, 'gap_open')
  gap_extend_per_example = dot_by_states(gap_extend, 'gap_extend')

  return sim_per_example + gap_open_per_example + gap_extend_per_example
268

269

270
def sw_score_from_paths(sw_params, paths):
  """Scores dense `paths` under the SW DP parameterized by `sw_params`.

  Args:
    sw_params: The parameters (sim_mat, gap_open, gap_extend) for the dynamic
      program underlying the Smith-Waterman algorithm.
      These can be input either as a tuple of tf.Tensor objects or as a single
      "packed" tensor of rank 4. See class `SWParamsFromEmbeddings` in module
      `sw_params_from_embeddings.py` for additional details.
    paths: A tf.Tensor of type tf.float32 and shape (batch_size, len_x, len_y,
      9) with binary entries, representing the trajectory of the indices along
      the alignment path, by having a one along the taken edges, with nine
      possible edges for each i,j.

  Returns:
    A tf.Tensor of type tf.float32 and shape (batch_size,) containing the SW
    score of each alignment in the batch.
  """
  # Unpacked params are first converted to per-edge weights; packed params are
  # used as the edge weights directly.
  edge_weights = (weights_from_sim_mat(*sw_params)
                  if isinstance(sw_params, Sequence) else sw_params)
  # The score is the total weight of the edges taken by each path.
  return tf.reduce_sum(edge_weights * paths, axis=[1, 2, 3])
291

292

293
def sw_score(
    sw_params,
    alignments_or_paths,
):
  """Computes SW scores, dispatching on the format of `alignments_or_paths`.

  Rank-3 inputs are treated as sparse alignments and handled by
  `sw_score_from_alignments`; any other rank is treated as dense paths and
  handled by `sw_score_from_paths`.
  """
  is_sparse = alignments_or_paths.shape.rank == 3
  score_fn = sw_score_from_alignments if is_sparse else sw_score_from_paths
  return score_fn(sw_params, alignments_or_paths)
302

303

304
def mask_from_similarities(sim_mat,
                           dtype = tf.float32,
                           pad_penalty = 1e8):
  """Recovers the padding / special token mask from a similarities tensor.

  Args:
    sim_mat: A tf.Tensor<float>[batch, len, len] of pairwise similarities. It is
      assumed that entries corresponding to padding / special tokens have been
      masked by being set to have magnitude greater than pad_penalty.
    dtype: The desired dtype for the output mask.
    pad_penalty: The magnitude at or above which entries are considered to have
      been masked (both comparisons are strict).

  Returns:
    A tf.Tensor<dtype>[batch, len, len] with binary entries, with 1.0 signifying
    "real" tokens and 0.0 padding / special tokens.
  """
  above_floor = tf.greater(sim_mat, -pad_penalty)
  below_ceiling = tf.less(sim_mat, pad_penalty)
  is_real = tf.logical_and(above_floor, below_ceiling)
  return tf.cast(is_real, dtype)
323

324

325
def broadcast_to_rank(t, rank, axis = -1):
  """Inserts size-1 dimensions into `t` at `axis` until it has rank `rank`.

  `t` is returned unchanged when its rank is already `rank` or larger.
  """
  # Assumes the rank of `t` is known at compile time (static).
  n_missing = rank - t.shape.rank
  while n_missing > 0:
    t = tf.expand_dims(t, axis=axis)
    n_missing -= 1
  return t
331

332

333
def broadcast_to_shape(
    t,
    shape,
):
  """Appends size-1 dimensions to `t`, then tiles it to match `shape`."""
  expanded = broadcast_to_rank(t, len(shape), axis=-1)
  # Number of repetitions per axis needed to go from the current to the target
  # shape (integer division of the two shapes, elementwise).
  multiples = shape // tf.shape(expanded)
  return tf.tile(expanded, multiples)
341

342

343
def weights_from_sim_mat(
    sim_mat,
    gap_open,
    gap_extend,
):
  """Builds the tensor of edge weights for the Smith-Waterman LP.

  Args:
    sim_mat: a tf.Tensor<float>[batch, len1, len2] with the substitution values
      for pairs of sequences.
    gap_open: a tf.Tensor<float>[batch, len1, len2] or tf.Tensor<float>[batch]
      of penalties for opening a gap.
    gap_extend: a tf.Tensor<float>[batch, len1, len2] or tf.Tensor<float>[batch]
      of penalties for extending a gap.

  Returns:
    A single tf.Tensor<float>[batch, len1, len2, 9] of edge weights for nine
    edge types. These correspond to a (strict) subset of allowed (from, to)
    state transitions between four state types, namely, start, match, gap_in_x
    and gap_in_y. Along the last dimension:
    + The first four (0:4) indices form a tf.Tensor<float>[batch, len1, len2, 4]
      of weights for all edges leading into match states. That is, these
      represent transitions (start, match), (match, match), (gap_in_x, match)
      and (gap_in_y, match), respectively.
    + The next two (4:6) indices form a tf.Tensor<float>[batch, len1, len2, 2]
      of weights for all edges leading into gap_in_x states. These represent
      transitions (match, gap_in_x) and (gap_in_x, gap_in_x), respectively. Note
      that, by convention, (gap_in_y, gap_in_x) transitions are disallowed.
    + The last three (6:9) indices form a tf.Tensor<float>[batch, len1, len2, 3]
      of weights for all edges leading into gap_in_y states. These represent
      transitions (match, gap_in_y) and (gap_in_x, gap_in_y) and, finally,
      (gap_in_y, gap_in_y), respectively.
  """
  dims = tf.shape(sim_mat)
  b, l1, l2 = dims[0], dims[1], dims[2]

  # Similarities are repeated over the four edge types leading into a match.
  match_w = broadcast_to_shape(sim_mat, [b, l1, l2, 4])
  # Gap penalties enter the weights with a negative sign.
  open_w = -broadcast_to_shape(gap_open, [b, l1, l2, 1])
  extend_w = -broadcast_to_shape(gap_extend, [b, l1, l2, 1])

  # Channel layout: 0:4 match | 4 open, 5 extend | 6 open, 7 open, 8 extend.
  return tf.concat(
      [match_w, open_w, extend_w, open_w, open_w, extend_w], axis=-1)
387

388

389
def adjoint_weights_from_sim_mat(
    weights,
    gap_open_shape,
    gap_extend_shape,
):
  """Applies the adjoint of `weights_from_sim_mat` to `weights`.

  Viewing `weights_from_sim_mat` as a linear map weights = A sw_params, this
  function implements the linear map A^{T} weights. Primarily to be used when
  implementing custom_gradients in functions downstream.

  Args:
    weights: a tf.Tensor<float>[batch, len1, len2, 9].
    gap_open_shape: a tf.TensorShape representing the shape of gap_open in
      sw_params.
    gap_extend_shape: a tf.TensorShape representing the shape of gap_extend in
      sw_params.

  Returns:
    A tuple (sim_mat_out, gap_open_out, gap_extend_out) such that
      + sim_mat_out is a tf.Tensor<float>[batch, len1, len2] representing the
        elements of A^{T} weights corresponding to sim_mat.
      + gap_open_out is a tf.Tensor<float>[gap_open_shape] representing the
        elements of A^{T} weights corresponding to gap_open.
      + gap_extend_out is a tf.Tensor<float>[gap_extend_shape] representing the
        elements of A^{T} weights corresponding to gap_extend.
  """

  def _aggregate(per_position, target_rank):
    """Sums over positions (rank 1) or over everything (rank 0) as needed."""
    if target_rank == 1:
      return tf.reduce_sum(per_position, axis=[1, 2])
    if target_rank == 0:
      return tf.reduce_sum(per_position)
    return per_position

  # Channels 0:4 all carry sim_mat.
  sim_mat_out = tf.reduce_sum(weights[Ellipsis, :4], axis=-1)

  # Channels 4, 6 and 7 carry -gap_open; channels 5 and 8 carry -gap_extend.
  open_per_position = - (
      weights[Ellipsis, 4] + weights[Ellipsis, 6] + weights[Ellipsis, 7])
  extend_per_position = - (weights[Ellipsis, 5] + weights[Ellipsis, 8])

  gap_open_out = _aggregate(open_per_position, gap_open_shape.rank)
  gap_extend_out = _aggregate(extend_per_position, gap_extend_shape.rank)
  return sim_mat_out, gap_open_out, gap_extend_out
432

433

434
def length(alignments_or_paths):
  """Returns the per-example length of sparse / dense alignments.

  Rank-3 inputs are treated as sparse alignments, whose length is the number
  of transitions with both positions greater than zero (i.e. non-padding).
  Any other rank is treated as dense paths, whose length is the sum of all
  entries of each example.
  """
  if alignments_or_paths.shape.rank != 3:  # Dense format.
    return tf.reduce_sum(alignments_or_paths, axis=[1, 2, 3])

  # Sparse format: one-based positions make zeros an implicit padding marker.
  pos_x = alignments_or_paths[:, 0]
  pos_y = alignments_or_paths[:, 1]
  is_transition = tf.logical_and(pos_x > 0, pos_y > 0)
  return tf.reduce_sum(tf.cast(is_transition, tf.float32), axis=-1)
442

443

444
def state_count(alignments_or_paths, states):
  """Counts, per example, the transitions whose edge type is in `states`.

  `states` is forwarded as is to `alignments_to_state_indices` (rank-3, sparse
  inputs) or to `paths_to_state_indicators` (any other rank, dense inputs).
  """
  if alignments_or_paths.shape.rank != 3:  # Dense format.
    indicators = paths_to_state_indicators(alignments_or_paths, states)
    return tf.reduce_sum(indicators, axis=[1, 2])

  # Sparse format: counts how many selected entries belong to each example.
  n_examples = tf.shape(alignments_or_paths)[0]
  example_ids = alignments_to_state_indices(alignments_or_paths, states)[:, 0]
  return tf.math.unsorted_segment_sum(
      tf.ones_like(example_ids, tf.float32), example_ids, n_examples)
455

456

457
def endpoints(alignments_or_paths, start = True):
  """Computes the endpoints in batch of sparse / dense alignments.

  Args:
    alignments_or_paths: Either sparse alignments (rank 3, laid out as
      tf.stack([pos_x, pos_y, enc_trans], 1)) or dense paths of shape
      (batch_size, len_x, len_y, 9).
    start: Whether to return the start (True) or the end (False) of each
      alignment.

  Returns:
    The one-based (x, y) coordinates of the requested endpoint.
    + Sparse format: a tensor of shape (batch_size, 2). The start is the first
      transition; the end is the maximum of each position over transitions.
    + Dense format: a tensor of shape (2, batch_size), found as the first
      (start) or last (end) match entry in row-major order.
    NOTE(review): the two formats return transposed layouts; this is kept as
    is for backward compatibility with existing callers.
  """
  if alignments_or_paths.shape.rank == 3:  # Sparse format.
    pos = alignments_or_paths[:, :2]
    return pos[Ellipsis, 0] if start else tf.reduce_max(pos, axis=-1)

  # Dense format.
  shape = tf.shape(alignments_or_paths)
  batch_size, len_x, len_y = shape[0], shape[1], shape[2]
  matches = paths_to_state_indicators(alignments_or_paths, 'match')
  # Row-major flattening: entry (x, y) lands at flat index x * len_y + y.
  matches = tf.reshape(matches, [batch_size, -1])
  # Reversing the flat axis makes argmax find the last match instead of the
  # first one; reversed flat index r maps to (len_x - 1 - x, len_y - 1 - y).
  matches = matches if start else matches[:, ::-1]
  raveled_indices = tf.cast(tf.argmax(matches, axis=-1), tf.int32)
  # Unravels with the row length, len_y. Dividing by len_x is only correct
  # when len_x == len_y.
  offset_x = raveled_indices // len_y
  offset_y = raveled_indices - offset_x * len_y
  # Uses one-based indexing for consistency with sparse format.
  if start:
    endpoint_x, endpoint_y = offset_x + 1, offset_y + 1
  else:
    endpoint_x, endpoint_y = len_x - offset_x, len_y - offset_y
  return tf.stack([endpoint_x, endpoint_y])
476

477

478
def path_label_squeeze(paths):
  """Returns a weighted sum of paths solutions, for visualization.

  The last axis of `paths` is collapsed by weighting its n-th channel
  (zero-based) by n + 1.
  """
  n_channels = tf.shape(paths)[-1]
  channel_labels = tf.range(1, n_channels + 1, dtype=paths.dtype)
  return tf.einsum('ijkn,n->ijk', paths, channel_labels)
482

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.