# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Embedding API for pretrained models."""

import functools
from flax.training import common_utils

import gin
import jax
import jax.numpy as jnp
import tensorflow.compat.v1 as tf

from protein_lm import data
from protein_lm import models
from protein_lm import utils

SMALL_NEGATIVE = -1e10


def _encode_string_sequences(string_sequences, domain, length):
  """Encodes string sequences as sequences of int tokens.

  Args:
    string_sequences: An iterable over strings.
    domain: An instance of VariableLengthDiscreteDomain.
    length: If provided, crop sequences to this length, otherwise use
      domain.length.

  Returns:
    A jax array of shape (batch_size, length) with the encoded sequences.
  """
  if domain is None:
    domain = data.protein_domain

  if length is None:
    length = domain.length

  # Encode sequences, mark the end with a single EOS, and pad with PAD.
  batch = domain.encode(string_sequences, pad=False)

  max_input_length = max(len(s) for s in string_sequences)
  crop_length = min(max_input_length, length)
  # We perform the padding manually since domain.encode(..., pad=True)
  # uses EOS for padding. We use tf directly rather than seq_utils since
  # the latter performs `pre` truncation.
  batch = [list(elem) + [domain.vocab.eos] for elem in batch]
  batch = tf.keras.preprocessing.sequence.pad_sequences(
      batch, maxlen=crop_length, value=domain.vocab.pad,
      padding='post', truncating='post')

  return jnp.asarray(batch)
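
# Illustrative sketch (editorial addition, assuming the default protein
# domain): for raw strings ['MKT', 'M'], the longest input has length 3, so
# the batch is cropped/padded to width 3; the shorter sequence gets a single
# EOS followed by PAD, while the longest fills the full width.
#
#   batch = _encode_string_sequences(['MKT', 'M'], domain=None, length=None)
#   batch.shape  # -> (2, 3)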


@gin.configurable
def sum_reducer(embedding, mask):
  """Returns the sum over unmasked positions of the length dimension.

  Args:
    embedding: An array of shape (batch_size, length, emb_size).
    mask: An array of shape (batch_size, length).

  Returns:
    An array of shape (batch_size, emb_size).
  """
  return jnp.sum(embedding * mask[Ellipsis, jnp.newaxis], axis=1)


@gin.configurable
def mean_reducer(embedding, mask):
  """Returns the mean over unmasked positions of the length dimension.

  Args:
    embedding: An array of shape (batch_size, length, emb_size).
    mask: An array of shape (batch_size, length).

  Returns:
    An array of shape (batch_size, emb_size).
  """
  return sum_reducer(embedding, mask) / jnp.sum(mask, axis=-1, keepdims=True)


@gin.configurable
def max_reducer(embedding, mask):
  """Returns the max over unmasked positions of the length dimension.

  Args:
    embedding: An array of shape (batch_size, length, emb_size).
    mask: An array of shape (batch_size, length).

  Returns:
    An array of shape (batch_size, emb_size).
  """
  mask = (-mask + 1) * SMALL_NEGATIVE
  return jnp.max(embedding + mask[Ellipsis, jnp.newaxis], axis=1)
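
# Note on the masking trick above (editorial sketch, assumed shapes):
# (-mask + 1) flips unmasked positions to 0 and masked positions to 1, so
# adding (1 - mask) * SMALL_NEGATIVE pushes masked positions far below any
# real activation and they never win the max.
#
#   emb = jnp.array([[[1.0], [9.0]]])   # (batch=1, length=2, emb_size=1)
#   mask = jnp.array([[1.0, 0.0]])      # second position is masked out
#   max_reducer(emb, mask)              # -> [[1.0]]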


@gin.configurable
def masked_reduce_fn(embedding,
                     inputs,
                     reducer_fn=mean_reducer,
                     domain=None,
                     ignore_eos=False,
                     ignore_bos=True,
                     ignore_pad=True,
                     ignore_mask=True):
  """Reduces embeddings across the length dimension, ignoring special tokens.

  Args:
    embedding: An array of shape (batch_size, length, emb_size).
    inputs: An array of shape (batch_size, length).
    reducer_fn: A callable to perform the reduction given embedding and mask.
    domain: An instance of VariableLengthDiscreteDomain.
    ignore_eos: Whether to ignore EOS tokens.
    ignore_bos: Whether to ignore BOS tokens.
    ignore_pad: Whether to ignore PAD tokens.
    ignore_mask: Whether to ignore MASK tokens.

  Returns:
    An array of shape (batch_size, emb_size) with the aggregated embeddings.
  """
  if domain is None:
    domain = data.protein_domain

  mask_tokens = []
  if ignore_eos:
    mask_tokens.append(domain.vocab.eos)
  if ignore_bos:
    mask_tokens.append(domain.vocab.bos)
  if ignore_pad:
    mask_tokens.append(domain.vocab.pad)
  if ignore_mask:
    mask_tokens.append(domain.vocab.mask)

  mask = jnp.ones_like(inputs)
  for token in mask_tokens:
    if token is not None:
      mask *= inputs != token

  return reducer_fn(embedding, mask)
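
# Illustrative sketch (editorial addition; the embedding width 8 is made up):
# with the default `reducer_fn=mean_reducer`, only positions whose ids are not
# BOS/PAD/MASK (and not EOS when ignore_eos=True) contribute to the average.
#
#   inputs = _encode_string_sequences(['MKT', 'M'], domain=None, length=None)
#   emb = jnp.ones(inputs.shape + (8,))      # stand-in for model activations
#   pooled = masked_reduce_fn(emb, inputs)   # -> shape (2, 8)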


@functools.lru_cache(10)
def get_embed_fn(model=None,
                 checkpoint_dir=None,
                 domain=None,
                 output_head='output_emb',
                 reduce_fn=None,
                 length=None):
  """Get a function that maps sequences to fixed-length embedding vectors.

  Args:
    model: A FlaxModel (e.g. FlaxLM or FlaxBERT).
    checkpoint_dir: A string directory where the model checkpoint is stored.
    domain: An instance of VariableLengthDiscreteDomain.
    output_head: Which model output to return. See embed.FlaxModel.
    reduce_fn: Postprocessing function to apply on top of embeddings, such as
      `masked_reduce_fn`. The reduce_fn takes as input padded embeddings
      and padded inputs (to allow masking the pad dimensions). If None, no
      reduction is applied.
    length: Input sequences will be cropped and padded to have length
      N = min(max_len, length), where max_len is the length of the longest
      sequence in the input data. If length is None, domain.length is used when
      computing N.

  Returns:
    Function which accepts sequences and returns batched embeddings. If the
      sequences are strings, we first encode them into the domain.
      Otherwise, we assume that they are already encoded.
  """
  if model is None:
    if checkpoint_dir is None:
      raise ValueError('Must provide a loaded model or checkpoint directory.')
    # Note that this assumes that the model_cls is stored in the config dict.
    model = models.load_model(checkpoint_dir=checkpoint_dir)
  else:
    if checkpoint_dir is not None:
      raise ValueError('Provide only one of `model` or checkpoint directory.')

  if domain is None:
    domain = data.protein_domain

  def predict_fn(model_target, inputs):
    emb = models.predict_step(
        model_target,
        inputs,
        preprocess_fn=model.preprocess,
        output_head=output_head)

    if reduce_fn:
      # Pass the inputs to allow padding-aware aggregation.
      emb = reduce_fn(emb, inputs)
    return emb

  if model.pmap:
    p_predict_step = jax.pmap(predict_fn, axis_name='batch')
  else:
    p_predict_step = predict_fn

  def _embed(protein_sequences):
    """Encode proteins into a batch, embed, and run reduce_fn on output."""
    if isinstance(protein_sequences[0], str):
      batch = _encode_string_sequences(protein_sequences,
                                       domain=domain, length=length)
    else:
      if not domain.are_valid(protein_sequences).any():
        raise ValueError('Input int-encoded sequences are not valid members '
                         'of input domain.')
      batch = protein_sequences

    if model.pmap:
      batch = common_utils.shard(batch)
    result = p_predict_step(model.optimizer.target, batch)

    if model.pmap:
      # Combine the leading two dimensions (ndevices, batch_size / n_devices).
      result = jax.numpy.reshape(result, [-1] + list(result.shape[2:]))
    return result

  return _embed
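
# Example usage (editorial sketch; the checkpoint directory is hypothetical
# and the embedding width depends on the loaded model):
#
#   embed_fn = get_embed_fn(checkpoint_dir='/tmp/protein_lm_ckpt',
#                           reduce_fn=masked_reduce_fn)
#   embeddings = embed_fn(['MKTAYIAKQR', 'MADE'])  # -> (2, emb_size)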


@gin.configurable
class ProteinLMEmbedder(object):
  """Embeddings from a pretrained language model.

  Stateful wrapper around get_embed_fn that calls the embed_fn on batches.
  """

  def __init__(self,
               model=None,
               checkpoint_dir=None,
               domain=None,
               output_head='output_emb',
               reduce_fn=None,
               length=None,
               batch_size=64):
    """Creates an instance of this class."""
    self._embed_fn = get_embed_fn(
        model=model,
        checkpoint_dir=checkpoint_dir,
        domain=domain,
        output_head=output_head,
        reduce_fn=reduce_fn)
    self._batch_size = batch_size
    self._domain = domain
    self._length = length

  def __call__(self, sequences):
    """Embeds int or string sequences."""
    if isinstance(sequences[0], str):
      sequences = _encode_string_sequences(sequences, domain=self._domain,
                                           length=self._length)
    return utils.batch_apply(self._embed_fn, sequences, self._batch_size)
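
# Example usage (editorial sketch; the checkpoint directory is hypothetical):
#
#   embedder = ProteinLMEmbedder(checkpoint_dir='/tmp/protein_lm_ckpt',
#                                reduce_fn=masked_reduce_fn,
#                                batch_size=32)
#   embeddings = embedder(['MKTAYIAKQR', 'MADE'])  # batched internally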