google-research / modules_test.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Flax modules."""

import functools

from absl.testing import parameterized
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import tensorflow.compat.v1 as tf

from protein_lm import domains
from protein_lm import models
from protein_lm import modules

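# A small FlaxLM configuration shared by the tests in this file.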
lm_cls = functools.partial(
    models.FlaxLM,
    num_layers=1,
    num_heads=1,
    emb_dim=64,
    mlp_dim=64,
    qkv_dim=64)


class ModulesTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(
      (modules.AddLearnedPositionalEncodings,),
      (modules.AddSinusoidalPositionalEncodings,))
  def test_positional_encodings(self, positional_encoding_module):
    """Tests that the model runs with both types of positional encodings."""
    domain = domains.FixedLengthDiscreteDomain(vocab_size=2, length=2)
    lm = lm_cls(domain=domain,
                positional_encoding_module=positional_encoding_module)
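    # Drawing a single sample exercises the model's full forward pass.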
    lm.sample(1)

  def test_embeddings_for_one_hot(self):
    """Tests that the embeddings match for int and one-hot representations."""
    vocab_size = 10
    emb_dim = 7
    x_int = jnp.array([[1, 3], [2, 8]])
    module = modules.Embed.partial(
        num_embeddings=vocab_size, num_features=emb_dim)
    _, params = module.init(jrandom.PRNGKey(0), x_int)
    emb_int = module.call(params, x_int)
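    # jnp.eye(vocab_size)[x_int] converts each integer token to its one-hot row.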
    x_one_hot = jnp.eye(vocab_size)[x_int]
    emb_one_hot = module.call(params, x_one_hot)
    self.assertAllEqual(emb_int, emb_one_hot)

  def test_embeddings_for_dist(self):
    """Tests that the embeddings for soft inputs contain both tokens."""
    vocab_size = 5
    emb_dim = 7
    x_int = np.array([[1], [3]])
    module = modules.Embed.partial(
        num_embeddings=vocab_size, num_features=emb_dim)
    _, params = module.init(jrandom.PRNGKey(0), x_int)
    emb_int = module.call(params, x_int)
    x_dist = np.array([[[0, 0.25, 0, 0.75, 0]], [[0, 0.5, 0, 0.5, 0]]])
    emb_dist = np.array(module.call(params, x_dist))
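    # Each soft input is a distribution over the vocabulary, so its embedding
    # should be the probability-weighted mix of the embeddings of tokens 1 and 3.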
    emb_expected = np.array([[emb_int[0, 0] * 0.25 + emb_int[1, 0] * 0.75],
                             [emb_int[0, 0] * 0.5 + emb_int[1, 0] * 0.5]])
    self.assertAllClose(emb_dist, emb_expected)

  @parameterized.parameters(
      ('logits', False),
      (['logits'], True),
      (('logits', 'output_emb'), True),
  )
  def test_output_head(self, output_head, multiple_heads):
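    """Tests that predict_step returns the requested output head(s)."""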
    domain = domains.FixedLengthDiscreteDomain(vocab_size=2, length=2)
    inputs = domain.sample_uniformly(8)
    lm = lm_cls(domain=domain, pmap=False)
    outputs = models.predict_step(
        lm.optimizer.target,
        inputs,
        preprocess_fn=lm.preprocess,
        output_head=output_head)
    if multiple_heads:
      self.assertIsInstance(outputs, dict)
      self.assertLen(outputs, len(output_head))
    else:
      # We should have gotten a single output, the logits.
      self.assertEqual(outputs.shape,
                       (inputs.shape[0], inputs.shape[1], lm.vocab_size))


if __name__ == '__main__':
  tf.test.main()