google-research
1093 строки · 39.9 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
17#
18# Licensed under the Apache License, Version 2.0 (the "License");
19# you may not use this file except in compliance with the License.
20# You may obtain a copy of the License at
21#
22# http://www.apache.org/licenses/LICENSE-2.0
23#
24# Unless required by applicable law or agreed to in writing, software
25# distributed under the License is distributed on an "AS IS" BASIS,
26# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27# See the License for the specific language governing permissions and
28# limitations under the License.
29# ==============================================================================
30"""Infrastructure and implementation of in-memory graph data.
31
32Instantiating an object will download a dataset, and cache it locally. The
33datasets will be cached on ~/data/ogb (for "ogbn-" and "ogbl-" datasets), which
34can be overridden by setting environment variable `OGB_CACHE_DIR`; and on
35~/data/planetoid (for "cora", "citeseer", "pubmed"), which can be overridden by
36environment variable `PLANETOID_CACHE_DIR`.
37
38High-level Abstract Classes:
39
40* `InMemoryGraphData`: provides nodes, edges, and features, for a
homogeneous or a heterogeneous graph.
42* `NodeClassificationGraphData`: an `InMemoryGraphData` that also provides
43list of {train, test, validation} nodes, as well as their labels.
44* `LinkPredictionGraphData`: an `InMemoryGraphData` that also provides lists
45of edges in {train, test, validation} partitions.
46
47
48`InMemoryGraphData` implementations can provide
49
50* a single GraphTensor for training on one big graph (e.g., for node
51classification with `tf_trainer.py` or `keras_trainer.py`),
52* a big graph from which in-memory sampling (e.g., `int_arithmetic_sampler`)
53can create dataset of sampled subgraphs (encoded as `tfgnn.GraphTensor`).
54
55All `InMemoryGraphData` implementations automatically inherit abilities of:
56
57* `as_graph_tensor()` .
58* These methods can be plugged-into TF-GNN models and training loops, e.g.,
59for node classification (see `tf_trainer.py` and `keras_trainer.py`).
60* In addition, they can be plugged-into in-memory sampling (see
61`int_arithmetic_sampler.py`, and example trainer script,
62`keras_minibatch_trainer.py`).
63
64
65Concrete implementations:
66
67* Node classification (inheriting `NodeClassificationGraphData`)
68
69* `OgbnData`: Wraps node classification graph data from OGB, i.e., with
70name prefix of "ogbn-", such as, "ogbn-arxiv".
71
72* `PlanetoidGraphData`: wraps graph data that are popularized by GCN paper
73(cora, citeseer, pubmed).
74
75* Link prediction (inheriting `LinkPredictionGraphData`)
76
77* `OgblData`: Wraps link prediction graph data from OGB, i.e., with name
78prefix of "ogbl-", such as, "ogbl-ddi".
79
80
81# Usage Example.
82
83```
84graph_data = datasets.OgbnData('ogbn-arxiv')
85
# Optionally, add self-loops:
graph_data = graph_data.with_self_loops(True)

# Make the graph undirected:
graph_data = graph_data.with_undirected_edges(True)
91
92# To get GraphTensor and GraphSchema at any graph data:
93graph_tensor = graph_data.as_graph_tensor()
94graph_schema = graph_data.graph_schema()
95
96spec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)
97# or optionally, by "relaxing" the batch dimension of `graph_tensor` (to None):
98# spec = graph_tensor.spec.relax(num_nodes=True, num_edges=True)
99```
100
101The first line is equivalent to
102`graph_data = datasets.get_in_memory_graph_data('ogbn-arxiv')`. Which is more
103general, because it can load other data types:
* ogbn-* are node-classification datasets from OGB.
105* 'pubmed', 'cora', 'citeseer', correspond to transductive graphs used in
106Planetoid (Yang et al, ICML'16).
107
108
109`graph_tensor` (type `GraphTensor`) contains all nodes, edges, and features.
110If it is a node-classification dataset, the training labels are also populated.
**For nodes not in training set**, label feature will be `-1`.
If you want to explicitly get all labels from all partitions, you may:
113
114```
115graph_data = graph_data.with_split(['train', 'test', 'validation'])
116graph_tensor = graph_data.graph_tensor
117```
118
119Chaining `with_*` calls can reduce verbosity. For example,
120```
121graph_data = (
122datasets.OgbnData('ogbn-arxiv').with_undirected_edges(True)
123.with_self_loops(True))
124graph_tensor = graph_data.as_graph_tensor()
125```
126"""
127import abc128import copy129import io130import json131import os132import pickle133import sys134from typing import Any, List, Mapping, NamedTuple, Tuple, Union, Optional135import urllib.request136
137import numpy as np138import scipy139import tensorflow as tf140import tensorflow_gnn as tfgnn141
142
class InMemoryGraphData(abc.ABC):
  """Abstract class for holding graph data in-memory (nodes, edges, features).

  Subclasses must implement methods `node_features_dicts()`, `node_counts()`,
  `edge_lists()`, `node_sets()`, and optionally, `context()`. They inherit
  methods `graph_schema()`, `edge_sets()`, and `as_graph_tensor()` based on
  those.
  """

  def __init__(self, make_undirected: bool = False,
               add_self_loops: bool = False):
    # Both flags only affect *exported* views (`edge_sets()`, `graph_schema()`,
    # `as_graph_tensor()`); the underlying `edge_lists()` stay untouched.
    self._make_undirected = make_undirected
    self._add_self_loops = add_self_loops

  @property
  def name(self):
    """Returns name of dataset object. Can be overridden to return data name."""
    return self.__class__.__name__

  def with_undirected_edges(self, make_undirected):
    """Returns same graph data but with undirected edges added (or removed).

    Subsequent calls to `.graph_schema()` and to `.as_graph_tensor()` will be
    affected. Specifically, the generated output `tfgnn.GraphTensor` (by
    `.as_graph_tensor()`) will reverse all homogeneous edge sets (where its
    source node set equals its target node set). Suppose edge `(i, j)` is
    included in *homogeneous* edge set "MyEdgeSet", then output `GraphTensor`
    will also contain edge `(j, i)` on edge set "MyEdgeSet". If edge `(j, i)`
    already exists, then it will be duplicated.

    If make_undirected == True:

    * output of `.as_graph_tensor()` will contain only edge-set names that are
      returned by `.edge_sets()`, where each homogeneous edge-set with M edges
      will be expanded to M*2 edges with edge `M+k` reversing edge `k`.
    * output of `.graph_schema()` will contain only edge-sets returned by
      `edge_sets`.

    If make_undirected == False:

    * output of `.as_graph_tensor()` will contain, for each edge set "EdgeSet"
      (returned by `.edge_sets()`) a new edge-set "rev_EdgeSet" that reverses
      the "EdgeSet".
    * output of `.graph_schema()` will have both "EdgeSet" and "rev_EdgeSet".
    * `with_reverse_edge_sets()` is an equivalent and a more explicit method
      to add reverse edge sets to the graph tensor and its schema.

    Args:
      make_undirected: If True, subsequent calls to `.graph_schema()` and
        `.as_graph_tensor()` will export an undirected graph. If False, a
        directed graph (with additional "rev_*" edges).
    """
    # Shallow copy keeps (potentially large) feature tensors shared between
    # the original object and the returned one; only the flag differs.
    modified = copy.copy(self)
    modified._make_undirected = make_undirected  # pylint: disable=protected-access -- same class.
    return modified

  def with_reverse_edge_sets(self):
    """Returns same graph data but with reverse edge sets added."""

    # Calling `with_undirected_edges` with `False` input automatically makes the
    # output of `.as_graph_tensor()` contain, for each edge set "EdgeSet"
    # (returned by `.edge_sets()`) a new edge-set "rev_EdgeSet" that reverses
    # the "EdgeSet". Similarly, output of `.graph_schema()` will have both
    # "EdgeSet" and "rev_EdgeSet".
    return self.with_undirected_edges(False)

  def with_self_loops(self, add_self_loops):
    """Returns same graph data but with self-loops added (or removed).

    If add_self_loops == True, then subsequent calls to `.as_graph_tensor()`
    will contain edges `[(i, i) for i in range(N_j)]`, for each homogeneous edge
    set j, where `N_j` is the number of nodes in node set connected by edge set
    `j`.

    NOTE: self-loops will be added *regardless* if they already exist or not.
    If the dataset already has self-loops, calling this will double the self-
    loop edges.

    Args:
      add_self_loops: If set, self-loops will be amended on subsequent calls to
        `.as_graph_tensor()`. If not, no self-loops will be automatically added.
    """
    modified = copy.copy(self)
    modified._add_self_loops = add_self_loops  # pylint: disable=protected-access -- same class.
    return modified

  @abc.abstractmethod
  def node_counts(self):
    """Returns total number of graph nodes per node set."""
    raise NotImplementedError()

  @abc.abstractmethod
  def node_features_dicts(self):
    """Returns 2-level dict: NodeSetName->FeatureName->Feature tensor.

    For every node set (`"x"`), feature tensor must have leading dimension equal
    to number of nodes in node set (`.node_counts()["x"]`). Other dimensions are
    dataset specific.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def edge_lists(self):
    """Returns dict from "edge type tuple" to int Tensor of shape (2, num_edges).

    "edge type tuple" is a string three-tuple:
    `(source node set name, edge set name, target node set name)`.
    where `edge set name` must be unique.
    """
    raise NotImplementedError()

  def node_sets(self):
    """Returns node sets of entire graph (dict: node set name -> NodeSet)."""
    node_counts = self.node_counts()
    features_dicts = self.node_features_dicts()
    # Union so that node sets with a count but no features (or vice versa)
    # still get a NodeSet entry.
    node_set_names = set(node_counts.keys()).union(features_dicts.keys())
    return (
        {name: tfgnn.NodeSet.from_fields(sizes=as_tensor([node_counts[name]]),
                                         features=features_dicts.get(name, {}))
         for name in node_set_names})

  def context(self):
    # Default: no graph-level context. Subclasses may override (e.g., to mark
    # seed nodes for node classification).
    return None

  def as_graph_tensor(self):
    """Returns `GraphTensor` holding the entire graph."""
    return tfgnn.GraphTensor.from_pieces(
        node_sets=self.node_sets(), edge_sets=self.edge_sets(),
        context=self.context())

  def graph_schema(self):
    """`tfgnn.GraphSchema` instance corresponding to `as_graph_tensor()`."""
    # Populate node features specs.
    schema = tfgnn.GraphSchema()
    for node_set_name, node_set in self.node_sets().items():
      node_features = schema.node_sets[node_set_name]
      for feat_name, feature in node_set.features.items():
        node_features.features[feat_name].dtype = feature.dtype.as_datatype_enum
        # Skip the leading (num_nodes) dimension; schema records feature dims.
        for dim in feature.shape[1:]:
          node_features.features[feat_name].shape.dim.add().size = dim

    # Populate edge specs.
    for edge_type in self.edge_lists().keys():
      src_node_set_name, edge_set_name, dst_node_set_name = edge_type
      # Populate edges with adjacency and its transpose.
      schema.edge_sets[edge_set_name].source = src_node_set_name
      schema.edge_sets[edge_set_name].target = dst_node_set_name
      if not self._make_undirected:
        # Directed export: add the reverse edge set, mirroring `edge_sets()`.
        schema.edge_sets['rev_' + edge_set_name].source = dst_node_set_name
        schema.edge_sets['rev_' + edge_set_name].target = src_node_set_name

    return schema

  def edge_sets(self):
    """Returns edge sets of entire graph (dict: edge set name -> EdgeSet)."""
    edge_sets = {}
    node_counts = self.node_counts() if self._add_self_loops else None
    for edge_type, edge_list in self.edge_lists().items():
      (source_node_set_name, edge_set_name, target_node_set_name) = edge_type

      # Undirecting and self-loops only make sense for homogeneous edge sets
      # (same source and target node set).
      if self._make_undirected and source_node_set_name == target_node_set_name:
        # Reversing the first axis swaps (source, target) rows.
        edge_list = tf.concat([edge_list, edge_list[::-1]], axis=-1)
      if self._add_self_loops and source_node_set_name == target_node_set_name:
        all_nodes = tf.range(node_counts[source_node_set_name],
                             dtype=edge_list.dtype)
        self_connections = tf.stack([all_nodes, all_nodes], axis=0)
        edge_list = tf.concat([edge_list, self_connections], axis=-1)
      edge_sets[edge_set_name] = tfgnn.EdgeSet.from_fields(
          sizes=tf.shape(edge_list)[1:2],
          adjacency=tfgnn.Adjacency.from_indices(
              source=(source_node_set_name, edge_list[0]),
              target=(target_node_set_name, edge_list[1])))
      if not self._make_undirected:
        # Directed export: materialize a "rev_*" edge set with swapped indices.
        edge_sets['rev_' + edge_set_name] = tfgnn.EdgeSet.from_fields(
            sizes=tf.shape(edge_list)[1:2],
            adjacency=tfgnn.Adjacency.from_indices(
                source=(target_node_set_name, edge_list[1]),
                target=(source_node_set_name, edge_list[0])))
    return edge_sets

  def save(self, filename):
    """Subclasses can save themselves to disk."""
    raise NotImplementedError()
326
class NodeSplit(NamedTuple):
  """Contains 1D int tensors holding positions of {train, valid, test} nodes.

  This is returned by `NodeClassificationGraphData.node_split()`.
  """
  # Each field holds indices into the labeled node set (not boolean masks).
  train: tf.Tensor
  validation: tf.Tensor
  test: tf.Tensor
336
class EdgeSplit(NamedTuple):
  """Contains positive and negative edges in {train, test, valid} partitions.

  Each `tf.Tensor` will be of shape `[2, num_edges]` with dtype int64.
  """
  # Only need positive edges for training. The (entire) graph complement can be
  # used for negative edges.
  train_edges: tf.Tensor
  validation_edges: tf.Tensor
  test_edges: tf.Tensor
  negative_validation_edges: tf.Tensor
  negative_test_edges: tf.Tensor
350
class NodeClassificationGraphData(InMemoryGraphData):
  """Adapts `InMemoryGraphData` for node classification settings.

  Subclasses should provide information for node classification: (node labels,
  name of node set, and partitions train:validation:test nodes).
  """

  def __init__(self, split: str = 'train', use_labels_as_features: bool = False):
    super().__init__()
    self._splits = [split]
    self._use_labels_as_features = use_labels_as_features

  def with_split(self, split='train'):
    """Returns same graph data but with specific partition.

    Args:
      split: must be one of {"train", "validation", "test"}, or a list/tuple
        thereof.
    """
    splits = split if isinstance(split, (tuple, list)) else [split]
    for split in splits:
      if split not in ('train', 'validation', 'test'):
        raise ValueError(
            'split must be one of {"train", "validation", "test"}.')
    # Shallow copy: shares underlying data, differs only in selected splits.
    modified = copy.copy(self)
    modified._splits = splits  # pylint: disable=protected-access -- same class.
    return modified

  def with_labels_as_features(
      self, use_labels_as_features):
    """Returns same graph data with labels as an additional feature on nodes.

    The feature will be added to the node-set with name `self.labeled_nodeset`.

    Args:
      use_labels_as_features: Label feature will be added iff set to True.
    """
    modified = copy.copy(self)
    modified._use_labels_as_features = use_labels_as_features  # pylint: disable=protected-access -- same class.
    return modified

  @property
  def splits(self):
    # Copy so callers cannot mutate internal state.
    return copy.copy(self._splits)

  @abc.abstractmethod
  def num_classes(self):
    """Number of node classes. Max of `labels` should be `< num_classes`."""
    raise NotImplementedError('num_classes')

  @abc.abstractmethod
  def node_split(self):
    """`NodeSplit` with attributes `train`, `validation`, `test` set.

    The attributes are set to indices of the `labeled_nodeset`. Specifically,
    they correspond to leading dimension of features of the node set.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def labels(self):
    """int vector containing labels for train & validation nodes.

    Size of vector is number of nodes in the labeled node set. In particular:
    `self.labels().shape[0] == self.node_counts()[self.labeled_nodeset]`.
    Specifically, the vector has as many entries as there are nodes belonging to
    the node set that this task aims to predict labels for.

    Entry `labels()[i]` will be -1 iff `i in self.node_split().test`. Otherwise,
    `labels()[i]` will be int in range [`0`, `self.num_classes() - 1`].
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def test_labels(self):
    """Like `labels()` but contains no -1's.

    Every {train, valid, test} node will have its class label.
    """
    raise NotImplementedError()

  @property
  @abc.abstractmethod
  def labeled_nodeset(self):
    """Name of node set which `labels` and `node_splits` reference."""
    raise NotImplementedError()

  @abc.abstractmethod
  def node_features_dicts_without_labels(self):
    # Like `node_features_dicts()` but guaranteed to exclude the 'label'
    # feature; used by `node_features_dicts()` and `save()`.
    raise NotImplementedError()

  def node_features_dicts(self):
    """Implements a method required by the base class.

    This method combines the data from `labels()` or `test_labels()` with the
    data from `node_features_dicts_without_labels()` into a single features
    dict.

    Subclasses need to implement aforementioned methods and may inherit this.

    Returns:
      NodeSetName -> FeatureName -> Feature Tensor.
    """
    node_features_dicts = self.node_features_dicts_without_labels()
    node_features_dicts = {ns: dict(features)  # Shallow copy.
                           for ns, features in node_features_dicts.items()}
    if self._use_labels_as_features:
      # Only expose true test labels when 'test' is an explicitly requested
      # split; otherwise test entries stay masked as -1 via `labels()`.
      if 'test' in self._splits:
        node_features_dicts[self.labeled_nodeset]['label'] = self.test_labels()
      else:
        node_features_dicts[self.labeled_nodeset]['label'] = self.labels()

    return node_features_dicts

  def context(self):
    # Marks the nodes of the selected split(s) as "seed" nodes in the graph
    # context, e.g. for samplers and readout.
    node_split = self.node_split()
    seed_nodes = tf.concat(
        [getattr(node_split, split) for split in self._splits], axis=0)
    # Leading batch dimension of 1 (single graph).
    seed_nodes = tf.expand_dims(seed_nodes, axis=0)
    seed_feature_name = 'seed_nodes.' + self.labeled_nodeset

    return tfgnn.Context.from_fields(features={seed_feature_name: seed_nodes})

  def graph_schema(self):
    graph_schema = super().graph_schema()
    # Register the context feature added by `context()` above.
    context_features = graph_schema.context.features
    context_features['seed_nodes.' + self.labeled_nodeset].dtype = (
        tf.int64.as_datatype_enum)
    return graph_schema

  def save(self, filename):
    """Saves the dataset on numpy compressed (.npz) file.

    Calls, once each, the functions
    (labeled_nodeset, test_labels, labels, node_split, edge_lists, node_counts,
    node_features, num_classes),
    composes a flat dict (keys are json-encoded tuples), then writes as numpy
    file. Flat dict is needed as numpy only saves named arrays, not nested
    structures.

    Args:
      filename: file path to save onto. ".npz" extension is recommended. Parent
        directory must exist.
    """
    features_without_labels = self.node_features_dicts_without_labels()
    node_split = self.node_split()

    attribute_dict = {
        ('num_classes',): self.num_classes(),
        ('node_split', 'train'): node_split.train.numpy(),
        ('node_split', 'test'): node_split.test.numpy(),
        ('node_split', 'validation'): node_split.validation.numpy(),
        ('labels',): self.labels().numpy(),
        ('test_labels',): self.test_labels().numpy(),
        ('labeled_nodeset',): self.labeled_nodeset,
    }

    # Edge sets. Keyed ('e', '#', src, edge-set, tgt); '#' marks edge indices.
    for (src_name, es_name, tgt_name), es_indices in self.edge_lists().items():
      key = ('e', '#', src_name, es_name, tgt_name)
      attribute_dict[key] = es_indices.numpy()

    # Node features, keyed ('n', node-set, feature-name).
    for ns_name, features in features_without_labels.items():
      for feature_name, feature_tensor in features.items():
        attribute_dict[('n', ns_name, feature_name)] = feature_tensor.numpy()

    # Node counts, keyed ('nc', node-set).
    for node_set_name, node_count in self.node_counts().items():
      attribute_dict[('nc', node_set_name)] = node_count

    # Serialize in-memory first, then write via GFile (supports remote paths).
    bytes_io = io.BytesIO()
    attribute_dict = {json.dumps(k): v for k, v in attribute_dict.items()}
    np.savez_compressed(bytes_io, **attribute_dict)
    with tf.io.gfile.GFile(filename, 'wb') as f:
      f.write(bytes_io.getvalue())

  @staticmethod
  def load(filename):
    """Loads from disk `NodeClassificationGraphData` that was `save()`ed."""
    dataset_dict = dict(np.load(tf.io.gfile.GFile(filename, 'rb')))
    # Invert the json-encoding of tuple keys done in `save()`.
    dataset_dict = {tuple(json.loads(k)): v for k, v in dataset_dict.items()}
    edge_lists = {}
    node_features = {}
    node_counts = {}
    for key, array in dataset_dict.items():
      # edge lists.
      if key[0] == 'e':
        if key[1] != '#':
          raise ValueError('Expecting ("e", "#", ...) but got %s' % str(key))
        src_name = key[2]
        es_name = key[3]
        tgt_name = key[4]
        indices = as_tensor(array)
        edge_lists[(src_name, es_name, tgt_name)] = indices
      # node features.
      if key[0] == 'n':
        node_set_name = key[1]
        feature_name = key[2]
        if node_set_name not in node_features:
          node_features[node_set_name] = {}
        node_features[node_set_name][feature_name] = as_tensor(array)
      # node counts.
      if key[0] == 'nc':
        node_counts[key[1]] = int(array)

    return _PreloadedNodeClassificationGraphData(
        num_classes=dataset_dict[('num_classes',)],
        node_features_dicts_without_labels=node_features,
        node_counts=node_counts,
        edge_lists=edge_lists,
        node_split=NodeSplit(
            train=as_tensor(dataset_dict[('node_split', 'train')]),
            validation=as_tensor(dataset_dict[('node_split', 'validation')]),
            test=as_tensor(dataset_dict[('node_split', 'test')])),
        labels=as_tensor(dataset_dict[('labels',)]),
        test_labels=as_tensor(dataset_dict[('test_labels',)]),
        labeled_nodeset=str(dataset_dict[('labeled_nodeset',)]))
567
class _PreloadedNodeClassificationGraphData(NodeClassificationGraphData):
  """Node-classification graph data backed entirely by pre-computed values.

  Every accessor simply returns the value handed to the constructor; no
  downloading or post-processing happens here. Instances are produced by
  `NodeClassificationGraphData.load()`.
  """

  def __init__(self, num_classes, node_features_dicts_without_labels,
               node_counts, edge_lists, node_split, labels, test_labels,
               labeled_nodeset):
    super().__init__()
    # Stash everything verbatim; the accessors below just read these back.
    self._num_classes = num_classes
    self._node_features_dicts_without_labels = (
        node_features_dicts_without_labels)
    self._node_counts = node_counts
    self._edge_lists = edge_lists
    self._node_split = node_split
    self._labels = labels
    self._test_labels = test_labels
    self._labeled_nodeset = labeled_nodeset

  def num_classes(self):
    """Returns the stored number of classes."""
    return self._num_classes

  def node_features_dicts_without_labels(self):
    """Returns the stored NodeSetName -> FeatureName -> Tensor dict."""
    return self._node_features_dicts_without_labels

  def node_counts(self):
    """Returns the stored node counts per node set."""
    return self._node_counts

  def edge_lists(self):
    """Returns the stored edge-type -> (2, num_edges) tensor dict."""
    return self._edge_lists

  def node_split(self):
    """Returns the stored `NodeSplit`."""
    return self._node_split

  def labels(self):
    """Returns the stored train/validation labels (test masked as -1)."""
    return self._labels

  def test_labels(self):
    """Returns the stored labels including test nodes."""
    return self._test_labels

  @property
  def labeled_nodeset(self):
    """Returns the stored name of the labeled node set."""
    return self._labeled_nodeset
614
class _OgbGraph:
  """Wraps data exposed by OGB graph objects, while enforcing heterogeneity.

  Attributes offered by this class are consistent with the APIs of GraphData.
  """

  def __init__(self, graph):
    """Reads the OGB `graph` dict into the attributes defined below.

    Args:
      graph: Dict, described in
        https://github.com/snap-stanford/ogb/blob/master/ogb/io/README.md#2-saving-graph-list
    """
    if 'edge_index_dict' in graph:  # Heterogeneous graph
      assert 'num_nodes_dict' in graph
      assert 'node_feat_dict' in graph

      # node set name -> feature name -> feature matrix (numNodes x featDim).
      node_set = {node_set_name: {'feat': as_tensor(feat)}
                  for node_set_name, feat in graph['node_feat_dict'].items()
                  if feat is not None}
      # Populate remaining features: any graph key "node_<x>" (other than
      # 'node_feat_dict', handled above) becomes feature "<x>".
      for key, node_set_name_to_feat in graph.items():
        if key.startswith('node_') and key != 'node_feat_dict':
          feat_name = key.split('node_', 1)[-1]
          for node_set_name, feat in node_set_name_to_feat.items():
            node_set[node_set_name][feat_name] = as_tensor(feat)
      self._num_nodes_dict = graph['num_nodes_dict']
      self._node_feat_dict = node_set
      self._edge_index_dict = tf.nest.map_structure(
          as_tensor, graph['edge_index_dict'])
    else:  # Homogeneous graph. Make heterogeneous.
      if graph.get('node_feat', None) is not None:
        node_features = {
            tfgnn.NODES: {'feat': as_tensor(graph['node_feat'])}
        }
      else:
        # No features: use an empty (num_nodes x 0) float feature so that
        # downstream code can always rely on 'feat' being present.
        node_features = {
            tfgnn.NODES: {
                'feat': tf.zeros([graph['num_nodes'], 0], dtype=tf.float32)
            }
        }

      # Single node set (tfgnn.NODES) and single edge set (tfgnn.EDGES).
      self._edge_index_dict = {
          (tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): as_tensor(
              graph['edge_index']),
      }
      self._num_nodes_dict = {tfgnn.NODES: graph['num_nodes']}
      self._node_feat_dict = node_features

  @property
  def num_nodes_dict(self):
    """Maps "node set name" -> number of nodes."""
    return self._num_nodes_dict

  @property
  def node_feat_dict(self):
    """Maps "node set name" to dict of "feature name"->tf.Tensor."""
    return self._node_feat_dict

  @property
  def edge_index_dict(self):
    """Adjacency lists for all edge sets.

    Returns:
      Dict (source node set name, edge set name, target node set name) -> edges.
      Where `edges` is tf.Tensor of shape (2, num edges), with `edges[0]` and
      `edges[1]`, respectively, containing source and target node IDs (as 1D int
      tf.Tensor).
    """
    return self._edge_index_dict
687
def _get_ogbn_dataset(dataset_name, cache_dir=None):
  """Lazily imports ogb and constructs a `NodePropPredDataset`."""
  # Deliberately deferred: ogb is only imported when an ogb dataset is asked
  # for, so the module works without ogb installed.
  import ogb.nodeproppred  # pylint: disable=g-import-not-at-top
  dataset_class = ogb.nodeproppred.NodePropPredDataset
  return dataset_class(dataset_name, root=cache_dir)
694
def _get_ogbl_dataset(dataset_name, cache_dir=None):
  """Lazily imports ogb and constructs a `LinkPropPredDataset`."""
  # Deliberately deferred: ogb is only imported when an ogb dataset is asked
  # for, so the module works without ogb installed.
  import ogb.linkproppred  # pylint: disable=g-import-not-at-top
  dataset_class = ogb.linkproppred.LinkPropPredDataset
  return dataset_class(dataset_name, root=cache_dir)
701
class OgbnData(NodeClassificationGraphData):
  """Wraps node classification graph data of ogbn-* for in-memory learning."""

  def __init__(self, dataset_name, cache_dir=None):
    super().__init__()
    self._dataset_name = dataset_name
    if cache_dir is None:
      cache_dir = os.environ.get(
          'OGB_CACHE_DIR', os.path.expanduser(os.path.join('~', 'data', 'ogb')))

    self._ogb_dataset = _get_ogbn_dataset(dataset_name, cache_dir)
    self._graph, self._node_labels, self._node_split, self._labeled_nodeset = (
        OgbnData._to_heterogeneous(self._ogb_dataset))

    # reshape from [N, 1] to [N].
    self._node_labels = self._node_labels[:, 0]

    # train labels (test set to -1).
    self._train_labels = np.copy(self._node_labels)
    self._train_labels[self._node_split.test] = -1

    self._train_labels = as_tensor(self._train_labels)
    self._node_labels = as_tensor(self._node_labels)

  @property
  def name(self):
    # Report the dataset name (e.g. "ogbn-arxiv"), not the class name.
    return self._dataset_name

  @staticmethod
  def _to_heterogeneous(
      ogb_dataset):
    """Returns heterogeneous dicts from homogeneous or heterogeneous OGB dataset.

    Args:
      ogb_dataset: OGBN dataset. It can be homogeneous (single node set type,
        single edge set type), or heterogeneous (various node/edge set types),
        and returns data structure as-if the dataset is heterogeneous (i.e.,
        names each node/edge set). If input is a homogeneous graph, then the
        node set will be named "nodes" and the edge set will be named "edges".

    Returns:
      tuple: `(ogb_graph, node_labels, idx_split, labeled_nodeset)`, where:
        `ogb_graph` is instance of _OgbGraph.
        `node_labels`: np.array of labels, with .shape[0] equals number of nodes
          in node set with name `labeled_nodeset`.
        `idx_split`: instance of NodeSplit. Members `train`, `test` and
          `validation`, respectively, contain indices of nodes in node set with
          name `labeled_nodeset`.
        `labeled_nodeset`: name of node set that the node-classification task is
          designed over.
    """
    graph, node_labels = ogb_dataset[0]
    ogb_graph = _OgbGraph(graph)
    if 'edge_index_dict' in graph:  # Graph is heterogeneous
      assert 'num_nodes_dict' in graph
      assert 'node_feat_dict' in graph
      labeled_nodeset = list(node_labels.keys())
      if len(labeled_nodeset) != 1:
        raise ValueError('Expecting OGB dataset with *one* node set with '
                         'labels. Found: ' + ', '.join(labeled_nodeset))
      labeled_nodeset = labeled_nodeset[0]

      node_labels = node_labels[labeled_nodeset]
      # idx_split is dict: {'train': {labeled_nodeset: np.array}, 'test': ...}.
      idx_split = ogb_dataset.get_idx_split()
      # Change to {'train': Tensor, 'test': Tensor, 'valid': Tensor}
      idx_split = {split_name: as_tensor(split_dict[labeled_nodeset])
                   for split_name, split_dict in idx_split.items()}
      # third-party OGB class returns dict with key 'valid'. Make consistent
      # with TF nomenclature by renaming.
      idx_split['validation'] = idx_split.pop('valid')  # Rename
      idx_split = NodeSplit(**idx_split)

      return ogb_graph, node_labels, idx_split, labeled_nodeset

    # Homogeneous case. Copy other node information ("node_<x>" keys become
    # features named "<x>" on the single node set).
    for key, value in graph.items():
      if key != 'node_feat' and key.startswith('node_'):
        key = key.split('node_', 1)[-1]
        ogb_graph.node_feat_dict[tfgnn.NODES][key] = as_tensor(value)  # pytype: disable=unsupported-operands  # always-use-property-annotation
    idx_split = ogb_dataset.get_idx_split()
    idx_split['validation'] = idx_split.pop('valid')  # Rename
    idx_split = NodeSplit(**tf.nest.map_structure(
        tf.convert_to_tensor, idx_split))
    return ogb_graph, node_labels, idx_split, tfgnn.NODES

  def num_classes(self):
    return self._ogb_dataset.num_classes

  def node_features_dicts_without_labels(self):
    # Deep-copy dict (*but* without copying tf.Tensor objects).
    node_sets = self._graph.node_feat_dict
    node_sets = {node_set_name: dict(node_set.items())
                 for node_set_name, node_set in node_sets.items()}
    node_counts = self.node_counts()
    # Add an '#id' feature (0..count-1) to every node set.
    for node_set_name, count in node_counts.items():
      if node_set_name not in node_sets:
        node_sets[node_set_name] = {}
      feat_dict = node_sets[node_set_name]
      feat_dict['#id'] = tf.range(count, dtype=tf.int32)
    return node_sets

  @property
  def labeled_nodeset(self):
    return self._labeled_nodeset

  def node_counts(self):
    return self._graph.num_nodes_dict

  def edge_lists(self):
    return self._graph.edge_index_dict

  def node_split(self):
    return self._node_split

  def labels(self):
    # Train/validation labels; test entries were masked to -1 in __init__.
    return self._train_labels

  def test_labels(self):
    """int tensor of length num_nodes containing train and test labels."""
    return self._node_labels
824
def _maybe_download_file(source_url, destination_path, make_dirs=True):
  """Downloads URL `source_url` onto file `destination_path` if not present."""
  if tf.io.gfile.exists(destination_path):
    return  # Already cached; nothing to do.
  if make_dirs:
    try:
      tf.io.gfile.makedirs(os.path.dirname(destination_path))
    except FileExistsError:
      pass  # Another process may have created it concurrently.
  with urllib.request.urlopen(source_url) as fin:
    with tf.io.gfile.GFile(destination_path, 'wb') as fout:
      fout.write(fin.read())
839
840class PlanetoidGraphData(NodeClassificationGraphData):841"""Wraps Planetoid node-classificaiton datasets.842
843These datasets first appeared in the Planetoid [1] paper and popularized by
844the GCN paper [2].
845
846[1] Yang et al, ICML'16
847[2] Kipf & Welling, ICLR'17.
848"""
849
  def __init__(self, dataset_name, cache_dir=None):
    """Downloads (if needed) and parses the Planetoid `dataset_name` files.

    Args:
      dataset_name: one of "pubmed", "citeseer", "cora".
      cache_dir: directory for downloaded files; defaults to environment
        variable `PLANETOID_CACHE_DIR` or `~/data/planetoid`.
    """
    super().__init__()
    self._dataset_name = dataset_name
    allowed_names = ('pubmed', 'citeseer', 'cora')

    url_template = (
        'https://github.com/kimiyoung/planetoid/blob/master/data/'
        'ind.%s.%s?raw=true')
    file_parts = ['ally', 'allx', 'graph', 'ty', 'tx', 'test.index']
    if dataset_name not in allowed_names:
      raise ValueError('Dataset must be one of: ' + ', '.join(allowed_names))
    if cache_dir is None:
      cache_dir = os.environ.get(
          'PLANETOID_CACHE_DIR', os.path.expanduser(
              os.path.join('~', 'data', 'planetoid')))
    base_path = os.path.join(cache_dir, 'ind.%s' % dataset_name)
    # Download all files.
    for file_part in file_parts:
      source_url = url_template % (dataset_name, file_part)
      destination_path = os.path.join(
          cache_dir, 'ind.%s.%s' % (dataset_name, file_part))
      _maybe_download_file(source_url, destination_path)

    # Load data files.
    # NOTE(review): pickle.load on downloaded data -- acceptable only because
    # the source is the fixed upstream Planetoid repository.
    edge_lists = pickle.load(tf.io.gfile.GFile(base_path + '.graph', 'rb'))
    allx = PlanetoidGraphData.load_x(base_path + '.allx')
    ally = np.load(tf.io.gfile.GFile(base_path + '.ally', 'rb'),
                   allow_pickle=True)

    testx = PlanetoidGraphData.load_x(base_path + '.tx')

    # Add test. One index per line in the '.test.index' file (trailing '\n'
    # yields an empty last element, hence [:-1]).
    test_idx = list(
        map(int, tf.io.gfile.GFile(
            base_path + '.test.index').read().split('\n')[:-1]))

    num_test_examples = max(test_idx) - min(test_idx) + 1
    sparse_zeros = scipy.sparse.csr_matrix((num_test_examples, allx.shape[1]),
                                           dtype='float32')

    # Extend the feature matrix with (initially zero) rows for test nodes,
    # then scatter the test features into their positions.
    allx = scipy.sparse.vstack((allx, sparse_zeros))
    llallx = allx.tolil()
    llallx[test_idx] = testx
    self._allx = as_tensor(np.array(llallx.todense()))

    # Same treatment for the one-hot label matrix.
    testy = np.load(tf.io.gfile.GFile(base_path + '.ty', 'rb'),
                    allow_pickle=True)
    ally = np.pad(ally, [(0, num_test_examples), (0, 0)], mode='constant')
    ally[test_idx] = testy

    self._num_nodes = len(edge_lists)
    self._num_classes = ally.shape[1]
    self._node_labels = np.argmax(ally, axis=1)
    self._train_labels = self._node_labels + 0  # Copy.
    self._train_labels[test_idx] = -1  # Mask test labels, per `labels()`.
    self._node_labels = as_tensor(self._node_labels)
    self._train_labels = as_tensor(self._train_labels)
    self._test_idx = as_tensor(np.array(test_idx, dtype='int32'))
    self._node_split = None  # Populated on `node_split()`

    # Will be used to construct (sparse) adjacency matrix.
    adj_src = []
    adj_target = []
    for node, neighbors in edge_lists.items():
      adj_src.extend([node] * len(neighbors))
      adj_target.extend(neighbors)

    self._edge_list = as_tensor(np.stack([adj_src, adj_target], axis=0))
919@property920def name(self):921return self._dataset_name922
923@staticmethod924def load_x(filename):925if sys.version_info > (3, 0):926return pickle.load(tf.io.gfile.GFile(filename, 'rb'), encoding='latin1')927else:928return np.load(tf.io.gfile.GFile(filename))929
930def num_classes(self):931return self._num_classes932
933def node_features_dicts_without_labels(self):934features = {'feat': self._allx}935features['#id'] = tf.range(self._num_nodes, dtype=tf.int32)936return {tfgnn.NODES: features}937
938def node_counts(self):939return {tfgnn.NODES: self._num_nodes}940
941def edge_lists(self):942return {(tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): self._edge_list}943
944def node_split(self):945if self._node_split is None:946# By default, we mimic Planetoid & GCN setup -- i.e., 20 labels per class.947labels_per_class = int(os.environ.get('PLANETOID_LABELS_PER_CLASS', '20'))948num_train_nodes = labels_per_class * self.num_classes()949num_validation_nodes = 500950train_ids = tf.range(num_train_nodes, dtype=tf.int32)951validation_ids = tf.range(952num_train_nodes,953num_train_nodes + num_validation_nodes, dtype=tf.int32)954self._node_split = NodeSplit(train=train_ids, validation=validation_ids,955test=self._test_idx)956return self._node_split957
958@property959def labeled_nodeset(self):960return tfgnn.NODES961
962def labels(self):963return self._train_labels964
965def test_labels(self):966"""int numpy array of length num_nodes containing train and test labels."""967return self._node_labels968
969
class LinkPredictionGraphData(InMemoryGraphData):
  """Base class wrapping dataset of graph(s) for link-prediction tasks."""

  @abc.abstractmethod
  def edge_split(self):
    """Returns edge endpoints for {train, test, valid} partitions."""
    raise NotImplementedError()

  @property
  @abc.abstractmethod
  def target_edgeset(self):
    """Name of edge set over which link prediction is defined."""
    raise NotImplementedError()

  @property
  def source_node_set_name(self):
    """Node set name at the source endpoint of (task) `target_edgeset`."""
    edge_set_schema = self.graph_schema().edge_sets[self.target_edgeset]
    return edge_set_schema.source

  @property
  def target_node_set_name(self):
    """Node set name at the target endpoint of (task) `target_edgeset`."""
    edge_set_schema = self.graph_schema().edge_sets[self.target_edgeset]
    return edge_set_schema.target

  @property
  def num_source_nodes(self):
    """Number of nodes in the source endpoint of (task) `target_edgeset`."""
    counts = self.node_counts()
    return counts[self.source_node_set_name]

  @property
  def num_target_nodes(self):
    """Number of nodes in the target endpoint of (task) `target_edgeset`."""
    counts = self.node_counts()
    return counts[self.target_node_set_name]
1004
class OgblData(LinkPredictionGraphData):
  """Wraps link prediction datasets of ogbl-* for in-memory learning."""

  def __init__(self, dataset_name, cache_dir=None):
    """Downloads (if not cached) and loads an ogbl-* dataset into memory.

    Args:
      dataset_name: Name of an OGB link-prediction dataset, e.g. 'ogbl-ddi'.
      cache_dir: Directory where downloaded OGB files are cached. Defaults to
        environment variable `OGB_CACHE_DIR`, or `~/data/ogb`.
    """
    super().__init__()
    self._dataset_name = dataset_name
    if cache_dir is None:
      cache_dir = os.environ.get(
          'OGB_CACHE_DIR', os.path.expanduser(os.path.join('~', 'data', 'ogb')))

    self._ogb_dataset = _get_ogbl_dataset(dataset_name, cache_dir)

    # OGB provides positive edges for all three partitions and negative edges
    # for validation and test.
    # BUGFIX: `validation_edges` previously read from the 'train' partition;
    # it now reads from 'valid', consistent with `negative_validation_edges`.
    ogb_edge_dict = self._ogb_dataset.get_edge_split()
    self._edge_split = EdgeSplit(
        train_edges=as_tensor(ogb_edge_dict['train']['edge']),
        validation_edges=as_tensor(ogb_edge_dict['valid']['edge']),
        test_edges=as_tensor(ogb_edge_dict['test']['edge']),
        negative_validation_edges=as_tensor(ogb_edge_dict['valid']['edge_neg']),
        negative_test_edges=as_tensor(ogb_edge_dict['test']['edge_neg']))

    self._ogb_graph = _OgbGraph(self._ogb_dataset.graph)

  @property
  def name(self):
    """Dataset name, e.g. 'ogbl-ddi'."""
    return self._dataset_name

  def node_features_dicts(self, add_id=True):
    """Returns {node set name: {feature name: tensor}}.

    The outer and inner dicts are copies, safe for the caller to mutate; the
    tensors themselves are shared with the underlying graph.

    Args:
      add_id: If True, adds an int32 '#id' range feature to every node set.
    """
    features = self._ogb_graph.node_feat_dict
    # 2-level dict shallow copy. Inner value stores reference to tf.Tensor,
    features = {node_set_name: copy.copy(features)
                for node_set_name, features in features.items()}
    if add_id:
      counts = self.node_counts()
      for node_set_name, feats in features.items():
        feats['#id'] = tf.range(counts[node_set_name], dtype=tf.int32)  # pytype: disable=unsupported-operands  # always-use-property-annotation
    return features

  def node_counts(self):
    """Returns mapping from node set name to number of nodes."""
    return dict(self._ogb_graph.num_nodes_dict)  # Return copy.

  def edge_lists(self):
    """Returns mapping from (source, edge set, target) to edge list tensor."""
    return dict(self._ogb_graph.edge_index_dict)  # Return shallow copy.

  def edge_split(self):
    """Returns `EdgeSplit` with positive & negative edge partitions."""
    return self._edge_split

  @property
  def target_edgeset(self):
    """Link prediction is defined over the single homogeneous edge set."""
    return tfgnn.EDGES
1055
def get_in_memory_graph_data(dataset_name):
  """Returns an in-memory graph data instance given a dataset name."""
  if dataset_name in ('cora', 'citeseer', 'pubmed'):
    return PlanetoidGraphData(dataset_name)
  if dataset_name.startswith('ogbn-'):
    return OgbnData(dataset_name)
  if dataset_name.startswith('ogbl-'):
    return OgblData(dataset_name)
  raise ValueError('Unknown Dataset name: ' + dataset_name)
1066
# Shorthand. Can be replaced with: `as_tensor = tf.convert_to_tensor`.
def as_tensor(obj):
  """Converts `obj` to a tensor (alias for `tf.convert_to_tensor`)."""
  tensor = tf.convert_to_tensor(obj)
  return tensor
1072
def load_ogbn_graph_tensor(
    dataset_path, *, add_reverse_edge_sets = False
):
  """Load OGBN graph data as a graph tensor from numpy compressed (.npz) files.

  To generate the .npz files from the original OGB dataset, please refer to
  tensorflow_gnn/converters/ogb/convert_ogb_to_npz.py

  Args:
    dataset_path: Path to the saved OGBN numpy compressed (.npz) files.
    add_reverse_edge_sets: Flag to determine whether to add reversed edge sets.

  Returns:
    A tfgnn.GraphTensor comprising of the full OGBN graph loaded in-memory.
  """
  # Load from disk, then expose labels as regular node features.
  data = NodeClassificationGraphData.load(dataset_path)
  data = data.with_labels_as_features(True)
  # Optionally materialize a reversed counterpart for every edge set.
  if add_reverse_edge_sets:
    data = data.with_reverse_edge_sets()
  return data.as_graph_tensor()