google-research
177 строк · 6.0 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Curvature estimates of interaction graphs functions."""
17
18import random
19
20from absl import app
21from absl import flags
22import networkx as nx
23import numpy as np
24
25from scipy.sparse import csr_matrix
26from tqdm import tqdm
27
28from hyperbolic.datasets.datasets import DatasetClass
29
FLAGS = flags.FLAGS

# Path handed to DatasetClass in main().
flags.DEFINE_string('dataset_path', 'data/ml-1m/', 'Path to dataset')
# How many node triples xi_stats draws for each interaction graph.
flags.DEFINE_integer('num_of_triangles', 20, 'number of triangles to sample')
37
38
def pairs_to_adj(dataset):
  """Builds the dense user-item adjacency matrix from training interactions.

  Args:
    dataset: Object exposing `n_users`, `n_items` and `data['train']`, an
      iterable of (user, item) index pairs.

  Returns:
    Float numpy array of shape (n_users, n_items) with 1.0 at every
    (user, item) pair present in the training data and 0.0 elsewhere.
  """
  adjacency = np.zeros((dataset.n_users, dataset.n_items))
  for user_idx, item_idx in dataset.data['train']:
    adjacency[user_idx, item_idx] = 1.0
  return adjacency
46
47
def interaction(adj_matrix, users_as_nodes=True):
  """Creates interaction matrix.

  Args:
    adj_matrix: Numpy array (anything `csr_matrix` accepts) representing the
      adjacency matrix of the user-item bipartite graph, users as rows and
      items as columns.
    users_as_nodes: Bool indicating which interaction matrix to generate. If
      True (False), generates a user-user (item-item) interaction matrix.

  Returns:
    Numpy ndarray of size n_users x n_users (n_items x n_items) with zeros
    on the diagonal and the number of shared items (users) elsewhere, if
    users_as_nodes=True (False).
  """
  sp_adj = csr_matrix(adj_matrix)
  # `@` is unambiguous matrix multiplication; `*` only means matmul for the
  # legacy sparse matrix classes and is elementwise for sparse arrays.
  if users_as_nodes:
    product = sp_adj @ sp_adj.transpose()
  else:
    product = sp_adj.transpose() @ sp_adj
  # toarray() gives a plain ndarray as documented; todense() would return the
  # discouraged np.matrix subclass, whose `*` means matmul downstream.
  ret_matrix = product.toarray()
  # A node's co-occurrence with itself is not an interaction edge.
  np.fill_diagonal(ret_matrix, 0.0)
  return ret_matrix
69
70
def weight_with_degree(interaction_matrix, adj_matrix, users_as_nodes=True):
  """Normalizes interaction counts by the bipartite degrees of each node pair.

  Entry (i, j) becomes 2 * interaction_matrix[i, j] / (deg_i + deg_j), where
  deg is the row sum (users) or column sum (items) of `adj_matrix`. Pair
  degree sums below 1 are clamped to 1 so the division never hits zero.

  Args:
    interaction_matrix: Square one-sided interaction matrix (see
      `interaction`).
    adj_matrix: User-item bipartite adjacency matrix, users as rows.
    users_as_nodes: If True (False), nodes are users (items).

  Returns:
    Array of the same shape as `interaction_matrix` with normalized weights.
  """
  # Users are the rows of the bipartite adjacency; items are its columns.
  degree_axis = 1 if users_as_nodes else 0
  degrees = np.sum(adj_matrix, axis=degree_axis).reshape(-1, 1)
  # Column + row broadcasts to deg_i + deg_j for every (i, j).
  pair_degree_sum = degrees + degrees.T
  return 2 * interaction_matrix / np.maximum(1, pair_degree_sum)
80
81
def weight_to_dist(weight, exp=False, *, exp_scale=10.0):
  """Turns the weights to distances, if needed uses exponential scale.

  Larger weights map to shorter distances. Zero weights (no interaction) map
  to a distance of 0, so the zero pattern of the input is preserved.

  Args:
    weight: Array-like of non-negative edge weights.
    exp: If True, distance is exp(-exp_scale * weight); otherwise it is
      1 / weight.
    exp_scale: Decay rate of the exponential scale (only used when `exp`).

  Returns:
    Floating-point array of distances with the same shape as `weight`.
  """
  # asanyarray accepts lists/scalars but keeps ndarray subclasses intact.
  weight = np.asanyarray(weight)
  if exp:
    # The mask keeps non-interacting pairs at 0 instead of exp(0) = 1.
    return (weight != 0) * np.exp(-exp_scale * weight)
  # The output buffer must hold the division's float result; a buffer
  # inheriting an integer dtype would make np.divide raise a casting error.
  out = np.zeros_like(weight, dtype=np.result_type(weight, 1.0))
  return np.divide(1.0, weight, out=out, where=weight != 0)
88
89
def _triangle_xi(graph, a, b, c):
  """Returns the curvature estimate xi for one node triple, or None.

  Follows section 3.2 of Gu et al., Learning Mixed Curvature..., 2019. With
  m the middle node of a shortest b-c path:
    xi = (d(a,m)^2 + d(b,c)^2 / 4 - (d(a,b)^2 + d(a,c)^2) / 2) / (2 d(a,m)).
  Returns None when the triple is unusable: the shortest b-c path has no
  interior node, or its middle node is `a` itself (d(a,m) = 0).
  """
  d_b_c, path_b_c = nx.single_source_dijkstra(graph, b, c)
  if len(path_b_c) <= 2:
    return None
  m = path_b_c[len(path_b_c) // 2]
  if m == a:
    return None
  all_len = nx.single_source_dijkstra_path_length(graph, a)
  d_a_m, d_a_b, d_a_c = all_len[m], all_len[b], all_len[c]
  numerator = d_a_m**2 + 0.25 * d_b_c**2 - 0.5 * (d_a_b**2 + d_a_c**2)
  return numerator / (2 * d_a_m)


def xi_stats(graph, n_iter=20, *, rng=None):
  """Calculates curvature estimates for a given graph.

  Only the largest connected component is used, so that all sampled
  distances are finite.

  Args:
    graph: NetworkX Graph class, representing undirected graph with positive
      edge weights (if no weights exist, assumes edges weights are 1).
    n_iter: Int indicating how many triangles to sample in the graph.
    rng: Optional object with a `sample(population, k)` method, e.g. a seeded
      `random.Random`, for reproducible sampling. Defaults to the global
      `random` module.

  Returns:
    Tuple of size 3 containing the mean of the curvatures of the triangles,
    the standard deviation of the curvatures of the triangles and the total
    number of legally sampled triangles. If no legal triangle could be
    sampled (including graphs whose largest component has fewer than 3
    nodes), returns (nan, nan, 0).
  """
  rng = random if rng is None else rng
  # nx.is_connected raises on the null graph, so guard the empty case.
  if graph.number_of_nodes() > 0 and not nx.is_connected(graph):
    largest_cc = max(nx.connected_components(graph), key=len)
    graph = graph.subgraph(largest_cc).copy()
  nodes_list = list(graph.nodes())
  if len(nodes_list) < 3:
    # No triangle exists; sampling 3 nodes would raise ValueError.
    return np.nan, np.nan, 0
  xis = []
  for _ in tqdm(range(n_iter), ascii=True, desc='Sample triangles'):
    a, b, c = rng.sample(nodes_list, 3)
    xi = _triangle_xi(graph, a, b, c)
    if xi is not None:
      xis.append(xi)
  if not xis:
    # Same values np.mean/np.std give for [], without the RuntimeWarnings.
    return np.nan, np.nan, 0
  return np.mean(xis), np.std(xis), len(xis)
125
126
def format_xi_stats(users_as_nodes, exp, xi_mean, xi_std, tot):
  """Renders curvature statistics as a short human-readable report.

  Args:
    users_as_nodes: Bool indicating which interaction graph was generated. If
      True (False), a user-user (item-item) interaction graph was generated.
    exp: Boolean indicating if the interaction graph distances are on an
      exponential scale.
    xi_mean: Float containing the mean of the curvatures of the sampled
      triangles.
    xi_std: Float containing the standard deviation of the curvatures of the
      sampled triangles.
    tot: Int containing the total number of legal sampled triangles.

  Returns:
    Three-line string: a header naming the graph (suffixed with
    " (using exp)" when `exp` is truthy), "mean +/- std" to three decimals,
    and the sample count.
  """
  header = 'User-user stats:' if users_as_nodes else 'Item-item stats:'
  suffix = ' (using exp)' if exp else ''
  return (f'{header}{suffix}\n'
          f'{xi_mean:.3f} +/- {xi_std:.3f} \n'
          f'out of {tot} samples.')
151
152
def all_stats(dataset, n_iter=20):
  """Estimates curvature for all interaction graphs, returns a string summary.

  Four graphs are evaluated: user-user and item-item, each with exponential
  and reciprocal edge distances. Each report from `format_xi_stats` is
  followed by a "\\n \\n" separator, and the summary starts with a newline.

  Args:
    dataset: Dataset object consumed by `pairs_to_adj`.
    n_iter: Number of triangles to sample per graph.

  Returns:
    The concatenated multi-graph summary string.
  """
  adj_matrix = pairs_to_adj(dataset)
  sections = []
  for users_as_nodes in (True, False):
    weights = weight_with_degree(
        interaction(adj_matrix, users_as_nodes), adj_matrix, users_as_nodes)
    for exp in (True, False):
      graph = nx.from_numpy_array(weight_to_dist(weights, exp))
      xi_mean, xi_std, tot = xi_stats(graph, n_iter)
      sections.append(
          format_xi_stats(users_as_nodes, exp, xi_mean, xi_std, tot))
  return '\n' + ''.join(section + '\n \n' for section in sections)
168
169
def main(_):
  """Loads the dataset named by --dataset_path and prints curvature stats."""
  dataset = DatasetClass(FLAGS.dataset_path, debug=False)
  summary = all_stats(dataset, FLAGS.num_of_triangles)
  print(summary)


if __name__ == '__main__':
  # absl parses the command-line flags before calling main(argv).
  app.run(main)
178