# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Code for clustering strings by edit distance.
17
18Includes exact and approximate strategies for clustering strings ("sequences")
19of the same length based on Hamming distance. Developed for clustering DNA
20sequences, but should be able to handle clustering for arbitrary strings.
21
22To constrain the nearest neighbor search necessary for clustering, we use exact
23and randomized approaches based on Locality Sensitive Hashing. Somewhat similar
24approaches have been described in the research literature [1]. Assuming clusters
25of constant size, this allows our clustering algorithms to run in linear time as
26the number of sequences increases.
27
28Typical usage example:
29
30>>> sequences = ['AAA', 'ATA', 'GGG']
31>>> clustering.cluster_by_edit_distance(sequences, edit_distance=1)
32[0, 0, 1]
33
34References:
35[1] http://www.ncbi.nlm.nih.gov/pmc/articles/PMC4281958/
36"""

import collections
import itertools
import logging
import math
import random


import Levenshtein
import numpy
import six

# Google internal
import gfile
import results_pb2
import sstable


_EMPTY_LIST = []


class AbstractMatcher:
  """Abstract base class for fast neighbor matching."""

  def match(self, sequence):
    """Return sequences in the vicinity of the given sequence.

    Beware: this method is only guaranteed to work for sequences that were
    used to initialize the AbstractMatcher. This is not checked.

    Args:
      sequence: string for which to look up neighbors.

    Returns:
      Iterable of strings giving all sequences in the vicinity of the given
      sequence. This may include false positives, depending on the distance
      metric.
    """
    raise NotImplementedError


class _IntegerEncoder:
  """Build an encoding of the given keys as integers.

  Attributes:
    key_to_id: dict mapping keys to encoded integers.
    id_to_key: list mapping integer IDs to decoded keys.
  """

  def __init__(self):
    self.key_to_id = {}
    self.id_to_key = []

  def __getitem__(self, key):
    """Look up the ID corresponding to the given key.

    If key does not yet have an ID, assign the next available integer ID.

    Args:
      key: hashable value.

    Returns:
      Integer ID.
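
    Example (a small sketch; IDs are assigned in first-seen order):

    >>> encoder = _IntegerEncoder()
    >>> encoder['AAA'], encoder['TTT'], encoder['AAA']
    (0, 1, 0)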
100"""
101try:
102return self.key_to_id[key]
103except KeyError:
104identifier = len(self.key_to_id)
105self.key_to_id[key] = identifier
106self.id_to_key.append(key)
107return identifier
108
109
class ScaMMatcher(AbstractMatcher):
  """Matcher that uses pre-computed lookup tables generated by ScaM."""

  def __init__(self, table, dtype='u4'):
    """Initialize a ScaMMatcher.

    Args:
      table: Mapping[str, results_pb2.NearestNeighbors] mapping each sequence
        to all of its (approximate) neighbors within some fixed edit distance.
      dtype: optional object convertible to numpy.dtype to use for storing
        positive integer IDs.

    Raises:
      ValueError: if dtype is not big enough to encode all sequence IDs.
    """
    neighbors = {}
    encoder = _IntegerEncoder()
    n_entries = len(table)

    for n, (sequence, value) in enumerate(table.items()):
      if n % 1000000 == 0 or (n < 1000000 and n % 100000 == 0):
        logging.info('loading ScaM results %r/%r', n, n_entries)

      neighbor_sequences = [neighbor.docid for neighbor in value.neighbor
                            if neighbor.docid != value.docid]
      if neighbor_sequences:
        neighbor_ids = numpy.array(
            [encoder[seq] for seq in neighbor_sequences], dtype=dtype)
        neighbors[encoder[sequence]] = neighbor_ids

    logging.info('finished loading ScaM results')

    if len(encoder.key_to_id) > numpy.iinfo(dtype).max:
      raise ValueError('ran out of integer IDs')

    self._neighbors = neighbors
    self._sequence_to_id = encoder.key_to_id
    self._id_to_sequence = encoder.id_to_key

  @classmethod
  def from_path(cls, pattern):
    """Create a ScaMMatcher from SSTables of ScaM NearestNeighbors results.

    Args:
      pattern: string pattern for paths to sstables holding output from the
        ScaM map-reduce.

    Returns:
      ScaMMatcher for doing lookups with these pre-computed neighbors.
    """
    paths = sorted(gfile.Glob(pattern))
    wrapper = sstable.TableWrapper(results_pb2.NearestNeighbors.FromString)
    table = sstable.ShardedSSTable(paths, wrapper=wrapper)
    return cls(table)

  def match(self, sequence):
    """See base class."""
    try:
      sequence_id = self._sequence_to_id[sequence]
    except KeyError:
      return _EMPTY_LIST
    else:
      # a sequence that only appears as a neighbor of other sequences has an
      # ID but no neighbor list of its own
      neighbor_ids = self._neighbors.get(sequence_id, _EMPTY_LIST)
      return [self._id_to_sequence[id_] for id_ in neighbor_ids]


class HashMatcher(AbstractMatcher):
  """Match sequences using a hash table and a single hash function."""

  def __init__(self, sequences, hash_func, max_shift=0):
    """Initialize a HashMatcher.

    Args:
      sequences: a sequence of strings to match.
      hash_func: callable that maps a sequence to a key corresponding to a
        hash bucket.
      max_shift: optional integer giving the maximum number of positional
        shifts to consider when partitioning the sequences.
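
    Example (a minimal sketch; the hypothetical hash function keys sequences
    by their first two bases, and singleton buckets are dropped):

    >>> matcher = HashMatcher(['AAA', 'AAT', 'GGG'], hash_func=lambda s: s[:2])
    >>> matcher.match('AAA')
    ['AAA', 'AAT']
    >>> matcher.match('GGG')
    []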
187"""
188buckets = {}
189for seq in sequences:
190for shift in range(max_shift + 1):
191shifted_seq = seq[shift:] + seq[:shift]
192key = hash_func(shifted_seq)
193# this is faster than using collections.defaultdict(list)
194if key in buckets:
195buckets[key].append(seq)
196else:
197buckets[key] = [seq]
198# filter out length one buckets to reduce memory requirements
199self._buckets = {k: v for k, v in buckets.items() if len(v) > 1}
200self._hash = hash_func
201
202def match(self, sequence):
203"""See base class."""
204key = self._hash(sequence)
205# this lets us drop keys with only a single element
206return self._buckets.get(key, _EMPTY_LIST)
207
208
class LSHMatcher(AbstractMatcher):
  """Match sequences using Locality Sensitive Hashing."""

  def __init__(self, sequences, hash_functions, max_shift=0):
    """Initialize an LSHMatcher.

    Args:
      sequences: a sequence of strings to partition.
      hash_functions: sequence of hash functions (callables) to use for
        partitioning sequences.
      max_shift: optional integer giving the maximum number of positional
        shifts to consider when partitioning.
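
    Example (a minimal sketch with two hypothetical hash functions; `match`
    returns the union of bucket mates across all partitions):

    >>> matcher = LSHMatcher(['AAA', 'AAT', 'TAT'],
    ...                      hash_functions=[lambda s: s[:2], lambda s: s[1:]])
    >>> sorted(matcher.match('AAT'))
    ['AAA', 'AAT', 'TAT']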
221"""
222self._partitions = [HashMatcher(sequences, func, max_shift)
223for func in hash_functions]
224
225def match(self, sequence):
226"""See base class."""
227neighbors = set()
228for group in self._partitions:
229neighbors.update(group.match(sequence))
230return neighbors
231
232
def exact_lsh_matches(sequences, edit_distance, measure='levenshtein',
                      target_occupancy=0.5, num_choices=4):
  """Build a callable for 'exact' matching of sequences using LSH.

  These matches are proven to be exact for Hamming distance, but we are not
  certain that they are exact for Levenshtein distance.

  Args:
    sequences: sequence of strings to partition.
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.
    target_occupancy: float indicating the maximum acceptable average number
      of randomly chosen sequences that would appear in the same hash bucket
      by chance.
    num_choices: optional integer giving the number of valid sequence
      elements. Default value is 4, corresponding to the four DNA bases.

  Returns:
    Callable for finding matches.
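
  Example (a small sketch using Hamming distance; 'AAAA' and 'AATA' share an
  unedited segment, while 'GGGG' never collides with them):

  >>> find_nearby = exact_lsh_matches(['AAAA', 'AATA', 'GGGG'],
  ...                                 edit_distance=1, measure='hamming')
  >>> sorted(find_nearby('AAAA'))
  ['AAAA', 'AATA']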
255"""
256sequence_length = _unique_length(sequences)
257max_shift = _max_shift(edit_distance, measure)
258min_hash_length = optimal_hash_length(
259len(sequences), max_shift, target_occupancy, num_choices)
260segment_count = _required_segment_count(
261sequence_length, min_hash_length, edit_distance)
262all_segments = itertools.combinations(
263list(range(segment_count)), segment_count - edit_distance)
264hashes = (segmented_hash(segments, segment_count, sequence_length)
265for segments in all_segments)
266return LSHMatcher(sequences, hashes, max_shift).match
267
268
def approximate_lsh_matches(sequences, edit_distance, measure='levenshtein',
                            hash_length=10, num_rounds=10, seed=None):
  """Build a callable for approximate matching of sequences using LSH.

  This approach was designed for Hamming distance. It may perform very poorly
  for Levenshtein distance.

  TODO(shoyer): refactor this API to take a desired success_probability
  instead of this lower level API.

  Args:
    sequences: sequence of strings to partition.
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.
    hash_length: integer number of bases from each sequence to use in the
      hash key.
    num_rounds: integer number of random partitions to create.
    seed: optional hashable random seed to guarantee reproducible results
      from the returned matcher.

  Returns:
    Callable for finding matches.
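
  Example (a sketch; the candidate set is randomized and may miss true
  neighbors, but sequences that never collide are reliably excluded):

  >>> find_nearby = approximate_lsh_matches(
  ...     ['AAAA', 'AATA', 'GGGG'], edit_distance=1, measure='hamming',
  ...     hash_length=2, num_rounds=4, seed=0)
  >>> 'GGGG' in find_nearby('AAAA')
  False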
293"""
294sequence_length = _unique_length(sequences)
295rand = random.Random(seed)
296hashes = (random_hash(hash_length, sequence_length, rand.random())
297for _ in range(num_rounds))
298max_shift = _max_shift(edit_distance, measure)
299return LSHMatcher(sequences, hashes, max_shift).match
300
301
def _unique_length(elements):
  """Calculate the unique length of the provided elements.

  Args:
    elements: iterable of objects with a defined length.

  Returns:
    Integer unique length.

  Raises:
    ValueError: if there is no unique length, or if the iterable is empty.
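
  Example:

  >>> _unique_length(['AAA', 'TTT'])
  3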
313"""
314lengths = set(len(elem) for elem in elements)
315if not lengths:
316raise ValueError('no sequences provided')
317length = lengths.pop()
318if lengths:
319raise ValueError('sequences to cluster must have a unique length')
320return length
321
322
def optimal_hash_length(num_sequences, max_shift=0, target_occupancy=0.5,
                        num_choices=4):
  """Calculate the shortest hash length such that random collisions are rare.

  Args:
    num_sequences: integer number of distinct sequences to partition.
    max_shift: optional integer giving the maximum number of positional shifts
      to consider when partitioning.
    target_occupancy: float indicating the maximum acceptable average number of
      randomly chosen sequences that would appear in the same hash bucket by
      chance.
    num_choices: optional integer giving the number of valid sequence elements.
      Default value is 4, corresponding to the four DNA bases.

  Returns:
    Integer giving the optimal hash length.
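
  Example: for one million DNA sequences with the default occupancy of 0.5,
  the hash must index at least 1e6 / 0.5 = 2e6 buckets, and 11 is the
  smallest n with 4**n >= 2e6:

  >>> optimal_hash_length(num_sequences=1000000)
  11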
339"""
340return int(math.ceil(math.log(float(num_sequences * (max_shift + 1))
341/ target_occupancy)
342/ math.log(num_choices)))
343
344
def _required_segment_count(sequence_length, min_hash_length, edit_distance):
  # The length of each hash, which is constructed from
  # (segment_count - edit_distance) out of segment_count segments, should be
  # at least min_hash_length. This means that we need to satisfy:
  #   sequence_length * (segment_count - edit_distance) / segment_count
  #       >= min_hash_length
  #
  # From some algebra, it follows that segment_count should be given by:
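  #   segment_count >= edit_distance / (1 - min_hash_length / sequence_length)
  # For example, with sequence_length=30, min_hash_length=10 and
  # edit_distance=2, this gives ceil(2 / (1 - 10/30)) = ceil(3.0) = 3.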
  return max(int(math.ceil(edit_distance /
                           (1 - float(min_hash_length) / sequence_length))),
             1)


def segmented_hash(segments, segment_count, sequence_length):
  """Create a hash function for segmented partitioning.

  Args:
    segments: tuple of integers indicating segments to include in the hash.
    segment_count: integer indicating the total number of segments.
    sequence_length: integer length of all sequences in this partition.

  Returns:
    Hash function suitable for partitioning sequences in `exact_lsh_matches`.
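
  Example (a small sketch: 12-base sequences split into 4 segments of 3
  bases, keeping segments 0 and 2 in the hash key):

  >>> hash_func = segmented_hash(segments=(0, 2), segment_count=4,
  ...                            sequence_length=12)
  >>> hash_func('AAATTTGGGCCC')
  'AAAGGG'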
367"""
368segment_length = int(sequence_length / segment_count)
369starts = [segment * segment_length for segment in segments]
370stops = [start + segment_length for start in starts]
371slices = [slice(start, stop) for start, stop in zip(starts, stops)]
372
373def hash_func(sequence):
374return ''.join(sequence[sl] for sl in slices)
375
376return hash_func
377
378
def random_hash(hash_length, sequence_length, seed=None):
  """Create a hash function for randomized partitioning.

  Args:
    hash_length: integer number of bases from each sequence to use in the
      hash key.
    sequence_length: integer length of all sequences in this partition.
    seed: optional hashable random seed to guarantee reproducible results.

  Returns:
    Hash function suitable for partitioning sequences in
    `approximate_lsh_matches`.
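
  Example (the sampled indices, and hence the exact key, depend on the seed):

  >>> hash_func = random_hash(hash_length=2, sequence_length=4, seed=0)
  >>> len(hash_func('ATGC'))
  2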
391"""
392rand = random.Random(seed)
393indices = sorted(rand.sample(range(sequence_length), hash_length))
394
395def hash_func(sequence):
396return ''.join(sequence[idx] for idx in indices)
397
398return hash_func
399
400
def _max_shift(edit_distance, measure='levenshtein'):
  """Determine the maximum number of shifted positions at an edit distance.

  Args:
    edit_distance: maximum edit distance between sequences.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.

  Returns:
    Integer giving the largest number of possible shifts in position between
    two sequences at the given edit distance.

  Raises:
    ValueError: if `measure` is invalid.
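
  Example: at Levenshtein distance 2, one insertion plus one deletion can
  shift positions by at most one:

  >>> _max_shift(2, measure='levenshtein')
  1
  >>> _max_shift(2, measure='hamming')
  0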
415"""
416if measure == 'levenshtein':
417# as long as all sequences are restricted to have the same length, it
418# requires two edits (one insertion and one deletion) shift bases by one
419# position
420max_shift = edit_distance // 2
421elif measure == 'hamming':
422# shifting isn't necessary for Hamming distance
423max_shift = 0
424else:
425raise ValueError('unexpected measure %r' % measure)
426return max_shift
427
428
def hamming_distance(seq1, seq2):
  """Compute the Hamming distance between two sequences.

  This is the edit distance from seq1 to seq2, ignoring insertions/deletions.

  Args:
    seq1: sequence 1.
    seq2: sequence 2.

  Returns:
    An integer giving the Hamming distance.
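
  Example:

  >>> hamming_distance('AAA', 'ATA')
  1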
440"""
441return Levenshtein.hamming(seq1, seq2)
442
443
def levenshtein_distance(seq1, seq2):
  """Compute the Levenshtein distance between two sequences.

  This is the edit distance from seq1 to seq2, allowing for
  insertions/deletions.

  Args:
    seq1: sequence 1.
    seq2: sequence 2.

  Returns:
    An integer giving the Levenshtein distance.
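
  Example (a deletion counts as a single edit):

  >>> levenshtein_distance('AAA', 'AA')
  1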
455"""
456seq1 = six.ensure_str(seq1)
457seq2 = six.ensure_str(seq2)
458return Levenshtein.distance(seq1, seq2)
459
460
def explore_cluster(seed_sequence, edit_distance, measure, find_nearby):
  """Returns all sequences in the same cluster as the provided sequence.

  Args:
    seed_sequence: string around which to find neighbors.
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: string 'levenshtein' or 'hamming', indicating how to calculate
      edit distance.
    find_nearby: callable that returns an iterable of all nearby sequences to
      a given sequence. These sequences are checked to see if they fall within
      the distance threshold.

  Returns:
    The set of all sequences in the same cluster as seed_sequence, including
    the seed sequence itself.

  Raises:
    ValueError: if `measure` is invalid.
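
  Example (a sketch using a brute-force `find_nearby` that proposes every
  sequence as a candidate):

  >>> sequences = ['AAA', 'ATA', 'TTA', 'GGG']
  >>> sorted(explore_cluster('AAA', edit_distance=1, measure='hamming',
  ...                        find_nearby=lambda seq: sequences))
  ['AAA', 'ATA', 'TTA']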
480"""
481if measure == 'levenshtein':
482calc_distance = levenshtein_distance
483elif measure == 'hamming':
484calc_distance = hamming_distance
485else:
486raise ValueError('unexpected measure %r' % measure)
487
488# depth first search
489cluster = set()
490not_yet_explored = set([seed_sequence])
491
492while not_yet_explored:
493sequence = not_yet_explored.pop()
494cluster.add(sequence)
495
496for candidate in find_nearby(sequence):
497if (candidate not in cluster and
498candidate not in not_yet_explored and
499calc_distance(candidate, sequence) <= edit_distance):
500not_yet_explored.add(candidate)
501
502return cluster
503
504
def cluster_by_edit_distance(sequences, edit_distance, measure='levenshtein',
                             find_nearby=None):
  """Cluster strings by edit distance.

  Every sequence in a cluster can be reached from every other sequence in the
  same cluster by taking steps no larger than the provided edit distance.

  Args:
    sequences: a sequence of strings with the same length, e.g., ['ATGC',
      'AATT', 'GCCG'].
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.
    find_nearby: optional callable that returns an iterable of all nearby
      sequences to a given sequence. These sequences are checked to see if
      they fall within the distance threshold. By default, uses a callable
      created by calling `exact_lsh_matches` on the provided sequences.

  Returns:
    A list of integer cluster IDs corresponding to each sequence. Cluster IDs
    are sequential integers starting from one and sorted by order of
    appearance.

  Raises:
    ValueError: if not all sequences have the same length.
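
  Example (cluster IDs start at 1, in order of first appearance):

  >>> cluster_by_edit_distance(['AAA', 'ATA', 'GGG'], edit_distance=1)
  [1, 1, 2]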
531"""
532
533if find_nearby is None:
534find_nearby = exact_lsh_matches(sequences, edit_distance, measure)
535
536# start from 1 so a default value of 0 corresponds to unclustered.
537cluster_idx = 1
538cluster_assignments = {} # {sequence: cluster_id}
539
540for i, seed in enumerate(sequences):
541if seed not in cluster_assignments:
542for seq in explore_cluster(seed, edit_distance, measure, find_nearby):
543cluster_assignments[seq] = cluster_idx
544cluster_idx += 1
545if i % 1000000 == 0 or (i < 1000000 and i % 100000 == 0):
546logging.info('assigned clusters to first %r/%r sequences', i,
547len(sequences))
548
549return [cluster_assignments[seq] for seq in sequences]
550
551
def set_of_cluster_sets(sequences, clusters):
  """Convert lists of sequences and cluster assignments to a set of sets.

  This provides an order-invariant way to compare cluster assignments.

  Args:
    sequences: list of strings indicating sequences.
    clusters: list of integers indicating cluster assignments.

  Returns:
    frozenset of frozensets of clustered sequences.

  Example:

  >>> set_of_cluster_sets(['AAA', 'AGA', 'GGG'], [0, 0, 1])
  frozenset({frozenset({'AAA', 'AGA'}), frozenset({'GGG'})})
  """
  all_clusters = collections.defaultdict(set)
  for sequence, cluster in zip(sequences, clusters):
    all_clusters[cluster].add(sequence)
  return frozenset(frozenset(seqs) for seqs in all_clusters.values())