google-research
218 строк · 6.4 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Merger layer of the GSL Layer.

This step merges an input graph with a generated graph and returns a
GraphTensor as the final output.
"""
21import tensorflow as tf22
23from ugsl import datasets24
25
@tf.keras.utils.register_keras_serializable(package="GSL")
class Merger(tf.keras.layers.Layer):
  """Base Keras layer for mergers that produce the final GraphTensor.

  A merger keeps a reference to the GSL graph data. Subclasses in this module
  implement `call`, which builds a GraphTensor through
  `graph_data.as_graph_tensor_given_adjacency`.
  """

  def __init__(self, graph_data, **kwargs):
    """Initializes the merger.

    Args:
      graph_data: the GSL graph data the merger reads the input graph from.
      **kwargs: standard Keras layer arguments (e.g. `name`, `trainable`,
        `dtype`), forwarded to `tf.keras.layers.Layer`. `get_config` below
        includes the base-layer config, so accepting these keys is what lets
        `from_config(get_config())` (i.e. `cls(**config)`) round-trip.
    """
    super().__init__(**kwargs)
    self._graph_data = graph_data

  def get_config(self):
    """Returns the layer config: `graph_data` plus the base Keras config."""
    # NOTE(review): `graph_data` is stored as the raw object; it is presumably
    # not JSON-serializable, so persisting this config to disk may need a
    # custom mechanism -- confirm against how models are saved.
    return dict(graph_data=self._graph_data, **super().get_config())
class WeightedSum(Merger):
  """Sums a generated adjacency with a given adjacency into a GraphTensor.

  The edges of the generated graph structure and the edges of the input graph
  held by `graph_data` are concatenated into one edge list. Generated edges
  keep their own weights; every input-graph edge receives
  `given_adjacency_weight`. Dropout is applied to the combined edge weights.
  """

  def __init__(
      self,
      graph_data,
      dropout_rate,
      given_adjacency_weight=1.0,
  ):
    """Initializes the merger.

    Args:
      graph_data: the GSL graph data providing the input graph.
      dropout_rate: rate of the `tf.keras.layers.Dropout` applied to the
        combined edge weights.
      given_adjacency_weight: weight assigned to every edge of the input graph.
    """
    super().__init__(graph_data)
    self._dropout_rate = dropout_rate
    self._dropout_layer = tf.keras.layers.Dropout(dropout_rate)
    self._given_adjacency_weight = given_adjacency_weight

  def call(self, inputs):
    """Builds a GraphTensor from the generated edges plus the input edges.

    Args:
      inputs: sequence whose element 0 is the generated graph structure
        (exposing `sources`, `targets` and `weights`) and whose element 1 is
        the node embeddings, used as node features.

    Returns:
      The GraphTensor produced by `graph_data.as_graph_tensor_given_adjacency`.
    """
    graph_structure = inputs[0]
    node_embeddings = inputs[1]
    noisy_gt = self._graph_data.get_input_graph_tensor()
    given_noisy_sources = noisy_gt.edge_sets["edges"].adjacency.source
    given_noisy_targets = noisy_gt.edge_sets["edges"].adjacency.target
    noisy_sources = tf.concat(
        (graph_structure.sources, given_noisy_sources), axis=0
    )
    noisy_targets = tf.concat(
        (graph_structure.targets, given_noisy_targets), axis=0
    )
    generated_weights = tf.convert_to_tensor(graph_structure.weights)
    # `tf.shape` (dynamic) also works when the static edge count is unknown,
    # and matching the generated weights' dtype keeps `tf.concat` from failing
    # when they are not float32 (e.g. under a float16 mixed-precision policy).
    given_weights = self._given_adjacency_weight * tf.ones(
        tf.shape(given_noisy_sources), dtype=generated_weights.dtype
    )
    noisy_weights = tf.concat((generated_weights, given_weights), axis=0)
    graph_tensor = self._graph_data.as_graph_tensor_given_adjacency(
        [noisy_sources, noisy_targets],
        edge_weights=self._dropout_layer(noisy_weights),
        node_features=node_embeddings,
    )
    return graph_tensor

  def get_config(self):
    """Returns the layer config, including the merger hyper-parameters."""
    return dict(
        dropout_rate=self._dropout_rate,
        given_adjacency_weight=self._given_adjacency_weight,
        **super().get_config(),
    )
class ToGraphTensor(Merger):
  """ToGraphTensor converts an adjacency in the form of rows, columns, and weights into a GraphTensor.

  Only the generated graph structure is used; dropout is applied to its edge
  weights before the GraphTensor is assembled.
  """

  def __init__(
      self,
      graph_data,
      dropout_rate,
      **kwargs,
  ):
    """Initializes the merger.

    Args:
      graph_data: the GSL graph data used to assemble the GraphTensor.
      dropout_rate: rate of the `tf.keras.layers.Dropout` applied to the
        edge weights.
      **kwargs: accepted and ignored (presumably so a shared merger config can
        be passed to any merger -- confirm against callers).
    """
    super().__init__(graph_data)
    self._dropout_rate = dropout_rate
    self._dropout_layer = tf.keras.layers.Dropout(dropout_rate)

  def call(self, inputs):
    """Turns the generated structure in `inputs[0]` into a GraphTensor.

    Args:
      inputs: sequence whose element 0 exposes `sources`, `targets` and
        `weights`, and whose element 1 holds the node features.

    Returns:
      The GraphTensor produced by `graph_data.as_graph_tensor_given_adjacency`.
    """
    structure, features = inputs[0], inputs[1]
    adjacency = [structure.sources, structure.targets]
    dropped_weights = self._dropout_layer(structure.weights)
    return self._graph_data.as_graph_tensor_given_adjacency(
        adjacency,
        edge_weights=dropped_weights,
        node_features=features,
    )

  def get_config(self):
    """Returns the base config extended with `dropout_rate`."""
    base_config = super().get_config()
    return dict(dropout_rate=self._dropout_rate, **base_config)
class RandomGraphTensor(Merger):
  """Generates a random graph tensor to be tested as baseline in the framework.

  The random graph is sampled once, in `__init__`:
  - it has as many edges as the input graph;
  - its endpoints are drawn uniformly from `[0, number_of_nodes)`;
  - its weights are drawn uniformly from `[0, 1)`.
  `call` ignores the generated graph structure (`inputs[0]`) and always reuses
  this fixed random graph.
  """

  def __init__(
      self,
      graph_data,
      dropout_rate,
      **kwargs,
  ):
    """Samples the fixed random graph.

    Args:
      graph_data: the GSL graph data; its input graph determines the number of
        edges and nodes of the random graph.
      dropout_rate: rate of the `tf.keras.layers.Dropout` applied to the
        random edge weights.
      **kwargs: accepted and ignored (presumably so a shared merger config can
        be passed to any merger -- confirm against callers).
    """
    # `Merger.__init__` already stores `graph_data` as `self._graph_data`.
    super().__init__(graph_data)
    self._dropout_rate = dropout_rate
    self._dropout_layer = tf.keras.layers.Dropout(dropout_rate)
    input_gt = self._graph_data.get_input_graph_tensor()
    number_of_edges = input_gt.edge_sets["edges"].adjacency.source.shape[0]
    number_of_nodes = input_gt.node_sets["nodes"].features["feat"].shape[0]
    self._random_sources = self._sample_node_ids(
        number_of_edges, number_of_nodes
    )
    self._random_targets = self._sample_node_ids(
        number_of_edges, number_of_nodes
    )
    self._random_weights = tf.random.uniform(
        shape=(number_of_edges,), minval=0, maxval=1.0, dtype=tf.float32
    )

  @staticmethod
  def _sample_node_ids(number_of_edges, number_of_nodes):
    """Draws `number_of_edges` int32 node ids uniformly from [0, number_of_nodes)."""
    return tf.random.uniform(
        shape=(number_of_edges,),
        minval=0,
        maxval=number_of_nodes,
        dtype=tf.int32,
    )

  def call(self, inputs):
    """Builds a GraphTensor from the fixed random graph.

    Args:
      inputs: sequence whose element 1 holds the node features; element 0
        (the generated structure) is not used.

    Returns:
      The GraphTensor produced by `graph_data.as_graph_tensor_given_adjacency`.
    """
    node_embeddings = inputs[1]
    graph_tensor = self._graph_data.as_graph_tensor_given_adjacency(
        tf.stack([self._random_sources, self._random_targets], axis=0),
        edge_weights=self._dropout_layer(self._random_weights),
        node_features=node_embeddings,
    )
    return graph_tensor

  def get_config(self):
    """Returns the base config extended with `dropout_rate`."""
    return dict(
        dropout_rate=self._dropout_rate,
        **super().get_config(),
    )
class InputGraphTensor(Merger):
  """Rebuilds the input graph as a GraphTensor, ignoring the generated graph.

  Every edge of the input graph held by `graph_data` gets weight 1.0. Dropout
  is applied to those weights, and the node embeddings are attached as node
  features. The generated graph structure (`inputs[0]`) is not used, which
  makes this merger a baseline that trains on the given graph only.
  """

  def __init__(
      self,
      graph_data,
      dropout_rate,
      **kwargs,
  ):
    """Initializes the merger.

    Args:
      graph_data: the GSL graph data providing the input graph.
      dropout_rate: rate of the `tf.keras.layers.Dropout` applied to the
        unit edge weights.
      **kwargs: accepted and ignored (presumably so a shared merger config can
        be passed to any merger -- confirm against callers).
    """
    super().__init__(graph_data)
    self._dropout_rate = dropout_rate
    self._dropout_layer = tf.keras.layers.Dropout(dropout_rate)

  def call(self, inputs):
    """Builds a GraphTensor from the input graph with unit edge weights.

    Args:
      inputs: sequence whose element 1 holds the node features; element 0
        (the generated structure) is not used.

    Returns:
      The GraphTensor produced by `graph_data.as_graph_tensor_given_adjacency`.
    """
    node_embeddings = inputs[1]
    noisy_gt = self._graph_data.get_input_graph_tensor()
    noisy_sources = noisy_gt.edge_sets["edges"].adjacency.source
    noisy_targets = noisy_gt.edge_sets["edges"].adjacency.target
    # Dynamic shape: also valid when the static edge count is unknown.
    noisy_weights = tf.ones(tf.shape(noisy_sources))

    graph_tensor = self._graph_data.as_graph_tensor_given_adjacency(
        [noisy_sources, noisy_targets],
        edge_weights=self._dropout_layer(noisy_weights),
        node_features=node_embeddings,
    )
    return graph_tensor

  def get_config(self):
    """Returns the base config extended with `dropout_rate`."""
    return dict(
        dropout_rate=self._dropout_rate,
        **super().get_config(),
    )
def get_merger(
    graph_data, name, **kwargs
):
  """Builds the merger layer registered under `name`.

  Recognized names:
    "none"         -> ToGraphTensor
    "weighted-sum" -> WeightedSum
    "random"       -> RandomGraphTensor
    "input"        -> InputGraphTensor

  Args:
    graph_data: the GSL graph data.
    name: name of the merger to use in the gsl framework (case-sensitive).
    **kwargs: keyword arguments forwarded unchanged to the selected merger's
      constructor (e.g. `dropout_rate`).

  Returns:
    Merger associated to the provided name.

  Raises:
    ValueError: if the merger name is not defined.
  """
  if name == "none":
    merger_cls = ToGraphTensor
  elif name == "weighted-sum":
    merger_cls = WeightedSum
  elif name == "random":
    merger_cls = RandomGraphTensor
  elif name == "input":
    merger_cls = InputGraphTensor
  else:
    raise ValueError(f"Merger {name} is not defined.")
  return merger_cls(graph_data, **kwargs)