# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code for clustering strings by edit distance.

Includes exact and approximate strategies for clustering strings ("sequences")
of the same length based on Hamming or Levenshtein distance. Developed for
clustering DNA sequences, but should be able to handle arbitrary strings.

To constrain the nearest neighbor search necessary for clustering, we use exact
and randomized approaches based on Locality Sensitive Hashing. Somewhat similar
approaches have been described in the research literature [1]. Assuming clusters
of constant size, this allows our clustering algorithms to run in linear time in
the number of sequences.

Typical usage example:

  >>> sequences = ['AAA', 'ATA', 'GGG']
  >>> clustering.cluster_by_edit_distance(sequences, edit_distance=1)
  [1, 1, 2]

References:
  [1] http://www.ncbi.nlm.nih.gov/pmc/articles/PMC4281958/
"""

import collections
import itertools
import logging
import math
import random


import Levenshtein
import numpy
import six

# Google internal
import gfile
import results_pb2
import sstable


_EMPTY_LIST = []


class AbstractMatcher:
  """Abstract base class for fast neighbor matching."""

  def match(self, sequence):
    """Return sequences in the vicinity of the given sequence.

    Beware: this method is only guaranteed to work for sequences that were used
    to initialize the AbstractMatcher. This is not checked.

    Args:
      sequence: string for which to look up neighbors.

    Returns:
      Iterable of strings giving all sequences in the vicinity of the given
      sequence. This may include false positives, depending on the distance
      metric.
    """
    raise NotImplementedError


class _IntegerEncoder:
  """Build an encoding of the given keys as integers.

  Attributes:
    key_to_id: dict mapping keys to encoded integers.
    id_to_key: list mapping integer IDs to decoded keys.
  """

  def __init__(self):
    self.key_to_id = {}
    self.id_to_key = []

  def __getitem__(self, key):
    """Look up the ID corresponding to the given key.

    If key does not yet have an ID, assign the next available integer ID.

    Args:
      key: hashable value.

    Returns:
      Integer ID.
    """
    try:
      return self.key_to_id[key]
    except KeyError:
      identifier = len(self.key_to_id)
      self.key_to_id[key] = identifier
      self.id_to_key.append(key)
      return identifier


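# Illustrative example of _IntegerEncoder (not part of the original module):
# repeated lookups of the same key return the same ID, and id_to_key inverts
# the mapping.
#
#   >>> encoder = _IntegerEncoder()
#   >>> encoder['AAA'], encoder['GGG'], encoder['AAA']
#   (0, 1, 0)
#   >>> encoder.id_to_key[1]
#   'GGG'

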
class ScaMMatcher(AbstractMatcher):
  """Matcher that uses pre-computed lookup tables generated by ScaM."""

  def __init__(self, table, dtype='u4'):
    """Initialize a ScaMMatcher.

    Args:
      table: Mapping[str, results_pb2.NearestNeighbors] mapping each sequence
        to all of its (approximate) neighbors within some fixed edit distance.
      dtype: optional object convertible to numpy.dtype to use for storing
        positive integer IDs.

    Raises:
      ValueError: if dtype is not big enough to hold all required IDs.
    """
    neighbors = {}
    encoder = _IntegerEncoder()
    n_entries = len(table)

    for n, (sequence, value) in enumerate(table.items()):
      if n % 1000000 == 0 or (n < 1000000 and n % 100000 == 0):
        logging.info('loading ScaM results %r/%r', n, n_entries)

      neighbor_sequences = [neighbor.docid for neighbor in value.neighbor
                            if neighbor.docid != value.docid]
      if neighbor_sequences:
        neighbor_ids = numpy.array(
            [encoder[seq] for seq in neighbor_sequences], dtype=dtype)
        neighbors[encoder[sequence]] = neighbor_ids

    logging.info('finished loading ScaM results')

    if len(encoder.key_to_id) > numpy.iinfo(dtype).max:
      raise ValueError('ran out of integer IDs')

    self._neighbors = neighbors
    self._sequence_to_id = encoder.key_to_id
    self._id_to_sequence = encoder.id_to_key

  @classmethod
  def from_path(cls, pattern):
    """Create a ScaMMatcher from SSTables of ScaM NearestNeighbors results.

    Args:
      pattern: string pattern for paths to SSTables holding output from the
        ScaM map-reduce.

    Returns:
      ScaMMatcher for doing lookups with these pre-computed neighbors.
    """
    paths = sorted(gfile.Glob(pattern))
    wrapper = sstable.TableWrapper(results_pb2.NearestNeighbors.FromString)
    table = sstable.ShardedSSTable(paths, wrapper=wrapper)
    return cls(table)

  def match(self, sequence):
    """See base class."""
    try:
      sequence_id = self._sequence_to_id[sequence]
    except KeyError:
      return _EMPTY_LIST
    else:
      # a sequence may have been encoded only as someone else's neighbor, in
      # which case it has no neighbor list of its own
      neighbor_ids = self._neighbors.get(sequence_id, _EMPTY_LIST)
      return [self._id_to_sequence[id_] for id_ in neighbor_ids]


class HashMatcher(AbstractMatcher):
  """Match sequences using a hash table and a single hash function."""

  def __init__(self, sequences, hash_func, max_shift=0):
    """Initialize a HashMatcher.

    Args:
      sequences: a sequence of strings to match.
      hash_func: callable that maps a sequence to a key corresponding to a
        hash bucket.
      max_shift: optional integer giving the maximum number of positional shifts
        to consider when partitioning the sequences.
    """
    buckets = {}
    for seq in sequences:
      for shift in range(max_shift + 1):
        shifted_seq = seq[shift:] + seq[:shift]
        key = hash_func(shifted_seq)
        # this is faster than using collections.defaultdict(list)
        if key in buckets:
          buckets[key].append(seq)
        else:
          buckets[key] = [seq]
    # filter out length one buckets to reduce memory requirements
    self._buckets = {k: v for k, v in buckets.items() if len(v) > 1}
    self._hash = hash_func

  def match(self, sequence):
    """See base class."""
    key = self._hash(sequence)
    # this lets us drop keys with only a single element
    return self._buckets.get(key, _EMPTY_LIST)


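# Illustrative example of HashMatcher (not part of the original module): a
# simple prefix hash groups sequences that share their first two characters,
# and match() returns all bucket-mates of a query.
#
#   >>> matcher = HashMatcher(['AAAT', 'AACG', 'GGGG'], lambda s: s[:2])
#   >>> sorted(matcher.match('AAAT'))
#   ['AAAT', 'AACG']
#   >>> matcher.match('GGGG')  # its bucket was a singleton, so it was dropped
#   []

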
class LSHMatcher(AbstractMatcher):
  """Match sequences using Locality Sensitive Hashing."""

  def __init__(self, sequences, hash_functions, max_shift=0):
    """Initialize an LSHMatcher.

    Args:
      sequences: a sequence of strings to partition.
      hash_functions: sequence of hash functions (callables) to use for
        partitioning sequences.
      max_shift: optional integer giving the maximum number of positional shifts
        to consider when partitioning.
    """
    self._partitions = [HashMatcher(sequences, func, max_shift)
                        for func in hash_functions]

  def match(self, sequence):
    """See base class."""
    neighbors = set()
    for group in self._partitions:
      neighbors.update(group.match(sequence))
    return neighbors


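# Illustrative example of LSHMatcher (not part of the original module): it
# unions bucket-mates across several hash functions, so two sequences need to
# collide under only one of them to be reported as neighbors.
#
#   >>> hashes = [lambda s: s[:2], lambda s: s[-2:]]
#   >>> matcher = LSHMatcher(['AATT', 'AAGG', 'CCGG'], hashes)
#   >>> sorted(matcher.match('AAGG'))
#   ['AAGG', 'AATT', 'CCGG']

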
def exact_lsh_matches(sequences, edit_distance, measure='levenshtein',
                      target_occupancy=0.5, num_choices=4):
  """Build a callable for 'exact' matching of sequences using LSH.

  These matches are proven to be exact for Hamming distance, but we are not
  certain this also holds for Levenshtein distance.

  Args:
    sequences: sequence of strings to partition.
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.
    target_occupancy: float indicating the maximum acceptable average number
      of randomly chosen sequences that would appear in the same hash bucket
      by chance.
    num_choices: optional integer giving the number of valid sequence
      elements. Default value is 4, corresponding to the four base pairs in
      DNA.

  Returns:
    Callable for finding matches.
  """
  sequence_length = _unique_length(sequences)
  max_shift = _max_shift(edit_distance, measure)
  min_hash_length = optimal_hash_length(
      len(sequences), max_shift, target_occupancy, num_choices)
  segment_count = _required_segment_count(
      sequence_length, min_hash_length, edit_distance)
  all_segments = itertools.combinations(
      range(segment_count), segment_count - edit_distance)
  hashes = (segmented_hash(segments, segment_count, sequence_length)
            for segments in all_segments)
  return LSHMatcher(sequences, hashes, max_shift).match


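# Why these matches are exact for Hamming distance (illustrative sketch, not
# part of the original module): in the example below the function derives
# segment_count=2 for these inputs, splitting each 8-mer into halves, with one
# hash per choice of segment_count - edit_distance = 1 half. A single mismatch
# falls inside at most one half, so two sequences within Hamming distance 1
# agree exactly on the other half and must collide under the hash that keeps
# it.
#
#   >>> find_nearby = exact_lsh_matches(
#   ...     ['AAAAAAAA', 'AAAAAAAT', 'GGGGGGGG'],
#   ...     edit_distance=1, measure='hamming')
#   >>> sorted(find_nearby('AAAAAAAA'))  # match includes the query itself
#   ['AAAAAAAA', 'AAAAAAAT']

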
def approximate_lsh_matches(sequences, edit_distance, measure='levenshtein',
                            hash_length=10, num_rounds=10, seed=None):
  """Build a callable for approximate matching of sequences using LSH.

  This approach was designed for Hamming distance. It may perform very poorly
  for Levenshtein distance.

  TODO(shoyer): refactor this API to take a desired success_probability
  instead of this lower level API.

  Args:
    sequences: sequence of strings to partition.
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.
    hash_length: integer number of bases from each sequence to use in the hash
      key.
    num_rounds: integer number of random partitions to create.
    seed: optional hashable random seed to guarantee that the returned matcher
      is built reproducibly.

  Returns:
    Callable for finding matches.
  """
  sequence_length = _unique_length(sequences)
  rand = random.Random(seed)
  hashes = (random_hash(hash_length, sequence_length, rand.random())
            for _ in range(num_rounds))
  max_shift = _max_shift(edit_distance, measure)
  return LSHMatcher(sequences, hashes, max_shift).match


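# Collision-probability sketch (illustrative numbers, not from the original
# module): two length-100 sequences at Hamming distance 5 collide under one
# random length-10 hash only if all 10 sampled positions avoid the 5
# mismatches, i.e. with probability C(95, 10) / C(100, 10) ~= 0.59. With
# num_rounds=10 independent hashes, the chance of never colliding is roughly
# (1 - 0.59)**10 ~= 1e-4, which is what the TODO above about choosing
# num_rounds from a target success probability refers to.

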
def _unique_length(elements):
  """Calculate the unique length of the provided elements.

  Args:
    elements: iterable of objects with a defined length.

  Returns:
    Integer unique length.

  Raises:
    ValueError: if there is no unique length, or if the iterable is empty.
  """
  lengths = set(len(elem) for elem in elements)
  if not lengths:
    raise ValueError('no sequences provided')
  length = lengths.pop()
  if lengths:
    raise ValueError('sequences to cluster must have a unique length')
  return length


def optimal_hash_length(num_sequences, max_shift=0, target_occupancy=0.5,
                        num_choices=4):
  """Calculate the shortest hash length such that random collisions are rare.

  Args:
    num_sequences: integer number of distinct sequences to partition.
    max_shift: optional integer giving the maximum number of positional shifts
      to consider when partitioning.
    target_occupancy: float indicating the maximum acceptable average number of
      randomly chosen sequences that would appear in the same hash bucket by
      chance.
    num_choices: optional integer giving the number of valid sequence elements.
      Default value is 4, corresponding to the four base pairs in DNA.

  Returns:
    Integer giving the optimal hash length.
  """
  return int(math.ceil(math.log(float(num_sequences * (max_shift + 1))
                                / target_occupancy)
                       / math.log(num_choices)))


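# Worked example of optimal_hash_length (illustrative numbers): hashing one
# million sequences over a 4-letter alphabet with max_shift=0 requires
# num_choices**length >= num_sequences / target_occupancy buckets, i.e.
# 4**length >= 2e6, so the shortest adequate hash length is 11:
#
#   >>> optimal_hash_length(10**6)
#   11

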
def _required_segment_count(sequence_length, min_hash_length, edit_distance):
  # The length of each hash, which is constructed from
  # (segment_count - edit_distance) out of segment_count segments, should be at
  # least min_hash_length. This means that we need to satisfy:
  #   sequence_length * (segment_count - edit_distance) / segment_count
  #     >= min_hash_length
  #
  # From some algebra, it follows that segment_count should be given by:
  return max(int(math.ceil(edit_distance /
                           (1 - float(min_hash_length) / sequence_length))), 1)


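# Worked example of _required_segment_count (illustrative numbers): with
# sequence_length=40, min_hash_length=11 and edit_distance=2, we need
# ceil(2 / (1 - 11/40)) = 3 segments; each hash then keeps 3 - 2 = 1 segment
# of length 40 // 3 = 13 >= 11 characters.
#
#   >>> _required_segment_count(40, 11, 2)
#   3

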
def segmented_hash(segments, segment_count, sequence_length):
  """Create a hash function for segmented partitioning.

  Args:
    segments: tuple of integers indicating segments to include in the hash.
    segment_count: integer indicating the total number of segments.
    sequence_length: integer length of all sequences in this partition.

  Returns:
    Hash function suitable for partitioning sequences in exact_lsh_matches.
  """
  segment_length = sequence_length // segment_count
  starts = [segment * segment_length for segment in segments]
  stops = [start + segment_length for start in starts]
  slices = [slice(start, stop) for start, stop in zip(starts, stops)]

  def hash_func(sequence):
    return ''.join(sequence[sl] for sl in slices)

  return hash_func


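# Illustrative example of segmented_hash (not part of the original module):
# with three segments of a length-9 sequence, keeping segments 0 and 2 drops
# the middle third.
#
#   >>> hash_func = segmented_hash((0, 2), segment_count=3, sequence_length=9)
#   >>> hash_func('ABCDEFGHI')
#   'ABCGHI'

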
def random_hash(hash_length, sequence_length, seed=None):
  """Create a hash function for randomized partitioning.

  Args:
    hash_length: integer number of bases from each sequence to use in the hash
      key.
    sequence_length: integer length of all sequences in this partition.
    seed: optional hashable random seed to guarantee reproducible results.

  Returns:
    Hash function suitable for partitioning sequences in
    approximate_lsh_matches.
  """
  rand = random.Random(seed)
  indices = sorted(rand.sample(range(sequence_length), hash_length))

  def hash_func(sequence):
    return ''.join(sequence[idx] for idx in indices)

  return hash_func


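# Illustrative example of random_hash (not part of the original module): the
# hash keeps a fixed random subset of positions, so equal seeds sample the
# same positions and produce identical hash functions.
#
#   >>> h1 = random_hash(hash_length=4, sequence_length=8, seed=0)
#   >>> h2 = random_hash(hash_length=4, sequence_length=8, seed=0)
#   >>> h1('ACGTACGT') == h2('ACGTACGT')
#   True

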
def _max_shift(edit_distance, measure='levenshtein'):
  """Determine the maximum number of shifted positions at an edit distance.

  Args:
    edit_distance: maximum edit distance between sequences.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.

  Returns:
    Integer giving the largest number of possible shifts in position between
    the two sequences at the given edit distance.

  Raises:
    ValueError: if `measure` is invalid.
  """
  if measure == 'levenshtein':
    # as long as all sequences are restricted to have the same length, it
    # requires two edits (one insertion and one deletion) to shift bases by
    # one position
    max_shift = edit_distance // 2
  elif measure == 'hamming':
    # shifting isn't necessary for Hamming distance
    max_shift = 0
  else:
    raise ValueError('unexpected measure %r' % measure)
  return max_shift


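# Illustrative example of _max_shift (not part of the original module): every
# one-position shift costs an insertion plus a deletion under Levenshtein
# distance, while Hamming distance never shifts.
#
#   >>> _max_shift(2), _max_shift(3), _max_shift(2, measure='hamming')
#   (1, 1, 0)

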
def hamming_distance(seq1, seq2):
  """Compute the Hamming distance between two sequences.

  This is the edit distance from seq1 to seq2, ignoring insertions/deletions.

  Args:
    seq1: sequence 1.
    seq2: sequence 2.

  Returns:
    An integer giving the Hamming distance.
  """
  return Levenshtein.hamming(seq1, seq2)


def levenshtein_distance(seq1, seq2):
  """Compute the Levenshtein distance between two sequences.

  This is the edit distance from seq1 to seq2, allowing for
  insertions/deletions.

  Args:
    seq1: sequence 1.
    seq2: sequence 2.

  Returns:
    An integer giving the Levenshtein distance.
  """
  seq1 = six.ensure_str(seq1)
  seq2 = six.ensure_str(seq2)
  return Levenshtein.distance(seq1, seq2)


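# Illustrative examples of the two distance wrappers (not part of the original
# module): Hamming distance counts mismatched positions, while Levenshtein
# distance also allows insertions and deletions.
#
#   >>> hamming_distance('ACGT', 'AGGT')
#   1
#   >>> levenshtein_distance('ACGT', 'CGTA')
#   2

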
def explore_cluster(seed_sequence, edit_distance, measure, find_nearby):
  """Returns all sequences in the same cluster as the provided sequence.

  Args:
    seed_sequence: string around which to find neighbors.
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: string 'levenshtein' or 'hamming', indicating how to calculate edit
      distance.
    find_nearby: callable that returns an iterable of all nearby sequences to a
      given sequence. These sequences are checked to see if they fall within the
      distance threshold.

  Returns:
    The set of all sequences in the same cluster as seed_sequence, including the
    seed sequence itself.

  Raises:
    ValueError: if `measure` is invalid.
  """
  if measure == 'levenshtein':
    calc_distance = levenshtein_distance
  elif measure == 'hamming':
    calc_distance = hamming_distance
  else:
    raise ValueError('unexpected measure %r' % measure)

  # graph traversal from the seed (set.pop makes the visit order arbitrary)
  cluster = set()
  not_yet_explored = set([seed_sequence])

  while not_yet_explored:
    sequence = not_yet_explored.pop()
    cluster.add(sequence)

    for candidate in find_nearby(sequence):
      if (candidate not in cluster and
          candidate not in not_yet_explored and
          calc_distance(candidate, sequence) <= edit_distance):
        not_yet_explored.add(candidate)

  return cluster


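# Illustrative example of explore_cluster (not part of the original module):
# find_nearby only needs to propose candidates; explore_cluster verifies the
# distance and transitively expands the cluster.
#
#   >>> nearby = {'AAA': ['ATA'], 'ATA': ['AAA', 'TTA'], 'TTA': ['ATA']}
#   >>> sorted(explore_cluster('AAA', edit_distance=1, measure='hamming',
#   ...                        find_nearby=lambda s: nearby.get(s, [])))
#   ['AAA', 'ATA', 'TTA']

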
def cluster_by_edit_distance(sequences, edit_distance, measure='levenshtein',
                             find_nearby=None):
  """Cluster strings by edit distance.

  Every sequence in a cluster can be reached from every other sequence in the
  same cluster by taking steps no larger than the provided edit distance.

  Args:
    sequences: a sequence of strings with the same length, e.g., ['ATGC',
      'AATT', 'GCCG'].
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.
    find_nearby: optional callable that returns an iterable of all nearby
      sequences to a given sequence. These sequences are checked to see if they
      fall within the distance threshold. Defaults to a callable created by
      calling `exact_lsh_matches` on the provided sequences.

  Returns:
    A list of integer cluster IDs corresponding to each sequence. Cluster IDs
    are sequential integers starting from one, assigned in order of
    appearance.

  Raises:
    ValueError: if not all sequences have the same length.
  """

  if find_nearby is None:
    find_nearby = exact_lsh_matches(sequences, edit_distance, measure)

  # start from 1 so a default value of 0 corresponds to unclustered.
  cluster_idx = 1
  cluster_assignments = {}  # {sequence: cluster_id}

  for i, seed in enumerate(sequences):
    if seed not in cluster_assignments:
      for seq in explore_cluster(seed, edit_distance, measure, find_nearby):
        cluster_assignments[seq] = cluster_idx
      cluster_idx += 1
    if i % 1000000 == 0 or (i < 1000000 and i % 100000 == 0):
      logging.info('assigned clusters to first %r/%r sequences', i,
                   len(sequences))

  return [cluster_assignments[seq] for seq in sequences]


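# Chaining example for cluster_by_edit_distance (illustrative, not part of the
# original module): clusters are closed under single steps, so 'AAAA' and
# 'ATTA' share a cluster via 'AATA' even though they are two edits apart.
#
#   >>> cluster_by_edit_distance(['AAAA', 'AATA', 'ATTA'], edit_distance=1)
#   [1, 1, 1]

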
def set_of_cluster_sets(sequences, clusters):
  """Convert lists of sequences and cluster assignments to a set of sets.

  This provides an order-invariant way to compare cluster assignments.

  Args:
    sequences: list of strings indicating sequences.
    clusters: list of integers indicating cluster assignments.

  Returns:
    frozenset of frozensets of clustered sequences.

  Example:

    >>> set_of_cluster_sets(['AAA', 'AGA', 'GGG'], [0, 0, 1])
    frozenset({frozenset({'AAA', 'AGA'}), frozenset({'GGG'})})
  """
  all_clusters = collections.defaultdict(set)
  for sequence, cluster in zip(sequences, clusters):
    all_clusters[cluster].add(sequence)
  return frozenset(frozenset(seqs) for seqs in all_clusters.values())