google-research
216 строк · 6.6 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""EdgeScorer layer of the GSL Model.
17
18The edge scorer contains multiple functions tried in existing models.
19"""
20import tensorflow as tf
21
22
@tf.keras.utils.register_keras_serializable(package="GSL")
class EdgeScorer(tf.keras.layers.Layer):
  """Wraps edge scorers to be used for graph structure learning.

  Base layer that provides the pairwise similarity helpers used by the
  concrete edge scorers.
  """

  def compute_cosine_similarities(self, node_embeddings):
    """Returns pairwise similarities of the L2-normalized embeddings.

    Args:
      node_embeddings: Tensor of node embeddings. Normalization is applied
        along axis 1 (assumes a `[num_nodes, dim]` layout -- TODO confirm
        against callers).

    Returns:
      The dot-product similarities of the normalized embeddings.
    """
    # `axis` is the supported argument name; `dim` is a deprecated alias.
    node_embeddings = tf.math.l2_normalize(node_embeddings, axis=1)
    return self.compute_dot_product_similarities(node_embeddings)

  def compute_dot_product_similarities(self, node_embeddings):
    """Returns `node_embeddings @ transpose(node_embeddings)`.

    Args:
      node_embeddings: Tensor of node embeddings.

    Returns:
      The product of the embeddings with their own transpose (taken over the
      last two dimensions).
    """
    similarities = tf.matmul(node_embeddings, node_embeddings, transpose_b=True)
    return similarities
34
35
@tf.keras.utils.register_keras_serializable(package="GSL")
class Attentive(EdgeScorer):
  """Generates a fully connected adjacency using an attentive approach.

  This edge scorer first uses a vector to project the node features and then
  creates a fully connected graph of their similarities. With several heads,
  the per-head similarities are averaged.
  """

  def __init__(
      self,
      initialization,
      nheads,
      seed=133337,
      **kwargs,
  ):
    """Initializes the attentive edge scorer.

    Args:
      initialization: Either "method1" (attention vectors start as ones) or
        "method2" (attention vectors are drawn by `random_uniform`).
      nheads: Number of attention vectors; must be greater than zero.
      seed: Base seed for the random initializers; head `i` uses `seed + i`.
      **kwargs: Extra keyword arguments. Only the standard Keras layer
        arguments (`name`, `trainable`, `dtype`) are forwarded to the base
        class; any other key is ignored.

    Raises:
      KeyError: If `initialization` is not a known method key.
      ValueError: If `nheads` is not positive.
    """
    # Forward only the base-layer arguments so that the config produced by
    # `get_config` can be fed back to the constructor without losing the layer
    # name, while unrelated keys are still tolerated.
    layer_kwargs = {
        key: kwargs[key] for key in ("name", "trainable", "dtype")
        if key in kwargs
    }
    super().__init__(**layer_kwargs)
    initialization_dict = {"method1": "ones", "method2": "random_uniform"}
    self._initialization = initialization_dict[initialization]
    # Keep the user-facing key: it is what the constructor accepts, so it is
    # what `get_config` has to emit for a working round trip.
    self._initialization_key = initialization
    if nheads <= 0:
      raise ValueError("Number of heads should be greater than zero.")
    self._nheads = nheads
    self._seed = seed

  def build(self, input_shape):
    """Creates one trainable attention vector per head.

    Args:
      input_shape: Shape of the input; its last entry is used as the length
        of every attention vector.
    """
    node_embedding_dim = input_shape[-1]
    self._attention_vectors = []
    for head in range(self._nheads):
      # Multiple instances of initializer should be created to give different
      # outputs.
      if self._initialization == "ones":
        cfg = {"class_name": self._initialization, "config": {}}
      else:
        # Derive the per-head seed without mutating `self._seed`, so that the
        # seed reported by `get_config` stays the one given at construction.
        cfg = {
            "class_name": self._initialization,
            "config": {"seed": self._seed + head},
        }
      initializer = tf.keras.initializers.get(cfg)
      self._attention_vectors.append(
          tf.Variable(
              initial_value=initializer(shape=(node_embedding_dim,)),
              trainable=True,
          )
      )

  def compute_one_head(
      self, features, attention_vector
  ):
    """Scores all node pairs for a single attention vector.

    Args:
      features: Node features passed to the layer.
      attention_vector: Vector multiplied element-wise with `features`.

    Returns:
      Cosine similarities of the re-weighted features.
    """
    node_embeddings = tf.multiply(attention_vector, features)
    return self.compute_cosine_similarities(node_embeddings)

  def call(self, inputs):
    """Returns the similarities averaged over all heads."""
    similarities = 0
    for vector in self._attention_vectors:
      similarities += self.compute_one_head(inputs, vector)
    return similarities / self._nheads

  def get_config(self):
    """Returns a config that can be passed back to the constructor."""
    return dict(
        # The method key ("method1"/"method2"), not the resolved initializer
        # name: the constructor only understands the key.
        initialization=self._initialization_key,
        nheads=self._nheads,
        seed=self._seed,
        **super().get_config(),
    )
100
101
@tf.keras.utils.register_keras_serializable(package="GSL")
class FP(EdgeScorer):
  """Generates a fully connected adjacency with all values as parameters."""

  def __init__(
      self,
      node_features,
      initialization,
      **kwargs,
  ):
    """Initializes the full-parameterization edge scorer.

    Args:
      node_features: Node features used to initialize the adjacency when
        `initialization` is "method1".
      initialization: Either "method1" (cosine similarities of
        `node_features`) or "method2" (`glorot_uniform`).
      **kwargs: Extra keyword arguments. Only the standard Keras layer
        arguments (`name`, `trainable`, `dtype`) are forwarded to the base
        class; any other key is ignored.

    Raises:
      KeyError: If `initialization` is not a known method key.
    """
    # Forward only the base-layer arguments so that the config produced by
    # `get_config` can be fed back to the constructor without losing the layer
    # name, while unrelated keys are still tolerated.
    layer_kwargs = {
        key: kwargs[key] for key in ("name", "trainable", "dtype")
        if key in kwargs
    }
    super().__init__(**layer_kwargs)
    initialization_dict = {"method1": "similarity", "method2": "glorot_uniform"}
    self._initialization = initialization_dict[initialization]
    # Keep the user-facing key: it is what the constructor accepts, so it is
    # what `get_config` has to emit for a working round trip.
    self._initialization_key = initialization
    self._node_features = node_features

  def build(self, input_shape=None):
    """Creates the trainable adjacency variable.

    Args:
      input_shape: Shape of the input. Only needed for the "glorot_uniform"
        initialization, where its second-to-last entry gives the number of
        nodes.
    """
    if self._initialization == "similarity":
      self._similarities = tf.Variable(
          initial_value=self.compute_cosine_similarities(self._node_features),
          trainable=True,
      )
    else:
      # Only this branch needs the input shape; reading it here keeps the
      # `input_shape=None` default usable for the similarity initialization.
      number_of_nodes = input_shape[-2]
      initializer = tf.keras.initializers.get(self._initialization)
      self._similarities = tf.Variable(
          initial_value=initializer(shape=(number_of_nodes, number_of_nodes)),
          trainable=True,
      )

  def call(self, inputs):
    """Returns the learned adjacency; `inputs` is not used."""
    # FP edge scorer only returns the adjacency as is.
    return self._similarities

  def get_config(self):
    """Returns a config that can be passed back to the constructor."""
    return dict(
        # NOTE(review): `node_features` is stored as given; if it is a tensor
        # this config is presumably not JSON-serializable -- confirm before
        # relying on model saving.
        node_features=self._node_features,
        # The method key ("method1"/"method2"), not the resolved initializer
        # name: the constructor only understands the key.
        initialization=self._initialization_key,
        **super().get_config(),
    )
141
142
@tf.keras.utils.register_keras_serializable(package="GSL")
class MLP(EdgeScorer):
  """Generates a fully connected adjacency using an MLP model.

  This edge scorer first uses an MLP to project the node features and then
  creates a fully connected graph of their similarities.
  """

  def __init__(
      self,
      hidden_size,
      output_size,
      nlayers,
      activation,
      initialization,
      dropout_rate,
      **kwargs,
  ):
    """Initializes the MLP edge scorer.

    Args:
      hidden_size: Number of units of every layer except the last one.
      output_size: Number of units of the last layer.
      nlayers: Number of dense layers.
      activation: Activation of every layer except the last one.
      initialization: Either "method1" (`identity` kernels) or "method2"
        (`glorot_uniform` kernels).
      dropout_rate: Rate of the dropout applied after every layer except the
        last one.
      **kwargs: Extra keyword arguments. Only the standard Keras layer
        arguments (`name`, `trainable`, `dtype`) are forwarded to the base
        class; any other key is ignored.

    Raises:
      KeyError: If `initialization` is not a known method key.
    """
    # Forward only the base-layer arguments so that the config produced by
    # `get_config` can be fed back to the constructor without losing the layer
    # name, while unrelated keys are still tolerated.
    layer_kwargs = {
        key: kwargs[key] for key in ("name", "trainable", "dtype")
        if key in kwargs
    }
    super().__init__(**layer_kwargs)
    initialization_dict = {"method1": "identity", "method2": "glorot_uniform"}
    self._initialization = initialization_dict[initialization]
    # Keep the user-facing key: it is what the constructor accepts, so it is
    # what `get_config` has to emit for a working round trip.
    self._initialization_key = initialization
    self._hidden_size = hidden_size
    self._output_size = output_size
    self._nlayers = nlayers
    self._activation = activation
    self._dropout_rate = dropout_rate

  def build(self, input_shape=None):
    """Builds the bias-free dense stack with dropout between layers."""
    last_layer = self._nlayers - 1
    layers = []
    for i in range(self._nlayers):
      is_hidden = i < last_layer
      layers.append(
          tf.keras.layers.Dense(
              units=self._hidden_size if is_hidden else self._output_size,
              # The final projection is linear.
              activation=self._activation if is_hidden else None,
              kernel_initializer=tf.keras.initializers.get(
                  self._initialization
              ),
              use_bias=False,
          )
      )
      if is_hidden:
        layers.append(tf.keras.layers.Dropout(rate=self._dropout_rate))
    self._model = tf.keras.Sequential(layers)

  def call(self, inputs, training=None):
    """Projects the inputs with the MLP and scores all node pairs.

    Args:
      inputs: Node features fed to the MLP.
      training: Optional training flag forwarded to the inner model so that
        dropout can be switched explicitly. When left as `None`, Keras
        resolves the mode itself.

    Returns:
      Cosine similarities of the projected node embeddings.
    """
    node_embeddings = self._model(inputs, training=training)
    similarities = self.compute_cosine_similarities(node_embeddings)
    return similarities

  def get_config(self):
    """Returns a config that can be passed back to the constructor."""
    return dict(
        hidden_size=self._hidden_size,
        output_size=self._output_size,
        nlayers=self._nlayers,
        activation=self._activation,
        # The method key ("method1"/"method2"), not the resolved initializer
        # name: the constructor only understands the key.
        initialization=self._initialization_key,
        dropout_rate=self._dropout_rate,
        **super().get_config(),
    )
204
205
def get_edge_scorer(
    name, node_features, **kwargs
):
  """Creates the edge scorer selected by `name`.

  Args:
    name: One of "mlp", "attentive" or "fp".
    node_features: Node features; only handed to the "fp" scorer.
    **kwargs: Keyword arguments passed on to the scorer's constructor.

  Returns:
    The constructed edge scorer layer.

  Raises:
    ValueError: If `name` does not match a known edge scorer.
  """
  if name == "mlp":
    return MLP(**kwargs)
  if name == "attentive":
    return Attentive(**kwargs)
  if name == "fp":
    # The only scorer that consumes the node features directly.
    return FP(node_features, **kwargs)
  raise ValueError(f"Edge scorer {name} is not defined.")
217