google-research
173 строки · 5.8 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Data utils for the clustering with strong and weak signals."""
17
18import gin
19import numpy as np
20from scipy.sparse import csr_matrix
21
22
def read_numpy(filename, allow_pickle=False):
    """Loads the contents of a NumPy file.

    Args:
      filename: Path of the file to read; it is opened in binary mode.
      allow_pickle: Forwarded to `np.load`. Leave it False unless the file is
        trusted, since unpickling can execute arbitrary code.

    Returns:
      Whatever `np.load` returns for the file.
    """
    with open(filename, mode='rb') as handle:
        contents = np.load(handle, allow_pickle=allow_pickle)
    return contents
26
27
def load_weak_and_strong_signals(path, name):
    """Reads the weak and strong signal files of one dataset.

    Both files live under `path` and are named `<name>_weak_signal` and
    `<name>_strong_signal`.

    Args:
      path: Directory that contains the signal files.
      name: Dataset name used as the file-name prefix.

    Returns:
      A `(weak_signal, strong_signal)` tuple.
    """
    weak_path = f'{path}/{name}_weak_signal'
    strong_path = f'{path}/{name}_strong_signal'

    # Every other dataset stores both signals ready to use.
    if name not in ('stackoverflow', 'search_snippets'):
        return read_numpy(weak_path), read_numpy(strong_path)

    # For these two datasets the weak-signal file holds features, and the
    # weak signal is the product of the features with their own transpose.
    features = read_numpy(weak_path)
    weak_signal = features.dot(features.T)
    # The strong signal is stored as a pickled object inside a 0-d array
    # (presumably a scipy sparse matrix, given `.toarray()`); `[()]` unwraps
    # it before it is densified.
    wrapped = read_numpy(strong_path, True)
    strong_signal = wrapped[()].toarray()
    return weak_signal, strong_signal
39
40
@gin.configurable
class Dataset:
    """Dataset holding pairwise weak and strong similarity signals.

    Attributes:
      weak_signal: Pairwise similarity matrix loaded from disk (assumed to be
        an N x N numpy array).
      strong_signal: Matrix with the same shape as `weak_signal`; an entry
        equal to 1 marks a pair that belongs to the same cluster.
      is_graph: Always False for this class.
      is_sparse: Always False for this class.
    """

    def __init__(self, path, name):
        """Loads the signals of dataset `name` stored under `path`."""
        self.get_weak_and_strong_signals(path, name)

    def get_weak_and_strong_signals(self, path, name):
        """Loads both signals and stores them on the instance.

        Args:
          path: Directory that contains the signal files.
          name: Dataset name used as the file-name prefix.
        """
        # Assumption: weak_signal is a NxN numpy array of similarities.
        # Assumption: strong signal is a NxN numpy array of similarities.
        self.weak_signal, self.strong_signal = load_weak_and_strong_signals(
            path, name
        )
        assert self.weak_signal.shape == self.strong_signal.shape, (
            'weak and strong signals must have the same shape'
        )
        self.is_graph = False
        self.is_sparse = False

    @property
    def num_examples(self):
        """Number of examples, i.e. rows of the strong signal."""
        return self.strong_signal.shape[0]

    @property
    def num_pairs(self):
        """Number of unordered pairs of distinct examples, as an integer."""
        # n * (n - 1) is always even, so floor division is exact and keeps
        # the count an int instead of a float.
        return (self.num_examples * (self.num_examples - 1)) // 2

    def same_cluster(self, ex_id1, ex_id2):
        """Returns whether the strong signal of the pair equals 1."""
        return self.strong_signal[ex_id1, ex_id2] == 1

    def pair_same_cluster_iterator(self):
        """Yields `(ex_id1, ex_id2, same_cluster)` for every ex_id1 < ex_id2."""
        for ex_id1 in range(self.num_examples):
            for ex_id2 in range(ex_id1 + 1, self.num_examples):
                yield ex_id1, ex_id2, self.same_cluster(ex_id1, ex_id2)

    def most_similar_pairs(self, k):
        """Selects the k most similar pairs from the weak_signal matrix.

        Only entries strictly above the diagonal are candidates, so every
        returned pair satisfies row < column and each unordered pair appears
        at most once.

        Consider a similarity matrix as follows:
          [[ 1 2 3 4]
           [ 2 4 5 6]
           [ 3 5 2 3]
           [ 4 6 3 2]]
        and assume we want the top 3 pairs. The entries above the diagonal,
        listed row by row, are
          pairs:  (0,1) (0,2) (0,3) (1,2) (1,3) (2,3)
          values:   2     3     4     5     6     3
        The three largest values are 4, 5 and 6, so the result holds
          [0 3], [1 2], [1 3]
        in no particular order.

        Args:
          k: Number of pairs to select. Values larger than the number of
            available pairs are clamped; values <= 0 give an empty list.

        Returns:
          A list of length-2 integer arrays `[row, column]`, unordered.
        """
        num_rows, num_cols = self.weak_signal.shape
        rows, cols = np.triu_indices(num_rows, k=1, m=num_cols)
        k = min(k, rows.size)
        if k <= 0:
            return []
        # Rank only the real candidate pairs. Zeroing the lower triangle and
        # ranking the whole matrix would let those zeros beat genuine pairs
        # whose similarity is zero or negative.
        candidate_sims = self.weak_signal[rows, cols]
        # kth = k - 1 is always in bounds here and guarantees that the first
        # k positions hold the k largest similarities.
        top = np.argpartition(-candidate_sims, k - 1)[:k]
        return list(np.column_stack((rows[top], cols[top])))

    def argsort_similarities(self, similarities):
        """Arg sorts a 2-D array of similarities row-wise in decreasing order."""
        return np.argsort(similarities)[:, ::-1]

    def k_nearest_neighbors_weak_signal(self, input_nodes, k):
        """Returns nearest neighbors of input nodes according to weak signal.

        Args:
          input_nodes: Indices of the rows of the weak signal to query.
          k: Number of neighbors wanted per node.

        Returns:
          An index array with one row per input node holding the (at most)
          k + 1 most similar columns in decreasing order of similarity. The
          extra column leaves room for the node itself, which
          `construct_weak_signal_knn_graph` filters out.
        """
        # Weak signal is cosine similarities of embeddings.
        # Higher values indicate closer neighbors.
        cosine_similarities = self.weak_signal[input_nodes, :]
        if self.is_sparse:
            cosine_similarities = cosine_similarities.toarray()
        return self.argsort_similarities(cosine_similarities)[:, 0 : k + 1]

    def symmetrize_graph(self, graph):
        """Returns a symmetric version of graph (csr matrix).

        For every pair of nodes the larger of the two directed weights is
        kept, so with non-negative weights an edge present in either
        direction ends up in both. A new matrix is returned; `graph` itself
        is not modified.

        Args:
          graph: Square scipy sparse matrix of edge weights.

        Returns:
          The symmetrized graph as a csr matrix.
        """
        # Element-wise assignment `graph[cols, rows] = graph[rows, cols]`
        # swaps the weights when both directions already exist with
        # different values, and changing the sparsity structure of a csr
        # matrix that way is slow.
        return graph.maximum(graph.T).tocsr()

    def construct_weak_signal_knn_graph(self, k):
        """Make k-nn graph according to weak_signal similarity matrix.

        Args:
          k: Number of neighbors requested per node.

        Returns:
          A symmetric csr matrix with weight 1 on every k-nn edge.
        """
        rows, cols = [], []
        for node in range(self.num_examples):
            neighbors = self.k_nearest_neighbors_weak_signal([node], k)[0, :]
            # Drop the self match so the graph has no self loops.
            neighbors = neighbors[neighbors != node]
            rows.extend([node] * len(neighbors))
            cols.extend(neighbors.tolist())
        graph = csr_matrix(
            ([1] * len(rows), (rows, cols)),
            shape=(self.num_examples, self.num_examples),
        )
        # Procedure to make graph symmetric.
        return self.symmetrize_graph(graph)

    def reweight_graph_using_strong_signal(self, possible_edges):
        """Keeps the given edges whose strong signal is truthy.

        Args:
          possible_edges: Iterable of `(v, u)` node-index pairs.

        Returns:
          A symmetric csr matrix with weight 1 on every retained edge.
        """
        filtered_rows = []
        filtered_columns = []
        for v, u in possible_edges:
            if self.strong_signal[v, u]:
                filtered_rows.append(v)
                filtered_columns.append(u)
        graph = csr_matrix(
            ([1] * len(filtered_rows), (filtered_rows, filtered_columns)),
            shape=(self.num_examples, self.num_examples),
        )
        return self.symmetrize_graph(graph)

    def construct_weighted_knn_graph(self, k):
        """Get weak signal knn graph and reweight edges using strong signal."""
        weak_signal_knn_graph = self.construct_weak_signal_knn_graph(k)
        rows, columns = weak_signal_knn_graph.nonzero()
        return self.reweight_graph_using_strong_signal(zip(rows, columns))
164
165
class AdhocDataset:
    """Graph dataset assembled on the fly from in-memory signals."""

    def __init__(self, strong_signal, features):
        """Stores the given signals and records the number of examples.

        Args:
          strong_signal: Object exposing `.shape`; its first dimension is
            taken as the number of examples.
          features: Stored unchanged as `self.features`.
        """
        self.is_graph = True
        self.features = features
        self.strong_signal = strong_signal
        self.num_examples = self.strong_signal.shape[0]
174