# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collections of preprocessing functions for different graph formats."""

import json
import time

from networkx.readwrite import json_graph
import numpy as np
import partition_utils
import scipy.sparse as sp
import sklearn.metrics
import sklearn.preprocessing
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import gfile


def parse_index_file(filename):
  """Parse an index file containing one integer node index per line."""
  index = []
  for line in gfile.Open(filename):
    index.append(int(line.strip()))
  return index


def sample_mask(idx, l):
  """Create a boolean mask of length `l` that is True at the given indices."""
  mask = np.zeros(l)
  mask[idx] = 1
  return np.array(mask, dtype=bool)


def sym_normalize_adj(adj):
  """Symmetric normalization: D^{-1/2} (A+I) D^{-1/2}."""
  adj = adj + sp.eye(adj.shape[0])
  rowsum = np.array(adj.sum(1)) + 1e-20
  d_inv_sqrt = np.power(rowsum, -0.5).flatten()
  d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
  d_mat_inv_sqrt = sp.diags(d_inv_sqrt, 0)
  adj = adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt)
  return adj


def normalize_adj(adj):
  """Row normalization: D^{-1} A, with row sums clipped below at 1."""
  rowsum = np.array(adj.sum(1)).flatten()
  d_inv = 1.0 / (np.maximum(1.0, rowsum))
  d_mat_inv = sp.diags(d_inv, 0)
  adj = d_mat_inv.dot(adj)
  return adj


def normalize_adj_diag_enhance(adj, diag_lambda):
  """Normalization by A' = (D+I)^{-1}(A+I), A' = A' + lambda*diag(A')."""
  adj = adj + sp.eye(adj.shape[0])
  rowsum = np.array(adj.sum(1)).flatten()
  d_inv = 1.0 / (rowsum + 1e-20)
  d_mat_inv = sp.diags(d_inv, 0)
  adj = d_mat_inv.dot(adj)
  adj = adj + diag_lambda * sp.diags(adj.diagonal(), 0)
  return adj
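
# Worked example (illustrative): for A = sp.csr_matrix([[0., 1.], [1., 0.]]),
# sym_normalize_adj(A) adds self-loops and returns [[0.5, 0.5], [0.5, 0.5]],
# normalize_adj(A) leaves A unchanged (every row already sums to 1), and
# normalize_adj_diag_enhance(A, diag_lambda=1) yields [[1.0, 0.5], [0.5, 1.0]].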


def sparse_to_tuple(sparse_mx):
  """Convert sparse matrix to tuple representation (coords, values, shape)."""

  def to_tuple(mx):
    if not sp.isspmatrix_coo(mx):
      mx = mx.tocoo()
    coords = np.vstack((mx.row, mx.col)).transpose()
    values = mx.data
    shape = mx.shape
    return coords, values, shape

  if isinstance(sparse_mx, list):
    for i in range(len(sparse_mx)):
      sparse_mx[i] = to_tuple(sparse_mx[i])
  else:
    sparse_mx = to_tuple(sparse_mx)

  return sparse_mx


def calc_f1(y_pred, y_true, multilabel):
  """Compute micro- and macro-averaged F1 scores for predictions."""
  if multilabel:
    y_pred[y_pred > 0] = 1
    y_pred[y_pred <= 0] = 0
  else:
    y_true = np.argmax(y_true, axis=1)
    y_pred = np.argmax(y_pred, axis=1)
  return sklearn.metrics.f1_score(
      y_true, y_pred, average='micro'), sklearn.metrics.f1_score(
          y_true, y_pred, average='macro')


def construct_feed_dict(features, support, labels, labels_mask, placeholders):
  """Construct feed dictionary."""
  feed_dict = dict()
  feed_dict.update({placeholders['labels']: labels})
  feed_dict.update({placeholders['labels_mask']: labels_mask})
  feed_dict.update({placeholders['features']: features})
  feed_dict.update({placeholders['support']: support})
  feed_dict.update({placeholders['num_features_nonzero']: features[1].shape})
  return feed_dict
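
# Note: `features` and `support` are expected in the (coords, values, shape)
# tuple form produced by sparse_to_tuple, which is why the count of nonzero
# feature entries is taken from features[1].shape; the matching placeholders
# are therefore assumed to be sparse (tf.sparse_placeholder) entries.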


def preprocess_multicluster(adj,
                            parts,
                            features,
                            y_train,
                            train_mask,
                            num_clusters,
                            block_size,
                            diag_lambda=-1):
  """Generate batches, each formed by merging multiple clusters."""

  features_batches = []
  support_batches = []
  y_train_batches = []
  train_mask_batches = []
  total_nnz = 0
  np.random.shuffle(parts)
  for _, st in enumerate(range(0, num_clusters, block_size)):
    # Merge `block_size` consecutive (shuffled) partitions into one batch.
    pt = parts[st]
    for pt_idx in range(st + 1, min(st + block_size, num_clusters)):
      pt = np.concatenate((pt, parts[pt_idx]), axis=0)
    features_batches.append(features[pt, :])
    y_train_batches.append(y_train[pt, :])
    support_now = adj[pt, :][:, pt]
    if diag_lambda == -1:
      support_batches.append(sparse_to_tuple(normalize_adj(support_now)))
    else:
      support_batches.append(
          sparse_to_tuple(normalize_adj_diag_enhance(support_now, diag_lambda)))
    total_nnz += support_now.count_nonzero()

    # Remap the training mask to the local (within-batch) node ordering.
    train_pt = []
    for newidx, idx in enumerate(pt):
      if train_mask[idx]:
        train_pt.append(newidx)
    train_mask_batches.append(sample_mask(train_pt, len(pt)))
  return (features_batches, support_batches, y_train_batches,
          train_mask_batches)
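
# For example, with num_clusters=50 and block_size=5 the function returns 10
# batches, each obtained by merging 5 randomly grouped partitions; the grouping
# changes between calls because `parts` is reshuffled in place.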


def preprocess(adj,
               features,
               y_train,
               train_mask,
               visible_data,
               num_clusters,
               diag_lambda=-1):
  """Do graph partitioning and preprocessing for SGD training."""

  # Do graph partitioning
  part_adj, parts = partition_utils.partition_graph(adj, visible_data,
                                                    num_clusters)
  if diag_lambda == -1:
    part_adj = normalize_adj(part_adj)
  else:
    part_adj = normalize_adj_diag_enhance(part_adj, diag_lambda)
  parts = [np.array(pt) for pt in parts]

  features_batches = []
  support_batches = []
  y_train_batches = []
  train_mask_batches = []
  total_nnz = 0
  for pt in parts:
    features_batches.append(features[pt, :])
    now_part = part_adj[pt, :][:, pt]
    total_nnz += now_part.count_nonzero()
    support_batches.append(sparse_to_tuple(now_part))
    y_train_batches.append(y_train[pt, :])

    train_pt = []
    for newidx, idx in enumerate(pt):
      if train_mask[idx]:
        train_pt.append(newidx)
    train_mask_batches.append(sample_mask(train_pt, len(pt)))
  return (parts, features_batches, support_batches, y_train_batches,
          train_mask_batches)


def load_graphsage_data(dataset_path, dataset_str, normalize=True):
  """Load GraphSAGE-format data (graph, id map, class map, features)."""
  start_time = time.time()

  graph_json = json.load(
      gfile.Open('{}/{}/{}-G.json'.format(dataset_path, dataset_str,
                                          dataset_str)))
  graph_nx = json_graph.node_link_graph(graph_json)

  id_map = json.load(
      gfile.Open('{}/{}/{}-id_map.json'.format(dataset_path, dataset_str,
                                               dataset_str)))
  is_digit = list(id_map.keys())[0].isdigit()
  id_map = {(int(k) if is_digit else k): int(v) for k, v in id_map.items()}
  class_map = json.load(
      gfile.Open('{}/{}/{}-class_map.json'.format(dataset_path, dataset_str,
                                                  dataset_str)))

  is_instance = isinstance(list(class_map.values())[0], list)
  class_map = {(int(k) if is_digit else k): (v if is_instance else int(v))
               for k, v in class_map.items()}

  broken_count = 0
  to_remove = []
  for node in graph_nx.nodes():
    if node not in id_map:
      to_remove.append(node)
      broken_count += 1
  for node in to_remove:
    graph_nx.remove_node(node)
  tf.logging.info(
      'Removed %d nodes that lacked proper annotations due to networkx versioning issues',
      broken_count)

  feats = np.load(
      gfile.Open(
          '{}/{}/{}-feats.npy'.format(dataset_path, dataset_str, dataset_str),
          'rb')).astype(np.float32)

  tf.logging.info('Loaded data (%f seconds); now preprocessing.',
                  time.time() - start_time)
  start_time = time.time()

  edges = []
  for edge in graph_nx.edges():
    if edge[0] in id_map and edge[1] in id_map:
      edges.append((id_map[edge[0]], id_map[edge[1]]))
  num_data = len(id_map)

  val_data = np.array(
      [id_map[n] for n in graph_nx.nodes() if graph_nx.node[n]['val']],
      dtype=np.int32)
  test_data = np.array(
      [id_map[n] for n in graph_nx.nodes() if graph_nx.node[n]['test']],
      dtype=np.int32)
  is_train = np.ones((num_data), dtype=bool)
  is_train[val_data] = False
  is_train[test_data] = False
  train_data = np.array([n for n in range(num_data) if is_train[n]],
                        dtype=np.int32)

  train_edges = [
      (e[0], e[1]) for e in edges if is_train[e[0]] and is_train[e[1]]
  ]
  edges = np.array(edges, dtype=np.int32)
  train_edges = np.array(train_edges, dtype=np.int32)

  # Process labels
  if isinstance(list(class_map.values())[0], list):
    num_classes = len(list(class_map.values())[0])
    labels = np.zeros((num_data, num_classes), dtype=np.float32)
    for k in class_map.keys():
      labels[id_map[k], :] = np.array(class_map[k])
  else:
    num_classes = len(set(class_map.values()))
    labels = np.zeros((num_data, num_classes), dtype=np.float32)
    for k in class_map.keys():
      labels[id_map[k], class_map[k]] = 1

  # Optionally standardize features using statistics from training nodes only.
  if normalize:
    train_ids = np.array([
        id_map[n]
        for n in graph_nx.nodes()
        if not graph_nx.node[n]['val'] and not graph_nx.node[n]['test']
    ])
    train_feats = feats[train_ids]
    scaler = sklearn.preprocessing.StandardScaler()
    scaler.fit(train_feats)
    feats = scaler.transform(feats)

  def _construct_adj(edges):
    """Build a symmetric adjacency matrix from an (m, 2) edge array."""
    adj = sp.csr_matrix((np.ones(
        (edges.shape[0]), dtype=np.float32), (edges[:, 0], edges[:, 1])),
                        shape=(num_data, num_data))
    adj += adj.transpose()
    return adj

  train_adj = _construct_adj(train_edges)
  full_adj = _construct_adj(edges)

  train_feats = feats[train_data]
  test_feats = feats

  tf.logging.info('Data loaded, %f seconds.', time.time() - start_time)
  return num_data, train_adj, full_adj, feats, train_feats, test_feats, labels, train_data, val_data, test_data
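
# Example usage sketch (illustrative; the dataset name, path and
# hyperparameters below are placeholders, not part of this module):
#
#   (num_data, train_adj, full_adj, feats, train_feats, test_feats, labels,
#    train_data, val_data, test_data) = load_graphsage_data('./data', 'ppi')
#   train_mask = sample_mask(train_data, num_data)
#   y_train = np.zeros(labels.shape, dtype=np.float32)
#   y_train[train_data, :] = labels[train_data, :]
#   _, parts = partition_utils.partition_graph(train_adj, train_data, 50)
#   parts = [np.array(pt) for pt in parts]
#   batches = preprocess_multicluster(
#       train_adj, parts, feats, y_train, train_mask,
#       num_clusters=50, block_size=5)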