google-research

Форк
0
/
smith_waterman_test.py 
350 строк · 13.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for the wavefront_tf_ops module."""
17

18
from absl.testing import parameterized
19
import numpy as np
20
import tensorflow as tf
21

22

23
from dedal import alignment
24
from dedal import smith_waterman as tf_ops
25
from dedal import smith_waterman_np as npy_ops
26

27

28
# Call wrappers every parameterized test is run under: `None` calls the op
# undecorated, `tf.function` calls a traced graph version of it.
# TODO(fllinares): include XLA when running on GPU & TPU by adding an
# XLA-compiled tf.function variant (`jit_compile=True`; older TF releases
# spell the flag `experimental_compile=True`).
_DECORATORS = [
    None,
    tf.function,
]
31

32

33
def random_sim_mat(b, l1, l2, emb_dim=3):
  """Builds a batch of random similarity matrices.

  Each matrix holds the pairwise inner products between two sequences of
  standard-normal embeddings.

  Args:
    b: batch size.
    l1: length of the first sequence.
    l2: length of the second sequence.
    emb_dim: embedding dimension used for both sequences.

  Returns:
    A float tensor of shape (b, l1, l2).
  """
  seq_emb1 = tf.random.normal((b, l1, emb_dim))
  seq_emb2 = tf.random.normal((b, l2, emb_dim))
  # Batched inner products over the embedding axis: out[n, i, j] =
  # <seq_emb1[n, i], seq_emb2[n, j]>.
  return tf.einsum('nik,njk->nij', seq_emb1, seq_emb2)
37

38

39
def random_gap_penalty(minval, maxval, b=None, l1=None, l2=None):
  """Samples gap penalties uniformly from [minval, maxval).

  The rank of the result depends on which sizes are given:
    * `b` is None                  -> scalar, shape ().
    * `b` given, `l1` or `l2` None -> one penalty per batch item, shape (b,).
    * `b`, `l1` and `l2` all given -> per-position penalties, shape (b, l1, l2).

  Args:
    minval: inclusive lower bound of the uniform distribution.
    maxval: exclusive upper bound of the uniform distribution.
    b: optional batch size.
    l1: optional length of the first sequence.
    l2: optional length of the second sequence.

  Returns:
    A float tensor with the shape described above.
  """
  if b is None:
    shape = ()
  elif l1 is None or l2 is None:
    shape = (b,)
  else:
    shape = (b, l1, l2)
  return tf.random.uniform(shape, minval=minval, maxval=maxval)
46

47

48
def best_alignment_brute_force(weights):
  """Finds the highest-scoring alignment by exhaustive enumeration.

  Every candidate from `npy_ops.alignment_matrices(len_1, len_2)` is scored
  by its flattened dot product (`np.vdot`) with `weights`.

  Args:
    weights: array-like whose shape unpacks as (len_1, len_2, num_states).

  Returns:
    The first candidate with the strictly highest score. Returns None if no
    candidate scores above -inf (including when there are no candidates).
  """
  len_1, len_2, _ = weights.shape
  best_alignment = None
  best_value = -np.inf
  for alignment_mat in npy_ops.alignment_matrices(len_1, len_2):
    value = np.vdot(alignment_mat, weights)
    # Strict `>` keeps the first maximizer on ties and never selects a NaN
    # score.
    if value > best_value:
      best_value = value
      best_alignment = alignment_mat
  return best_alignment
58

59

60
class SmithWatermanAffineTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the TF Smith-Waterman ops with affine gap penalties.

  The ops in `tf_ops` are checked against:
    * hand-computed expectations on toy problems,
    * their own autodiff gradients,
    * brute-force enumeration of alignments,
    * the NumPy reference implementation in `npy_ops`.
  """

  def setUp(self):
    """Seeds TF's RNG and builds a batch of four 5x5 toy problems."""
    super().setUp()
    tf.random.set_seed(1)

    # +1 on the diagonal and -5 everywhere else.
    single_sub = 6*tf.eye(5) - 5*tf.ones((5, 5))
    # Variants in which two diagonal entries are overwritten with -5.
    second_sub = tf.tensor_scatter_nd_update(single_sub,
                                             indices=[[0, 0], [1, 1]],
                                             updates=[-5, -5])
    third_sub = tf.tensor_scatter_nd_update(single_sub,
                                            indices=[[0, 0], [2, 2]],
                                            updates=[-5, -5])
    fourth_sub = tf.tensor_scatter_nd_update(single_sub,
                                             indices=[[0, 0], [4, 4]],
                                             updates=[-5, -5])
    self._toy_sub = tf.stack([single_sub, second_sub, third_sub, fourth_sub])
    # One (gap_open, gap_extend) pair per batch element.
    self._toy_gap_open = 0.03 * tf.ones((self._toy_sub.shape[0],))
    self._toy_gap_extend = 0.02 * tf.ones((self._toy_sub.shape[0],))
    self._w = alignment.weights_from_sim_mat(self._toy_sub,
                                             self._toy_gap_open,
                                             self._toy_gap_extend)

  @parameterized.product(
      decorator=_DECORATORS,
      b=[1, 8],
      # The test used to pass for 'rank0' but does not anymore.
      # Skipping this case for now as it's not used anyway.
      gap_pen_type=['rank1', 'rank3'],
  )
  def test_weights_from_sim_mat(self, decorator, b, gap_pen_type):
    """Checks the 9-channel weight layout built from a similarity matrix."""
    l1, l2 = 14, 37
    minval_open, maxval_open = 10.5, 11.5
    minval_extend, maxval_extend = 0.8, 1.2

    sim_mat = random_sim_mat(b, l1=l1, l2=l2, emb_dim=3)
    if gap_pen_type == 'rank0':
      gap_open = random_gap_penalty(minval_open, maxval_open)
      gap_extend = random_gap_penalty(minval_extend, maxval_extend)
    elif gap_pen_type == 'rank1':
      gap_open = random_gap_penalty(minval_open, maxval_open, b)
      gap_extend = random_gap_penalty(minval_extend, maxval_extend, b)
    else:
      gap_open = random_gap_penalty(minval_open, maxval_open, b, l1, l2)
      gap_extend = random_gap_penalty(minval_extend, maxval_extend, b, l1, l2)

    weights_from_sim_mat_fn = alignment.weights_from_sim_mat
    if decorator is not None:
      weights_from_sim_mat_fn = decorator(weights_from_sim_mat_fn)

    w = weights_from_sim_mat_fn(sim_mat, gap_open, gap_extend)

    self.assertEqual(w.shape, (b, l1, l2, 9))
    # Channels 0-3 all carry the same values.
    self.assertAllEqual(w[Ellipsis, 0], w[Ellipsis, 1])
    self.assertAllEqual(w[Ellipsis, 0], w[Ellipsis, 2])
    self.assertAllEqual(w[Ellipsis, 0], w[Ellipsis, 3])
    # Broadcast lower-rank penalties to (b, l1, l2) for the comparisons below.
    if gap_pen_type == 'rank0':
      gap_open = tf.fill([b, l1, l2], gap_open)
      gap_extend = tf.fill([b, l1, l2], gap_extend)
    elif gap_pen_type == 'rank1':
      gap_open = tf.tile(gap_open[:, None, None], [1, l1, l2])
      gap_extend = tf.tile(gap_extend[:, None, None], [1, l1, l2])
    # Channels 4, 6, 7 hold -gap_open; channels 5, 8 hold -gap_extend.
    self.assertAllEqual(w[Ellipsis, 4], -gap_open)
    self.assertAllEqual(w[Ellipsis, 5], -gap_extend)
    self.assertAllEqual(w[Ellipsis, 6], -gap_open)
    self.assertAllEqual(w[Ellipsis, 7], -gap_open)
    self.assertAllEqual(w[Ellipsis, 8], -gap_extend)

  @parameterized.product(
      decorator=_DECORATORS,
      b=[1, 8],
  )
  def test_wavefrontify(self, decorator, b):
    """Checks the wavefront layout and its inverse on random weights."""
    l1, l2, s = 14, 37, 9
    minval_open, maxval_open = 10.5, 11.5
    minval_extend, maxval_extend = 0.8, 1.2

    sim_mat = random_sim_mat(b, l1=l1, l2=l2, emb_dim=3)
    gap_open = random_gap_penalty(minval_open, maxval_open, b, l1, l2)
    gap_extend = random_gap_penalty(minval_extend, maxval_extend, b, l1, l2)
    w = alignment.weights_from_sim_mat(sim_mat, gap_open, gap_extend)

    wavefrontify_fn = tf_ops.wavefrontify
    unwavefrontify_fn = tf_ops.unwavefrontify
    if decorator is not None:
      wavefrontify_fn = decorator(wavefrontify_fn)
      unwavefrontify_fn = decorator(unwavefrontify_fn)

    w_wavefrontified = wavefrontify_fn(w)
    w_unwavefrontified = unwavefrontify_fn(w_wavefrontified)

    self.assertEqual(w_wavefrontified.shape, (l1 + l2 - 1, s, l1, b))
    self.assertAllEqual(w_unwavefrontified, w)

    # Every entry w[n, i, j, a] must sit at w_wavefrontified[i + j, a, i, n].
    # The whole (b, s, l1, l2) grid is checked in one vectorized comparison
    # instead of an element-by-element Python loop.
    w_np = np.asarray(w)
    w_wavefrontified_np = np.asarray(w_wavefrontified)
    rows, cols = np.meshgrid(np.arange(l1), np.arange(l2), indexing='ij')
    # The two advanced indices are separated by a slice, so NumPy moves the
    # broadcast (l1, l2) index dims to the front. The result has shape
    # (l1, l2, s, b) with gathered[i, j, a, n] = w_wavefrontified[i+j, a, i, n].
    gathered = w_wavefrontified_np[rows + cols, :, rows, :]
    # Reorder w from (n, i, j, a) to (i, j, a, n) to match `gathered`.
    self.assertAllEqual(gathered, np.transpose(w_np, (1, 2, 3, 0)))

  @parameterized.product(
      decorator=_DECORATORS,
      tol=[1e-3, 1e-6, 1e-9],
  )
  def test_toy_smith_waterman(self, decorator, tol):
    """Compares hard SW scores, paths and match cells to hand-computed ones."""
    smith_waterman_fn = tf_ops.hard_sw_affine
    if decorator is not None:
      smith_waterman_fn = decorator(smith_waterman_fn)

    values, paths = smith_waterman_fn(self._w, tol=tol)
    paths_squeeze = alignment.path_label_squeeze(paths)
    all_matches = tf.where(
        alignment.paths_to_state_indicators(paths, 'match'))

    values_test = tf.constant([5.0, 3.0, 2.94, 3.0], dtype=tf.float32)
    self.assertAllClose(values, values_test, atol=2 * tol)

    paths_squeeze_test = tf.constant([[[1., 0., 0., 0., 0.],
                                       [0., 2., 0., 0., 0.],
                                       [0., 0., 2., 0., 0.],
                                       [0., 0., 0., 2., 0.],
                                       [0., 0., 0., 0., 2.]],
                                      [[0., 0., 0., 0., 0.],
                                       [0., 0., 0., 0., 0.],
                                       [0., 0., 1., 0., 0.],
                                       [0., 0., 0., 2., 0.],
                                       [0., 0., 0., 0., 2.]],
                                      [[0., 0., 0., 0., 0.],
                                       [0., 1., 5., 0., 0.],
                                       [0., 0., 8., 0., 0.],
                                       [0., 0., 0., 4., 0.],
                                       [0., 0., 0., 0., 2.]],
                                      [[0., 0., 0., 0., 0.],
                                       [0., 1., 0., 0., 0.],
                                       [0., 0., 2., 0., 0.],
                                       [0., 0., 0., 2., 0.],
                                       [0., 0., 0., 0., 0.]]], dtype=tf.float32)
    self.assertAllEqual(paths_squeeze, paths_squeeze_test)

    # Rows are (batch index, i, j) of every cell in the 'match' state.
    all_matches_test = tf.constant([[0, 0, 0],
                                    [0, 1, 1],
                                    [0, 2, 2],
                                    [0, 3, 3],
                                    [0, 4, 4],
                                    [1, 2, 2],
                                    [1, 3, 3],
                                    [1, 4, 4],
                                    [2, 1, 1],
                                    [2, 3, 3],
                                    [2, 4, 4],
                                    [3, 1, 1],
                                    [3, 2, 2],
                                    [3, 3, 3]], dtype=tf.int32)
    self.assertAllEqual(all_matches, all_matches_test)

  @parameterized.product(
      decorator=_DECORATORS,
  )
  def test_smith_waterman_termination(self, decorator):
    """Checks a non-square 3x4 problem whose best alignment ends at a corner."""
    smith_waterman_fn = tf_ops.hard_sw_affine
    if decorator is not None:
      smith_waterman_fn = decorator(smith_waterman_fn)
    tol = 1e-6

    # A -5 column followed by the 3x3 (+1 diagonal, -5 elsewhere) block.
    single_sub = tf.concat([- 5*tf.ones((3, 1)),
                            6*tf.eye(3) - 5*tf.ones((3, 3))], 1)
    toy_sub = tf.expand_dims(single_sub, 0)
    toy_gap_open = 0.03 * tf.ones((toy_sub.shape[0],))
    toy_gap_extend = 0.02 * tf.ones((toy_sub.shape[0],))
    w = alignment.weights_from_sim_mat(toy_sub, toy_gap_open, toy_gap_extend)

    values, paths = smith_waterman_fn(w, tol=tol)
    paths_squeeze = alignment.path_label_squeeze(paths)

    self.assertAllClose(values, [3.0], atol=2 * tol)
    paths_squeeze_test = tf.convert_to_tensor([[[0., 1., 0., 0.],
                                                [0., 0., 2., 0.],
                                                [0., 0., 0., 2.]]], tf.float32)
    self.assertAllEqual(paths_squeeze, paths_squeeze_test)

  @parameterized.product(
      decorator=_DECORATORS,
  )
  def test_smith_waterman_empty(self, decorator):
    """Checks that an all-negative problem yields score 0 and an empty path."""
    smith_waterman_fn = tf_ops.hard_sw_affine
    if decorator is not None:
      smith_waterman_fn = decorator(smith_waterman_fn)
    tol = 1e-6

    single_sub = - 5*tf.ones((5, 5))
    toy_sub = tf.expand_dims(single_sub, 0)
    toy_gap_open = 0.03 * tf.ones((toy_sub.shape[0],))
    toy_gap_extend = 0.02 * tf.ones((toy_sub.shape[0],))
    w = alignment.weights_from_sim_mat(toy_sub, toy_gap_open, toy_gap_extend)

    values, paths = smith_waterman_fn(w, tol=tol)
    paths_squeeze = alignment.path_label_squeeze(paths)

    self.assertAllClose(values, [0.0], atol=2 * tol)
    self.assertAllEqual(paths_squeeze, tf.zeros([1, 5, 5], tf.float32))

  @parameterized.product(
      decorator=_DECORATORS,
  )
  def test_backtracking_against_autodiff(self, decorator):
    """Checks that returned paths equal the gradient of the scores wrt w."""
    def grad_fn(w):
      with tf.GradientTape() as tape:
        tape.watch(w)
        maxes, _ = tf_ops.hard_sw_affine(w)
      return tape.gradient(maxes, w)
    smith_waterman_fn = tf_ops.hard_sw_affine
    if decorator is not None:
      grad_fn = decorator(grad_fn)
      smith_waterman_fn = decorator(smith_waterman_fn)

    # Check that autodiff recovers the handwritten backtracking.
    _, paths = smith_waterman_fn(self._w)
    paths2 = grad_fn(self._w)

    self.assertAllEqual(paths, paths2)

  @parameterized.product(
      decorator=_DECORATORS,
  )
  def test_backtracking_against_bruteforce(self, decorator):
    """Checks returned paths against exhaustive search over all alignments."""
    smith_waterman_fn = tf_ops.hard_sw_affine
    if decorator is not None:
      smith_waterman_fn = decorator(smith_waterman_fn)

    _, paths = smith_waterman_fn(self._w)
    paths3 = np.array([best_alignment_brute_force(self._w[i])
                       for i in range(self._w.shape[0])])

    self.assertAllEqual(paths, paths3)

  @parameterized.product(
      decorator=_DECORATORS,
  )
  def test_perturbation_friendly_version_against_numpy_version(self, decorator):
    """Compares hard TF scores with NumPy soft scores at temperature 0."""
    smith_waterman_fn = tf_ops.hard_sw_affine
    if decorator is not None:
      smith_waterman_fn = decorator(smith_waterman_fn)

    maxes, _ = smith_waterman_fn(self._w)
    maxes2 = npy_ops.soft_sw_affine(self._toy_sub.numpy(),
                                    self._toy_gap_open.numpy(),
                                    self._toy_gap_extend.numpy(),
                                    temperature=0)

    self.assertAllClose(maxes, maxes2)

  @parameterized.product(
      decorator=_DECORATORS,
  )
  def test_soft_version_against_perturbation_friendly_version(self, decorator):
    """Compares hard TF scores with the soft TF forward pass at temp=None."""
    soft_version_fn = tf_ops.soft_sw_affine_fwd
    perturbation_friendly_fn = tf_ops.hard_sw_affine
    if decorator is not None:
      soft_version_fn = decorator(soft_version_fn)
      perturbation_friendly_fn = decorator(perturbation_friendly_fn)

    maxes, _ = perturbation_friendly_fn(self._w)
    # NOTE(review): temp=None presumably selects the hard (zero-temperature)
    # recursion in soft_sw_affine_fwd; confirm against its definition.
    maxes2 = soft_version_fn(self._toy_sub,
                             self._toy_gap_open,
                             self._toy_gap_extend,
                             temp=None)

    self.assertAllClose(maxes, maxes2)

  @parameterized.product(
      decorator=_DECORATORS,
      temp=[1e-3, 1.0, 1e3],
  )
  def test_soft_version_against_numpy_version(self, decorator, temp):
    """Compares soft TF scores with the NumPy reference at several temps."""
    soft_version_fn = tf_ops.soft_sw_affine_fwd
    if decorator is not None:
      soft_version_fn = decorator(soft_version_fn)

    # Run the comparison for every decorator. It was previously nested under
    # the `if` above, so the undecorated (decorator=None) case asserted
    # nothing.
    maxes = npy_ops.soft_sw_affine(self._toy_sub.numpy(),
                                   self._toy_gap_open.numpy(),
                                   self._toy_gap_extend.numpy(),
                                   temperature=temp)
    maxes2 = soft_version_fn(self._toy_sub,
                             self._toy_gap_open,
                             self._toy_gap_extend,
                             temp=temp)

    self.assertAllClose(maxes, maxes2)
347

348

349
if __name__ == '__main__':
  # Hand control to TF's test runner when the module is executed directly.
  tf.test.main()
351

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.