# google-research
# (extraction artifact from hosting page: "894 строки · 34.5 Кб" = "894 lines · 34.5 KB")
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Pipelines to generate datasets for the alignment and homology tasks."""
17
import collections
import functools
import hashlib
import itertools
import random
from typing import Callable, Dict, Iterable, Iterator, Optional

import apache_beam as beam
import numpy as np
import tensorflow as tf

from dedal.preprocessing import alignment
from dedal.preprocessing import schemas
from dedal.preprocessing import schemas_lib
from dedal.preprocessing import types
from dedal.preprocessing import utils
34
35# Type aliases
36Record = types.Record37Array = np.ndarray38PRNG = random.Random39
40
41# Constants
42ALIGNMENT_FIELDS = (43'key', 'pfam_acc', 'clan_acc', 'seq_start', 'seq_end', 'passed_qc',44'sequence', 'gapped_sequence',45'hmm_hit_seq_start', 'hmm_hit_seq_end', 'hmm_hit_clan_acc',46'other_hit_seq_start', 'other_hit_seq_end', 'other_hit_type_id')47OTHER_REGIONS = (48'coiled_coil', 'disorder', 'low_complexity', 'sig_p', 'transmembrane')49CONFOUNDING_REGIONS = ('clans',)50SPLITS = ('train', 'iid_validation', 'ood_validation', 'iid_test', 'ood_test')51PREFIXES = ('n', 'c') # Indices of the two flanks (N-terminus, C-terminus).52SUFFIXES = ('x', 'y') # Indices of the two regions in the pairs.53AA_PROB_TABLE = {'A': 0.0832177442683677,54'C': 0.013846953304580488,55'D': 0.05746458363960445,56'E': 0.0662881179936497,57'F': 0.03796971998594428,58'G': 0.06893382361885941,59'H': 0.021288200487753935,60'I': 0.05563140763547133,61'K': 0.05514607272883951,62'L': 0.09466292698151958,63'M': 0.021795269979457313,64'N': 0.042539379191607996,65'P': 0.04820225072073993,66'Q': 0.0397779223030366,67'R': 0.05446797313841446,68'S': 0.07095409709482754,69'T': 0.05849503653768657,70'V': 0.06729418111891682,71'W': 0.011878265406494127,72'Y': 0.030146073864228278}73
74
def load_pfam_accs(path):
  """Returns a mapping from Pfam accessions to unique integer indices.

  Args:
    path: The path to a plain text file containing one Pfam accession per line.

  Returns:
    A dictionary mapping Pfam accessions to integer-valued identifiers numbered
    between 0 (inclusive) and the total number of distinct Pfam accessions
    (exclusive).

  Raises:
    ValueError: If the same accession occurs more than once in the file.
  """
  accession_to_index = {}
  with tf.io.gfile.GFile(path, 'r') as f:
    for index, raw_line in enumerate(f):
      accession = raw_line.strip()
      if accession in accession_to_index:
        raise ValueError(f'Key {accession} is duplicated in {path}.')
      accession_to_index[accession] = index
  return accession_to_index
95
class ReadParsedPfamData(beam.PTransform):
  """Reads TSV data generated by Step 1. of the preprocessing pipeline.

  Joins the parsed Pfam region table with the key -> split mapping table and,
  optionally, filters regions by quality control flag and by maximum region
  length.
  """

  def __init__(
      self,
      file_pattern,
      dataset_splits_path,
      fields_to_keep = None,
      with_flank_seeds = False,
      max_len = None,
      filter_by_qc = True,
  ):
    """Initializes the transform.

    Args:
      file_pattern: The file pattern from which to read preprocessed Pfam
        shards.
      dataset_splits_path: The path to the key, split mapping file.
      fields_to_keep: If not None, the subset of fields to retain from each
        record.
      with_flank_seeds: Whether the input rows contain precomputed UniProt
        flank seed fields (`schemas.ExtendedParsedPfamRow`) or not
        (`schemas.ParsedPfamRow`).
      max_len: If not None, regions longer than `max_len` residues (inclusive
        endpoints) are dropped.
      filter_by_qc: Whether to drop regions that did not pass quality control
        (`passed_qc` field).
    """
    self.file_pattern = file_pattern
    self.dataset_splits_path = dataset_splits_path
    self.fields_to_keep = fields_to_keep
    self.with_flank_seeds = with_flank_seeds
    self.max_len = max_len
    self.filter_by_qc = filter_by_qc
    # Schema depends on whether UniProt flank seed columns are present.
    self._schema_cls = (schemas.ExtendedParsedPfamRow if self.with_flank_seeds
                        else schemas.ParsedPfamRow)

  def expand(
      self,
      root,
  ):
    read_alignment_data_cls = functools.partial(
        schemas_lib.ReadFromTable,
        schema_cls=self._schema_cls,
        key_field='key',
        skip_header_lines=1,
        fields_to_keep=self.fields_to_keep)

    read_dataset_splits_cls = functools.partial(
        schemas_lib.ReadFromTable,
        schema_cls=schemas.DatasetSplits,
        key_field='key',
        skip_header_lines=1,
        fields_to_keep=('split',))

    # Reads Pfam data from the preceding pipeline. Each element in the output
    # `PCollection` represents a different Pfam region.
    regions = root | 'ReadRegions' >> read_alignment_data_cls(self.file_pattern)

    # Optionally, removes any Pfam regions that did not pass the quality
    # control checks from the `PCollection`.
    if self.filter_by_qc:
      regions = regions | 'QCFilter' >> beam.Filter(lambda x: x[1]['passed_qc'])

    # Optionally, drops Pfam regions that do not pass the maximum region length
    # filter. Endpoints are inclusive, hence the `+ 1`.
    if self.max_len is not None:
      len_fn = lambda x: x[1]['seq_end'] - x[1]['seq_start'] + 1 <= self.max_len
      regions = regions | 'LengthFilter' >> beam.Filter(len_fn)

    # Reads mapping Pfam region key: split.
    dataset_splits = (
        root
        | 'ReadDatasetSplits' >> read_dataset_splits_cls(
            self.dataset_splits_path))
    # Merges split info from `dataset_splits` into elements `regions` and
    # removes the key of each element, which is no longer needed after merging
    # the two `PCollection`s.
    return (
        {'pfam_regions': regions, 'dataset_splits': dataset_splits}
        | 'MergeSplitInfoByKey' >> schemas_lib.JoinTables(
            left_join_tables=['pfam_regions'])
        | 'RemoveKey' >> beam.Values())
164
def get_prng(record, global_seed, field_name = 'key'):
  """Generates a per-example pair reproducible PRNG.

  Args:
    record: The record describing a pair of regions; the pair is identified by
      the fields `f'{field_name}_x'` and `f'{field_name}_y'`.
    global_seed: A global seed, mixed into the per-pair seed.
    field_name: The prefix of the per-region key fields.

  Returns:
    A `random.Random` seeded deterministically from the pair's keys and
    `global_seed`. A stable (process-independent) hash is used instead of the
    built-in `hash`, which is salted per process (PYTHONHASHSEED) and would
    make sampling irreproducible across runs and across Beam workers.
  """
  pair_id = repr(tuple(record[f'{field_name}_{s}'] for s in SUFFIXES))
  digest = hashlib.sha256(pair_id.encode('utf-8')).digest()
  return random.Random(int.from_bytes(digest[:8], 'little') + global_seed)
170
171def sample_flank_lengths(172rng,173seq_start,174seq_end,175max_len,176):177"""Samples length of the N-terminus and C-terminus flanks at random."""178region_len = seq_end - seq_start + 1 # Endpoints are inclusive.179max_ctx_len = max_len - region_len180max_n_len = max_c_len = max_ctx_len181ctx_len = rng.randint(0, max_ctx_len)182n_len = rng.randint(max(ctx_len - max_c_len, 0), min(max_n_len, ctx_len))183c_len = ctx_len - n_len184return {'n_len': n_len, 'c_len': c_len}185
186
def validate_flank_seeds(
    region_pair,
    indices,
    extra_margin = 0,
    min_overlap = 1,
):
  """Tests if a combination of UniProt flanks is valid for a region pair.

  Args:
    region_pair: The record describing the pair of Pfam regions, including
      precomputed flank seed fields (presumably `schemas.ExtendedParsedPfamRow`
      columns — see `ReadParsedPfamData`).
    indices: A dict mapping each of the four flank slots ('n_x', 'n_y', 'c_x',
      'c_y') to the index of the candidate flank seed to validate.
    extra_margin: Extends region boundaries by this many residues when
      evaluating overlap between annotations.
    min_overlap: The minimum number of overlapping residues for an annotation
      to be attributed to a region.

  Returns:
    True if the candidate flank seed combination passes all checks: every
    required (non-empty) flank has a non-empty seed key and sequence, the
    flanks of the two sequences share no clan annotations with each other, and
    no flank shares clan annotations with the other sequence's main region.
    False otherwise.
  """
  # First, verifies that all four flank seeds are non-empty if a flank needs to
  # be generated.
  for p, s in itertools.product(PREFIXES, SUFFIXES):
    flank_len = region_pair[f'{p}_len_{s}']
    idx = indices[f'{p}_{s}']
    key = region_pair[f'{p}_flank_seed_key_{idx}_{s}']
    sequence = region_pair[f'{p}_flank_seed_sequence_{idx}_{s}']
    if flank_len and (not key or not sequence):
      return False

  # Second, retrieves the collection of hmm_hits and other_hits in each flank
  # and checks that there are no shared annotations between the flanks of each
  # sequence.
  flank_hits = collections.defaultdict(set)
  for ann_type in ('hmm_hit_clan_acc',):
    # Pools the annotations of both flanks (N and C-terminus) per sequence.
    for p, s in itertools.product(PREFIXES, SUFFIXES):
      idx = indices[f'{p}_{s}']
      hits = region_pair[f'{p}_flank_seed_{ann_type}_{idx}_{s}']
      flank_hits[f'{ann_type}_{s}'] |= set(hits)
    if set.intersection(*[flank_hits[f'{ann_type}_{s}'] for s in SUFFIXES]):
      return False

  # Finally, tests if there are shared hmm_hit annotations between the flanks
  # of one sequence and the main region of the other sequence.
  for s1, s2 in zip(SUFFIXES, reversed(SUFFIXES)):
    # TODO(fllinares): pre-compute `hmm_hits`, perhaps at the cost of clarity.
    overlaps = interval_overlaps(
        start=region_pair[f'seq_start_{s1}'] - extra_margin,
        end=region_pair[f'seq_end_{s1}'] + extra_margin,
        ref_starts=np.asarray(region_pair[f'hmm_hit_seq_start_{s1}']),
        ref_ends=np.asarray(region_pair[f'hmm_hit_seq_end_{s1}']))
    hit_ids = np.asarray(region_pair[f'hmm_hit_clan_acc_{s1}'])
    hmm_hits = set(hit_ids[overlaps >= min_overlap])
    hmm_hits.add(region_pair[f'clan_acc_{s1}'])  # Likely redundant.
    if hmm_hits & flank_hits[f'hmm_hit_clan_acc_{s2}']:
      return False

  return True
233
def pair_flank_seeds(
    rng,
    region_pair,
    extra_margin = 0,
    min_overlap = 1,
):
  """Searches for a valid combination of UniProt flanks for the region pair.

  Args:
    rng: A `random.Random` used to randomize the order in which candidate
      flank seed combinations are tried.
    region_pair: The record describing the pair of Pfam regions, including
      precomputed flank seed fields.
    extra_margin: Forwarded to `validate_flank_seeds`.
    min_overlap: Forwarded to `validate_flank_seeds`.

  Returns:
    `region_pair`, modified in-place: if a valid combination is found, the
    chosen seed index is stored under `f'{p}_flank_seed_idx_{s}'` for each
    non-empty flank. If no combination validates, no index fields are added
    (downstream code then generates no flanks for this pair).
  """
  num_flank_seeds = schemas.NUM_FLANK_SEEDS
  idx_keys = tuple(f'{p}_{s}' for p, s in itertools.product(PREFIXES, SUFFIXES))
  # Iterates over all possible flank seed pairings in a random order. Flank
  # seed indices are 1-based (1 ... NUM_FLANK_SEEDS).
  for indices in itertools.product(
      rng.sample(range(1, num_flank_seeds + 1), k=num_flank_seeds),
      rng.sample(range(1, num_flank_seeds + 1), k=num_flank_seeds),
      rng.sample(range(1, num_flank_seeds + 1), k=num_flank_seeds),
      rng.sample(range(1, num_flank_seeds + 1), k=num_flank_seeds)):
    indices = dict(zip(idx_keys, indices))
    if validate_flank_seeds(region_pair, indices, extra_margin, min_overlap):
      for p, s in itertools.product(PREFIXES, SUFFIXES):
        if region_pair[f'{p}_len_{s}']:  # Skips empty flanks.
          region_pair[f'{p}_flank_seed_idx_{s}'] = indices[f'{p}_{s}']
      break
  return region_pair
257
def sample_synthetic_flank(rng, length, prob_table = None):
  """Samples `length` characters i.i.d. from an amino acid distribution.

  Args:
    rng: A `random.Random` instance used for sampling.
    length: The number of residues to sample.
    prob_table: An optional mapping from amino acid (single character) to its
      sampling probability. Defaults to the background frequencies in
      `AA_PROB_TABLE`, preserving the previous hard-coded behavior.

  Returns:
    A string of `length` residues drawn i.i.d. (with replacement) from
    `prob_table`.
  """
  if prob_table is None:
    prob_table = AA_PROB_TABLE
  amino_acids = list(prob_table.keys())
  probabilities = list(prob_table.values())
  return ''.join(rng.choices(amino_acids, probabilities, k=length))
264
def generate_flanks(
    rng,
    region_pair,
    flanks,
    max_len,
    extra_margin = 0,
    min_overlap = 1,
):
  """Extends a pair of Pfam domains adding N and C-terminus flanks.

  Args:
    rng: A `random.Random` used for all sampling (flank lengths, synthetic
      residues, UniProt seed choice and offsets).
    region_pair: The record describing the pair of Pfam regions.
    flanks: Either 'synthetic' or 'uniprot'; any other value raises.
    max_len: The maximum total length (region plus both flanks) per sequence.
    extra_margin: Forwarded to `pair_flank_seeds` ('uniprot' mode only).
    min_overlap: Forwarded to `pair_flank_seeds` ('uniprot' mode only).

  Returns:
    `region_pair`, modified in-place with, for each prefix p in `PREFIXES` and
    suffix s in `SUFFIXES`: the sampled flank lengths (`f'{p}_len_{s}'`), the
    flank sequences (`f'{p}_flank_{s}'`; possibly absent in 'uniprot' mode when
    no valid seed combination exists) and a per-flank key (`f'{p}_key_{s}'`,
    the empty string for empty flanks).

  Raises:
    ValueError: If `flanks` is neither 'synthetic' nor 'uniprot'.
  """
  # Samples the lengths of the N-terminus and C-terminus flanks, adding the
  # resulting variables to `region_pair`.
  for s in SUFFIXES:
    out = sample_flank_lengths(
        rng=rng,
        seq_start=region_pair[f'seq_start_{s}'],
        seq_end=region_pair[f'seq_end_{s}'],
        max_len=max_len)
    region_pair.update({f'{k}_{s}': v for k, v in out.items()})

  # When using flanks from UniProt, the choice of which flank "seeds" to pick
  # for the N-terminus and C-terminus ends for each of the two regions must be
  # done jointly. In contrast, the rest of the flank generation pipeline can be
  # done independently for each sequence and flank.
  if flanks == 'uniprot':
    region_pair = pair_flank_seeds(rng, region_pair, extra_margin, min_overlap)

  # Processes each of the two regions, `sequence_x` and `sequence_y`,
  # independently.
  for s in SUFFIXES:
    seq_start = region_pair[f'seq_start_{s}']
    seq_end = region_pair[f'seq_end_{s}']
    # Generates the N-terminus and C-terminus flanks for `sequence_{s}`.
    for p in PREFIXES:
      # Synthetic flanks are obtained by sampling amino acids i.i.d. (with
      # replacement) from the background frequencies in `AA_PROB_TABLE`,
      # independently for each flank (see `sample_synthetic_flank`).
      if flanks == 'synthetic':
        flank_acc = 'synth'
        flank_len = region_pair[f'{p}_len_{s}']
        # Computes the (inclusive) endpoints of the N-terminus flank. Note that
        # `flank_start` could be negative.
        if p == PREFIXES[0]:
          flank_start = seq_start - flank_len
          flank_end = seq_start - 1
        # Computes the (inclusive) endpoints of the C-terminus flank. Note that
        # `flank_end` could be larger than the length of the original sequence.
        else:
          flank_start = seq_end + 1
          flank_end = seq_end + flank_len
        region_pair[f'{p}_flank_{s}'] = sample_synthetic_flank(
            rng=rng, length=flank_len)
      # UniProt flanks are obtained by (brute-force) searching for a
      # combination of UniProt (sub)sequences that satisfy all of the "quality
      # control" criteria in `validate_flank_seeds`. Note that, since each
      # sequence has only a finite number of precomputed flank "seeds"
      # available (for tractability), there is a small but non-zero probability
      # that no flank combination is valid. In these (rare) cases, no flanks
      # are added.
      elif flanks == 'uniprot':
        idx = region_pair.get(f'{p}_flank_seed_idx_{s}', None)
        if idx is not None:  # A valid flank combination was found.
          flank_key = region_pair[f'{p}_flank_seed_key_{idx}_{s}']
          flank_seq = region_pair[f'{p}_flank_seed_sequence_{idx}_{s}']

          # Seed keys follow the format '{accession}/{start}-{end}' with
          # one-based, inclusive endpoints.
          flank_acc, flank_endpoints = flank_key.split('/')
          flank_start, flank_end = [int(x) for x in flank_endpoints.split('-')]
          assert len(flank_seq) == (flank_end - flank_start + 1)

          # Samples a random window of the requested flank length within the
          # seed sequence.
          offset = rng.randint(0, len(flank_seq) - region_pair[f'{p}_len_{s}'])
          flank_start += offset
          flank_end = flank_start + region_pair[f'{p}_len_{s}'] - 1

          region_pair[f'{p}_flank_{s}'] = flank_seq[
              offset:offset + region_pair[f'{p}_len_{s}']]
        else:  # No valid flank combination was found.
          # Marks flank as empty.
          flank_start = 0
          flank_end = -1
      # Unrecognized.
      else:
        raise ValueError(
            f"flanks must be 'synthetic' or 'uniprot'. Got {flanks} instead.")

      # Sets a flank key for inspectability purposes if the flank is not empty.
      if flank_start <= flank_end:
        region_pair[f'{p}_key_{s}'] = f'{flank_acc}/{flank_start}-{flank_end}'
      else:
        region_pair[f'{p}_key_{s}'] = ''

  return region_pair
354
def extend_sequences(region_pair):
  """Optionally, extends domain sequences with N and C-terminus flanks."""
  for suffix in SUFFIXES:
    full_seq = region_pair[f'sequence_{suffix}']
    start = region_pair[f'seq_start_{suffix}']
    end = region_pair[f'seq_end_{suffix}']
    # Flanks default to empty strings when no flank generation ran.
    n_flank = region_pair.get(f'n_flank_{suffix}', '')
    c_flank = region_pair.get(f'c_flank_{suffix}', '')

    # Endpoints are one-based and inclusive.
    domain = full_seq[start - 1:end]
    region_pair[f'sequence_{suffix}'] = ''.join([n_flank, domain, c_flank])

    # Prepending the N-terminus flank shifts the alignment start accordingly.
    if f'ali_start_{suffix}' in region_pair:
      region_pair[f'ali_start_{suffix}'] += len(n_flank)

    # Backs up the original (full-length) sequence.
    region_pair[f'original_sequence_{suffix}'] = full_seq

  return region_pair
375
def interval_overlaps(
    start,
    end,
    ref_starts,
    ref_ends,
):
  """Computes the overlap between closed intervals.

  Args:
    start: The (inclusive) start of the query interval.
    end: The (inclusive) end of the query interval.
    ref_starts: Array of (inclusive) starts of the reference intervals.
    ref_ends: Array of (inclusive) ends of the reference intervals.

  Returns:
    An array with the number of positions shared by `[start, end]` and each
    `[ref_starts[i], ref_ends[i]]`, clipped below at zero.
  """
  lower = np.maximum(start, ref_starts)
  upper = np.minimum(end, ref_ends)
  return np.maximum(0, upper - lower + 1)
386
def annotate_regions(
    region_pair,
    extra_margin = 0,
    min_overlap = 1,
):
  """Finds any annotations that overlap with the regions in `region_pair`.

  Args:
    region_pair: The record describing the pair of Pfam regions, including the
      per-sequence hit endpoint and identifier lists.
    extra_margin: Extends region boundaries by this many residues when
      evaluating overlap between annotations.
    min_overlap: The minimum number of overlapping residues for an annotation
      to be attributed to a region.

  Returns:
    `region_pair`, modified in-place with two new fields per sequence:
    `f'overlapping_hmm_hit_{s}'` (a set of clan accessions) and
    `f'overlapping_other_hit_{s}'` (a set of region type identifiers).
  """
  for ann_type, id_name in zip(('hmm_hit', 'other_hit'),
                               ('clan_acc', 'type_id')):
    for s in SUFFIXES:
      hit_start = np.asarray(region_pair[f'{ann_type}_seq_start_{s}'])
      hit_end = np.asarray(region_pair[f'{ann_type}_seq_end_{s}'])
      hit_ids = np.asarray(region_pair[f'{ann_type}_{id_name}_{s}'])

      # If the flanks are synthetic or have passed the "quality control" checks
      # in `validate_flank_seeds`, then only annotations within the original
      # Pfam region endpoints should be taken into account.
      seq_start = region_pair[f'seq_start_{s}']
      seq_end = region_pair[f'seq_end_{s}']

      overlaps = interval_overlaps(
          start=seq_start - extra_margin,
          end=seq_end + extra_margin,
          ref_starts=hit_start,
          ref_ends=hit_end)
      # Boolean mask over hits attributed to the (extended) region.
      indices = overlaps >= min_overlap

      region_pair[f'overlapping_{ann_type}_{s}'] = set(hit_ids[indices])

  return region_pair
417
def eval_confounding_in_regions(region_pair):
  """Inspects region for edge-cases, such as nested domains."""
  # Clan annotations shared by both regions, beyond their own `clan_acc`s.
  hits_x, hits_y = (region_pair[f'overlapping_hmm_hit_{s}'] for s in SUFFIXES)
  shared_clans = hits_x & hits_y
  # For homologous pairs, the regions' own clan accession(s) are expected to
  # be shared and are therefore not counted.
  if region_pair['homology_label'] > 0:
    shared_clans = shared_clans - {
        region_pair[f'clan_acc_{s}'] for s in SUFFIXES}
  region_pair['shares_clans'] = bool(shared_clans)

  # Flags each "other" annotation category present in both regions.
  for type_id in OTHER_REGIONS:
    is_shared = all(
        type_id in region_pair[f'overlapping_other_hit_{s}'] for s in SUFFIXES)
    region_pair[f'shares_{type_id}'] = is_shared

  # A pair is "potentially confounded" for the homology task when the regions
  # are non-homologous yet share annotations in any of the categories listed
  # in `CONFOUNDING_REGIONS`.
  is_negative_pair = region_pair['homology_label'] == 0
  region_pair['maybe_confounded'] = is_negative_pair and any(
      region_pair[f'shares_{type_id}'] for type_id in CONFOUNDING_REGIONS)

  return region_pair
443
def add_bos_and_eos_flags(
    region_pair,
    flanks = None,
):
  """Marks whether the sequences lie at the start/end of a full protein seq."""
  has_generated_flanks = flanks in ('synthetic', 'uniprot')
  for suffix in SUFFIXES:
    full_len = len(region_pair[f'original_sequence_{suffix}'])
    bos = region_pair[f'seq_start_{suffix}'] == 1
    eos = region_pair[f'seq_end_{suffix}'] == full_len
    # With synthetic or uniprot flanks, a biologically relevant start / end of
    # a full sequence only remains when the corresponding flank is empty.
    if has_generated_flanks:
      bos = bos and region_pair[f'n_len_{suffix}'] == 0
      eos = eos and region_pair[f'c_len_{suffix}'] == 0
    region_pair[f'bos_{suffix}'] = bos
    region_pair[f'eos_{suffix}'] = eos
  return region_pair
460
def add_extended_region_keys(region_pair):
  """Creates new keys for the sequences, summarizing the new endpoints."""
  for suffix in SUFFIXES:
    # Joins the flank keys around the original region key; absent flank keys
    # default to the empty string.
    parts = (
        region_pair.get(f'n_key_{suffix}', ''),
        region_pair[f'key_{suffix}'],
        region_pair.get(f'c_key_{suffix}', ''),
    )
    region_pair[f'extended_key_{suffix}'] = ';'.join(parts)
  return region_pair
470
def compute_alignment_path(region_pair):
  """Computes alignment path from `region_pair`'s gapped sequences.

  Args:
    region_pair: The record describing the pair of Pfam regions; must contain
      the fields 'gapped_sequence_x' and 'gapped_sequence_y'.

  Returns:
    `region_pair`, modified in-place with the fields 'matches' (the alignment
    path, as produced by `alignment.alignment_from_gapped_sequences`) and
    'ali_start_x' / 'ali_start_y' (the alignment's starting position in each
    sequence).
  """
  matches, ali_start = alignment.alignment_from_gapped_sequences(
      gapped_sequence_x=region_pair['gapped_sequence_x'],
      gapped_sequence_y=region_pair['gapped_sequence_y'])
  region_pair['matches'] = matches
  region_pair['ali_start_x'] = ali_start[0]
  region_pair['ali_start_y'] = ali_start[1]
  return region_pair
481
def subsample_region_pairs(
    region_pair,
    count,
    min_count,
    resample_ratio,
    global_seed,
):
  """Randomly discards region pairs to rebalance label distribution."""
  # Keeps each pair with probability proportional to how over-represented its
  # class is relative to the rarest class (`min_count` elements).
  keep_prob = resample_ratio * (min_count / count)
  prng = get_prng(region_pair, global_seed)
  if prng.uniform(0, 1) <= keep_prob:
    yield region_pair
495
def add_homology_label(region):
  """Computes (ternary) homology label (non-homologs / same clan / same fam)."""
  # Note: if both sequences have the same Pfam accession (`pfam_acc`), they are
  # guaranteed to also have the same Clan accession (`clan_acc`). The converse
  # does not hold, however. Hence the sum yields 0, 1 or 2.
  region['homology_label'] = sum(
      int(region[f'{key}_{SUFFIXES[0]}'] == region[f'{key}_{SUFFIXES[1]}'])
      for key in ('pfam_acc', 'clan_acc'))
  return region
507
def process_region_pair_alignment(
    region_pair,
    flanks = None,
    max_len = 511,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """Generates a sample for the alignment task from a pair of Pfam regions.

  Args:
    region_pair: The record describing the pair of Pfam regions, as produced
      by `utils.Combinations` over regions of the same Pfam family.
    flanks: If 'synthetic' or 'uniprot', flanking sequences are generated and
      added around each region; otherwise no flanks are added.
    max_len: The maximum total length (region plus flanks) per sequence.
    extra_margin: Forwarded to `generate_flanks`.
    min_overlap: Forwarded to `generate_flanks`.
    global_seed: A global seed for the per-pair PRNG.

  Returns:
    `region_pair`, modified in-place with the alignment path ('matches',
    'states', 'ali_start_*'), the (possibly flank-extended) sequences, the
    'percent_identity' of the pair, BOS/EOS flags and extended region keys.
  """
  rng = get_prng(region_pair, global_seed=global_seed)

  # Parses the gapped sequences in `region_pair` to extract the ground-truth
  # alignment path, described in terms of its starting positions and matches.
  region_pair = compute_alignment_path(region_pair)

  # Optionally, samples synthetic sequences ('synthetic') or UniProt sequences
  # ('uniprot') to generate a ground-truth alignment with flanks.
  if flanks in ('synthetic', 'uniprot'):
    region_pair = generate_flanks(
        rng=rng,
        region_pair=region_pair,
        flanks=flanks,
        max_len=max_len,
        extra_margin=extra_margin,
        min_overlap=min_overlap)
  # Extracts the (sub)sequences to be aligned, optionally including changes to
  # the N-terminus and C-terminus flanks.
  region_pair = extend_sequences(region_pair)
  # Only real flanks should be potentially problematic.
  region_pair['maybe_confounded'] = False
  region_pair['fallback'] = False

  # Adds flags to element indicating if the new endpoints correspond to the
  # start (resp. end) of a full sequence.
  region_pair = add_bos_and_eos_flags(region_pair, flanks)

  # Compresses the set of `matches` into a CIGAR-like state string.
  states = alignment.states_from_matches(region_pair['matches'])
  region_pair['states'] = alignment.compress_states(states)

  # Computes the percent identity of the sequence pair, based on the
  # ground-truth alignment, taking any modifications to the flanks into
  # account.
  region_pair['percent_identity'] = alignment.pid_from_matches(
      sequence_x=region_pair['sequence_x'],
      sequence_y=region_pair['sequence_y'],
      matches=region_pair['matches'],
      ali_start_x=region_pair['ali_start_x'],
      ali_start_y=region_pair['ali_start_y'])

  # Creates a new key for the sequences, summarizing the new endpoints.
  region_pair = add_extended_region_keys(region_pair)

  return region_pair
562
def process_region_pair_homology(
    region_pair,
    flanks = None,
    max_len = 511,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """Generates a sample for the homology task from a pair of Pfam regions.

  Args:
    region_pair: The record describing the pair of Pfam regions, with its
      'homology_label' already set by `add_homology_label`.
    flanks: If 'synthetic' or 'uniprot', flanking sequences are generated and
      added around each region; otherwise no flanks are added.
    max_len: The maximum total length (region plus flanks) per sequence.
    extra_margin: Forwarded to `generate_flanks` and `annotate_regions`.
    min_overlap: Forwarded to `generate_flanks` and `annotate_regions`.
    global_seed: A global seed for the per-pair PRNG.

  Returns:
    `region_pair`, modified in-place with the (possibly flank-extended)
    sequences, confounding flags, BOS/EOS flags, extended region keys and
    'percent_identity' (NaN unless `homology_label == 2`, the only case where
    a ground-truth alignment exists).
  """
  rng = get_prng(region_pair, global_seed=global_seed)

  # Ground-truth alignments are only available whenever both regions belong to
  # the same Pfam family.
  if region_pair['homology_label'] == 2:
    # Parses the gapped sequences in `region_pair` to extract the ground-truth
    # alignment path, described in terms of its starting positions and matches.
    region_pair = compute_alignment_path(region_pair)

  # Optionally, samples synthetic sequences ('synthetic') or UniProt sequences
  # ('uniprot') to generate flanks around each region.
  if flanks in ('synthetic', 'uniprot'):
    region_pair = generate_flanks(
        rng=rng,
        region_pair=region_pair,
        flanks=flanks,
        max_len=max_len,
        extra_margin=extra_margin,
        min_overlap=min_overlap)
  # Extracts the (sub)sequences to be aligned, optionally including changes to
  # the N-terminus and C-terminus flanks.
  region_pair = extend_sequences(region_pair)

  # Regions may contain shared annotations that could act as confounding
  # factors. We perform a best-effort attempt to detect such cases.
  # However, the incompleteness of annotation databases necessarily implies
  # this step will never be perfect and residual, undetected "confounding"
  # might persist for some region pairs.
  region_pair = annotate_regions(region_pair, extra_margin, min_overlap)
  region_pair = eval_confounding_in_regions(region_pair)

  # Adds flags to element indicating if the new endpoints correspond to the
  # start (resp. end) of a full sequence.
  region_pair = add_bos_and_eos_flags(region_pair, flanks)

  # Ground-truth percent identities for the region pair can only be computed
  # at the highest level of homology, namely, when both regions belong to the
  # same Pfam family.
  if region_pair['homology_label'] == 2:
    # Computes the percent identity of the sequence pair, based on the
    # ground-truth alignment, taking any modifications to the flanks into
    # account.
    region_pair['percent_identity'] = alignment.pid_from_matches(
        sequence_x=region_pair['sequence_x'],
        sequence_y=region_pair['sequence_y'],
        matches=region_pair['matches'],
        ali_start_x=region_pair['ali_start_x'],
        ali_start_y=region_pair['ali_start_y'])
  else:
    region_pair['percent_identity'] = float('nan')

  # Creates a new key for the sequences, summarizing the new endpoints.
  region_pair = add_extended_region_keys(region_pair)

  return region_pair
628
def build_pfam_alignments_pipeline(
    file_pattern,
    dataset_splits_path,
    target_split,
    output_path,
    max_len = 511,
    flanks = None,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """3a) Returns a pipeline to generate samples for sequence alignment task.

  Args:
    file_pattern: The file pattern from which to read preprocessed Pfam shards.
      This is assumed to be the result of steps 1a), 1b) and, optionally, step
      2) of the full preprocessing pipeline.
      See `preprocess_tables_lib.py` and `uniprot_flanks_lib.py` for additional
      details.
    dataset_splits_path: The path to the key, split mapping file.
    target_split: The dataset split for which to generate pairwise alignment
      data.
    output_path: The path prefix to the output files.
    max_len: The maximum length of sequences to be included in the output
      dataset (without BOS or EOS tokens).
    flanks: The approach to be used to add flanking sequences to Pfam regions.
      If `None`, no flanking sequences will be added. Supported modes include
      `synthetic` and `uniprot`.
    extra_margin: Extends sequence boundaries by `extra_margin` residues when
      evaluating overlap between annotations.
    min_overlap: The minimum number of residues in a sequence that need to
      overlap with a region annotation in order for the annotation to be
      applied to the sequence.
    global_seed: A global seed for the PRNG.

  Returns:
    A pipeline function to be run with a `beam.Pipeline` root.
  """
  def pipeline(root):
    # Reads data preprocessed by steps 1a) and 1b) in
    # `preprocess_tables_lib.py` and, optionally, step 2) in
    # `uniprot_flanks_lib.py`.
    with_flank_seeds = flanks == 'uniprot'
    fields_to_keep = ALIGNMENT_FIELDS
    if with_flank_seeds:
      fields_to_keep += tuple(f[0] for f in schemas.FLANK_FIELDS)
    regions = (
        root
        | 'ReadParsedPfamData' >> ReadParsedPfamData(
            file_pattern=file_pattern,
            dataset_splits_path=dataset_splits_path,
            fields_to_keep=fields_to_keep,
            with_flank_seeds=with_flank_seeds,
            max_len=max_len,
            filter_by_qc=True))
    # Filters sequences not belonging to the target split and removes the no
    # longer needed split field.
    filtered_regions = (
        regions
        | 'FilterBySplit' >> beam.Filter(
            lambda x: x['split'] == target_split)
        | 'DropSplitField' >> beam.Map(
            functools.partial(
                utils.drop_record_fields,
                fields_to_drop=['split'])))
    # Enumerates all pairs of regions sharing the same Pfam accession. Each
    # region pair is processed to produce the final fields that will be used
    # for training and evaluating the models on the pairwise alignment task.
    region_pairs = (
        filtered_regions
        | 'EnumerateAllFamilyPairs' >> utils.Combinations(
            groupby_field='pfam_acc',
            key_field='key',
            num_samples=None,
            suffixes=SUFFIXES)
        | 'ProcessRegionPairs' >> beam.Map(
            functools.partial(
                process_region_pair_alignment,
                flanks=flanks,
                max_len=max_len,
                extra_margin=extra_margin,
                min_overlap=min_overlap,
                global_seed=global_seed)))
    # Writes postprocessed region pairs to disk as tab-delimited sharded text
    # files.
    _ = (
        region_pairs
        | 'WriteToTable' >> schemas_lib.WriteToTable(
            file_path_prefix=output_path,
            schema_cls=schemas.PairwiseAlignmentRow))

  return pipeline
721
def build_pfam_homology_pipeline(
    file_pattern,
    dataset_splits_path,
    target_split,
    output_path,
    avg_num_samples,
    prob_pos_different_family = 0.11,
    prob_neg = 0.5,
    max_len = 511,
    flanks = None,
    extra_margin = 0,
    min_overlap = 1,
    global_seed = 0,
):
  """3b) Returns a pipeline to generate samples for homology detection task.

  Args:
    file_pattern: The file pattern from which to read preprocessed Pfam shards.
      This is assumed to be the result of steps 1a), 1b) and, optionally, step
      2) of the full preprocessing pipeline.
      See `preprocess_tables_lib.py` and `uniprot_flanks_lib.py` for additional
      details.
    dataset_splits_path: The path to the key, split mapping file.
    target_split: The dataset split for which to generate pairwise alignment
      data.
    output_path: The path prefix to the output files.
    avg_num_samples: The (expected) number of samples (sequence pairs) to
      subsample (homologous and non-homologous).
    prob_pos_different_family: The (expected) proportion of samples consisting
      of region pairs in the same clan but different families.
    prob_neg: The (expected) proportion of samples consisting of
      non-homologous region pairs, that is, regions in different clans.
    max_len: The maximum length of sequences to be included in the output
      dataset.
    flanks: The approach to be used to add flanking sequences to Pfam regions.
      If `None`, no flanking sequences will be added. Supported modes include
      `synthetic` and `uniprot`.
    extra_margin: Extends sequence boundaries by `extra_margin` residues when
      evaluating overlap between annotations.
    min_overlap: The minimum number of residues in a sequence that need to
      overlap with a region annotation in order for the annotation to be
      applied to the sequence.
    global_seed: A global seed for the PRNG.

  Returns:
    A pipeline function to be run with a `beam.Pipeline` root.
  """
  def pipeline(root):
    # Reads data preprocessed by steps 1a) and 1b) in
    # `preprocess_tables_lib.py` and, optionally, step 2) in
    # `uniprot_flanks_lib.py`.
    with_flank_seeds = flanks == 'uniprot'
    fields_to_keep = ALIGNMENT_FIELDS
    if with_flank_seeds:
      fields_to_keep += tuple(f[0] for f in schemas.FLANK_FIELDS)
    regions = (
        root
        | 'ReadParsedPfamData' >> ReadParsedPfamData(
            file_pattern=file_pattern,
            dataset_splits_path=dataset_splits_path,
            fields_to_keep=fields_to_keep,
            with_flank_seeds=with_flank_seeds,
            max_len=max_len,
            filter_by_qc=True))
    # Filters sequences not belonging to the target split and removes the no
    # longer needed split field.
    filtered_regions = (
        regions
        | 'FilterBySplit' >> beam.Filter(
            lambda x: x['split'] == target_split)
        | 'DropSplitField' >> beam.Map(
            functools.partial(
                utils.drop_record_fields,
                fields_to_drop=['split'])))

    # Enumerates a subsample of (on average) `avg_num_samples` pairs of
    # homologous and non-homologous regions, keeping only the latter and
    # adding a class label for each pair:
    # + neg: homology_label = 0, if non-homologs (different clan).
    # + mid: homology_label = 1, if remote homologs (same clan, different
    #   fams).
    # + pos: homology_label = 2, if homologs (same family).
    # This subsample will be heavily biased towards negative samples (neg).
    neg_region_pairs = (
        filtered_regions
        | 'EnumerateNegRegionPairs' >> utils.SubsampleOuterProduct(
            avg_num_samples=avg_num_samples,
            groupby_field=None,
            key_field='key',
            suffixes=SUFFIXES)
        | 'KeepNegRegionPairs' >> beam.Filter(
            lambda x: x['clan_acc_x'] != x['clan_acc_y'])
        | 'AddHomologyLabelNegRegionPairs' >> beam.Map(add_homology_label))

    # Enumerates a subsample of (on average) `avg_num_samples` pairs of
    # homologous regions (either in the same family, or in the same clan but
    # in different families), keeping only the latter and adding a class label
    # for each pair.
    # This subsample will be heavily biased towards samples in the same clan
    # but in different families (mid).
    mid_region_pairs = (
        filtered_regions
        | 'EnumerateMidRegionPairs' >> utils.SubsampleOuterProduct(
            avg_num_samples=avg_num_samples,
            groupby_field='clan_acc',
            key_field='key',
            suffixes=SUFFIXES)
        | 'KeepMidRegionPairs' >> beam.Filter(
            lambda x: x['pfam_acc_x'] != x['pfam_acc_y'])
        | 'AddHomologyLabelMidRegionPairs' >> beam.Map(add_homology_label))

    # Enumerates all pairs of regions sharing the same Pfam accession, adding
    # the corresponding homology label (2) to all of the resulting records.
    pos_region_pairs = (
        filtered_regions
        | 'EnumeratePosRegionPairs' >> utils.Combinations(
            groupby_field='pfam_acc',
            key_field='key',
            num_samples=None,
            suffixes=SUFFIXES)
        | 'AddHomologyLabelPosRegionPairs' >> beam.Map(add_homology_label))

    # Computes the number of region pairs in each category. We assume that
    # homologs (same family), i.e., `count_pos`, is the smallest. That holds
    # true for Pfam-A seed 34.0.
    count_pos = pos_region_pairs | 'CountPos' >> beam.combiners.Count.Globally()
    count_mid = mid_region_pairs | 'CountMid' >> beam.combiners.Count.Globally()
    count_neg = neg_region_pairs | 'CountNeg' >> beam.combiners.Count.Globally()

    # Downsamples each category to obtain data with the desired class label
    # distribution.
    prob_pos_same_family = 1.0 - prob_pos_different_family - prob_neg
    # NOTE(review): `assert` is stripped under `python -O`; this only guards
    # against inconsistent caller-provided probabilities.
    assert prob_pos_same_family > 0

    # The counts are passed as singleton side inputs (positional args `count`
    # and `min_count` of `subsample_region_pairs`).
    mid_region_pairs = (
        mid_region_pairs
        | 'DownsampleMidRegionPairs' >> beam.FlatMap(
            subsample_region_pairs,
            beam.pvalue.AsSingleton(count_mid),
            beam.pvalue.AsSingleton(count_pos),
            resample_ratio=prob_pos_different_family / prob_pos_same_family,
            global_seed=global_seed))
    neg_region_pairs = (
        neg_region_pairs
        | 'DownsampleNegRegionPairs' >> beam.FlatMap(
            subsample_region_pairs,
            beam.pvalue.AsSingleton(count_neg),
            beam.pvalue.AsSingleton(count_pos),
            resample_ratio=prob_neg / prob_pos_same_family,
            global_seed=global_seed))

    # Homologous and non-homologous regions are merged. The region pairs are
    # then processed to produce the final fields that will be used for
    # training and evaluating the models on the pairwise homology detection
    # task.
    region_pairs = (
        (pos_region_pairs, mid_region_pairs, neg_region_pairs)
        | 'MergeAllClasses' >> beam.Flatten()
        | 'ReshuffleAfterMerging' >> beam.Reshuffle()
        | 'ProcessRegionPairs' >> beam.Map(
            functools.partial(
                process_region_pair_homology,
                flanks=flanks,
                max_len=max_len,
                extra_margin=extra_margin,
                min_overlap=min_overlap,
                global_seed=global_seed)))
    # Writes postprocessed region pairs to disk as tab-delimited sharded text
    # files.
    _ = (
        region_pairs
        | 'WriteToTable' >> schemas_lib.WriteToTable(
            file_path_prefix=output_path,
            schema_cls=schemas.PairwiseHomologyRow))

  return pipeline