google-research
350 строк · 13.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for the wavefront_tf_ops module."""
17
18from absl.testing import parameterized19import numpy as np20import tensorflow as tf21
22
23from dedal import alignment24from dedal import smith_waterman as tf_ops25from dedal import smith_waterman_np as npy_ops26
27
28_DECORATORS = [None, tf.function]29# TODO(fllinares): include XLA when running on GPU & TPU
30# _DECORATORS = [None, tf.function, tf.function(experimental_compile=True)]
31
32
33def random_sim_mat(b, l1, l2, emb_dim=3):34seq_emb1 = tf.random.normal((b, l1, emb_dim))35seq_emb2 = tf.random.normal((b, l2, emb_dim))36return tf.einsum('nik,njk->nij', seq_emb1, seq_emb2)37
38
39def random_gap_penalty(minval, maxval, b=None, l1=None, l2=None):40if b is None:41return tf.random.uniform((), minval=minval, maxval=maxval)42elif l1 is None or l2 is None:43return tf.random.uniform((b,), minval=minval, maxval=maxval)44else:45return tf.random.uniform((b, l1, l2), minval=minval, maxval=maxval)46
47
48def best_alignment_brute_force(weights):49len_1, len_2, _ = weights.shape50best_alignment = None51best_value = -np.inf52for alignment_mat in npy_ops.alignment_matrices(len_1, len_2):53value = np.vdot(alignment_mat, weights)54if value > best_value:55best_value = value56best_alignment = alignment_mat57return best_alignment58
59
60class SmithWatermanAffineTest(parameterized.TestCase, tf.test.TestCase):61
62def setUp(self):63super().setUp()64tf.random.set_seed(1)65
66single_sub = 6*tf.eye(5) - 5*tf.ones((5, 5))67second_sub = tf.tensor_scatter_nd_update(single_sub,68indices=[[0, 0], [1, 1]],69updates=[-5, -5])70third_sub = tf.tensor_scatter_nd_update(single_sub,71indices=[[0, 0], [2, 2]],72updates=[-5, -5])73fourth_sub = tf.tensor_scatter_nd_update(single_sub,74indices=[[0, 0], [4, 4]],75updates=[-5, -5])76self._toy_sub = tf.stack([single_sub, second_sub, third_sub, fourth_sub])77self._toy_gap_open = 0.03 * tf.ones((self._toy_sub.shape[0],))78self._toy_gap_extend = 0.02 * tf.ones((self._toy_sub.shape[0],))79self._w = alignment.weights_from_sim_mat(self._toy_sub,80self._toy_gap_open,81self._toy_gap_extend)82
83@parameterized.product(84decorator=_DECORATORS,85b=[1, 8],86# The test used to pass for 'rank0' but does not anymore.87# Skipping this case for now as it's not used anyway.88gap_pen_type=['rank1', 'rank3'],89)90def test_weights_from_sim_mat(self, decorator, b, gap_pen_type):91l1, l2 = 14, 3792minval_open, maxval_open = 10.5, 11.593minval_extend, maxval_extend = 0.8, 1.294
95sim_mat = random_sim_mat(b, l1=l1, l2=l2, emb_dim=3)96if gap_pen_type == 'rank0':97gap_open = random_gap_penalty(minval_open, maxval_open)98gap_extend = random_gap_penalty(minval_extend, maxval_extend)99elif gap_pen_type == 'rank1':100gap_open = random_gap_penalty(minval_open, maxval_open, b)101gap_extend = random_gap_penalty(minval_extend, maxval_extend, b)102else:103gap_open = random_gap_penalty(minval_open, maxval_open, b, l1, l2)104gap_extend = random_gap_penalty(minval_extend, maxval_extend, b, l1, l2)105
106weights_from_sim_mat_fn = alignment.weights_from_sim_mat107if decorator is not None:108weights_from_sim_mat_fn = decorator(weights_from_sim_mat_fn)109
110w = weights_from_sim_mat_fn(sim_mat, gap_open, gap_extend)111
112self.assertEqual(w.shape, (b, l1, l2, 9))113self.assertAllEqual(w[Ellipsis, 0], w[Ellipsis, 1])114self.assertAllEqual(w[Ellipsis, 0], w[Ellipsis, 2])115self.assertAllEqual(w[Ellipsis, 0], w[Ellipsis, 3])116if gap_pen_type == 'rank0':117gap_open = tf.fill([b, l1, l2], gap_open)118gap_extend = tf.fill([b, l1, l2], gap_extend)119elif gap_pen_type == 'rank1':120gap_open = tf.tile(gap_open[:, None, None], [1, l1, l2])121gap_extend = tf.tile(gap_extend[:, None, None], [1, l1, l2])122self.assertAllEqual(w[Ellipsis, 4], -gap_open)123self.assertAllEqual(w[Ellipsis, 5], -gap_extend)124self.assertAllEqual(w[Ellipsis, 6], -gap_open)125self.assertAllEqual(w[Ellipsis, 7], -gap_open)126self.assertAllEqual(w[Ellipsis, 8], -gap_extend)127
128@parameterized.product(129decorator=_DECORATORS,130b=[1, 8],131)132def test_wavefrontify(self, decorator, b):133l1, l2, s = 14, 37, 9134minval_open, maxval_open = 10.5, 11.5135minval_extend, maxval_extend = 0.8, 1.2136
137sim_mat = random_sim_mat(b, l1=l1, l2=l2, emb_dim=3)138gap_open = random_gap_penalty(minval_open, maxval_open, b, l1, l2)139gap_extend = random_gap_penalty(minval_extend, maxval_extend, b, l1, l2)140w = alignment.weights_from_sim_mat(sim_mat, gap_open, gap_extend)141
142wavefrontify_fn = tf_ops.wavefrontify143unwavefrontify_fn = tf_ops.unwavefrontify144if decorator is not None:145wavefrontify_fn = decorator(wavefrontify_fn)146unwavefrontify_fn = decorator(unwavefrontify_fn)147
148w_wavefrontified = wavefrontify_fn(w)149w_unwavefrontified = unwavefrontify_fn(w_wavefrontified)150
151self.assertEqual(w_wavefrontified.shape, (l1 + l2 - 1, s, l1, b))152self.assertAllEqual(w_unwavefrontified, w)153for n in tf.range(b):154for a in tf.range(s):155for i in tf.range(l1):156for j in tf.range(l2):157self.assertEqual(w_wavefrontified[i + j, a, i, n], w[n, i, j, a])158
159@parameterized.product(160decorator=_DECORATORS,161tol=[1e-3, 1e-6, 1e-9],162)163def test_toy_smith_waterman(self, decorator, tol):164smith_waterman_fn = tf_ops.hard_sw_affine165if decorator is not None:166smith_waterman_fn = decorator(smith_waterman_fn)167
168values, paths = smith_waterman_fn(self._w, tol)169paths_squeeze = alignment.path_label_squeeze(paths)170all_matches = tf.where(171alignment.paths_to_state_indicators(paths, 'match'))172
173values_test = tf.constant([5.0, 3.0, 2.94, 3.0], dtype=tf.float32)174self.assertAllClose(values, values_test, atol=2 * tol)175
176paths_squeeze_test = tf.constant([[[1., 0., 0., 0., 0.],177[0., 2., 0., 0., 0.],178[0., 0., 2., 0., 0.],179[0., 0., 0., 2., 0.],180[0., 0., 0., 0., 2.]],181[[0., 0., 0., 0., 0.],182[0., 0., 0., 0., 0.],183[0., 0., 1., 0., 0.],184[0., 0., 0., 2., 0.],185[0., 0., 0., 0., 2.]],186[[0., 0., 0., 0., 0.],187[0., 1., 5., 0., 0.],188[0., 0., 8., 0., 0.],189[0., 0., 0., 4., 0.],190[0., 0., 0., 0., 2.]],191[[0., 0., 0., 0., 0.],192[0., 1., 0., 0., 0.],193[0., 0., 2., 0., 0.],194[0., 0., 0., 2., 0.],195[0., 0., 0., 0., 0.]]], dtype=tf.float32)196self.assertAllEqual(paths_squeeze, paths_squeeze_test)197
198all_matches_test = tf.constant([[0, 0, 0],199[0, 1, 1],200[0, 2, 2],201[0, 3, 3],202[0, 4, 4],203[1, 2, 2],204[1, 3, 3],205[1, 4, 4],206[2, 1, 1],207[2, 3, 3],208[2, 4, 4],209[3, 1, 1],210[3, 2, 2],211[3, 3, 3]], dtype=tf.int32)212self.assertAllEqual(all_matches, all_matches_test)213
214@parameterized.product(215decorator=_DECORATORS,216)217def test_smith_waterman_termination(self, decorator):218smith_waterman_fn = tf_ops.hard_sw_affine219if decorator is not None:220smith_waterman_fn = decorator(smith_waterman_fn)221tol = 1e-6222
223single_sub = tf.concat([- 5*tf.ones((3, 1)),2246*tf.eye(3) - 5*tf.ones((3, 3))], 1)225toy_sub = tf.expand_dims(single_sub, 0)226toy_gap_open = 0.03 * tf.ones((toy_sub.shape[0],))227toy_gap_extend = 0.02 * tf.ones((toy_sub.shape[0],))228w = alignment.weights_from_sim_mat(toy_sub, toy_gap_open, toy_gap_extend)229
230values, paths = smith_waterman_fn(w, tol=tol)231paths_squeeze = alignment.path_label_squeeze(paths)232
233self.assertAllClose(values, [3.0], atol=2 * tol)234paths_squeeze_test = tf.convert_to_tensor([[[0., 1., 0., 0.],235[0., 0., 2., 0.],236[0., 0., 0., 2.]]], tf.float32)237self.assertAllEqual(paths_squeeze, paths_squeeze_test)238
239@parameterized.product(240decorator=_DECORATORS,241)242def test_smith_waterman_empty(self, decorator):243smith_waterman_fn = tf_ops.hard_sw_affine244if decorator is not None:245smith_waterman_fn = decorator(smith_waterman_fn)246tol = 1e-6247
248single_sub = - 5*tf.ones((5, 5))249toy_sub = tf.expand_dims(single_sub, 0)250toy_gap_open = 0.03 * tf.ones((toy_sub.shape[0],))251toy_gap_extend = 0.02 * tf.ones((toy_sub.shape[0],))252w = alignment.weights_from_sim_mat(toy_sub, toy_gap_open, toy_gap_extend)253
254values, paths = smith_waterman_fn(w, tol=tol)255paths_squeeze = alignment.path_label_squeeze(paths)256
257self.assertAllClose(values, [0.0], atol=2 * tol)258self.assertAllEqual(paths_squeeze, tf.zeros([1, 5, 5], tf.float32))259
260@parameterized.product(261decorator=_DECORATORS,262)263def test_backtracking_against_autodiff(self, decorator):264def grad_fn(w):265with tf.GradientTape() as tape:266tape.watch(w)267maxes, _ = tf_ops.hard_sw_affine(w)268return tape.gradient(maxes, w)269smith_waterman_fn = tf_ops.hard_sw_affine270if decorator is not None:271grad_fn = decorator(grad_fn)272smith_waterman_fn = decorator(smith_waterman_fn)273
274# Check that autodiff recovers the handwritten backtracking.275_, paths = smith_waterman_fn(self._w)276paths2 = grad_fn(self._w)277
278self.assertAllEqual(paths, paths2)279
280@parameterized.product(281decorator=_DECORATORS,282)283def test_backtracking_against_bruteforce(self, decorator):284smith_waterman_fn = tf_ops.hard_sw_affine285if decorator is not None:286smith_waterman_fn = decorator(smith_waterman_fn)287
288_, paths = smith_waterman_fn(self._w)289paths3 = np.array([best_alignment_brute_force(self._w[i])290for i in range(self._w.shape[0])])291
292self.assertAllEqual(paths, paths3)293
294@parameterized.product(295decorator=_DECORATORS,296)297def test_perturbation_friendly_version_against_numpy_version(self, decorator):298smith_waterman_fn = tf_ops.hard_sw_affine299if decorator is not None:300smith_waterman_fn = decorator(smith_waterman_fn)301
302maxes, _ = smith_waterman_fn(self._w)303maxes2 = npy_ops.soft_sw_affine(self._toy_sub.numpy(),304self._toy_gap_open.numpy(),305self._toy_gap_extend.numpy(),306temperature=0)307
308self.assertAllClose(maxes, maxes2)309
310@parameterized.product(311decorator=_DECORATORS,312)313def test_soft_version_against_perturbation_friendly_version(self, decorator):314soft_version_fn = tf_ops.soft_sw_affine_fwd315perturbation_friendly_fn = tf_ops.hard_sw_affine316if decorator is not None:317soft_version_fn = decorator(soft_version_fn)318perturbation_friendly_fn = decorator(perturbation_friendly_fn)319
320maxes, _ = perturbation_friendly_fn(self._w)321maxes2 = soft_version_fn(self._toy_sub,322self._toy_gap_open,323self._toy_gap_extend,324temp=None)325
326self.assertAllClose(maxes, maxes2)327
328@parameterized.product(329decorator=_DECORATORS,330temp=[1e-3, 1.0, 1e3],331)332def test_soft_version_against_numpy_version(self, decorator, temp):333soft_version_fn = tf_ops.soft_sw_affine_fwd334if decorator is not None:335soft_version_fn = decorator(soft_version_fn)336
337maxes = npy_ops.soft_sw_affine(self._toy_sub.numpy(),338self._toy_gap_open.numpy(),339self._toy_gap_extend.numpy(),340temperature=temp)341maxes2 = soft_version_fn(self._toy_sub,342self._toy_gap_open,343self._toy_gap_extend,344temp=temp)345
346self.assertAllClose(maxes, maxes2)347
348
349if __name__ == '__main__':350tf.test.main()351