# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Embedding API for pretrained models."""

import functools
from flax.training import common_utils

import gin
import jax
import jax.numpy as jnp
import tensorflow.compat.v1 as tf

from protein_lm import data
from protein_lm import models
from protein_lm import utils

SMALL_NEGATIVE = -1e10


def _encode_string_sequences(string_sequences, domain, length):
  """Encodes string sequences as sequences of int tokens.

  Args:
    string_sequences: An iterable over strings.
    domain: An instance of VariableLengthDiscreteDomain.
    length: If provided, crop sequences to this length, otherwise use
      domain.length.

  Returns:
    A jax array of shape (batch_size, length) with the encoded sequences.
  """
  if domain is None:
    domain = data.protein_domain

  if length is None:
    length = domain.length

  # Encode sequences, mark the end with a single EOS, and pad with PAD.
  batch = domain.encode(string_sequences, pad=False)

  max_input_length = max(len(s) for s in string_sequences)
  crop_length = min(max_input_length, length)
  # We perform the padding manually since domain.encode(..., pad=True)
  # uses EOS for padding. We use tf directly rather than seq_utils since
  # the latter performs `pre` truncation.
  batch = [list(elem) + [domain.vocab.eos] for elem in batch]
  batch = tf.keras.preprocessing.sequence.pad_sequences(
      batch, maxlen=crop_length, value=domain.vocab.pad,
      padding='post', truncating='post')

  return jnp.asarray(batch)

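# A minimal usage sketch for _encode_string_sequences; the amino-acid strings
# below are illustrative and assume data.protein_domain encodes one token per
# residue with no BOS prefix:
#
#   batch = _encode_string_sequences(['MKV', 'MKVLA'], domain=None, length=8)
#   # crop_length = min(5, 8) = 5, so batch has shape (2, 5). The shorter
#   # sequence ends with EOS followed by PAD; the longer one fills all five
#   # positions (its appended EOS is truncated).

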
@gin.configurable
def sum_reducer(embedding, mask):
  """Returns the sum across unmasked positions of the length dimension.

  Args:
    embedding: An array of shape (batch_size, length, emb_size).
    mask: An array of shape (batch_size, length).

  Returns:
    An array of shape (batch_size, emb_size).
  """
  return jnp.sum(embedding * mask[Ellipsis, jnp.newaxis], axis=1)


@gin.configurable
def mean_reducer(embedding, mask):
  """Returns the mean across unmasked positions of the length dimension.

  Args:
    embedding: An array of shape (batch_size, length, emb_size).
    mask: An array of shape (batch_size, length).

  Returns:
    An array of shape (batch_size, emb_size).
  """
  return sum_reducer(embedding, mask) / jnp.sum(mask, axis=-1, keepdims=True)


@gin.configurable
def max_reducer(embedding, mask):
  """Returns the max across unmasked positions of the length dimension.

  Args:
    embedding: An array of shape (batch_size, length, emb_size).
    mask: An array of shape (batch_size, length).

  Returns:
    An array of shape (batch_size, emb_size).
  """
  # Add a large negative offset to masked positions so they never win the max.
  mask = (-mask + 1) * SMALL_NEGATIVE
  return jnp.max(embedding + mask[Ellipsis, jnp.newaxis], axis=1)


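# All three reducers share the same contract: `embedding` has shape
# (batch_size, length, emb_size) and `mask` has shape (batch_size, length),
# with ones at positions to keep. A small sanity-check sketch with
# hypothetical values (not part of the original API):
#
#   emb = jnp.ones((1, 4, 8))
#   mask = jnp.array([[1., 1., 0., 0.]])
#   sum_reducer(emb, mask)   # shape (1, 8), every entry 2.0
#   mean_reducer(emb, mask)  # shape (1, 8), every entry 1.0
#   max_reducer(emb, mask)   # shape (1, 8), every entry 1.0

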
@gin.configurable
def masked_reduce_fn(embedding,
                     inputs,
                     reducer_fn=mean_reducer,
                     domain=None,
                     ignore_eos=False,
                     ignore_bos=True,
                     ignore_pad=True,
                     ignore_mask=True):
  """Reduces embeddings across the length dimension, ignoring special tokens.

  Args:
    embedding: An array of shape (batch_size, length, emb_size).
    inputs: An array of shape (batch_size, length).
    reducer_fn: A callable to perform the reduction given embedding and mask.
    domain: An instance of VariableLengthDiscreteDomain.
    ignore_eos: Whether to ignore EOS tokens.
    ignore_bos: Whether to ignore BOS tokens.
    ignore_pad: Whether to ignore PAD tokens.
    ignore_mask: Whether to ignore MASK tokens.

  Returns:
    An array of shape (batch_size, emb_size) with the aggregated embeddings.
  """
  if domain is None:
    domain = data.protein_domain

  mask_tokens = []
  if ignore_eos:
    mask_tokens.append(domain.vocab.eos)
  if ignore_bos:
    mask_tokens.append(domain.vocab.bos)
  if ignore_pad:
    mask_tokens.append(domain.vocab.pad)
  if ignore_mask:
    mask_tokens.append(domain.vocab.mask)

  mask = jnp.ones_like(inputs)
  for token in mask_tokens:
    if token is not None:
      mask *= inputs != token

  return reducer_fn(embedding, mask)


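# Since masked_reduce_fn is gin-configurable, the reducer and the ignored
# tokens can be selected from a config rather than in code. A hypothetical gin
# binding (illustrative, not shipped with this module):
#
#   masked_reduce_fn.reducer_fn = @max_reducer
#   masked_reduce_fn.ignore_eos = True
#
# Alternatively, partially apply it in Python before handing it to
# get_embed_fn:
#
#   reduce_fn = functools.partial(masked_reduce_fn, reducer_fn=max_reducer)

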
@functools.lru_cache(10)
def get_embed_fn(model=None,
                 checkpoint_dir=None,
                 domain=None,
                 output_head='output_emb',
                 reduce_fn=None,
                 length=None):
  """Get a function that maps sequences to fixed-length embedding vectors.

  Args:
    model: A FlaxModel (e.g. FlaxLM or FlaxBERT).
    checkpoint_dir: A string directory where the model checkpoint is stored.
    domain: An instance of VariableLengthDiscreteDomain.
    output_head: Which model output to return. See embed.FlaxModel.
    reduce_fn: Postprocessing function to apply on top of embeddings, such as
      `masked_reduce_fn`. The reduce_fn takes as input padded embeddings
      and padded inputs (to allow masking the pad positions). If None, no
      reduction is made.
    length: Input sequences will be cropped and padded to have length
      N = min(max_len, length), where max_len is the length of the longest
      sequence in the input data. If length is None, domain.length is used when
      computing N.

  Returns:
    Function which accepts sequences and returns batched embeddings. If the
    sequences are strings, we first encode them into the domain.
    Otherwise, we assume that they are already encoded.
  """
  if model is None:
    if checkpoint_dir is None:
      raise ValueError('Must provide a loaded model or checkpoint directory.')
    # Note that this assumes that the model_cls is stored in the config dict.
    model = models.load_model(checkpoint_dir=checkpoint_dir)
  else:
    if checkpoint_dir is not None:
      raise ValueError('Provide only one of `model` or checkpoint directory.')

  if domain is None:
    domain = data.protein_domain

  def predict_fn(model_target, inputs):
    emb = models.predict_step(
        model_target,
        inputs,
        preprocess_fn=model.preprocess,
        output_head=output_head)

    if reduce_fn:
      # Pass the inputs to allow padding-aware aggregation.
      emb = reduce_fn(emb, inputs)
    return emb

  if model.pmap:
    p_predict_step = jax.pmap(predict_fn, axis_name='batch')
  else:
    p_predict_step = predict_fn

  def _embed(protein_sequences):
    """Encode proteins into a batch, embed, and run reduce_fn on output."""
    if isinstance(protein_sequences[0], str):
      batch = _encode_string_sequences(protein_sequences,
                                       domain=domain, length=length)
    else:
      # Require every int-encoded sequence to be a valid member of the domain.
      if not domain.are_valid(protein_sequences).all():
        raise ValueError('Input int-encoded sequences are not valid members '
                         'of input domain.')
      batch = protein_sequences

    if model.pmap:
      batch = common_utils.shard(batch)
    result = p_predict_step(model.optimizer.target, batch)

    if model.pmap:
      # Combine the leading two dimensions (ndevices, batch_size / n_devices)
      result = jax.numpy.reshape(result, [-1] + list(result.shape[2:]))
    return result

  return _embed


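# A hedged end-to-end sketch of get_embed_fn; the checkpoint path is a
# placeholder and the sequences are illustrative:
#
#   embed_fn = get_embed_fn(
#       checkpoint_dir='/path/to/checkpoint',
#       reduce_fn=masked_reduce_fn,
#       length=128)
#   embeddings = embed_fn(['MKVLA', 'MKV'])
#   # With masked_reduce_fn the result has shape (batch_size, emb_size);
#   # with reduce_fn=None it would keep the length dimension.

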
@gin.configurable
class ProteinLMEmbedder(object):
  """Embeddings from a pretrained language model.

  Stateful wrapper around get_embed_fn that calls the embed_fn on batches.
  """

  def __init__(self,
               model=None,
               checkpoint_dir=None,
               domain=None,
               output_head='output_emb',
               reduce_fn=None,
               length=None,
               batch_size=64):
    """Creates an instance of this class."""
    # `length` is applied when encoding string sequences in __call__, so it is
    # not passed to get_embed_fn here.
    self._embed_fn = get_embed_fn(
        model=model,
        checkpoint_dir=checkpoint_dir,
        domain=domain,
        output_head=output_head,
        reduce_fn=reduce_fn)
    self._batch_size = batch_size
    self._domain = domain
    self._length = length

  def __call__(self, sequences):
    """Embeds int or string sequences."""
    if isinstance(sequences[0], str):
      sequences = _encode_string_sequences(sequences, domain=self._domain,
                                           length=self._length)
    return utils.batch_apply(self._embed_fn, sequences, self._batch_size)
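

# A hypothetical usage sketch for ProteinLMEmbedder; the checkpoint path and
# sequences are placeholders:
#
#   embedder = ProteinLMEmbedder(
#       checkpoint_dir='/path/to/checkpoint',
#       reduce_fn=masked_reduce_fn,
#       batch_size=32)
#   embeddings = embedder(['MKVLA', 'MKV', 'MKWV'])  # batched internally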
267