# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16"""Collections of preprocessing functions for different graph formats."""
17
18import json19import time20
21from networkx.readwrite import json_graph22import numpy as np23import partition_utils24import scipy.sparse as sp25import sklearn.metrics26import sklearn.preprocessing27import tensorflow.compat.v1 as tf28from tensorflow.compat.v1 import gfile29


def parse_index_file(filename):
  """Parse index file."""
  index = []
  for line in gfile.Open(filename):
    index.append(int(line.strip()))
  return index


def sample_mask(idx, l):
  """Create mask."""
  mask = np.zeros(l)
  mask[idx] = 1
  return np.array(mask, dtype=bool)


def sym_normalize_adj(adj):
  """Normalization by D^{-1/2} (A+I) D^{-1/2}."""
  adj = adj + sp.eye(adj.shape[0])
  rowsum = np.array(adj.sum(1)) + 1e-20
  d_inv_sqrt = np.power(rowsum, -0.5).flatten()
  d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
  d_mat_inv_sqrt = sp.diags(d_inv_sqrt, 0)
  adj = adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt)
  return adj
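
# A minimal usage sketch (the 3-node path graph below is an illustrative
# assumption, not part of the original module):
#
#   adj = sp.csr_matrix(
#       np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=np.float32))
#   adj_norm = sym_normalize_adj(adj)
#
# With self-loops added the degrees are [2, 3, 2], so e.g. the (0, 1) entry
# of adj_norm is 1 / sqrt(2 * 3).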


def normalize_adj(adj):
  """Row-normalization by D^{-1} A, with degrees clamped to at least 1."""
  rowsum = np.array(adj.sum(1)).flatten()
  d_inv = 1.0 / np.maximum(1.0, rowsum)
  d_mat_inv = sp.diags(d_inv, 0)
  adj = d_mat_inv.dot(adj)
  return adj


def normalize_adj_diag_enhance(adj, diag_lambda):
  """Normalization by A'=(D+I)^{-1}(A+I), A'=A'+lambda*diag(A')."""
  adj = adj + sp.eye(adj.shape[0])
  rowsum = np.array(adj.sum(1)).flatten()
  d_inv = 1.0 / (rowsum + 1e-20)
  d_mat_inv = sp.diags(d_inv, 0)
  adj = d_mat_inv.dot(adj)
  adj = adj + diag_lambda * sp.diags(adj.diagonal(), 0)
  return adj
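
# Note: elsewhere in this module diag_lambda == -1 acts as a sentinel that
# selects plain normalize_adj() instead of this diagonal-enhanced variant.
# A minimal usage sketch (the diag_lambda value is an illustrative
# assumption):
#
#   adj_norm = normalize_adj_diag_enhance(adj, diag_lambda=1.0)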


def sparse_to_tuple(sparse_mx):
  """Convert sparse matrix to tuple representation."""

  def to_tuple(mx):
    if not sp.isspmatrix_coo(mx):
      mx = mx.tocoo()
    coords = np.vstack((mx.row, mx.col)).transpose()
    values = mx.data
    shape = mx.shape
    return coords, values, shape

  if isinstance(sparse_mx, list):
    for i in range(len(sparse_mx)):
      sparse_mx[i] = to_tuple(sparse_mx[i])
  else:
    sparse_mx = to_tuple(sparse_mx)

  return sparse_mx
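
# A minimal usage sketch (the 2x2 identity matrix is an illustrative
# assumption):
#
#   coords, values, shape = sparse_to_tuple(sp.eye(2).tocsr())
#   # coords -> [[0, 0], [1, 1]], values -> [1., 1.], shape -> (2, 2)
#
# The resulting (coords, values, shape) triple is the format expected when
# feeding a tf.sparse_placeholder.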


def calc_f1(y_pred, y_true, multilabel):
  """Compute micro- and macro-averaged F1 scores."""
  if multilabel:
    y_pred[y_pred > 0] = 1
    y_pred[y_pred <= 0] = 0
  else:
    y_true = np.argmax(y_true, axis=1)
    y_pred = np.argmax(y_pred, axis=1)
  return sklearn.metrics.f1_score(
      y_true, y_pred, average='micro'), sklearn.metrics.f1_score(
          y_true, y_pred, average='macro')


def construct_feed_dict(features, support, labels, labels_mask, placeholders):
  """Construct feed dictionary."""
  feed_dict = dict()
  feed_dict.update({placeholders['labels']: labels})
  feed_dict.update({placeholders['labels_mask']: labels_mask})
  feed_dict.update({placeholders['features']: features})
  feed_dict.update({placeholders['support']: support})
  feed_dict.update({placeholders['num_features_nonzero']: features[1].shape})
  return feed_dict
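
# Note: the placeholders['num_features_nonzero'] feed assumes `features` is
# the (coords, values, shape) tuple produced by sparse_to_tuple(), so that
# features[1] holds the array of nonzero values.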


def preprocess_multicluster(adj,
                            parts,
                            features,
                            y_train,
                            train_mask,
                            num_clusters,
                            block_size,
                            diag_lambda=-1):
  """Generate the batches for multiple clusters."""

  features_batches = []
  support_batches = []
  y_train_batches = []
  train_mask_batches = []
  total_nnz = 0
  np.random.shuffle(parts)
  for st in range(0, num_clusters, block_size):
    # Merge block_size partitions into one batch.
    pt = parts[st]
    for pt_idx in range(st + 1, min(st + block_size, num_clusters)):
      pt = np.concatenate((pt, parts[pt_idx]), axis=0)
    features_batches.append(features[pt, :])
    y_train_batches.append(y_train[pt, :])
    support_now = adj[pt, :][:, pt]
    if diag_lambda == -1:
      support_batches.append(sparse_to_tuple(normalize_adj(support_now)))
    else:
      support_batches.append(
          sparse_to_tuple(normalize_adj_diag_enhance(support_now,
                                                     diag_lambda)))
    total_nnz += support_now.count_nonzero()

    # Re-index the training mask into the batch's local node ordering.
    train_pt = []
    for newidx, idx in enumerate(pt):
      if train_mask[idx]:
        train_pt.append(newidx)
    train_mask_batches.append(sample_mask(train_pt, len(pt)))
  return (features_batches, support_batches, y_train_batches,
          train_mask_batches)
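
# A minimal usage sketch (the variable names, num_clusters=50 and
# block_size=5 are illustrative assumptions):
#
#   _, parts = partition_utils.partition_graph(train_adj, train_data, 50)
#   parts = [np.array(pt) for pt in parts]
#   (features_b, support_b, y_train_b, mask_b) = preprocess_multicluster(
#       train_adj, parts, train_feats, y_train, train_mask, 50, 5)
#
# Calling this once per epoch reshuffles `parts`, so clusters are regrouped
# into different batches each time.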


def preprocess(adj,
               features,
               y_train,
               train_mask,
               visible_data,
               num_clusters,
               diag_lambda=-1):
  """Do graph partitioning and preprocessing for SGD training."""

  # Do graph partitioning.
  part_adj, parts = partition_utils.partition_graph(adj, visible_data,
                                                    num_clusters)
  if diag_lambda == -1:
    part_adj = normalize_adj(part_adj)
  else:
    part_adj = normalize_adj_diag_enhance(part_adj, diag_lambda)
  parts = [np.array(pt) for pt in parts]

  features_batches = []
  support_batches = []
  y_train_batches = []
  train_mask_batches = []
  total_nnz = 0
  for pt in parts:
    features_batches.append(features[pt, :])
    now_part = part_adj[pt, :][:, pt]
    total_nnz += now_part.count_nonzero()
    support_batches.append(sparse_to_tuple(now_part))
    y_train_batches.append(y_train[pt, :])

    # Re-index the training mask into the batch's local node ordering.
    train_pt = []
    for newidx, idx in enumerate(pt):
      if train_mask[idx]:
        train_pt.append(newidx)
    train_mask_batches.append(sample_mask(train_pt, len(pt)))
  return (parts, features_batches, support_batches, y_train_batches,
          train_mask_batches)
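
# A minimal usage sketch (variable names and num_clusters are illustrative
# assumptions):
#
#   (parts, features_b, support_b, y_train_b, mask_b) = preprocess(
#       train_adj, train_feats, y_train, train_mask, train_data, 50)
#
# Unlike preprocess_multicluster(), this normalizes the partitioned graph
# once up front and emits exactly one batch per cluster.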


def load_graphsage_data(dataset_path, dataset_str, normalize=True):
  """Load GraphSAGE data."""
  start_time = time.time()

  graph_json = json.load(
      gfile.Open('{}/{}/{}-G.json'.format(dataset_path, dataset_str,
                                          dataset_str)))
  graph_nx = json_graph.node_link_graph(graph_json)

  id_map = json.load(
      gfile.Open('{}/{}/{}-id_map.json'.format(dataset_path, dataset_str,
                                               dataset_str)))
  is_digit = list(id_map.keys())[0].isdigit()
  id_map = {(int(k) if is_digit else k): int(v) for k, v in id_map.items()}
  class_map = json.load(
      gfile.Open('{}/{}/{}-class_map.json'.format(dataset_path, dataset_str,
                                                  dataset_str)))

  is_instance = isinstance(list(class_map.values())[0], list)
  class_map = {(int(k) if is_digit else k): (v if is_instance else int(v))
               for k, v in class_map.items()}

  broken_count = 0
  to_remove = []
  for node in graph_nx.nodes():
    if node not in id_map:
      to_remove.append(node)
      broken_count += 1
  for node in to_remove:
    graph_nx.remove_node(node)
  tf.logging.info(
      'Removed %d nodes that lacked proper annotations due to networkx '
      'versioning issues', broken_count)

  feats = np.load(
      gfile.Open(
          '{}/{}/{}-feats.npy'.format(dataset_path, dataset_str, dataset_str),
          'rb')).astype(np.float32)

  tf.logging.info('Loaded data (%f seconds).. now preprocessing..',
                  time.time() - start_time)
  start_time = time.time()

  edges = []
  for edge in graph_nx.edges():
    if edge[0] in id_map and edge[1] in id_map:
      edges.append((id_map[edge[0]], id_map[edge[1]]))
  num_data = len(id_map)

  val_data = np.array(
      [id_map[n] for n in graph_nx.nodes() if graph_nx.nodes[n]['val']],
      dtype=np.int32)
  test_data = np.array(
      [id_map[n] for n in graph_nx.nodes() if graph_nx.nodes[n]['test']],
      dtype=np.int32)
  is_train = np.ones(num_data, dtype=bool)
  is_train[val_data] = False
  is_train[test_data] = False
  train_data = np.array([n for n in range(num_data) if is_train[n]],
                        dtype=np.int32)

  train_edges = [
      (e[0], e[1]) for e in edges if is_train[e[0]] and is_train[e[1]]
  ]
  edges = np.array(edges, dtype=np.int32)
  train_edges = np.array(train_edges, dtype=np.int32)

  # Process labels.
  if isinstance(list(class_map.values())[0], list):
    num_classes = len(list(class_map.values())[0])
    labels = np.zeros((num_data, num_classes), dtype=np.float32)
    for k in class_map.keys():
      labels[id_map[k], :] = np.array(class_map[k])
  else:
    num_classes = len(set(class_map.values()))
    labels = np.zeros((num_data, num_classes), dtype=np.float32)
    for k in class_map.keys():
      labels[id_map[k], class_map[k]] = 1

  if normalize:
    train_ids = np.array([
        id_map[n]
        for n in graph_nx.nodes()
        if not graph_nx.nodes[n]['val'] and not graph_nx.nodes[n]['test']
    ])
    train_feats = feats[train_ids]
    scaler = sklearn.preprocessing.StandardScaler()
    scaler.fit(train_feats)
    feats = scaler.transform(feats)

  def _construct_adj(edges):
    adj = sp.csr_matrix(
        (np.ones(edges.shape[0], dtype=np.float32),
         (edges[:, 0], edges[:, 1])),
        shape=(num_data, num_data))
    adj += adj.transpose()
    return adj

  train_adj = _construct_adj(train_edges)
  full_adj = _construct_adj(edges)

  train_feats = feats[train_data]
  test_feats = feats

  tf.logging.info('Data loaded, %f seconds.', time.time() - start_time)
  return (num_data, train_adj, full_adj, feats, train_feats, test_feats,
          labels, train_data, val_data, test_data)
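
# A minimal usage sketch (the 'data'/'ppi' path and prefix are illustrative
# assumptions; any GraphSAGE-format dataset with {prefix}-G.json,
# {prefix}-id_map.json, {prefix}-class_map.json and {prefix}-feats.npy files
# follows the same pattern):
#
#   (num_data, train_adj, full_adj, feats, train_feats, test_feats,
#    labels, train_data, val_data, test_data) = load_graphsage_data(
#        'data', 'ppi')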