google-research

Форк
0
/
graph_analysis.py 
177 строк · 6.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Curvature estimates for the user-user and item-item interaction graphs."""
17

18
import random
19

20
from absl import app
21
from absl import flags
22
import networkx as nx
23
import numpy as np
24

25
from scipy.sparse import csr_matrix
26
from tqdm import tqdm
27

28
from hyperbolic.datasets.datasets import DatasetClass
29

30
FLAGS = flags.FLAGS

# Location of the dataset handed to DatasetClass in main().
flags.DEFINE_string(
    'dataset_path',
    default='data/ml-1m/',
    help='Path to dataset',
)
# Number of random node triples drawn by xi_stats for every graph.
flags.DEFINE_integer(
    'num_of_triangles',
    default=20,
    help='number of triangles to sample',
)
37

38

39
def pairs_to_adj(dataset):
  """Builds the dense 0/1 adjacency matrix of the user-item bipartite graph.

  Args:
    dataset: Object exposing `n_users`, `n_items` and `data['train']`, an
      iterable of (user, item) index pairs.

  Returns:
    Float numpy array of shape (n_users, n_items) that is 1.0 at every
    (user, item) position seen in the training pairs and 0.0 elsewhere.
  """
  adj = np.zeros((dataset.n_users, dataset.n_items))
  for user, item in dataset.data['train']:
    adj[user, item] = 1.0
  return adj
46

47

48
def interaction(adj_matrix, users_as_nodes=True):
  """Creates interaction matrix.

  Args:
    adj_matrix: Numpy array representing the adjacency matrix of the
      user-item bipartite graph.
    users_as_nodes: Bool indicating which interaction matrix to generate. If
      True (False), generates a user-user (item-item) interaction matrix.

  Returns:
    Numpy ndarray of size n_users x n_users (n_items x n_items) with zeros
    on the diagonal and the number of shared items (users) elsewhere, if
    users_as_nodes=True (False).
  """
  sp_adj = csr_matrix(adj_matrix)
  # `@` is explicit matrix multiplication for scipy sparse matrices.
  if users_as_nodes:
    co_counts = sp_adj @ sp_adj.transpose()
  else:
    co_counts = sp_adj.transpose() @ sp_adj
  # Use .toarray() rather than .todense(): .todense() returns np.matrix, for
  # which `*` means matrix multiplication. That silently turned the
  # element-wise mask in weight_to_dist into a matrix product.
  ret_matrix = co_counts.toarray()
  # A node trivially "shares" all of its own neighbours; drop self-loops.
  np.fill_diagonal(ret_matrix, 0.0)
  return ret_matrix
69

70

71
def weight_with_degree(interaction_matrix, adj_matrix, users_as_nodes=True):
  """Normalizes interaction counts by the bipartite degrees of both endpoints.

  Entry (i, j) of the result is 2 * interaction_matrix[i, j] divided by
  deg(i) + deg(j). Degrees are taken in the bipartite graph: row sums of
  `adj_matrix` for users, column sums for items. The denominator is clamped
  below at 1 so isolated pairs do not divide by zero.

  Args:
    interaction_matrix: Square interaction-count matrix (see `interaction`).
    adj_matrix: User-item adjacency matrix.
    users_as_nodes: True if `interaction_matrix` is user-user, False if it is
      item-item.

  Returns:
    Matrix of the same shape as `interaction_matrix` with normalized weights.
  """
  axis = 1 if users_as_nodes else 0
  deg_col = np.sum(adj_matrix, axis=axis).reshape(-1, 1)
  denom = np.maximum(1, deg_col + deg_col.transpose())
  return 2 * interaction_matrix / denom
80

81

82
def weight_to_dist(weight, exp=False, scale=10.0):
  """Turns edge weights into edge distances.

  A zero weight means "no edge" and always maps to distance 0. Downstream,
  nx.from_numpy_array creates no edge for zero entries.

  Args:
    weight: Array-like of edge weights; zeros mark missing edges.
    exp: If True, maps a weight w to exp(-scale * w). Otherwise maps it to
      1 / w.
    scale: Decay rate of the exponential scale; only used when exp=True.
      Defaults to 10, the value that was previously hard-coded.

  Returns:
    Float numpy ndarray with the same shape as `weight`.
  """
  # Converting to a plain float ndarray does two things:
  # - It keeps the operations below element-wise even when the caller passes
  #   an np.matrix, where `*` would be a matrix product.
  # - It stops integer inputs from making np.divide fail to cast its float
  #   result into an integer `out` buffer.
  weight = np.asarray(weight, dtype=float)
  nonzero = weight != 0
  if exp:
    return np.where(nonzero, np.exp(-scale * weight), 0.0)
  return np.divide(1.0, weight, out=np.zeros_like(weight), where=nonzero)
88

89

90
def xi_stats(graph, n_iter=20, rng=None):
  """Calculates curvature estimates for a given graph.

  Follows section 3.2 of Gu et al., "Learning Mixed-Curvature Representations
  in Product Spaces", 2019. The code samples node triples (a, b, c), takes a
  node m in the middle of the shortest b-c path, and evaluates
      xi = (d(a,m)^2 + d(b,c)^2 / 4 - (d(a,b)^2 + d(a,c)^2) / 2) / (2 d(a,m)).

  Args:
    graph: NetworkX Graph class, representing undirected graph with positive
      edge weights (if no weights exist, assumes edges weights are 1).
    n_iter: Int indicating how many triangles to sample in the graph.
    rng: Optional object with a `random.sample`-compatible `sample` method,
      e.g. a seeded `random.Random`, for reproducible estimates. Defaults to
      the module-level `random` functions.

  Returns:
    Tuple of size 3 containing the mean of the curvatures of the triangles,
    the standard deviation of the curvatures of the triangles and the total
    number of legally sampled triangles. If no triangle could be evaluated,
    for example because the largest connected component has fewer than 3
    nodes, the mean and std are NaN and the count is 0.
  """
  # Fewer than 3 nodes: no triangle can be sampled. This also avoids
  # nx.is_connected raising on the null graph.
  if graph.number_of_nodes() < 3:
    return np.nan, np.nan, 0
  if not nx.is_connected(graph):
    # Shortest-path distances are only finite inside a component, so
    # restrict sampling to the largest one.
    largest_cc = max(nx.connected_components(graph), key=len)
    graph = graph.subgraph(largest_cc).copy()
    if graph.number_of_nodes() < 3:
      return np.nan, np.nan, 0
  sampler = random if rng is None else rng
  nodes_list = list(graph.nodes())
  xis = []
  for _ in tqdm(range(n_iter), ascii=True, desc='Sample triangles'):
    # sample a triangle
    a, b, c = sampler.sample(nodes_list, 3)
    d_b_c, path_b_c = nx.single_source_dijkstra(graph, b, c)
    if len(path_b_c) <= 2:
      # b and c are adjacent: the path has no interior node to act as m.
      continue
    # Middle node of the path by hop count. On weighted graphs this only
    # approximates the geodesic midpoint.
    m = path_b_c[len(path_b_c) // 2]
    if m == a:
      # d(a, m) would be 0 and the formula below would divide by zero.
      continue
    all_len = nx.single_source_dijkstra_path_length(graph, a)
    d_a_m, d_a_b, d_a_c = all_len[m], all_len[b], all_len[c]
    xi = (d_a_m**2 + 0.25 * d_b_c**2
          - 0.5 * (d_a_b**2 + d_a_c**2)) / (2 * d_a_m)
    xis.append(xi)
  if not xis:
    # Return NaN explicitly instead of triggering numpy's empty-slice warnings.
    return np.nan, np.nan, 0
  return np.mean(xis), np.std(xis), len(xis)
125

126

127
def format_xi_stats(users_as_nodes, exp, xi_mean, xi_std, tot):
  """Renders curvature statistics as a short human-readable report.

  Args:
    users_as_nodes: True if the statistics come from the user-user
      interaction graph, False if they come from the item-item graph.
    exp: True if the graph distances were computed on the exponential scale.
    xi_mean: Mean curvature over the sampled triangles.
    xi_std: Standard deviation of the sampled curvatures.
    tot: Number of triangles that contributed to the statistics.

  Returns:
    A three-line string with these lines:
    - A header naming the graph, plus " (using exp)" when applicable.
    - The mean +/- std to three decimals.
    - The sample count.
  """
  header = 'User-user stats:' if users_as_nodes else 'Item-item stats:'
  suffix = ' (using exp)' if exp else ''
  lines = [
      header + suffix,
      f'{xi_mean:.3f} +/- {xi_std:.3f} ',
      f'out of {tot} samples.',
  ]
  return '\n'.join(lines)
151

152

153
def all_stats(dataset, n_iter=20):
  """Estimates curvature for every interaction graph; returns a summary.

  Covers both interaction graphs (user-user and item-item). For each, both
  distance scales are tried (exponential first, then inverse weight).
  `n_iter` triangles are sampled with xi_stats, and the formatted result is
  appended to the summary.

  Args:
    dataset: Dataset object accepted by pairs_to_adj.
    n_iter: Number of triangles to sample per graph.

  Returns:
    String that starts with a newline and contains one formatted block per
    (graph, scale) combination, each followed by a '\\n \\n' separator.
  """
  adj_matrix = pairs_to_adj(dataset)
  sections = ['\n']
  for users_as_nodes in (True, False):
    weights = weight_with_degree(
        interaction(adj_matrix, users_as_nodes), adj_matrix, users_as_nodes)
    for exp in (True, False):
      graph = nx.from_numpy_array(weight_to_dist(weights, exp))
      xi_mean, xi_std, tot = xi_stats(graph, n_iter)
      sections.append(
          format_xi_stats(users_as_nodes, exp, xi_mean, xi_std, tot))
      sections.append('\n \n')
  return ''.join(sections)
168

169

170
def main(_):
  """Loads the dataset at --dataset_path and prints its curvature summary."""
  dataset = DatasetClass(FLAGS.dataset_path, debug=False)
  summary = all_stats(dataset, FLAGS.num_of_triangles)
  print(summary)
174

175

176
if __name__ == '__main__':
  # Script entry point: delegate to absl's app runner, which calls main().
  app.run(main)
178

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.