google-research

Форк
0
/
data_utils.py 
173 строки · 5.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Data utils for the clustering with strong and weak signals."""
17

18
import gin
19
import numpy as np
20
from scipy.sparse import csr_matrix
21

22

23
def read_numpy(filename, allow_pickle=False):
  """Reads the NumPy data stored in `filename` via `np.load`.

  The file is opened in binary mode, and the handle is closed before this
  function returns.

  Args:
    filename: path of a file that `np.load` understands (e.g. written by
      `np.save`).
    allow_pickle: forwarded to `np.load`; needed for object arrays. Only
      enable it for trusted files, since unpickling can execute arbitrary
      code.

  Returns:
    Whatever `np.load` returns for the file.
  """
  with open(filename, 'rb') as stream:
    contents = np.load(stream, allow_pickle=allow_pickle)
  return contents
26

27

28
def load_weak_and_strong_signals(path, name):
  """Loads the weak and strong signal matrices of dataset `name`.

  Files are read from `{path}/{name}_weak_signal` and
  `{path}/{name}_strong_signal`.

  For 'stackoverflow' and 'search_snippets', the weak-signal file holds a
  feature matrix F, and the weak signal returned is F @ F.T. The strong
  signal is stored pickled; the stored object is extracted with `[()]` and
  densified with `.toarray()` (presumably a scipy sparse matrix — confirm
  against the data files). For every other dataset, both files are returned
  as loaded.

  Args:
    path: directory containing the signal files.
    name: dataset name; also the file-name prefix.

  Returns:
    Tuple `(weak_signal, strong_signal)`.
  """
  weak_file = f'{path}/{name}_weak_signal'
  strong_file = f'{path}/{name}_strong_signal'
  if name not in ('stackoverflow', 'search_snippets'):
    weak = read_numpy(weak_file)
    return weak, read_numpy(strong_file)
  embeddings = read_numpy(weak_file)
  weak = embeddings.dot(embeddings.T)
  # allow_pickle is required because the strong signal is an object array.
  pickled = read_numpy(strong_file, True)
  return weak, pickled[()].toarray()
39

40

41
@gin.configurable
42
class Dataset:
  """Pairwise weak and strong signals for one clustering dataset.

  Attributes:
    weak_signal: N x N matrix of weak similarities; larger means more similar.
    strong_signal: N x N matrix; an entry equal to 1 marks a same-cluster
      pair (see `same_cluster`).
    is_graph: always False for this class.
    is_sparse: always False for this class; when True,
      `k_nearest_neighbors_weak_signal` densifies row slices with `.toarray()`.
  """

  def __init__(self, path, name):
    self.get_weak_and_strong_signals(path, name)

  def get_weak_and_strong_signals(self, path, name):
    """Loads both signal matrices for dataset `name` stored under `path`.

    Raises:
      ValueError: if the weak and strong signals have different shapes.
    """
    # Assumption: weak_signal is a NxN numpy array of similarities.
    # Assumption: strong signal is a NxN numpy array of similarities.
    self.weak_signal, self.strong_signal = load_weak_and_strong_signals(
        path, name
    )
    # An explicit raise (rather than `assert`) keeps the check under `-O`.
    if self.weak_signal.shape != self.strong_signal.shape:
      raise ValueError(
          'weak_signal and strong_signal must have the same shape, got '
          f'{self.weak_signal.shape} and {self.strong_signal.shape}.'
      )
    self.is_graph = False
    self.is_sparse = False

  @property
  def num_examples(self):
    """Number of examples N (rows of the strong signal)."""
    return self.strong_signal.shape[0]

  @property
  def num_pairs(self):
    """Number of unordered pairs of distinct examples, N * (N - 1) / 2."""
    # N * (N - 1) is always even, so integer division is exact and returns
    # an int rather than a float.
    return (self.num_examples * (self.num_examples - 1)) // 2

  def same_cluster(self, ex_id1, ex_id2):
    """Returns whether the strong-signal entry for the pair equals 1."""
    return self.strong_signal[ex_id1, ex_id2] == 1

  def pair_same_cluster_iterator(self):
    """Yields `(i, j, same_cluster(i, j))` for every pair with i < j."""
    for ex_id1 in range(self.num_examples):
      for ex_id2 in range(ex_id1 + 1, self.num_examples):
        yield ex_id1, ex_id2, self.same_cluster(ex_id1, ex_id2)

  def most_similar_pairs(self, k):
    """Selects the k pairs (i, j), i < j, with the largest weak similarity.

    Only strictly-upper-triangular entries are ranked. The diagonal and lower
    triangle are never candidates, even when similarities are negative.

    Example with the similarity matrix
    [[ 1  2  3  4]
     [ 2  4  5  6]
     [ 3  5  2  3]
     [ 4  6  3  2]]
    The candidate pairs and values are (0,1)=2, (0,2)=3, (0,3)=4, (1,2)=5,
    (1,3)=6 and (2,3)=3. For k = 3 the result is
    [[1 3], [1 2], [0 3]].

    Args:
      k: number of pairs to return. Values <= 0 give an empty list; values
        larger than N * (N - 1) / 2 return every pair.

    Returns:
      List of length-2 index arrays `[i, j]`, ordered by decreasing
      similarity.
    """
    n = self.weak_signal.shape[0]
    rows, cols = np.triu_indices(n, k=1)
    sims = np.asarray(self.weak_signal)[rows, cols]
    k = min(max(int(k), 0), sims.size)
    if k == 0:
      return []
    # argpartition with kth = k - 1 places the k largest values first; kth
    # stays in bounds even when k equals the number of pairs.
    top = np.argpartition(-sims, k - 1)[:k]
    # argpartition leaves the selection unordered; sort it for determinism.
    top = top[np.argsort(-sims[top], kind='stable')]
    return list(np.column_stack((rows[top], cols[top])))

  def argsort_similarities(self, similarities):
    """Arg sorts an array of similarities row-wise in decreasing order."""
    return np.argsort(similarities)[:, ::-1]

  def k_nearest_neighbors_weak_signal(self, input_nodes, k):
    """Returns the k + 1 most weakly-similar nodes for each input node.

    Row r holds the indices of the k + 1 largest entries of
    `weak_signal[input_nodes[r]]`, in decreasing order. The extra column
    usually holds the node itself; callers must drop it.
    """
    # Weak signal is cosine similarities of embeddings.
    # Higher values indicate closer neighbors.
    cosine_similarities = self.weak_signal[input_nodes, :]
    if self.is_sparse:
      cosine_similarities = cosine_similarities.toarray()
    return self.argsort_similarities(cosine_similarities)[:, 0 : k + 1]

  def symmetrize_graph(self, graph):
    """Returns a symmetric version of a sparse graph.

    Entry (i, j) of the result is max(graph[i, j], graph[j, i]). For the 0/1
    graphs built in this class, an edge present in either direction therefore
    appears in both; the rule is intended for non-negative weights. The input
    matrix is not modified, and the result keeps the input's sparse format.
    """
    return graph.maximum(graph.T).asformat(graph.format)

  def construct_weak_signal_knn_graph(self, k):
    """Builds a symmetric k-nearest-neighbor graph from the weak signal.

    Each example is linked to at most k other examples with the highest weak
    similarity; an example is never its own neighbor. The graph is then
    symmetrized.

    Args:
      k: number of neighbors per example.

    Returns:
      N x N csr_matrix with 1 on every edge.
    """
    rows, cols = [], []
    for i in range(self.num_examples):
      candidates = self.k_nearest_neighbors_weak_signal([i], k)[0, :]
      # The k + 1 candidates normally include `i` itself. Drop it and cap at
      # k, so an example whose self-similarity is not its row maximum does
      # not receive k + 1 neighbors.
      neighbors = [j for j in candidates if j != i][:k]
      rows.extend([i] * len(neighbors))
      cols.extend(neighbors)
    graph = csr_matrix(
        ([1] * len(rows), (rows, cols)),
        shape=(self.num_examples, self.num_examples),
    )
    return self.symmetrize_graph(graph)

  def reweight_graph_using_strong_signal(self, possible_edges):
    """Keeps only the candidate edges supported by the strong signal.

    Despite the name, no new weights are computed. Each pair (v, u) whose
    strong-signal entry is truthy becomes an edge with value 1 (duplicate
    pairs are summed by `csr_matrix`); all other pairs are dropped.

    Args:
      possible_edges: iterable of (v, u) index pairs.

    Returns:
      Symmetrized N x N csr_matrix of the kept edges.
    """
    # NOTE(review): any truthy strong-signal value keeps an edge, while
    # `same_cluster` requires the value to equal 1; confirm which is intended.
    kept = [(v, u) for v, u in possible_edges if self.strong_signal[v, u]]
    rows = [v for v, _ in kept]
    cols = [u for _, u in kept]
    graph = csr_matrix(
        ([1] * len(kept), (rows, cols)),
        shape=(self.num_examples, self.num_examples),
    )
    return self.symmetrize_graph(graph)

  def construct_weighted_knn_graph(self, k):
    """Builds the weak-signal k-NN graph and filters it by the strong signal."""
    weak_signal_knn_graph = self.construct_weak_signal_knn_graph(k)
    rows, columns = weak_signal_knn_graph.nonzero()
    return self.reweight_graph_using_strong_signal(zip(rows, columns))
164

165

166
class AdhocDataset:
  """In-memory graph dataset assembled from an existing strong signal.

  Attributes:
    features: the `features` object passed in, stored unchanged.
    strong_signal: the `strong_signal` matrix passed in, stored unchanged.
    num_examples: size of the first axis of `strong_signal`.
    is_graph: always True for this class.
  """

  def __init__(self, strong_signal, features):
    n_rows = strong_signal.shape[0]
    self.strong_signal, self.features = strong_signal, features
    self.is_graph = True
    self.num_examples = n_rows
174

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.