# (Removed stray repository-viewer artifact lines: "google-research" / file-size listing.)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Functionally annotate a fasta file.
17
18Write Pfam domain predictions as a TSV with columns
19- sequence_name (string)
20- predicted_label (string)
21- start (int, 1-indexed, inclusive)
22- end (int, 1-indexed, inclusive)
23- label_description (string); a human-readable label description.
24"""
25
26import io
27import json
28import logging
29import os
30from typing import Dict, List, Optional
31
32from absl import app
33from absl import flags
34from Bio.SeqIO import FastaIO
35import pandas as pd
36import tqdm
37
38from protenn import inference_lib
39from protenn import utils
40
41os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # TF c++ logging set to ERROR
42import tensorflow.compat.v1 as tf # pylint: disable=g-import-not-at-top,g-bad-import-order
43
44
_logger = logging.getLogger('protenn')


_INPUT_FASTA_FILE_PATH_FLAG = flags.DEFINE_string(
    'i', None, 'Input fasta file path.'
)
_OUTPUT_WRITE_PATH_FLAG = flags.DEFINE_string(
    'o',
    '/dev/stdout',
    'Output write path. Default is to print to the terminal.',
)

_NUM_ENSEMBLE_ELEMENTS_FLAG = flags.DEFINE_integer(
    'num_ensemble_elements',
    1,
    'In order to run with more than one ensemble element, you will need to run '
    'install_models.py --install_ensemble=true. '
    'More ensemble elements takes more time, but tends to be more accurate. '
    'Run-time scales linearly with the number of ensemble elements. '
    'Maximum value of this flag is {}.'.format(
        utils.MAX_NUM_ENSEMBLE_ELS_FOR_INFERENCE
    ),
)
_MIN_DOMAIN_CALL_LENGTH_FLAG = flags.DEFINE_integer(
    'min_domain_call_length',
    20,
    "Don't consider any domain calls valid that are shorter than this length.",
)
_REPORTING_THRESHOLD_FLAG = flags.DEFINE_float(
    'reporting_threshold',
    0.025,
    # Help-text typo fixed: was "resporting_threshold".
    'Number between 0 (exclusive) and 1 (inclusive). Predicted labels with '
    'confidence at least reporting_threshold will be included in the output.',
    lower_bound=1e-30,
    upper_bound=1.0,
)

_MODEL_CACHE_PATH_FLAG = flags.DEFINE_string(
    'model_cache_path',
    os.path.join(os.path.expanduser('~'), 'cached_models'),
    'Path from which to use downloaded models and metadata.',
)

# A list of inferrers that all have the same label set.
_InferrerEnsemble = List[inference_lib.Inferrer]
90
91
def _gcs_path_to_relative_unzipped_path(p):
  """Parses a GCS path, takes the last component, and removes .tar.gz.

  Args:
    p: GCS url (or any path) of a zipped model, e.g.
      'gs://bucket/models/model_1.tar.gz'.

  Returns:
    The final path component with any '.tar.gz' suffix removed, e.g.
    'model_1'.
  """
  # normpath drops any trailing slash so basename is never empty. The
  # original wrapped this in a one-argument os.path.join, which is a no-op.
  return os.path.basename(os.path.normpath(p)).replace('.tar.gz', '')
97
def _get_inferrer_paths(
    model_urls, model_cache_path
):
  """Convert list of model GCS urls to a list of locally cached paths."""
  local_paths = []
  for url in model_urls:
    # Each remote .tar.gz was unzipped into a directory of the same stem
    # under model_cache_path when install_models was run.
    unzipped_name = _gcs_path_to_relative_unzipped_path(url)
    local_paths.append(os.path.join(model_cache_path, unzipped_name))
  return local_paths
106
107
def load_models(
    model_cache_path, num_ensemble_elements
):
  """Loads Pfam inferrers from the local model cache.

  Args:
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models.
    num_ensemble_elements: number of ensemble elements to load.

  Returns:
    List of Pfam inferrers.

  Raises:
    ValueError: if the models were not found. The exception message describes
      that install_models.py needs to be rerun.
  """
  try:
    pfam_inferrer_paths = _get_inferrer_paths(
        utils.OSS_PFAM_ZIPPED_MODELS_URLS, model_cache_path
    )

    to_return = []
    # Only the first num_ensemble_elements cached models are loaded; a
    # progress bar is shown because SavedModel loading is slow.
    for p in tqdm.tqdm(
        pfam_inferrer_paths[:num_ensemble_elements],
        desc='Loading models',
        position=0,
        leave=True,
        dynamic_ncols=True,
    ):
      to_return.append(inference_lib.Inferrer(p, use_tqdm=False))

    return to_return

  except tf.errors.NotFoundError as exc:
    # Missing SavedModel files surface as NotFoundError; rewrap with
    # actionable installation instructions.
    err_msg = 'Unable to find cached models in {}.'.format(model_cache_path)
    if num_ensemble_elements > 1:
      err_msg += (
          ' Make sure you have installed the entire ensemble of models by '
          'running\n install_models.py --install_ensemble '
          '--model_cache_path={}'.format(model_cache_path))
    else:
      err_msg += (
          ' Make sure you have installed the models by running\n '
          'install_models.py --model_cache_path={}'.format(model_cache_path))
    err_msg += '\nThen try rerunning this script.'

    raise ValueError(err_msg) from exc
157
158
def _assert_fasta_parsable(input_text):
  """Raises ValueError when `input_text` contains no parsable FASTA records."""
  with io.StringIO(initial_value=input_text) as handle:
    record_iter = FastaIO.FastaIterator(handle)
    sentinel = object()

    # Only pull the first record via `next`, so the whole file is never
    # parsed here. Biopython yields zero entries on malformed input instead
    # of raising, which is why this explicit check is needed.
    if next(record_iter, sentinel) is sentinel:
      raise ValueError('Failed to parse any input from fasta file. '
                       'Consider checking the formatting of your fasta file. '
                       'First bit of contents from the fasta file was\n'
                       '{}'.format(input_text.splitlines()[:3]))
172
173
def parse_input_to_text(input_fasta_path):
  """Reads and sanity-checks the input fasta file.

  Args:
    input_fasta_path: path to FASTA file.

  Returns:
    Contents of file as a string.

  Raises:
    ValueError: if parsing the FASTA file gives no records.
  """
  _logger.info('Parsing input from %s', input_fasta_path)
  with tf.io.gfile.GFile(input_fasta_path, 'r') as input_file:
    file_contents = input_file.read()

  # Fail fast on malformed input before loading any (slow) models.
  _assert_fasta_parsable(input_text=file_contents)
  return file_contents
192
193
def input_text_to_df(input_text):
  """Converts fasta contents to a df with columns sequence_name and sequence."""
  with io.StringIO(initial_value=input_text) as handle:
    rows = [
        (record.name, str(record.seq))
        for record in FastaIO.FastaIterator(handle)
    ]
  return pd.DataFrame(rows, columns=['sequence_name', 'sequence'])
202
203
def perform_inference(
    input_df,
    models,
    model_cache_path,
    reporting_threshold,
    min_domain_call_length,
):
  """Perform inference for Pfam using given models.

  Args:
    input_df: pd.DataFrame with columns sequence_name (str) and sequence (str).
    models: list of Pfam inferrers.
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models.
    reporting_threshold: report labels with mean confidence across ensemble
      elements that exceeds this threshold.
    min_domain_call_length: don't consider as valid any domain calls shorter
      than this length.

  Returns:
    df with columns sequence_name (str), predicted_label (str), start (int),
    end (int). (The human-readable `description` column is added downstream
    by format_df_for_output, not here.)
  """
  predictions = inference_lib.get_preds_at_or_above_threshold(
      input_df=input_df,
      inferrer_list=models,
      model_cache_path=model_cache_path,
      reporting_threshold=reporting_threshold,
      min_domain_call_length=min_domain_call_length,
  )

  print('\n')  # Because the tqdm bar is position 1, we need to print a newline.

  # predictions[i] holds the domain calls for input_df.sequence_name[i];
  # flatten them into one row per (sequence, label, region).
  flat_rows = []
  for sequence_name, single_seq_predictions in zip(
      input_df.sequence_name, predictions
  ):
    for label, (start, end) in single_seq_predictions:
      flat_rows.append({
          'sequence_name': sequence_name,
          'predicted_label': label,
          'start': start,
          'end': end,
      })

  return pd.DataFrame(flat_rows)
251
252
def _sort_df_multiple_columns(df, key):
  """Sort df based on callable key.

  Args:
    df: pd.DataFrame.
    key: function from rows of df (namedtuples) to tuple. This is used in the
      builtin `sorted` method as the key.

  Returns:
    A sorted copy of df.
  """
  # pd.DataFrame.sort_values cannot be used here: its key function is applied
  # to one column at a time, whereas `key` may need to inspect several fields
  # of a row at once. So unpack the rows and lean on builtin `sorted`.
  ordered_rows = sorted(df.itertuples(index=False), key=key)
  return pd.DataFrame(ordered_rows, columns=df.columns)
270
271
def order_df_for_output(predictions_df):
  """Semantically group/sort predictions df for output.

  Sort order:
  Sort by query sequence name as they are in `predictions_df`.
  Sort by start index ascending.
  Given that, sort by description alphabetically.

  Args:
    predictions_df: df with columns sequence_name (str), predicted_label (str),
      start (int), end (int), description (str).

  Returns:
    df with columns sequence_name (str), predicted_label (str), start (int),
    end (int), description (str).
  """
  # Map each sequence name to its (last) position in the input so output
  # groups sequences in their original order.
  original_position = {}
  for idx, name in enumerate(predictions_df.sequence_name):
    original_position[name] = idx

  def _sort_key(df_row):
    """See outer function docstring."""
    return (
        original_position[df_row.sequence_name],
        df_row.start,
        df_row.description,
    )

  return _sort_df_multiple_columns(predictions_df, _sort_key)
302
303
def format_df_for_output(
    predictions_df,
    *,
    model_cache_path = None,
    label_to_description = None,
):
  """Formats df for outputting.

  Args:
    predictions_df: df with columns sequence_name (str), predicted_label (str),
      start (int), end (int).
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models. Only consulted when label_to_description is None.
    label_to_description: contents of label_descriptions.json.gz. Map from label
      to a human-readable description.

  Returns:
    df with columns sequence_name (str), predicted_label (str), start (int),
    end (int), description (str).
  """
  # Work on a copy so the caller's frame is not mutated.
  predictions_df = predictions_df.copy()

  if label_to_description is None:
    # Fall back to the description metadata installed alongside the models.
    descriptions_file = os.path.join(
        model_cache_path, 'accession_to_description_pfam_35.json'
    )
    with tf.io.gfile.GFile(descriptions_file) as f:
      label_to_description = json.loads(f.read())

  # A missing label intentionally raises KeyError (same as dict.__getitem__).
  predictions_df['description'] = [
      label_to_description[label] for label in predictions_df.predicted_label
  ]

  return order_df_for_output(predictions_df)
338
339
def write_output(predictions_df, output_path):
  """Write predictions_df to tsv file."""
  _logger.info('Writing output to %s', output_path)
  # gfile supports /dev/stdout (the default -o) as well as remote paths.
  with tf.io.gfile.GFile(output_path, 'w') as out_file:
    predictions_df.to_csv(out_file, sep='\t', index=False)
345
346
def run(
    input_text,
    models,
    reporting_threshold,
    label_to_description,
    model_cache_path,
    min_domain_call_length,
):
  """Runs inference and returns output as a pd.DataFrame.

  Args:
    input_text: contents of a fasta file.
    models: List of Pfam inferrers.
    reporting_threshold: report labels with mean confidence across ensemble
      elements that exceeds this threshold.
    label_to_description: contents of label_descriptions.json.gz. Map from label
      to a human-readable description.
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models.
    min_domain_call_length: don't consider as valid any domain calls shorter
      than this length.

  Returns:
    df with columns sequence_name (str), predicted_label (str), start(int), end
    (int), description (str).
  """
  # Pipeline: fasta text -> dataframe -> raw domain calls -> formatted output.
  sequences_df = input_text_to_df(input_text)

  raw_predictions_df = perform_inference(
      input_df=sequences_df,
      models=models,
      model_cache_path=model_cache_path,
      reporting_threshold=reporting_threshold,
      min_domain_call_length=min_domain_call_length,
  )

  return format_df_for_output(
      predictions_df=raw_predictions_df,
      label_to_description=label_to_description,
      model_cache_path=model_cache_path,
  )
390
391
def load_assets_and_run(
    input_fasta_path,
    output_path,
    num_ensemble_elements,
    model_cache_path,
    reporting_threshold,
    min_domain_call_length,
):
  """Loads models/metadata, runs inference, and writes output to tsv file.

  Args:
    input_fasta_path: path to FASTA file.
    output_path: path to which to write a tsv of inference results.
    num_ensemble_elements: Number of ensemble elements to load and perform
      inference with.
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models.
    reporting_threshold: report labels with mean confidence across ensemble
      elements that exceeds this threshold.
    min_domain_call_length: don't consider as valid any domain calls shorter
      than this length.
  """
  _logger.info('Running with %d ensemble elements', num_ensemble_elements)
  input_text = parse_input_to_text(input_fasta_path)

  models = load_models(model_cache_path, num_ensemble_elements)

  # Read via tf.io.gfile for consistency with every other file access in this
  # module (plain open() would fail for non-local model_cache_paths).
  label_descriptions_path = os.path.join(
      model_cache_path, 'accession_to_description_pfam_35.json'
  )
  with tf.io.gfile.GFile(label_descriptions_path) as f:
    label_to_description = json.load(f)

  predictions_df = run(
      input_text,
      models,
      reporting_threshold,
      label_to_description,
      model_cache_path=model_cache_path,
      min_domain_call_length=min_domain_call_length,
  )
  write_output(predictions_df, output_path)
433
434
def main(_):
  # TF logging is too noisy otherwise.
  tf.get_logger().setLevel(tf.logging.ERROR)

  # Drive the whole pipeline (parse fasta -> load models -> infer -> write
  # tsv) entirely from command-line flag values.
  load_assets_and_run(
      input_fasta_path=_INPUT_FASTA_FILE_PATH_FLAG.value,
      output_path=_OUTPUT_WRITE_PATH_FLAG.value,
      num_ensemble_elements=_NUM_ENSEMBLE_ELEMENTS_FLAG.value,
      model_cache_path=_MODEL_CACHE_PATH_FLAG.value,
      reporting_threshold=_REPORTING_THRESHOLD_FLAG.value,
      min_domain_call_length=_MIN_DOMAIN_CALL_LENGTH_FLAG.value,
  )
447
448
if __name__ == '__main__':
  _logger.info('Process started.')
  # The input fasta path (-i) is the only flag without a usable default.
  flags.mark_flags_as_required(['i'])

  app.run(main)
454