# Source listing: google-research (481 lines, 19.7 KB).
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Functions used to manipulate alignments and smith-waterman parameters."""
17
18from typing import Sequence, Tuple, Union
19
20import tensorflow as tf
21
# Type aliases.
PackedSWParams = tf.Tensor
UnpackedSWParams = Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
SWParams = Union[PackedSWParams, UnpackedSWParams]


# Smith-Waterman dynamic program edge types, grouped by the kind of edge
# weight each one carries. See `weights_from_sim_mat` for the 9-state
# encoding these indices refer to.
MATCH_STATES = [0, 1, 2, 3]
GAP_OPEN_STATES = [4, 6, 7]
GAP_EXTEND_STATES = [5, 8]
STATES = {
    'match': MATCH_STATES,
    'gap_open': GAP_OPEN_STATES,
    'gap_extend': GAP_EXTEND_STATES,
}
37
38
def large_compatible_positive(tensor_type):
  """Returns a large positive constant representable in `tensor_type`.

  The module-wide stand-in for "inf" (1e9) overflows tf.float16, so for that
  dtype the dtype's own maximum value is returned instead.

  NOTE(fllinares): Borrowed from
  tensorflow/python/keras/layers/advanced_activations.py
  which is used already in this codebase indirectly (e.g. in self-attention
  layers).

  Args:
    tensor_type: a dtype to determine the type.

  Returns:
    A large positive number.
  """
  # tf.float16.max is the largest finite half-precision value (~65504).
  if tensor_type == tf.dtypes.float16:
    return tf.dtypes.float16.max
  return tf.convert_to_tensor(1e9, dtype=tensor_type)
59
60
def top_pad(t, v):
  """Pads tf.Tensor `t` by prepending `v` along the leading dimension."""
  paddings = [[1, 0], [0, 0], [0, 0]]  # One new slice at the front of axis 0.
  return tf.pad(t, paddings, constant_values=v)
64
65
def left_pad(t, v):
  """Pads tf.Tensor `t` by prepending `v` along the second leading dimension."""
  paddings = [[0, 0], [1, 0], [0, 0]]  # One new slice at the front of axis 1.
  return tf.pad(t, paddings, constant_values=v)
69
70
def right_pad(t, v):
  """Pads tf.Tensor `t` by appending `v` along the second leading dimension."""
  paddings = [[0, 0], [0, 1], [0, 0]]  # One new slice at the end of axis 1.
  return tf.pad(t, paddings, constant_values=v)
74
75
def alignments_to_paths(alignments, len_x, len_y):
  """Converts sparse representation of alignments into dense paths tensor.

  Args:
    alignments: A tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x,
      pos_y, enc_trans], 1) such that (pos_x[b][i], pos_y[b][i],
      enc_trans[b][i]) represents the i-th transition in the alignment for the
      b-th sequence pair in the minibatch. Both pos_x and pos_y are assumed to
      use one-based indexing and enc_trans follows the (categorical) 9-state
      encoding of edge types used throughout alignment/smith_waterman.py.
    len_x: The (padded) length of "X"/"query" sequences in the minibatch.
    len_y: The (padded) length of "Y"/"subject" sequences in the minibatch.

  Returns:
    A tf.Tensor of type tf.float32 and shape (batch_size, len_x, len_y, 9)
    with binary entries, representing the trajectory of the indices along the
    alignment path, by having a one along the taken edges, with nine possible
    edges for each i,j.
  """
  batch_size = tf.shape(alignments)[0]
  align_len = tf.shape(alignments)[-1]

  # batch_idx[b][l] = b for every position l in [0, align_len).
  batch_idx = tf.tile(tf.range(batch_size)[:, None], [1, align_len])
  # Each scatter index is a 4-tuple (batch, pos_x, pos_y, enc_trans).
  scatter_idx = tf.concat(
      [batch_idx[:, :, None], tf.transpose(alignments, (0, 2, 1))], axis=-1)
  scatter_idx = tf.reshape(scatter_idx, (-1, 4))
  scatter_upd = tf.ones(tf.shape(scatter_idx)[0], dtype=tf.float32)
  out_shape = (batch_size, len_x + 1, len_y + 1, 9)

  # Note(fllinares): this is a (fairly ugly) hack to deal with padding.
  # - pos_x, pos_y must use one-based indexing instead of zero-based indexing.
  # - we use the (b, 0, 0, 0) entries of the dense tensor as "padding dumps".
  # - the result is sliced to drop that extra leading row and column.
  dense = tf.scatter_nd(scatter_idx, scatter_upd, out_shape)
  return dense[:, 1:, 1:, :]
117
118
def alignments_to_state_indices(alignments, states, zero_based_idx = True):
  """Retrieves indices of MATCH/GAP OPEN/GAP EXTEND states in alignments.

  Args:
    alignments: A tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x,
      pos_y, enc_trans], 1) such that (pos_x[b][i], pos_y[b][i],
      enc_trans[b][i]) represents the i-th transition in the alignment for the
      b-th sequence pair in the minibatch. Both pos_x and pos_y are assumed to
      use one-based indexing and enc_trans follows the (categorical) 9-state
      encoding of edge types used throughout alignment/smith_waterman.py.
    states: A Python list of integers in [0, 9), representing an arbitrary
      subset of (encoded) edge types. Can also be set to 'match', 'gap_open'
      or 'gap_extend' to query the set of edge types associated with each of
      those conditions.
    zero_based_idx: Whether to use zero-based (True) or one-based (False)
      indexing for the function's output. Note that, however, alignment must
      use one-based indexing regardless of the value of this argument.

  Returns:
    A tf.Tensor `state_indices` of type tf.int32 and shape (n_entries, 3)
    such that, for a tf.Tensor `sim_mat` of shape (batch_size, len_x, len_y),
      tf.gather_nd(sim_mat, state_indices)
    returns the set of entries in `sim_mat` along the alignments described by
    `alignment` that correspond to one of the states in `states`.

  Note(fllinares): this function aims to provide a way to avoid materializing
  weights in the crf_loss function in alignment/smith_waterman.py, as
  suggested by @mblondel. Some extra care might be needed to keep per-example
  losses, as tf.gather_nd will flatten the output by default. For
  position-independent gap penalties, only the total number of entries per
  example in state_indices would be needed. See `score_from_alignment` below
  for extra details.
  """
  pos_x = alignments[:, 0]
  pos_y = alignments[:, 1]
  enc_trans = alignments[:, 2]
  # Maps 'match' / 'gap_open' / 'gap_extend' to the encoded edge types,
  # passing explicit lists of states through unchanged.
  states = STATES.get(states, states)

  # Note(fllinares): another ugly "hack", here we assume one-based idx to
  # encode the padding mask implicitly (padded transitions have pos 0).
  not_padding = tf.logical_and(pos_x > 0, pos_y > 0)
  in_states = enc_trans == states[0]
  for state in states[1:]:
    in_states = tf.logical_or(in_states, enc_trans == state)
  keep = tf.logical_and(in_states, not_padding)
  flat_indices = tf.cast(tf.where(keep), tf.int32)

  offset = int(zero_based_idx)
  batch_indices = flat_indices[:, 0]
  x_indices = tf.gather_nd(pos_x, flat_indices) - offset
  y_indices = tf.gather_nd(pos_y, flat_indices) - offset
  return tf.stack([batch_indices, x_indices, y_indices], axis=-1)
174
175
def paths_to_state_indicators(paths, states):
  """Computes (batch_size, len_x, len_y) tensor of binary state indicators.

  Args:
    paths: A tf.Tensor of type tf.float32 and shape (batch_size, len_x,
      len_y, 9) with binary entries, representing the trajectory of the
      indices along the alignment path, by having a one along the taken
      edges, with nine possible edges for each i,j.
    states: A Python list of integers in [0, 9), representing an arbitrary
      subset of (encoded) edge types. This can also be set to 'match',
      'gap_open' or 'gap_extend' to query the set of edge types associated
      with each of those conditions.

  Returns:
    A tf.Tensor `state_indicators` of type tf.float32 and shape (batch_size,
    len_x, len_y) with binary entries such that
      state_indicators[b][i][j] = 1.0
    iff the trajectory of the alignment for the b-th sequence pair passes by
    character pair (i, j) under one of the states in `states`.
  """
  states = STATES.get(states, states)
  # Max over the selected edge-type channels: 1.0 iff any of them is taken.
  per_state = tf.gather(paths, indices=states, axis=-1)
  return tf.reduce_max(per_state, axis=-1)
201
202
def sw_score_from_alignments(sw_params, alignments):
  """Computes SW score of `alignments` for DP parameterized by `sw_params`.

  Args:
    sw_params: The parameters (sim_mat, gap_open, gap_extend) for the dynamic
      program underlying the Smith-Waterman algorithm. These can be input
      either as a tuple of tf.Tensor objects or as a single "packed" tensor
      of rank 4. See class `SWParamsFromEmbeddings` in module
      `sw_params_from_embeddings.py` for additional details.
    alignments: A tf.Tensor<int>[batch, 3, align_len] = tf.stack([pos_x,
      pos_y, enc_trans], 1) such that (pos_x[b][i], pos_y[b][i],
      enc_trans[b][i]) represents the i-th transition in the alignment for the
      b-th sequence pair in the minibatch. Both pos_x and pos_y are assumed to
      use one-based indexing and enc_trans follows the (categorical) 9-state
      encoding of edge types used throughout alignment/smith_waterman.py.

  Returns:
    A tf.Tensor of type tf.float32 and shape (batch_size,) containing the SW
    score of each alignment in the batch.
  """
  # Normalizes SW params to the "unpacked" convention, with gap penalties
  # negated so all three components are additive weights.
  if isinstance(sw_params, Sequence):  # _UnpackedSWParams format
    sim_mat, gap_open, gap_extend = sw_params
    gap_open = -gap_open
    gap_extend = -gap_extend
  else:  # _PackedSWParams format
    sim_mat = sw_params[Ellipsis, MATCH_STATES[0]]
    gap_open = sw_params[Ellipsis, GAP_OPEN_STATES[0]]
    gap_extend = sw_params[Ellipsis, GAP_EXTEND_STATES[0]]

  batch_size = tf.shape(sim_mat)[0]  # Assumed consistent with gap_open/extend.

  def accumulate(t, states):
    """Sums entries of `t` along the alignments for the queried states."""
    state_indices = alignments_to_state_indices(alignments, states)
    batch_indices = state_indices[:, 0]
    if t.shape.rank == 3:
      # Position-dependent parameters: gather the entries t[b, l1, l2] that
      # lie on each alignment path and reduce them per example.
      entries_on_path = tf.gather_nd(t, state_indices)
      return tf.math.unsorted_segment_sum(
          entries_on_path, batch_indices, batch_size)
    # Position-independent parameters: count path entries per example and
    # scale t[b] by that count.
    # Note(fllinares): tf.math.bincount unsupported in TPU :(
    counts_per_example = tf.math.unsorted_segment_sum(
        tf.ones_like(batch_indices, tf.float32), batch_indices, batch_size)
    return t * counts_per_example

  return (accumulate(sim_mat, 'match') +
          accumulate(gap_open, 'gap_open') +
          accumulate(gap_extend, 'gap_extend'))
268
269
def sw_score_from_paths(sw_params, paths):
  """Computes SW score of `paths` for DP parameterized by `sw_params`.

  Args:
    sw_params: The parameters (sim_mat, gap_open, gap_extend) for the dynamic
      program underlying the Smith-Waterman algorithm. These can be input
      either as a tuple of tf.Tensor objects or as a single "packed" tensor
      of rank 4. See class `SWParamsFromEmbeddings` in module
      `sw_params_from_embeddings.py` for additional details.
    paths: A tf.Tensor of type tf.float32 and shape (batch_size, len_x,
      len_y, 9) with binary entries, representing the trajectory of the
      indices along the alignment path, by having a one along the taken
      edges, with nine possible edges for each i,j.

  Returns:
    A tf.Tensor of type tf.float32 and shape (batch_size,) containing the SW
    score of each alignment in the batch.
  """
  if isinstance(sw_params, Sequence):  # _UnpackedSWParams format
    # Packs the params into a [batch, len1, len2, 9] edge-weight tensor.
    sw_params = weights_from_sim_mat(*sw_params)
  # The score is the inner product between edge weights and the binary path.
  weighted_edges = sw_params * paths
  return tf.reduce_sum(weighted_edges, axis=[1, 2, 3])
291
292
def sw_score(sw_params, alignments_or_paths):
  """Wraps over sw_score_from_paths and sw_score_from_alignments."""
  # Rank-3 tensors hold sparse alignments; rank-4 tensors hold dense paths.
  is_sparse = alignments_or_paths.shape.rank == 3
  score_fn = sw_score_from_alignments if is_sparse else sw_score_from_paths
  return score_fn(sw_params, alignments_or_paths)
302
303
def mask_from_similarities(sim_mat, dtype = tf.float32, pad_penalty = 1e8):
  """Recovers padding / special token mask from a similarities tensor.

  Args:
    sim_mat: A tf.Tensor<float>[batch, len, len] of pairwise similarities. It
      is assumed that entries corresponding to padding / special tokens have
      been masked by being set to have magnitude greater than pad_penalty.
    dtype: The desired dtype for the output mask.
    pad_penalty: The magnitude above which entries are considered to have
      been masked.

  Returns:
    A tf.Tensor<dtype>[batch, len, len] with binary entries, with 1.0
    signifying "real" tokens and 0.0 padding / special tokens.
  """
  # |s| < pad_penalty is equivalent to -pad_penalty < s < pad_penalty.
  is_real_token = tf.abs(sim_mat) < pad_penalty
  return tf.cast(is_real_token, dtype)
323
324
def broadcast_to_rank(t, rank, axis = -1):
  """Appends dimensions to tf.Tensor `t` at axis `axis` to match rank `rank`."""
  # Assumes ranks are known at compile time (static).
  n_missing = rank - t.shape.rank
  for _ in range(n_missing):
    t = tf.expand_dims(t, axis=axis)
  return t
331
332
def broadcast_to_shape(t, shape):
  """Appends dimensions to and tiles tf.Tensor t to match desired shape."""
  # First pad the rank with trailing singleton dims, then tile each axis by
  # the ratio between the target shape and the current shape.
  t = broadcast_to_rank(t, len(shape), axis=-1)
  multiples = shape // tf.shape(t)
  return tf.tile(t, multiples)
341
342
def weights_from_sim_mat(sim_mat, gap_open, gap_extend):
  """Computes the edge weights for the Smith-Waterman LP.

  Args:
    sim_mat: a tf.Tensor<float>[batch, len1, len2] with the substitution
      values for pairs of sequences.
    gap_open: a tf.Tensor<float>[batch, len1, len2] or tf.Tensor<float>[batch]
      of penalties for opening a gap.
    gap_extend: a tf.Tensor<float>[batch, len1, len2] or
      tf.Tensor<float>[batch] of penalties for extending a gap.

  Returns:
    A single tf.Tensor<float>[batch, len1, len2, 9] of edge weights for nine
    edge types. These correspond to a (strict) subset of allowed (from, to)
    state transitions between four state types, namely, start, match,
    gap_in_x and gap_in_y. Along the last dimension:
    + The first four (0:4) indices form a
      tf.Tensor<float>[batch, len1, len2, 4] of weights for all edges leading
      into match states, i.e. transitions (start, match), (match, match),
      (gap_in_x, match) and (gap_in_y, match), respectively.
    + The next two (4:6) indices form a
      tf.Tensor<float>[batch, len1, len2, 2] of weights for all edges leading
      into gap_in_x states, i.e. transitions (match, gap_in_x) and
      (gap_in_x, gap_in_x). Note that, by convention, (gap_in_y, gap_in_x)
      transitions are disallowed.
    + The last three (6:9) indices form a
      tf.Tensor<float>[batch, len1, len2, 3] of weights for all edges leading
      into gap_in_y states, i.e. transitions (match, gap_in_y),
      (gap_in_x, gap_in_y) and (gap_in_y, gap_in_y), respectively.
  """
  dims = tf.shape(sim_mat)
  b, l1, l2 = dims[0], dims[1], dims[2]

  # Broadcasts position-independent penalties to full positional tensors.
  sim_mat = broadcast_to_shape(sim_mat, [b, l1, l2, 4])
  gap_open = broadcast_to_shape(gap_open, [b, l1, l2, 1])
  gap_extend = broadcast_to_shape(gap_extend, [b, l1, l2, 1])

  # Penalties enter the LP as negative additive weights.
  into_match = sim_mat
  into_gap_in_x = tf.concat([-gap_open, -gap_extend], axis=-1)
  into_gap_in_y = tf.concat([-gap_open, into_gap_in_x], axis=-1)

  return tf.concat([into_match, into_gap_in_x, into_gap_in_y], axis=-1)
387
388
def adjoint_weights_from_sim_mat(weights, gap_open_shape, gap_extend_shape):
  """Computes the adjoint of `weights_from_sim_mat`.

  Viewing `weights_from_sim_mat` as a linear map weights = A sw_params, this
  function implements the linear map A^{T} weights. Primarily to be used when
  implementing custom_gradients in functions downstream.

  Args:
    weights: a tf.Tensor<float>[batch, len1, len2, 9].
    gap_open_shape: a tf.TensorShape representing the shape of gap_open in
      sw_params.
    gap_extend_shape: a tf.TensorShape representing the shape of gap_extend
      in sw_params.

  Returns:
    A tuple (sim_mat_out, gap_open_out, gap_extend_out) such that
    + sim_mat_out is a tf.Tensor<float>[batch, len1, len2] representing the
      elements of A^{T} weights corresponding to sim_mat.
    + gap_open_out is a tf.Tensor<float>[gap_open_shape] representing the
      elements of A^{T} weights corresponding to gap_open.
    + gap_extend_out is a tf.Tensor<float>[gap_extend_shape] representing the
      elements of A^{T} weights corresponding to gap_extend.
  """
  # sim_mat appears (unnegated) in the four match channels.
  sim_mat_out = tf.reduce_sum(weights[Ellipsis, :4], axis=-1)

  def match_param_shape(grad, param_shape):
    """Aggregates across positions / examples to match the param's rank."""
    if param_shape.rank == 1:  # Per-example scalar penalty.
      return tf.reduce_sum(grad, axis=[1, 2])
    if param_shape.rank == 0:  # Single shared scalar penalty.
      return tf.reduce_sum(grad)
    return grad  # Position-dependent penalty: keep full resolution.

  # gap_open appears negated in channels 4, 6 and 7; gap_extend in 5 and 8.
  gap_open_out = match_param_shape(
      -(weights[Ellipsis, 4] + weights[Ellipsis, 6] + weights[Ellipsis, 7]),
      gap_open_shape)
  gap_extend_out = match_param_shape(
      -(weights[Ellipsis, 5] + weights[Ellipsis, 8]), gap_extend_shape)

  return sim_mat_out, gap_open_out, gap_extend_out
432
433
def length(alignments_or_paths):
  """Computes the lengths in batch of sparse / dense alignments."""
  if alignments_or_paths.shape.rank == 3:  # Sparse format.
    pos_x = alignments_or_paths[:, 0]
    pos_y = alignments_or_paths[:, 1]
    # One-based indexing: padded transitions have pos_x = pos_y = 0.
    is_real = tf.logical_and(pos_x > 0, pos_y > 0)
    return tf.reduce_sum(tf.cast(is_real, tf.float32), axis=-1)
  # Dense format: every taken edge contributes exactly one unit.
  return tf.reduce_sum(alignments_or_paths, axis=[1, 2, 3])
442
443
def state_count(alignments_or_paths, states):
  """Counts match/gap_open/gap_extend in batch of sparse / dense alignments."""
  if alignments_or_paths.shape.rank == 3:  # Sparse format.
    batch_size = tf.shape(alignments_or_paths)[0]
    state_indices = alignments_to_state_indices(alignments_or_paths, states)
    batch_ids = state_indices[:, 0]
    # Counts hits per example; tf.math.bincount is unsupported on TPU.
    return tf.math.unsorted_segment_sum(
        tf.ones_like(batch_ids, tf.float32), batch_ids, batch_size)
  # Dense format: sum the binary per-position indicators.
  indicators = paths_to_state_indicators(alignments_or_paths, states)
  return tf.reduce_sum(indicators, axis=[1, 2])
455
456
def endpoints(alignments_or_paths, start = True):
  """Computes the endpoints in batch of sparse / dense alignments.

  Args:
    alignments_or_paths: either a sparse tf.Tensor<int>[batch, 3, align_len]
      alignment or a dense tf.Tensor<float>[batch, len_x, len_y, 9] paths
      tensor, as used elsewhere in this module.
    start: if True, returns the first aligned position per example; otherwise
      returns the last one.

  Returns:
    A tf.Tensor<int>[2, batch] stacking the (one-based) x and y coordinates
    of the requested endpoint for each example.
  """
  if alignments_or_paths.shape.rank == 3:  # Sparse format.
    pos = alignments_or_paths[:, :2]
    # Padded entries are 0, so the max over the alignment is the last real
    # position while index 0 holds the first one.
    return pos[Ellipsis, 0] if start else tf.reduce_max(pos, axis=-1)
  else:  # Dense format.
    shape = tf.shape(alignments_or_paths)
    batch_size = shape[0]
    len_x, len_y = shape[1], shape[2]
    matches = paths_to_state_indicators(alignments_or_paths, 'match')
    matches = tf.reshape(matches, [batch_size, -1])
    # To find the last match, scan the (reversed) flattened grid for the
    # first one instead.
    matches = matches if start else matches[:, ::-1]
    raveled_indices = tf.cast(tf.argmax(matches, axis=-1), tf.int32)
    # Unravels row-major indices over a (len_x, len_y) grid: flattened index
    # = x * len_y + y, so the row is the quotient by len_y. (Bug fix: the
    # previous code divided by len_x, which is only correct for square
    # grids.)
    start_x = raveled_indices // len_y
    start_y = raveled_indices - start_x * len_y
    # Uses one-based indexing for consistency with sparse format. For the
    # reversed scan, (len - start) converts the reversed coordinate back.
    endpoint_x = start_x + 1 if start else len_x - start_x
    endpoint_y = start_y + 1 if start else len_y - start_y
    return tf.stack([endpoint_x, endpoint_y])
476
477
def path_label_squeeze(paths):
  """Returns a weights sum of paths solutions, for visualization."""
  # Labels edge types 1..9 so distinct states remain distinguishable after
  # collapsing the last axis.
  n_states = tf.shape(paths)[-1]
  labels = tf.range(1, n_states + 1, dtype=paths.dtype)
  return tf.einsum('ijkn,n->ijk', paths, labels)
482