# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pipelines to generate datasets for the alignment and homology tasks."""

import collections
import functools
import itertools
import random
from typing import Callable, Dict, Iterable, Iterator, Optional

import apache_beam as beam
import numpy as np
import tensorflow as tf

from dedal.preprocessing import alignment
from dedal.preprocessing import schemas
from dedal.preprocessing import schemas_lib
from dedal.preprocessing import types
from dedal.preprocessing import utils


# Type aliases
Record = types.Record
Array = np.ndarray
PRNG = random.Random


# Constants
ALIGNMENT_FIELDS = (
    'key', 'pfam_acc', 'clan_acc', 'seq_start', 'seq_end', 'passed_qc',
    'sequence', 'gapped_sequence',
    'hmm_hit_seq_start', 'hmm_hit_seq_end', 'hmm_hit_clan_acc',
    'other_hit_seq_start', 'other_hit_seq_end', 'other_hit_type_id')
OTHER_REGIONS = (
    'coiled_coil', 'disorder', 'low_complexity', 'sig_p', 'transmembrane')
CONFOUNDING_REGIONS = ('clans',)
SPLITS = ('train', 'iid_validation', 'ood_validation', 'iid_test', 'ood_test')
PREFIXES = ('n', 'c')  # Indices of the two flanks (N-terminus, C-terminus).
SUFFIXES = ('x', 'y')  # Indices of the two regions in the pairs.
AA_PROB_TABLE = {'A': 0.0832177442683677,
                 'C': 0.013846953304580488,
                 'D': 0.05746458363960445,
                 'E': 0.0662881179936497,
                 'F': 0.03796971998594428,
                 'G': 0.06893382361885941,
                 'H': 0.021288200487753935,
                 'I': 0.05563140763547133,
                 'K': 0.05514607272883951,
                 'L': 0.09466292698151958,
                 'M': 0.021795269979457313,
                 'N': 0.042539379191607996,
                 'P': 0.04820225072073993,
                 'Q': 0.0397779223030366,
                 'R': 0.05446797313841446,
                 'S': 0.07095409709482754,
                 'T': 0.05849503653768657,
                 'V': 0.06729418111891682,
                 'W': 0.011878265406494127,
                 'Y': 0.030146073864228278}
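# The probabilities above sum to (approximately) 1.0; they are used by
# `sample_synthetic_flank` below to draw synthetic flank residues i.i.d.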


def load_pfam_accs(path):
  """Returns a mapping from Pfam accessions to unique integer indices.

  Args:
    path: The path to a plain text file containing one Pfam accession per line.

  Returns:
  A dictionary mapping Pfam accessions to integer-valued identifiers numbered
  between 0 (inclusive) and the total number of distinct Pfam accessions
  (exclusive).
  """
  pfam_acc_to_index = {}
  with tf.io.gfile.GFile(path, 'r') as f:
    for i, line in enumerate(f):
      pfam_acc = line.strip()
      if pfam_acc in pfam_acc_to_index:
        raise ValueError(f'Key {pfam_acc} is duplicated in {path}.')
      pfam_acc_to_index[pfam_acc] = i
  return pfam_acc_to_index
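# Illustrative example for `load_pfam_accs` (hypothetical file contents): for a
# file containing the two lines 'PF00001' and 'PF00002', the function returns
# {'PF00001': 0, 'PF00002': 1}.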


class ReadParsedPfamData(beam.PTransform):
  """Reads TSV data generated by Step 1. of the preprocessing pipeline."""

  def __init__(
      self,
      file_pattern,
      dataset_splits_path,
      fields_to_keep = None,
      with_flank_seeds = False,
      max_len = None,
      filter_by_qc = True,
  ):
    self.file_pattern = file_pattern
    self.dataset_splits_path = dataset_splits_path
    self.fields_to_keep = fields_to_keep
    self.with_flank_seeds = with_flank_seeds
    self.max_len = max_len
    self.filter_by_qc = filter_by_qc
    self._schema_cls = (schemas.ExtendedParsedPfamRow if self.with_flank_seeds
                        else schemas.ParsedPfamRow)

  def expand(
      self,
      root,
  ):
    read_alignment_data_cls = functools.partial(
        schemas_lib.ReadFromTable,
        schema_cls=self._schema_cls,
        key_field='key',
        skip_header_lines=1,
        fields_to_keep=self.fields_to_keep)

    read_dataset_splits_cls = functools.partial(
        schemas_lib.ReadFromTable,
        schema_cls=schemas.DatasetSplits,
        key_field='key',
        skip_header_lines=1,
        fields_to_keep=('split',))

    # Reads Pfam data from the preceding pipeline. Each element in the output
    # `PCollection` represents a different Pfam region.
    regions = root | 'ReadRegions' >> read_alignment_data_cls(self.file_pattern)

    # Optionally, removes any Pfam regions that did not pass the quality control
    # checks from the `PCollection`.
    if self.filter_by_qc:
      regions = regions | 'QCFilter' >> beam.Filter(lambda x: x[1]['passed_qc'])

    # Optionally, drops Pfam regions that do not pass the maximum region length
    # filter.
    if self.max_len is not None:
      len_fn = lambda x: x[1]['seq_end'] - x[1]['seq_start'] + 1 <= self.max_len
      regions = regions | 'LengthFilter' >> beam.Filter(len_fn)

    # Reads the mapping from Pfam region key to dataset split.
    dataset_splits = (
        root
        | 'ReadDatasetSplits' >> read_dataset_splits_cls(
            self.dataset_splits_path))
    # Merges split info from `dataset_splits` into elements of `regions` and
    # removes the key of each element, which is no longer needed after merging
    # the two `PCollection`s.
    return (
        {'pfam_regions': regions, 'dataset_splits': dataset_splits}
        | 'MergeSplitInfoByKey' >> schemas_lib.JoinTables(
            left_join_tables=['pfam_regions'])
        | 'RemoveKey' >> beam.Values())


def get_prng(record, global_seed, field_name = 'key'):
  """Returns a reproducible per-example-pair PRNG."""
  return random.Random(
      hash(tuple(record[f'{field_name}_{s}'] for s in SUFFIXES)) + global_seed)
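# Note on `get_prng`: for the default `field_name`, the PRNG is seeded from
# hash((record['key_x'], record['key_y'])) + global_seed, so the same region
# pair and global seed yield the same stream of draws within a run. For string
# keys, reproducibility across processes additionally assumes a fixed
# string-hash seed (e.g., PYTHONHASHSEED).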


def sample_flank_lengths(
    rng,
    seq_start,
    seq_end,
    max_len,
):
  """Samples the N-terminus and C-terminus flank lengths at random."""
  region_len = seq_end - seq_start + 1  # Endpoints are inclusive.
  max_ctx_len = max_len - region_len
  max_n_len = max_c_len = max_ctx_len
  ctx_len = rng.randint(0, max_ctx_len)
  n_len = rng.randint(max(ctx_len - max_c_len, 0), min(max_n_len, ctx_len))
  c_len = ctx_len - n_len
  return {'n_len': n_len, 'c_len': c_len}
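# Illustrative example for `sample_flank_lengths` (hypothetical values): for a
# region spanning positions 5-12 (8 residues, 1-based inclusive) and
# max_len = 20, the flank budget is max_ctx_len = 12. A total context length
# ctx_len is drawn uniformly from [0, 12] and then split at random between the
# two flanks, so that n_len + c_len = ctx_len and the extended sequence never
# exceeds max_len residues.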


def validate_flank_seeds(
    region_pair,
    indices,
    extra_margin = 0,
    min_overlap = 1,
):
  """Tests if a combination of UniProt flanks is valid for a region pair."""
  # First, verifies that all four flank seeds are non-empty if a flank needs to
  # be generated.
  for p, s in itertools.product(PREFIXES, SUFFIXES):
    flank_len = region_pair[f'{p}_len_{s}']
    idx = indices[f'{p}_{s}']
    key = region_pair[f'{p}_flank_seed_key_{idx}_{s}']
    sequence = region_pair[f'{p}_flank_seed_sequence_{idx}_{s}']
    if flank_len and (not key or not sequence):
      return False

  # Second, retrieves the collection of hmm_hits and other_hits in each flank
  # and checks that there are no shared annotations between the flanks of each
  # sequence.
  flank_hits = collections.defaultdict(set)
  for ann_type in ('hmm_hit_clan_acc',):
    for p, s in itertools.product(PREFIXES, SUFFIXES):
      idx = indices[f'{p}_{s}']
      hits = region_pair[f'{p}_flank_seed_{ann_type}_{idx}_{s}']
      flank_hits[f'{ann_type}_{s}'] |= set(hits)
    if set.intersection(*[flank_hits[f'{ann_type}_{s}'] for s in SUFFIXES]):
      return False

  # Finally, tests if there are shared hmm_hit annotations between the flanks
  # of one sequence and the main region of the other sequence.
  for s1, s2 in zip(SUFFIXES, reversed(SUFFIXES)):
    # TODO(fllinares): pre-compute `hmm_hits`, perhaps at the cost of clarity.
    overlaps = interval_overlaps(
        start=region_pair[f'seq_start_{s1}'] - extra_margin,
        end=region_pair[f'seq_end_{s1}'] + extra_margin,
        ref_starts=np.asarray(region_pair[f'hmm_hit_seq_start_{s1}']),
        ref_ends=np.asarray(region_pair[f'hmm_hit_seq_end_{s1}']))
    hit_ids = np.asarray(region_pair[f'hmm_hit_clan_acc_{s1}'])
    hmm_hits = set(hit_ids[overlaps >= min_overlap])
    hmm_hits.add(region_pair[f'clan_acc_{s1}'])  # Likely redundant.
    if hmm_hits & flank_hits[f'hmm_hit_clan_acc_{s2}']:
      return False

  return True


def pair_flank_seeds(
    rng,
    region_pair,
    extra_margin = 0,
    min_overlap = 1,
):
  """Searches for a valid combination of UniProt flanks for the region pair."""
  num_flank_seeds = schemas.NUM_FLANK_SEEDS
  idx_keys = tuple(f'{p}_{s}' for p, s in itertools.product(PREFIXES, SUFFIXES))
  # Iterates over all possible flank seed pairings in a random order.
  for indices in itertools.product(
      rng.sample(range(1, num_flank_seeds + 1), k=num_flank_seeds),
      rng.sample(range(1, num_flank_seeds + 1), k=num_flank_seeds),
      rng.sample(range(1, num_flank_seeds + 1), k=num_flank_seeds),
      rng.sample(range(1, num_flank_seeds + 1), k=num_flank_seeds)):
    indices = dict(zip(idx_keys, indices))
    if validate_flank_seeds(region_pair, indices, extra_margin, min_overlap):
      for p, s in itertools.product(PREFIXES, SUFFIXES):
        if region_pair[f'{p}_len_{s}']:  # Skips empty flanks.
          region_pair[f'{p}_flank_seed_idx_{s}'] = indices[f'{p}_{s}']
      break
  return region_pair


def sample_synthetic_flank(rng, length):
  """Samples `length` characters i.i.d. from `AA_PROB_TABLE`."""
  amino_acids = list(AA_PROB_TABLE.keys())
  probabilities = list(AA_PROB_TABLE.values())
  return ''.join(rng.choices(amino_acids, probabilities, k=length))
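# Illustrative usage for `sample_synthetic_flank`: calling it with, e.g.,
# rng=random.Random(0) and length=5 returns a 5-character string over the
# 20-letter amino-acid alphabet, with residues drawn i.i.d. from the
# frequencies in `AA_PROB_TABLE`.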


def generate_flanks(
    rng,
    region_pair,
    flanks,
    max_len,
    extra_margin = 0,
    min_overlap = 1,
):
  """Extends a pair of Pfam domains by adding N- and C-terminus flanks."""
  # Samples the lengths of the N-terminus and C-terminus flanks, adding the
  # resulting variables to `region_pair`.
  for s in SUFFIXES:
    out = sample_flank_lengths(
        rng=rng,
        seq_start=region_pair[f'seq_start_{s}'],
        seq_end=region_pair[f'seq_end_{s}'],
        max_len=max_len)
    region_pair.update({f'{k}_{s}': v for k, v in out.items()})

  # When using flanks from UniProt, the choice of which flank "seeds" to pick
  # for the N-terminus and C-terminus ends for each of the two regions must be
  # done jointly. In contrast, the rest of the flank generation pipeline can be
  # done independently for each sequence and flank.
  if flanks == 'uniprot':
    region_pair = pair_flank_seeds(rng, region_pair, extra_margin, min_overlap)

  # Processes each of the two regions, `sequence_x` and `sequence_y`,
  # independently.
  for s in SUFFIXES:
    seq_start = region_pair[f'seq_start_{s}']
    seq_end = region_pair[f'seq_end_{s}']
    # Generates the N-terminus and C-terminus flanks for `sequence_{s}`.
    for p in PREFIXES:
      # Synthetic flanks are obtained by randomly sampling amino acids with
      # replacement from the background distribution in `AA_PROB_TABLE`,
      # independently for each flank.
      if flanks == 'synthetic':
        flank_acc = 'synth'
        flank_len = region_pair[f'{p}_len_{s}']
        # Computes the (inclusive) endpoints of the N-terminus flank. Note that
        # `flank_start` could be negative.
        if p == PREFIXES[0]:
          flank_start = seq_start - flank_len
          flank_end = seq_start - 1
        # Computes the (inclusive) endpoints of the C-terminus flank. Note that
        # `flank_end` could be larger than the length of the original sequence.
        else:
          flank_start = seq_end + 1
          flank_end = seq_end + flank_len
        region_pair[f'{p}_flank_{s}'] = sample_synthetic_flank(
            rng=rng, length=flank_len)
      # UniProt flanks are obtained by (brute-force) searching for a combination
      # of UniProt (sub)sequences that satisfy all of the "quality control"
      # criteria in `validate_flank_seeds`. Note that, since each sequence has
      # only a finite number of precomputed flank "seeds" available (for
      # tractability), there is a small but non-zero probability that no flank
      # combination is valid. In these (rare) cases, no flanks are added.
      elif flanks == 'uniprot':
        idx = region_pair.get(f'{p}_flank_seed_idx_{s}', None)
        if idx is not None:  # A valid flank combination was found.
          flank_key = region_pair[f'{p}_flank_seed_key_{idx}_{s}']
          flank_seq = region_pair[f'{p}_flank_seed_sequence_{idx}_{s}']

          flank_acc, flank_endpoints = flank_key.split('/')
          flank_start, flank_end = [int(x) for x in flank_endpoints.split('-')]
          assert len(flank_seq) == (flank_end - flank_start + 1)

          offset = rng.randint(0, len(flank_seq) - region_pair[f'{p}_len_{s}'])
          flank_start += offset
          flank_end = flank_start + region_pair[f'{p}_len_{s}'] - 1

          region_pair[f'{p}_flank_{s}'] = flank_seq[
              offset:offset + region_pair[f'{p}_len_{s}']]
        else:  # No valid flank combination was found.
          # Marks flank as empty.
          flank_start = 0
          flank_end = -1
      # Unrecognized.
      else:
        raise ValueError(
            f"flanks must be 'synthetic' or 'uniprot'. Got {flanks} instead.")

      # Sets a flank key for inspectability purposes if the flank is not empty.
      if flank_start <= flank_end:
        region_pair[f'{p}_key_{s}'] = f'{flank_acc}/{flank_start}-{flank_end}'
      else:
        region_pair[f'{p}_key_{s}'] = ''

  return region_pair


def extend_sequences(region_pair):
  """Optionally, extends domain sequences with N and C-terminus flanks."""
  for s in SUFFIXES:
    seq_start = region_pair[f'seq_start_{s}']
    seq_end = region_pair[f'seq_end_{s}']
    n_flank = region_pair.get(f'n_flank_{s}', '')
    c_flank = region_pair.get(f'c_flank_{s}', '')
    sequence = region_pair[f'sequence_{s}']

    region_pair[f'sequence_{s}'] = ''.join(
        [n_flank, sequence[seq_start - 1:seq_end], c_flank])

    if f'ali_start_{s}' in region_pair:
      region_pair[f'ali_start_{s}'] += len(n_flank)

    # Backs up the original sequence.
    region_pair[f'original_sequence_{s}'] = sequence

  return region_pair
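# Worked example for `extend_sequences` (hypothetical record): with
# sequence_x = 'ACDEFGHIKL', seq_start_x = 3, seq_end_x = 6 (1-based,
# inclusive), n_flank_x = 'GG' and c_flank_x = 'W', the extended sequence
# becomes 'GG' + 'DEFG' + 'W' = 'GGDEFGW', `ali_start_x` (if present) is
# shifted by len('GG') = 2, and the untouched 'ACDEFGHIKL' is stored in
# `original_sequence_x`.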


def interval_overlaps(
    start,
    end,
    ref_starts,
    ref_ends,
):
  """Computes the overlap between closed intervals."""
  overlaps = np.minimum(end, ref_ends) - np.maximum(start, ref_starts) + 1
  return np.maximum(0, overlaps)
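# Worked example for `interval_overlaps`, using the 1-based closed intervals
# assumed throughout this module:
#   interval_overlaps(start=5, end=10,
#                     ref_starts=np.array([1, 8, 12]),
#                     ref_ends=np.array([4, 15, 20]))
# returns array([0, 3, 0]): the first and third reference intervals do not
# overlap [5, 10], while the second overlaps it on residues 8-10.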


def annotate_regions(
    region_pair,
    extra_margin = 0,
    min_overlap = 1,
):
  """Finds any annotations that overlap with the regions in `region_pair`."""
  for ann_type, id_name in zip(('hmm_hit', 'other_hit'),
                               ('clan_acc', 'type_id')):
    for s in SUFFIXES:
      hit_start = np.asarray(region_pair[f'{ann_type}_seq_start_{s}'])
      hit_end = np.asarray(region_pair[f'{ann_type}_seq_end_{s}'])
      hit_ids = np.asarray(region_pair[f'{ann_type}_{id_name}_{s}'])

      # If the flanks are synthetic or have passed the "quality control" checks
      # in `validate_flank_seeds`, then only annotations within the original
      # Pfam region endpoints should be taken into account.
      seq_start = region_pair[f'seq_start_{s}']
      seq_end = region_pair[f'seq_end_{s}']

      overlaps = interval_overlaps(
          start=seq_start - extra_margin,
          end=seq_end + extra_margin,
          ref_starts=hit_start,
          ref_ends=hit_end)
      indices = overlaps >= min_overlap

      region_pair[f'overlapping_{ann_type}_{s}'] = set(hit_ids[indices])

  return region_pair


def eval_confounding_in_regions(region_pair):
  """Inspects regions for edge cases, such as nested domains."""
  # Checks whether both regions have shared clan annotations other than their
  # own original `clan_acc`s.
  shared_clans = set.intersection(
      *[region_pair[f'overlapping_hmm_hit_{s}'] for s in SUFFIXES])
  # If both regions are labelled as belonging to the same clan, removes this
  # (expected) shared clan annotation from the set.
  if region_pair['homology_label'] > 0:
    shared_clans -= set(region_pair[f'clan_acc_{s}'] for s in SUFFIXES)
  region_pair['shares_clans'] = bool(shared_clans)
  # Checks whether both regions share other types of region annotations.
  for type_id in OTHER_REGIONS:
    region_pair[f'shares_{type_id}'] = all(
        type_id in region_pair[f'overlapping_other_hit_{s}'] for s in SUFFIXES)

  # We mark a region pair as "potentially confounded" for the homology task if
  # the regions are non-homologous and share any annotations in the annotation
  # categories described by `CONFOUNDING_REGIONS`.
  region_pair['maybe_confounded'] = (
      region_pair['homology_label'] == 0 and
      any(region_pair[f'shares_{type_id}'] for type_id in CONFOUNDING_REGIONS))

  return region_pair


def add_bos_and_eos_flags(
    region_pair,
    flanks = None,
):
  """Marks whether the sequences lie at the start/end of a full protein seq."""
  for s in SUFFIXES:
    seq_len = len(region_pair[f'original_sequence_{s}'])
    region_pair[f'bos_{s}'] = region_pair[f'seq_start_{s}'] == 1
    region_pair[f'eos_{s}'] = region_pair[f'seq_end_{s}'] == seq_len
    # In the case of synthetic or uniprot flanks, we only have a biologically
    # relevant start / end of a full sequence if no flanks were added.
    if flanks in ('synthetic', 'uniprot'):
      region_pair[f'bos_{s}'] &= region_pair[f'n_len_{s}'] == 0
      region_pair[f'eos_{s}'] &= region_pair[f'c_len_{s}'] == 0
  return region_pair


def add_extended_region_keys(region_pair):
  """Creates new keys for the sequences, summarizing the new endpoints."""
  for s in SUFFIXES:
    n_key_s = region_pair.get(f'n_key_{s}', '')
    c_key_s = region_pair.get(f'c_key_{s}', '')
    region_pair[f'extended_key_{s}'] = ';'.join(
        [n_key_s, region_pair[f'key_{s}'], c_key_s])
  return region_pair
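# Illustrative example for `add_extended_region_keys` (hypothetical keys): with
# key_x = 'Q12345/30-80', n_key_x = 'synth/25-29' and c_key_x = 'synth/81-86',
# the extended key becomes 'synth/25-29;Q12345/30-80;synth/81-86'. If no flanks
# were generated, the flank keys are empty and the result is ';Q12345/30-80;'.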


def compute_alignment_path(region_pair):
  """Computes alignment path from `region_pair`'s gapped sequences."""
  matches, ali_start = alignment.alignment_from_gapped_sequences(
      gapped_sequence_x=region_pair['gapped_sequence_x'],
      gapped_sequence_y=region_pair['gapped_sequence_y'])
  region_pair['matches'] = matches
  region_pair['ali_start_x'] = ali_start[0]
  region_pair['ali_start_y'] = ali_start[1]
  return region_pair


def subsample_region_pairs(
    region_pair,
    count,
    min_count,
    resample_ratio,
    global_seed,
):
  """Randomly discards region pairs to rebalance label distribution."""
  rng = get_prng(region_pair, global_seed)
  keep_prob = resample_ratio * (min_count / count)
  if rng.uniform(0, 1) <= keep_prob:
    yield region_pair
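# Illustrative example for `subsample_region_pairs` (hypothetical counts): with
# count = 1_000_000 pairs in a category, min_count = 100_000 pairs in the
# rarest category and resample_ratio = 0.5, each pair is kept with probability
# 0.5 * (100_000 / 1_000_000) = 0.05, so roughly 50_000 pairs survive, i.e.,
# half as many as in the rarest category.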


def add_homology_label(region):
  """Computes (ternary) homology label (non-homologs / same clan / same fam)."""
  region['homology_label'] = 0
  # Note: if both sequences have the same Pfam accession (`pfam_acc`), they are
  # guaranteed to also have the same Clan accession (`clan_acc`). The converse
  # does not hold, however.
  for key in ('pfam_acc', 'clan_acc'):
    region['homology_label'] += int(
        region[f'{key}_{SUFFIXES[0]}'] == region[f'{key}_{SUFFIXES[1]}'])
  return region
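# Illustrative examples for `add_homology_label`:
#   same `pfam_acc` (and hence same `clan_acc`)       -> homology_label = 2
#   same `clan_acc` but different `pfam_acc`          -> homology_label = 1
#   different `clan_acc` (and different `pfam_acc`)   -> homology_label = 0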


def process_region_pair_alignment(
    region_pair,
    flanks = None,
    max_len = 511,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """Generates a sample for the alignment task from a pair of Pfam regions."""
  rng = get_prng(region_pair, global_seed=global_seed)

  # Parses the gapped sequences in `region_pair` to extract the ground-truth
  # alignment path, described in terms of its starting positions and matches.
  region_pair = compute_alignment_path(region_pair)

  # Optionally, samples synthetic sequences ('synthetic') or UniProt sequences
  # ('uniprot') to generate a ground-truth alignment with flanks.
  if flanks in ('synthetic', 'uniprot'):
    region_pair = generate_flanks(
        rng=rng,
        region_pair=region_pair,
        flanks=flanks,
        max_len=max_len,
        extra_margin=extra_margin,
        min_overlap=min_overlap)
  # Extracts the (sub)sequences to be aligned, optionally including changes to
  # the N-terminus and C-terminus flanks.
  region_pair = extend_sequences(region_pair)
  # Only real flanks should be potentially problematic.
  region_pair['maybe_confounded'] = False
  region_pair['fallback'] = False

  # Adds flags to element indicating if the new endpoints correspond to the
  # start (resp. end) of a full sequence.
  region_pair = add_bos_and_eos_flags(region_pair, flanks)

  # Compresses the set of `matches` into a CIGAR-like state string.
  states = alignment.states_from_matches(region_pair['matches'])
  region_pair['states'] = alignment.compress_states(states)

  # Computes the percent identity of the sequence pair, based on the
  # ground-truth alignment, taking any modifications to the flanks into
  # account.
  region_pair['percent_identity'] = alignment.pid_from_matches(
      sequence_x=region_pair['sequence_x'],
      sequence_y=region_pair['sequence_y'],
      matches=region_pair['matches'],
      ali_start_x=region_pair['ali_start_x'],
      ali_start_y=region_pair['ali_start_y'])

  # Creates a new key for the sequences, summarizing the new endpoints.
  region_pair = add_extended_region_keys(region_pair)

  return region_pair


def process_region_pair_homology(
    region_pair,
    flanks = None,
    max_len = 511,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """Generates a sample for the homology task from a pair of Pfam regions."""
  rng = get_prng(region_pair, global_seed=global_seed)

  # Ground-truth alignments are only available whenever both regions belong to
  # the same Pfam family.
  if region_pair['homology_label'] == 2:
    # Parses the gapped sequences in `region_pair` to extract the ground-truth
    # alignment path, described in terms of its starting positions and matches.
    region_pair = compute_alignment_path(region_pair)

  # Optionally, samples synthetic sequences ('synthetic') or UniProt sequences
  # ('uniprot') to generate flanked sequences.
  if flanks in ('synthetic', 'uniprot'):
    region_pair = generate_flanks(
        rng=rng,
        region_pair=region_pair,
        flanks=flanks,
        max_len=max_len,
        extra_margin=extra_margin,
        min_overlap=min_overlap)
  # Extracts the (sub)sequences to be aligned, optionally including changes to
  # the N-terminus and C-terminus flanks.
  region_pair = extend_sequences(region_pair)

  # Regions may contain shared annotations that could act as confounding
  # factors. We perform a best-effort attempt to detect such cases.
  # However, the incompleteness of annotation databases necessarily implies this
  # step will never be perfect and residual, undetected "confounding" might
  # persist for some region pairs.
  region_pair = annotate_regions(region_pair, extra_margin, min_overlap)
  region_pair = eval_confounding_in_regions(region_pair)

  # Adds flags to element indicating if the new endpoints correspond to the
  # start (resp. end) of a full sequence.
  region_pair = add_bos_and_eos_flags(region_pair, flanks)

  # Ground-truth percent identities for the region pair can only be computed
  # at the highest level of homology, namely, when both regions belong to the
  # same Pfam family.
  if region_pair['homology_label'] == 2:
    # Computes the percent identity of the sequence pair, based on the
    # ground-truth alignment, taking any modifications to the flanks into
    # account.
    region_pair['percent_identity'] = alignment.pid_from_matches(
        sequence_x=region_pair['sequence_x'],
        sequence_y=region_pair['sequence_y'],
        matches=region_pair['matches'],
        ali_start_x=region_pair['ali_start_x'],
        ali_start_y=region_pair['ali_start_y'])
  else:
    region_pair['percent_identity'] = float('nan')

  # Creates a new key for the sequences, summarizing the new endpoints.
  region_pair = add_extended_region_keys(region_pair)

  return region_pair


def build_pfam_alignments_pipeline(
    file_pattern,
    dataset_splits_path,
    target_split,
    output_path,
    max_len = 511,
    flanks = None,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """3a) Returns a pipeline to generate samples for sequence alignment task.

  Args:
    file_pattern: The file pattern from which to read preprocessed Pfam shards.
      This is assumed to be the result of steps 1a), 1b) and, optionally, step
      2) of the full preprocessing pipeline.
      See `preprocess_tables_lib.py` and `uniprot_flanks_lib.py` for additional
      details.
    dataset_splits_path: The path to the (key, split) mapping file.
    target_split: The dataset split for which to generate pairwise alignment
      data.
    output_path: The path prefix to the output files.
    max_len: The maximum length of sequences to be included in the output
      dataset (without BOS or EOS tokens).
    flanks: The approach used to add flanking sequences to Pfam regions. If
      `None`, no flanking sequences will be added. Supported modes include
      `synthetic` and `uniprot`.
    extra_margin: Extends sequence boundaries by `extra_margin` residues when
      evaluating overlap between annotations.
    min_overlap: The minimum number of residues in a sequence that need to
      overlap with a region annotation in order for the annotation to be applied
      to the sequence.
    global_seed: A global seed for the PRNG.

  Returns:
  A beam.Pipeline.
  """
  def pipeline(root):
    # Reads data preprocessed by steps 1a) and 1b) in `preprocess_tables_lib.py`
    # and, optionally, step 2) in `uniprot_flanks_lib.py`.
    with_flank_seeds = flanks == 'uniprot'
    fields_to_keep = ALIGNMENT_FIELDS
    if with_flank_seeds:
      fields_to_keep += tuple(f[0] for f in schemas.FLANK_FIELDS)
    regions = (
        root
        | 'ReadParsedPfamData' >> ReadParsedPfamData(
            file_pattern=file_pattern,
            dataset_splits_path=dataset_splits_path,
            fields_to_keep=fields_to_keep,
            with_flank_seeds=with_flank_seeds,
            max_len=max_len,
            filter_by_qc=True))
    # Filters sequences not belonging to the target split and removes the no
    # longer needed split field.
    filtered_regions = (
        regions
        | 'FilterBySplit' >> beam.Filter(
            lambda x: x['split'] == target_split)
        | 'DropSplitField' >> beam.Map(
            functools.partial(
                utils.drop_record_fields,
                fields_to_drop=['split'])))
    # Enumerates all pairs of regions sharing the same Pfam accession. Each
    # region pair is processed to produce the final fields that will be used for
    # training and evaluating the models on the pairwise alignment task.
    region_pairs = (
        filtered_regions
        | 'EnumerateAllFamilyPairs' >> utils.Combinations(
            groupby_field='pfam_acc',
            key_field='key',
            num_samples=None,
            suffixes=SUFFIXES)
        | 'ProcessRegionPairs' >> beam.Map(
            functools.partial(
                process_region_pair_alignment,
                flanks=flanks,
                max_len=max_len,
                extra_margin=extra_margin,
                min_overlap=min_overlap,
                global_seed=global_seed)))
    # Writes postprocessed region pairs to disk as tab-delimited sharded text
    # files.
    _ = (
        region_pairs
        | 'WriteToTable' >> schemas_lib.WriteToTable(
            file_path_prefix=output_path,
            schema_cls=schemas.PairwiseAlignmentRow))

  return pipeline
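# Illustrative usage sketch for `build_pfam_alignments_pipeline` (hypothetical
# paths; actual launch scripts and runner configuration may differ):
#   pipeline_fn = build_pfam_alignments_pipeline(
#       file_pattern='/path/to/parsed_pfam/*.tsv',
#       dataset_splits_path='/path/to/dataset_splits.tsv',
#       target_split='train',
#       output_path='/path/to/output/alignments',
#       flanks='uniprot')
#   with beam.Pipeline() as root:
#     pipeline_fn(root)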


def build_pfam_homology_pipeline(
    file_pattern,
    dataset_splits_path,
    target_split,
    output_path,
    avg_num_samples,
    prob_pos_different_family = 0.11,
    prob_neg = 0.5,
    max_len = 511,
    flanks = None,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """3b) Returns a pipeline to generate samples for homology detection task.

  Args:
    file_pattern: The file pattern from which to read preprocessed Pfam shards.
      This is assumed to be the result of steps 1a), 1b) and, optionally, step
      2) of the full preprocessing pipeline.
      See `preprocess_tables_lib.py` and `uniprot_flanks_lib.py` for additional
      details.
    dataset_splits_path: The path to the (key, split) mapping file.
    target_split: The dataset split for which to generate pairwise homology
      data.
    output_path: The path prefix to the output files.
    avg_num_samples: The (expected) number of samples (sequence pairs) to
      subsample (homologous and non-homologous).
    prob_pos_different_family: The (expected) proportion of samples consisting
      of region pairs in the same clan but different families.
    prob_neg: The (expected) proportion of samples consisting of non-homologous
      region pairs, that is, regions in different clans.
    max_len: The maximum length of sequences to be included in the output
      dataset.
    flanks: The approach used to add flanking sequences to Pfam regions. If
      `None`, no flanking sequences will be added. Supported modes include
      `synthetic` and `uniprot`.
    extra_margin: Extends sequence boundaries by `extra_margin` residues when
      evaluating overlap between annotations.
    min_overlap: The minimum number of residues in a sequence that need to
      overlap with a region annotation in order for the annotation to be applied
      to the sequence.
    global_seed: A global seed for the PRNG.

  Returns:
  A beam.Pipeline.
  """
  def pipeline(root):
    # Reads data preprocessed by steps 1a) and 1b) in `preprocess_tables_lib.py`
    # and, optionally, step 2) in `uniprot_flanks_lib.py`.
    with_flank_seeds = flanks == 'uniprot'
    fields_to_keep = ALIGNMENT_FIELDS
    if with_flank_seeds:
      fields_to_keep += tuple(f[0] for f in schemas.FLANK_FIELDS)
    regions = (
        root
        | 'ReadParsedPfamData' >> ReadParsedPfamData(
            file_pattern=file_pattern,
            dataset_splits_path=dataset_splits_path,
            fields_to_keep=fields_to_keep,
            with_flank_seeds=with_flank_seeds,
            max_len=max_len,
            filter_by_qc=True))
    # Filters sequences not belonging to the target split and removes the no
    # longer needed split field.
    filtered_regions = (
        regions
        | 'FilterBySplit' >> beam.Filter(
            lambda x: x['split'] == target_split)
        | 'DropSplitField' >> beam.Map(
            functools.partial(
                utils.drop_record_fields,
                fields_to_drop=['split'])))

    # Enumerates a subsample of (on average) `avg_num_samples` pairs of
    # homologous and non-homologous regions, keeping only the latter and adding
    # a class label for each pair:
    # + neg: homology_label = 0, if non-homologs (different clan).
    # + mid: homology_label = 1, if remote homologs (same clan, different fams).
    # + pos: homology_label = 2, if homologs (same family).
    # This subsample will be heavily biased towards negative samples (neg).
    neg_region_pairs = (
        filtered_regions
        | 'EnumerateNegRegionPairs' >> utils.SubsampleOuterProduct(
            avg_num_samples=avg_num_samples,
            groupby_field=None,
            key_field='key',
            suffixes=SUFFIXES)
        | 'KeepNegRegionPairs' >> beam.Filter(
            lambda x: x['clan_acc_x'] != x['clan_acc_y'])
        | 'AddHomologyLabelNegRegionPairs' >> beam.Map(add_homology_label))

    # Enumerates a subsample of (on average) `avg_num_samples` pairs of
    # homologous regions (either in the same family, or in the same clan but
    # in different families), keeping only the latter and adding a class label
    # for each pair.
    # This subsample will be heavily biased towards samples in the same clan but
    # in different families (mid).
    mid_region_pairs = (
        filtered_regions
        | 'EnumerateMidRegionPairs' >> utils.SubsampleOuterProduct(
            avg_num_samples=avg_num_samples,
            groupby_field='clan_acc',
            key_field='key',
            suffixes=SUFFIXES)
        | 'KeepMidRegionPairs' >> beam.Filter(
            lambda x: x['pfam_acc_x'] != x['pfam_acc_y'])
        | 'AddHomologyLabelMidRegionPairs' >> beam.Map(add_homology_label))

    # Enumerates all pairs of regions sharing the same Pfam accession, adding
    # the corresponding homology label (2) to all of the resulting records.
    pos_region_pairs = (
        filtered_regions
        | 'EnumeratePosRegionPairs' >> utils.Combinations(
            groupby_field='pfam_acc',
            key_field='key',
            num_samples=None,
            suffixes=SUFFIXES)
        | 'AddHomologyLabelPosRegionPairs' >> beam.Map(add_homology_label))

    # Computes the number of region pairs in each category. We assume that the
    # same-family homologs category, i.e. `count_pos`, is the smallest. That
    # holds true for Pfam-A seed 34.0.
    count_pos = pos_region_pairs | 'CountPos' >> beam.combiners.Count.Globally()
    count_mid = mid_region_pairs | 'CountMid' >> beam.combiners.Count.Globally()
    count_neg = neg_region_pairs | 'CountNeg' >> beam.combiners.Count.Globally()

    # Downsamples each category to obtain data with the desired class label
    # distribution.
    prob_pos_same_family = 1.0 - prob_pos_different_family - prob_neg
    assert prob_pos_same_family > 0
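    # Illustrative arithmetic (hypothetical counts): with the defaults
    # prob_pos_different_family = 0.11 and prob_neg = 0.5, we get
    # prob_pos_same_family = 0.39. `subsample_region_pairs` then keeps `mid`
    # pairs with probability (0.11 / 0.39) * (count_pos / count_mid) and `neg`
    # pairs with probability (0.5 / 0.39) * (count_pos / count_neg), so the
    # expected post-downsampling counts are roughly
    # count_pos : count_pos * 0.11 / 0.39 : count_pos * 0.5 / 0.39, i.e., the
    # intended 0.39 : 0.11 : 0.5 class proportions for pos / mid / neg.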

    mid_region_pairs = (
        mid_region_pairs
        | 'DownsampleMidRegionPairs' >> beam.FlatMap(
            subsample_region_pairs,
            beam.pvalue.AsSingleton(count_mid),
            beam.pvalue.AsSingleton(count_pos),
            resample_ratio=prob_pos_different_family / prob_pos_same_family,
            global_seed=global_seed))
    neg_region_pairs = (
        neg_region_pairs
        | 'DownsampleNegRegionPairs' >> beam.FlatMap(
            subsample_region_pairs,
            beam.pvalue.AsSingleton(count_neg),
            beam.pvalue.AsSingleton(count_pos),
            resample_ratio=prob_neg / prob_pos_same_family,
            global_seed=global_seed))

    # Homologous and non-homologous regions are merged. The region pairs are
    # then processed to produce the final fields that will be used for training
    # and evaluating the models on the pairwise homology detection task.
    region_pairs = (
        (pos_region_pairs, mid_region_pairs, neg_region_pairs)
        | 'MergeAllClasses' >> beam.Flatten()
        | 'ReshuffleAfterMerging' >> beam.Reshuffle()
        | 'ProcessRegionPairs' >> beam.Map(
            functools.partial(
                process_region_pair_homology,
                flanks=flanks,
                max_len=max_len,
                extra_margin=extra_margin,
                min_overlap=min_overlap,
                global_seed=global_seed)))
    # Writes postprocessed region pairs to disk as tab-delimited sharded text
    # files.
    _ = (
        region_pairs
        | 'WriteToTable' >> schemas_lib.WriteToTable(
            file_path_prefix=output_path,
            schema_cls=schemas.PairwiseHomologyRow))

  return pipeline