# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functionally annotate a fasta file.

Write Pfam domain predictions as a TSV with columns
- sequence_name (string)
- predicted_label (string)
- start (int, 1-indexed, inclusive)
- end (int, 1-indexed, inclusive)
- label_description (string); a human-readable label description.
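
An example output row (illustrative values, not from a real run):
  seq1    PF00069    10    290    Protein kinase domain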
"""

import io
import json
import logging
import os
from typing import Dict, List, Optional

from absl import app
from absl import flags
from Bio.SeqIO import FastaIO
import pandas as pd
import tqdm

from protenn import inference_lib
from protenn import utils

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # TF c++ logging set to ERROR
import tensorflow.compat.v1 as tf  # pylint: disable=g-import-not-at-top,g-bad-import-order


_logger = logging.getLogger('protenn')


_INPUT_FASTA_FILE_PATH_FLAG = flags.DEFINE_string(
    'i', None, 'Input fasta file path.'
)
_OUTPUT_WRITE_PATH_FLAG = flags.DEFINE_string(
    'o',
    '/dev/stdout',
    'Output write path. Default is to print to the terminal.',
)

_NUM_ENSEMBLE_ELEMENTS_FLAG = flags.DEFINE_integer(
    'num_ensemble_elements',
    1,
    'In order to run with more than one ensemble element, you will need to run '
    'install_models.py --install_ensemble=true. '
    'Using more ensemble elements takes more time, but tends to be more '
    'accurate. '
    'Run-time scales linearly with the number of ensemble elements. '
    'Maximum value of this flag is {}.'.format(
        utils.MAX_NUM_ENSEMBLE_ELS_FOR_INFERENCE
    ),
)
_MIN_DOMAIN_CALL_LENGTH_FLAG = flags.DEFINE_integer(
    'min_domain_call_length',
    20,
    "Don't consider any domain calls valid that are shorter than this length.",
)
_REPORTING_THRESHOLD_FLAG = flags.DEFINE_float(
    'reporting_threshold',
    0.025,
    'Number between 0 (exclusive) and 1 (inclusive). Predicted labels with '
    'confidence at least reporting_threshold will be included in the output.',
    lower_bound=1e-30,
    upper_bound=1.0,
)

_MODEL_CACHE_PATH_FLAG = flags.DEFINE_string(
    'model_cache_path',
    os.path.join(os.path.expanduser('~'), 'cached_models'),
    'Path from which to use downloaded models and metadata.',
)
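
# Example invocation (a sketch; the script file name and input path below are
# assumptions, not taken from this repo's docs):
#   python annotate_fasta.py -i proteins.fasta -o predictions.tsv \
#       --num_ensemble_elements=1 --reporting_threshold=0.025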

# A list of inferrers that all have the same label set.
_InferrerEnsemble = List[inference_lib.Inferrer]


def _gcs_path_to_relative_unzipped_path(p):
  """Parses a GCS path, takes its last component, and strips '.tar.gz'."""
  return os.path.join(
      os.path.basename(os.path.normpath(p)).replace('.tar.gz', ''))


def _get_inferrer_paths(
    model_urls, model_cache_path
):
  """Convert list of model GCS urls to a list of locally cached paths."""
  return [
      os.path.join(model_cache_path, _gcs_path_to_relative_unzipped_path(p))
      for p in model_urls
  ]


def load_models(
    model_cache_path, num_ensemble_elements
):
  """Loads models from the cache path into a list of Inferrers.

  Args:
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models.
    num_ensemble_elements: number of ensemble elements of each type to load.

  Returns:
    list_of_inferrers

  Raises:
    ValueError if the models were not found. The exception message describes
    that install_models.py needs to be rerun.
  """
  try:
    pfam_inferrer_paths = _get_inferrer_paths(
        utils.OSS_PFAM_ZIPPED_MODELS_URLS, model_cache_path
    )

    to_return = []
    for p in tqdm.tqdm(
        pfam_inferrer_paths[:num_ensemble_elements],
        desc='Loading models',
        position=0,
        leave=True,
        dynamic_ncols=True,
    ):
      to_return.append(inference_lib.Inferrer(p, use_tqdm=False))

    return to_return

  except tf.errors.NotFoundError as exc:
    err_msg = 'Unable to find cached models in {}.'.format(model_cache_path)
    if num_ensemble_elements > 1:
      err_msg += (
          ' Make sure you have installed the entire ensemble of models by '
          'running\n    install_models.py --install_ensemble '
          '--model_cache_path={}'.format(model_cache_path))
    else:
      err_msg += (
          ' Make sure you have installed the models by running\n    '
          'install_models.py --model_cache_path={}'.format(model_cache_path))
    err_msg += '\nThen try rerunning this script.'

    raise ValueError(err_msg) from exc


def _assert_fasta_parsable(input_text):
  with io.StringIO(initial_value=input_text) as f:
    fasta_itr = FastaIO.FastaIterator(f)
    end_iteration_sentinel = object()

    # Avoid parsing the entire FASTA contents by using `next`.
    # Unfortunately, a malformed FASTA file yields an empty FastaIterator
    # instead of raising an error, so check that at least one record exists.
    if next(fasta_itr, end_iteration_sentinel) is end_iteration_sentinel:
      raise ValueError('Failed to parse any input from fasta file. '
                       'Consider checking the formatting of your fasta file. '
                       'First bit of contents from the fasta file was\n'
                       '{}'.format(input_text.splitlines()[:3]))


def parse_input_to_text(input_fasta_path):
  """Parses input fasta file.

  Args:
    input_fasta_path: path to FASTA file.

  Returns:
    Contents of file as a string.

  Raises:
    ValueError if parsing the FASTA file gives no records.
  """
  _logger.info('Parsing input from %s', input_fasta_path)
  with tf.io.gfile.GFile(input_fasta_path, 'r') as input_file:
    input_text = input_file.read()

  _assert_fasta_parsable(input_text=input_text)
  return input_text


def input_text_to_df(input_text):
  """Converts fasta contents to a df with columns sequence_name and sequence."""
  with io.StringIO(initial_value=input_text) as f:
    fasta_records = list(FastaIO.FastaIterator(f))
    fasta_df = pd.DataFrame(
        [(record.name, str(record.seq)) for record in fasta_records],
        columns=['sequence_name', 'sequence'])

  return fasta_df


def perform_inference(
    input_df,
    models,
    model_cache_path,
    reporting_threshold,
    min_domain_call_length,
):
  """Perform inference for Pfam using given models.

  Args:
    input_df: pd.DataFrame with columns sequence_name (str) and sequence (str).
    models: list of Pfam inferrers.
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models.
    reporting_threshold: report labels with mean confidence across ensemble
      elements that exceeds this threshold.
    min_domain_call_length: don't consider as valid any domain calls shorter
      than this length.

  Returns:
    df with columns sequence_name (str), predicted_label (str), start (int),
    end (int), description (str).
  """
  predictions = inference_lib.get_preds_at_or_above_threshold(
      input_df=input_df,
      inferrer_list=models,
      model_cache_path=model_cache_path,
      reporting_threshold=reporting_threshold,
      min_domain_call_length=min_domain_call_length,
  )
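  # Each element of `predictions` corresponds to one row of input_df and is
  # expected to be an iterable of (label, (start, end)) tuples, unpacked below.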

  print('\n')  # Because the tqdm bar is position 1, we need to print a newline.

  to_return_df = []
  for sequence_name, single_seq_predictions in zip(
      input_df.sequence_name, predictions
  ):
    for label, (start, end) in single_seq_predictions:
      to_return_df.append({
          'sequence_name': sequence_name,
          'predicted_label': label,
          'start': start,
          'end': end,
      })

  return pd.DataFrame(to_return_df)


def _sort_df_multiple_columns(df, key):
  """Sort df based on callable key.

  Args:
    df: pd.DataFrame.
    key: function from rows of df (namedtuples) to tuple. This is used in the
      builtin `sorted` function as the key.

  Returns:
    A sorted copy of df.
  """
  # Unpack into list to take advantage of builtin sorted function.
  # Note that pd.DataFrame.sort_values will not work because sort_values'
  # sorting function is applied to each column at a time, whereas we need to
  # consider multiple fields at once.
  df_rows_sorted = sorted(df.itertuples(index=False), key=key)
  return pd.DataFrame(df_rows_sorted, columns=df.columns)


def order_df_for_output(predictions_df):
  """Semantically group/sort predictions df for output.

  Sort order:
  Sort by query sequence name, keeping the order in which names appear in
  `predictions_df`. Within a sequence, sort by start index, ascending.
  Ties are then broken by description, alphabetically.

  Args:
    predictions_df: df with columns sequence_name (str), predicted_label (str),
      start (int), end (int), description (str).

  Returns:
    df with columns sequence_name (str), predicted_label (str), start (int),
    end (int), description (str).
  """
  seq_name_to_original_order = {
      item: idx for idx, item in enumerate(predictions_df.sequence_name)
  }
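  # Note: the dict comprehension keeps the last index for duplicate names, and
  # rows for a given sequence are contiguous in this pipeline, so the relative
  # order of sequences from predictions_df is preserved.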

  def _orderer_pfam(df_row):
    """See outer function docstring."""
    return (
        seq_name_to_original_order[df_row.sequence_name],
        df_row.start,
        df_row.description,
    )

  pfam_df_sorted = _sort_df_multiple_columns(predictions_df, _orderer_pfam)
  return pfam_df_sorted


def format_df_for_output(
    predictions_df,
    *,
    model_cache_path: Optional[str] = None,
    label_to_description: Optional[Dict[str, str]] = None,
):
  """Formats df for outputting.

  Args:
    predictions_df: df with columns sequence_name (str), predicted_label (str),
      start (int), end (int).
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models.
    label_to_description: contents of label_descriptions.json.gz. Map from label
      to a human-readable description.

  Returns:
    df with columns sequence_name (str), predicted_label (str), start (int),
    end (int), description (str).
  """
  predictions_df = predictions_df.copy()

  if label_to_description is None:
    with tf.io.gfile.GFile(
        os.path.join(model_cache_path, 'accession_to_description_pfam_35.json')
    ) as f:
      label_to_description = json.loads(f.read())

  predictions_df['description'] = predictions_df.predicted_label.apply(
      label_to_description.__getitem__
  )

  return order_df_for_output(predictions_df)


def write_output(predictions_df, output_path):
  """Write predictions_df to tsv file."""
  _logger.info('Writing output to %s', output_path)
  with tf.io.gfile.GFile(output_path, 'w') as f:
    predictions_df.to_csv(f, sep='\t', index=False)


def run(
    input_text,
    models,
    reporting_threshold,
    label_to_description,
    model_cache_path,
    min_domain_call_length,
):
  """Runs inference and returns output as a pd.DataFrame.

  Args:
    input_text: contents of a fasta file.
    models: List of Pfam inferrers.
    reporting_threshold: report labels with mean confidence across ensemble
      elements that exceeds this threshold.
    label_to_description: contents of label_descriptions.json.gz. Map from label
      to a human-readable description.
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models.
    min_domain_call_length: don't consider as valid any domain calls shorter
      than this length.

  Returns:
    df with columns sequence_name (str), predicted_label (str), start (int),
    end (int), description (str).
  """
  input_df = input_text_to_df(input_text)
  predictions_df = perform_inference(
      input_df=input_df,
      models=models,
      model_cache_path=model_cache_path,
      reporting_threshold=reporting_threshold,
      min_domain_call_length=min_domain_call_length,
  )

  predictions_df = format_df_for_output(
      predictions_df=predictions_df,
      label_to_description=label_to_description,
      model_cache_path=model_cache_path,
  )

  return predictions_df


def load_assets_and_run(
    input_fasta_path,
    output_path,
    num_ensemble_elements,
    model_cache_path,
    reporting_threshold,
    min_domain_call_length,
):
  """Loads models/metadata, runs inference, and writes output to tsv file.

  Args:
    input_fasta_path: path to FASTA file.
    output_path: path to which to write a tsv of inference results.
    num_ensemble_elements: Number of ensemble elements to load and perform
      inference with.
    model_cache_path: path that contains downloaded SavedModels and associated
      metadata. Same path that was used when installing the models via
      install_models.
    reporting_threshold: report labels with mean confidence across ensemble
      elements that exceeds this threshold.
    min_domain_call_length: don't consider as valid any domain calls shorter
      than this length.
  """
  _logger.info('Running with %d ensemble elements', num_ensemble_elements)
  input_text = parse_input_to_text(input_fasta_path)

  models = load_models(model_cache_path, num_ensemble_elements)
  with open(
      os.path.join(model_cache_path, 'accession_to_description_pfam_35.json')
  ) as f:
    label_to_description = json.loads(f.read())

  predictions_df = run(
      input_text,
      models,
      reporting_threshold,
      label_to_description,
      model_cache_path=model_cache_path,
      min_domain_call_length=min_domain_call_length,
  )
  write_output(predictions_df, output_path)


def main(_):
  # TF logging is too noisy otherwise.
  tf.get_logger().setLevel(tf.logging.ERROR)

  load_assets_and_run(
      input_fasta_path=_INPUT_FASTA_FILE_PATH_FLAG.value,
      output_path=_OUTPUT_WRITE_PATH_FLAG.value,
      num_ensemble_elements=_NUM_ENSEMBLE_ELEMENTS_FLAG.value,
      model_cache_path=_MODEL_CACHE_PATH_FLAG.value,
      reporting_threshold=_REPORTING_THRESHOLD_FLAG.value,
      min_domain_call_length=_MIN_DOMAIN_CALL_LENGTH_FLAG.value,
  )


if __name__ == '__main__':
  _logger.info('Process started.')
  flags.mark_flags_as_required(['i'])

  app.run(main)