google-research
591 строка · 19.1 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Dataset definitions for in-memory graph structure learning."""
17import copy18import io19import math20import os21import random22from typing import List, Mapping, MutableMapping, Tuple23
24import numpy as np25import scipy.sparse26import tensorflow as tf27import tensorflow_gnn as tfgnn28import tensorflow_hub as tfhub29
30from ugsl import tfgnn_datasets31
32
class GSLGraphData:
  """Wraps graph datasets to be used for graph structure learning.

  GSLGraphData can take a given tensor as a generated adjacency and incorporate
  it in the graph tensor.

  This class is designed as a mixin: concrete subclasses combine it with a
  dataset class that provides `node_counts()`, `context()`, `node_sets()`,
  `edge_lists()`, and the feature/label accessors declared abstract below.
  """

  def __init__(
      self,
      remove_noise_ratio = 0.0,
      add_noise_ratio = 0.0,
  ):
    """Builds and caches the (optionally noisy) input graph tensor.

    Args:
      remove_noise_ratio: ratio of existing edges to randomly remove.
      add_noise_ratio: ratio of non-existing edges to randomly add.
    """
    super().__init__()
    # Saving the generated noisy adjacency to reuse across calls.
    self._cached_noisy_adjacency = None
    self._input_gt = self.as_graph_tensor_noisy_adjacency(
        remove_noise_ratio=remove_noise_ratio, add_noise_ratio=add_noise_ratio
    )

  def node_sets(self):
    raise NotImplementedError

  def splits(self):
    # Shallow copy so callers cannot mutate the stored split list.
    return copy.copy(self._splits)

  def num_classes(self):
    raise NotImplementedError('num_classes')

  def node_split(self):
    raise NotImplementedError()

  def labels(self):
    raise NotImplementedError()

  def test_labels(self):
    raise NotImplementedError()

  @property
  def labeled_nodeset(self):
    raise NotImplementedError()

  def node_features_dicts_without_labels(
      self,
  ):
    raise NotImplementedError()

  def edge_lists(
      self,
  ):
    raise NotImplementedError()

  def as_graph_tensor(self):
    raise NotImplementedError()

  def node_features_dicts(
      self,
  ):
    raise NotImplementedError()

  def get_input_graph_tensor(self):
    # Returns the graph tensor built once at construction time.
    return self._input_gt

  def as_graph_tensor_given_adjacency(
      self,
      adjacency_tensor,
      edge_weights,
      node_features,
      make_undirected = False,
      add_self_loops = False,
  ):
    """Returns `GraphTensor` holding the entire graph.

    Args:
      adjacency_tensor: [2, num_edges] edge list (sources row 0, targets
        row 1).
      edge_weights: per-edge weight tensor aligned with `adjacency_tensor`.
      node_features: feature tensor attached to every node set as 'feat'.
      make_undirected: if True, every edge is mirrored.
      add_self_loops: if True, a self-loop (weight 1) is added for each node.
    """
    return tfgnn.GraphTensor.from_pieces(
        node_sets=self.node_sets_given_features(node_features),
        edge_sets=self.edge_sets_given_adjacency(
            adjacency_tensor,
            edge_weights,
            make_undirected,
            add_self_loops,
        ),
        context=self.context(),
    )

  def node_sets_given_features(
      self, node_features
  ):
    """Returns node sets of entire graph (dict: node set name -> NodeSet)."""
    node_counts = self.node_counts()
    features_dicts = self.node_features_dicts()
    node_set_names = set(node_counts.keys()).union(features_dicts.keys())
    # NOTE(review): the same `node_features` tensor is attached to every node
    # set name; this is only meaningful for single-node-set graphs — confirm.
    return {
        name: tfgnn.NodeSet.from_fields(
            sizes=tf.convert_to_tensor([node_counts[name]]),
            features={'feat': node_features})
        for name in node_set_names
    }

  def edge_sets_given_adjacency(
      self,
      edge_list,
      edge_weights,
      make_undirected = False,
      add_self_loops = False,
  ):
    """Returns edge sets of entire graph (dict: edge set name -> EdgeSet)."""
    if make_undirected:
      edge_list = tf.concat([edge_list, edge_list[::-1]], axis=-1)
      edge_weights = tf.concat([edge_weights, edge_weights[::-1]], axis=-1)
    if add_self_loops:
      node_counts = self.node_counts()
      all_nodes = tf.range(node_counts[tfgnn.NODES], dtype=edge_list.dtype)
      self_connections = tf.stack([all_nodes, all_nodes], axis=0)
      # The following line adds self_connections to the existing edges.
      # It is possible for an edge to be both available in the edge_list and
      # also in the self_connections.
      edge_list = tf.concat([edge_list, self_connections], axis=-1)
      # Self-loops receive weight 1.
      edge_weights = tf.concat(
          [edge_weights, tf.ones(node_counts[tfgnn.NODES])], axis=-1
      )
    return {
        tfgnn.EDGES: tfgnn.EdgeSet.from_fields(
            sizes=tf.shape(edge_list)[1:2],
            adjacency=tfgnn.Adjacency.from_indices(
                source=(tfgnn.NODES, edge_list[0]),
                target=(tfgnn.NODES, edge_list[1]),
            ),
            features={'weights': edge_weights},
        )
    }

  def as_graph_tensor_noisy_adjacency(
      self,
      remove_noise_ratio,
      add_noise_ratio,
      make_undirected = False,
      add_self_loops = False,
  ):
    """Returns `GraphTensor` holding the entire graph with noisy edges."""
    return tfgnn.GraphTensor.from_pieces(
        node_sets=self.node_sets(),
        edge_sets=self.edge_sets_noisy_adjacency(
            add_noise_ratio=add_noise_ratio,
            remove_noise_ratio=remove_noise_ratio,
            make_undirected=make_undirected,
            add_self_loops=add_self_loops,
        ),
        context=self.context(),
    )

  def edge_sets_noisy_adjacency(
      self,
      add_noise_ratio,
      remove_noise_ratio,
      make_undirected = False,
      add_self_loops = False,
  ):
    """Returns noisy edge sets of entire graph (dict: edge set name -> EdgeSet).

    Edges are removed/added with module-level `random` (not seeded here), so
    the noise is only reproducible within one process via the cache below.
    """
    if self._cached_noisy_adjacency:
      return self._cached_noisy_adjacency
    edge_sets = {}
    node_counts = self.node_counts()
    for edge_type, edge_list in self.edge_lists().items():
      (source_node_set_name, edge_set_name, target_node_set_name) = edge_type
      number_of_nodes = node_counts[source_node_set_name]
      sources = edge_list[0].numpy()
      targets = edge_list[1].numpy()
      number_of_edges = len(sources)
      if add_noise_ratio:
        # Sample a fraction of the absent edge slots (n^2/2 minus existing).
        number_of_edges_to_add = math.floor(
            ((number_of_nodes * number_of_nodes) / 2 - number_of_edges)
            * add_noise_ratio
        )
        sources_to_add = np.array(
            random.choices(range(number_of_nodes), k=number_of_edges_to_add)
        )
        targets_to_add = np.array(
            random.choices(range(number_of_nodes), k=number_of_edges_to_add)
        )
      else:
        # BUGFIX: use the sources'/targets' integer dtype. A bare
        # np.array([]) is float64, and concatenating it below would promote
        # the whole edge array to float, breaking the int32 conversion.
        sources_to_add = np.array([], dtype=sources.dtype)
        targets_to_add = np.array([], dtype=targets.dtype)
      if remove_noise_ratio:
        number_of_edges_to_remove = math.floor(
            number_of_edges * remove_noise_ratio
        )
        edge_indices_to_remove = random.sample(
            range(0, number_of_edges), number_of_edges_to_remove
        )
        noisy_sources = np.delete(sources, edge_indices_to_remove)
        noisy_targets = np.delete(targets, edge_indices_to_remove)
      else:
        noisy_sources, noisy_targets = sources, targets
      noisy_sources = tf.constant(
          np.concatenate((noisy_sources, sources_to_add)), dtype=tf.int32
      )
      noisy_targets = tf.constant(
          np.concatenate((noisy_targets, targets_to_add)), dtype=tf.int32
      )
      edge_list = tf.stack([noisy_sources, noisy_targets])
      if make_undirected:
        edge_list = tf.concat([edge_list, edge_list[::-1]], axis=-1)
      if add_self_loops:
        all_nodes = tf.range(number_of_nodes, dtype=edge_list.dtype)
        self_connections = tf.stack([all_nodes, all_nodes], axis=0)
        edge_list = tf.concat([edge_list, self_connections], axis=-1)
      edge_sets[edge_set_name] = tfgnn.EdgeSet.from_fields(
          sizes=tf.shape(edge_list)[1:2],
          adjacency=tfgnn.Adjacency.from_indices(
              source=(source_node_set_name, edge_list[0]),
              target=(target_node_set_name, edge_list[1]),
          ),
      )
    self._cached_noisy_adjacency = edge_sets
    return edge_sets
246
class GSLPlanetoidGraphData(tfgnn_datasets.PlanetoidGraphData, GSLGraphData):
  """Planetoid (cora/citeseer/pubmed) data adapted for structure learning.

  Initializes the underlying Planetoid dataset first, then the GSL mixin,
  which builds the (optionally noisy) input graph tensor from it.
  """

  def __init__(
      self,
      dataset_name,
      remove_noise_ratio,
      add_noise_ratio,
  ):
    # Explicit two-step initialization: the Planetoid base must be fully
    # loaded before GSLGraphData can derive the noisy adjacency from it.
    tfgnn_datasets.PlanetoidGraphData.__init__(self, dataset_name)
    GSLGraphData.__init__(
        self,
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
267
class GcnBenchmarkFileGraphData(tfgnn_datasets.NodeClassificationGraphData):
  """Adapt npz with format of github.com/shchur/gnn-benchmark into TF-GNN.

  NOTE: This can be moved to TF-GNN (tfgnn/experimental/in_memory/datasets.py).
  """

  def __init__(self, dataset_path):
    """Loads .npz file following shchur's format."""
    if not tf.io.gfile.exists(dataset_path):
      raise ValueError('Dataset file not found: ' + dataset_path)

    adjacency, features, labels, mask = _load_npz_to_sparse_graph(dataset_path)
    del mask  # Unused here.

    self._edge_lists = {
        (tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): tf.convert_to_tensor(
            adjacency.nonzero())
    }

    total_nodes = features.shape[0]
    self._node_features_dicts = {
        tfgnn.NODES: {
            'feat': tf.convert_to_tensor(features),
            '#id': tf.range(total_nodes),
        }
    }
    self._node_counts = {tfgnn.NODES: total_nodes}
    self._num_classes = labels.max() + 1
    self._test_labels = tf.convert_to_tensor(labels)

    # Deterministic (seed=1234) 10%/10%/80% train/validate/test partition.
    shuffled = np.random.default_rng(seed=1234).permutation(total_nodes)
    train_size = total_nodes // 10
    validate_size = total_nodes // 10
    train_ids, validate_ids, test_ids = np.split(
        shuffled, [train_size, train_size + validate_size])

    self._node_split = tfgnn_datasets.NodeSplit(
        tf.convert_to_tensor(train_ids),
        tf.convert_to_tensor(validate_ids),
        tf.convert_to_tensor(test_ids))

    # Copy labels and hide the test portion from the training view.
    masked_labels = labels + 0
    masked_labels[test_ids] = -1
    self._train_labels = tf.convert_to_tensor(masked_labels)
    super().__init__()

  def node_counts(self):
    # Mapping: node set name -> number of nodes.
    return self._node_counts

  def edge_lists(self):
    # Mapping: (source set, edge set, target set) -> [2, num_edges] tensor.
    return self._edge_lists

  def num_classes(self):
    return self._num_classes

  def node_split(self):
    return self._node_split

  def labels(self):
    # Training labels; test entries are set to -1.
    return self._train_labels

  def test_labels(self):
    # Full, unmasked label vector.
    return self._test_labels

  @property
  def labeled_nodeset(self):
    return tfgnn.NODES

  def node_features_dicts_without_labels(self):
    return self._node_features_dicts
340
# Local alias for the download helper shared with tfgnn_datasets.
_maybe_download_file = tfgnn_datasets._maybe_download_file  # pylint: disable=protected-access
343
class GcnBenchmarkUrlGraphData(GcnBenchmarkFileGraphData):
  """Downloads a gnn-benchmark .npz from a URL, caching it under `cache_dir`."""

  def __init__(
      self, npz_url,
      cache_dir = os.path.expanduser(
          os.path.join('~', 'data', 'gnn-benchmark'))):
    # Cache file keeps the URL's basename; download is skipped if present.
    local_path = os.path.join(cache_dir, os.path.basename(npz_url))
    _maybe_download_file(npz_url, local_path)
    super().__init__(local_path)
354
def _load_npz_to_sparse_graph(file_name):
  """Copied from experimental/users/tsitsulin/gcns/cgcn/utilities/graph.py.

  Args:
    file_name: path to an .npz file in the shchur/gnn-benchmark format.

  Returns:
    Tuple (adj_matrix, attr_matrix, labels, label_mask):
      adj_matrix: scipy.sparse.csr_matrix adjacency.
      attr_matrix: dense node-attribute matrix.
      labels: 1-D numpy array of node class labels.
      label_mask: indices (CSR case) or boolean mask (dense case) of
        labeled nodes.

  Raises:
    ValueError: if the file has no attributes or no labels.
  """
  # BUGFIX: close the file handle via a context manager instead of leaking it.
  with tf.io.gfile.GFile(file_name, 'rb') as f:
    file_bytes = f.read()
  bytes_io = io.BytesIO(file_bytes)
  with np.load(bytes_io, allow_pickle=True) as fin:
    # Materialize all arrays so the NpzFile can be closed.
    loader = dict(fin)

  adj_matrix = scipy.sparse.csr_matrix(
      (loader['adj_data'], loader['adj_indices'], loader['adj_indptr']),
      shape=loader['adj_shape'])

  if 'attr_data' in loader:
    # Attributes are stored as a sparse CSR matrix
    attr_matrix = scipy.sparse.csr_matrix(
        (loader['attr_data'], loader['attr_indices'],
         loader['attr_indptr']),
        shape=loader['attr_shape']).todense()
  elif 'attr_matrix' in loader:
    # Attributes are stored as a (dense) np.ndarray
    attr_matrix = loader['attr_matrix']
  else:
    raise ValueError('No attributes in the data file: ' + file_name)

  if 'labels_data' in loader:
    # Labels are stored as a CSR matrix
    labels = scipy.sparse.csr_matrix(
        (loader['labels_data'], loader['labels_indices'],
         loader['labels_indptr']),
        shape=loader['labels_shape'])
    # Row indices = labeled nodes; column indices = their class ids.
    label_mask = labels.nonzero()[0]
    labels = labels.nonzero()[1]
  elif 'labels' in loader:
    # Labels are stored as a numpy array
    labels = loader['labels']
    label_mask = np.ones(labels.shape, dtype=np.bool_)
  else:
    raise ValueError('No labels in the data file: ' + file_name)

  return adj_matrix, attr_matrix, labels, label_mask
394
class GSLAmazonPhotosGraphData(GcnBenchmarkUrlGraphData, GSLGraphData):
  """Wraps GCN Benchmark datasets to be used for graph structure learning."""

  def __init__(
      self,
      dataset_name,
      remove_noise_ratio,
      add_noise_ratio,
  ):
    # `dataset_name` is accepted for interface parity with the other GSL
    # wrappers but is not used: the amazon-photos URL is fixed.
    del dataset_name
    GcnBenchmarkUrlGraphData.__init__(
        self,
        'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'
        'amazon_electronics_photo.npz')
    GSLGraphData.__init__(
        self,
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
414
class StackOverflowGraphlessData(tfgnn_datasets.NodeClassificationGraphData):
  """Stackoverflow dataset contains node features and labels (but no edges)."""

  def __init__(
      self, cache_dir = os.path.expanduser(
          os.path.join('~', 'data', 'stackoverflow-bert'))):
    """Loads cached BERT embeddings/labels, downloading and embedding if absent.

    Args:
      cache_dir: directory holding (or to receive) 'labels.npy' and
        'embeddings.npy'.
    """
    labels_path = os.path.join(cache_dir, 'labels.npy')
    embeddings_path = os.path.join(cache_dir, 'embeddings.npy')

    if (not tf.io.gfile.exists(labels_path) or
        not tf.io.gfile.exists(embeddings_path)):
      if not tf.io.gfile.exists(cache_dir):
        tf.io.gfile.makedirs(cache_dir)
      # Download raw text and compute BERT embeddings (slow; done once).
      self._download_dataset_extract_features(labels_path, embeddings_path)

    node_features = np.load(tf.io.gfile.GFile(embeddings_path, 'rb'))
    node_labels = np.load(tf.io.gfile.GFile(labels_path, 'rb'))
    num_nodes = node_features.shape[0]
    self._node_counts = {tfgnn.NODES: num_nodes}
    self._num_classes = node_labels.max() + 1
    self._test_labels = tf.convert_to_tensor(node_labels)
    # No edges: the edge list is an empty [2, 0] tensor.
    self._edge_lists = {
        (tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): (
            tf.zeros(shape=[2, 0], dtype=tf.int32))}
    self._node_features_dicts = {
        tfgnn.NODES: {
            'feat': tf.convert_to_tensor(node_features, dtype=tf.float32),
            '#id': tf.range(num_nodes),
        }
    }
    # Deterministic (seed=1234) 10%/10%/80% train/validate/test split.
    permutation = np.random.default_rng(seed=1234).permutation(num_nodes)
    num_train_examples = num_nodes // 10
    num_validate_examples = num_nodes // 10
    train_indices = permutation[:num_train_examples]
    num_validate_plus_train = num_validate_examples + num_train_examples
    validate_indices = permutation[num_train_examples:num_validate_plus_train]
    test_indices = permutation[num_validate_plus_train:]

    self._node_split = tfgnn_datasets.NodeSplit(
        tf.convert_to_tensor(train_indices),
        tf.convert_to_tensor(validate_indices),
        tf.convert_to_tensor(test_indices))

    self._train_labels = node_labels + 0  # Make a copy.
    # Hide test labels from the training view.
    self._train_labels[test_indices] = -1
    self._train_labels = tf.convert_to_tensor(self._train_labels)
    super().__init__()

  def node_counts(self):
    # Mapping: node set name -> number of nodes.
    return self._node_counts

  def edge_lists(self):
    # Mapping: edge type triple -> empty [2, 0] edge tensor (no edges).
    return self._edge_lists

  def num_classes(self):
    return self._num_classes

  def node_split(self):
    return self._node_split

  def labels(self):
    # Training labels; test entries are -1.
    return self._train_labels

  def test_labels(self):
    # Full, unmasked label vector.
    return self._test_labels

  @property
  def labeled_nodeset(self):
    return tfgnn.NODES

  def node_features_dicts_without_labels(self):
    return self._node_features_dicts

  def _download_dataset_extract_features(
      self, labels_path, embeddings_path):
    """Downloads the StackOverflow TSV and writes BERT embeddings + labels."""
    cache_dir = os.path.dirname(labels_path)
    url = ('https://raw.githubusercontent.com/rashadulrakib/'
           'short-text-clustering-enhancement/master/data/stackoverflow/'
           'traintest')
    tab_separated_filepath = os.path.join(cache_dir, 'traintest.tsv')
    _maybe_download_file(url, tab_separated_filepath)

    # Each TSV row: two integer labels and the text; keyed by text, so
    # duplicate texts collapse to the last row's labels.
    data_cluster = {}
    with tf.io.gfile.GFile(tab_separated_filepath, 'r') as f:
      for line in f:
        l1, l2, text = line.strip().split('\t')
        data_cluster[text] = (int(l1), int(l2))

    def remove_cls_sep(masks):
      """Zeroes the [CLS] (first) and [SEP] (last set) positions in-place."""
      last_1s = np.sum(masks, axis=1) - 1  # Index of each row's last 1.
      for i in range(masks.shape[0]):
        masks[i][0] = 0
        masks[i][last_1s[i]] = 0
      return masks

    def bert_embs(texts):
      """Returns mean-pooled BERT token embeddings (excluding CLS/SEP)."""
      text_preprocessed = bert_preprocess_model(texts)
      bert_results = bert_model(text_preprocessed)
      masks = np.expand_dims(
          remove_cls_sep(text_preprocessed['input_mask'].numpy()), axis=2)
      # Masked mean over the sequence dimension.
      emb = (np.sum(bert_results['sequence_output'].numpy() * masks, axis=1)
             / np.sum(masks, axis=1))
      return emb

    # Instantiate BERT model.
    bert_preprocess_model = tfhub.KerasLayer(
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
    bert_model = tfhub.KerasLayer(
        'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3')

    # Embed all texts through `bert_model`, in batches of 100.
    data_cluster_keys = list(data_cluster.keys())
    embeddings = []
    for i in range(0, len(data_cluster_keys), 100):
      embeddings.append(bert_embs(data_cluster_keys[i:i+100]))
    embeddings = np.vstack(embeddings)
    # The second label column is used as the node label.
    labels = np.array([data_cluster[t][1] for t in data_cluster_keys])

    with tf.io.gfile.GFile(labels_path, 'wb') as fout:
      np.save(fout, labels)

    with tf.io.gfile.GFile(embeddings_path, 'wb') as fout:
      np.save(fout, embeddings)
540
class GSLStackOverflowGraphlessData(StackOverflowGraphlessData, GSLGraphData):
  """Wraps Stackoverflow datasets to be used for graph structure learning."""

  def __init__(
      self,
      remove_noise_ratio,
      add_noise_ratio,
      cache_dir = os.path.expanduser(
          os.path.join('~', 'data', 'stackoverflow-bert'))):
    # Load (or build) the BERT-embedded dataset first, then let the GSL
    # mixin derive its noisy input graph from it.
    StackOverflowGraphlessData.__init__(self, cache_dir=cache_dir)
    GSLGraphData.__init__(
        self, remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio)
555
def get_in_memory_graph_data(
    dataset_name,
    remove_noise_ratio,
    add_noise_ratio,
):
  """Getting the dataset based on the name.

  Args:
    dataset_name: the name of the dataset to prepare.
    remove_noise_ratio: ratio of the existing edge to remove.
    add_noise_ratio: ratio of the non-existing edges to add.

  Returns:
    The graph data to be used in training.

  Raises:
    ValueError: if the name of the dataset is not defined.
  """
  # Guard-style dispatch: each supported name returns immediately.
  if dataset_name in ('cora', 'citeseer', 'pubmed'):
    return GSLPlanetoidGraphData(
        dataset_name,
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
  if dataset_name == 'amazon_photos':
    return GSLAmazonPhotosGraphData(
        dataset_name,
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
  if dataset_name == 'stackoverflow':
    return GSLStackOverflowGraphlessData(
        remove_noise_ratio=remove_noise_ratio,
        add_noise_ratio=add_noise_ratio,
    )
  raise ValueError('Unknown Dataset name: ' + dataset_name)