# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Models on graphs.

These models take a graph as input, and output per-node or per-edge features.
"""

import abc
from typing import Any, Callable, List, Optional, Tuple

import flax
from flax import struct
import flax.linen as nn

import jax
from jax.nn import initializers
import jax.numpy as jnp

import numpy as np

from jaxsel._src import graph_api

################
# Graph models #
################

class GraphModel(abc.ABC):
  """Abstract class for all graph models.

  Graph models take a batch of problem specific features (node, task, edges)
  as input. Their output is task specific: usually a feature vector per node,
  which may be aggregated later, e.g. into class logits.
  """

  @abc.abstractmethod
  def __call__(
      self,
      node_features,
      adjacency_mat,
      qstar,
  ):
    """Performs a forward pass on the model.

    Args:
      node_features: Features associated with the nodes of the extracted
        subgraph.
      adjacency_mat: Extracted adjacency matrix.
      qstar: Optimal weights on the nodes, given by our subgraph extraction
        scheme. If not using subgraph extraction, `qstar` should be a vector
        of ones.

    Returns:
      Output of the model, e.g. log probabilities for a classification task.
    """
    Ellipsis

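# A minimal, hypothetical illustration of the GraphModel interface (added for
# readers; not used anywhere else in this module): a concrete model that
# ignores the adjacency matrix and mean-pools node features weighted by
# `qstar`. This is a sketch, not part of the original library.
class _ExampleMeanPoolGraphModel(GraphModel):
  """Toy GraphModel: qstar-weighted average of node features."""

  def __call__(self, node_features, adjacency_mat, qstar):
    del adjacency_mat  # Unused in this sketch.
    return jnp.sum(qstar[Ellipsis, None] * node_features, axis=0)
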
###########################
# Flax based Graph Models #
###########################

# Transformer models are adapted from
# https://github.com/google/flax/blob/main/examples/wmt/models.py


@struct.dataclass
class TransformerConfig:
  """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
  graph_parameters: graph_api.GraphParameters
  hidden_dim: int  # Used to standardize node feature and position embeddings.
  num_classes: int
  image_size: int
  dtype: Any = jnp.float32
  embedding_dim: int = 512
  num_heads: int = 8
  num_layers: int = 6
  qkv_dim: int = 512
  mlp_dim: int = 2048
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  deterministic: bool = False
  # Initializers take in (key, shape, dtype) and return arrays.
  kernel_init: Callable[[Any, Any, Any], jnp.ndarray] = (
      nn.initializers.xavier_uniform()
  )
  bias_init: Callable[[Any, Any, Any], jnp.ndarray] = nn.initializers.normal(
      stddev=1e-6
  )

class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block.

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
    out_dim: optionally specify out dimension.
  """
  config: TransformerConfig
  out_dim: Optional[int] = None

  @nn.compact
  def __call__(self, inputs):
    """Applies Transformer MlpBlock module."""
    cfg = self.config
    actual_out_dim = (
        inputs.shape[-1] if self.out_dim is None else self.out_dim)
    x = nn.Dense(
        cfg.mlp_dim,
        dtype=cfg.dtype,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init)(
            inputs)
    x = nn.relu(x)
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
    output = nn.Dense(
        actual_out_dim,
        dtype=cfg.dtype,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init)(
            x)
    output = nn.Dropout(rate=cfg.dropout_rate)(
        output, deterministic=cfg.deterministic)
    return output

class Encoder1DBlock(nn.Module):
  """Transformer encoder layer.

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
  """
  config: TransformerConfig

  @nn.compact
  def __call__(  # pytype: disable=annotation-type-mismatch  # jnp-array
      self, inputs, encoder_mask=None
  ):
    """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      encoder_mask: encoder self-attention mask.

    Returns:
      output after transformer encoder block.
    """
    cfg = self.config

    # Attention block.
    assert inputs.ndim == 2
    x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(x, encoder_mask)

    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    y = MlpBlock(config=cfg)(y)

    return x + y

class SubgraphEmbedding(nn.Module):
  """Embeds a bag of node features and positions."""
  config: TransformerConfig

  def setup(self):
    cfg = self.config
    self.node_embedding = nn.Embed(cfg.graph_parameters.node_vocab_size,
                                   cfg.embedding_dim)
    # graph_embedding is for embedding the whole bag of nodes. Similar to the
    # CLS token in BERT.
    self.graph_embedding = nn.Embed(1, cfg.hidden_dim)
    # The +2 accounts for the -1 "out of bounds" node, and the "not a node"
    # index.
    # The "not a node" index stems from jax sparse: for an array of shape n,
    # if part of the `nse` elements of the array are actually 0,
    # they will be matched to the index `n`.
    # This happens in our setup when we use L1 penalties causing the actual
    # size of the subgraph to be strictly smaller than max_subgraph_size.
    # Because jax arrays clip out of bounds indices, we only need to
    # add 1 element in the embedding to account for this, and not mix the info
    # with a different node.
    self.position_embedding = nn.Embed(cfg.image_size + 2, cfg.embedding_dim)
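    # Illustrative note (an assumption about JAX indexing semantics, not part
    # of the original code): on reads, out-of-bounds integer indices are
    # clamped into range, e.g. jnp.arange(4)[10] evaluates to 3, so a single
    # extra embedding row can absorb every "not a node" lookup.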

    self.node_hidden_layer = nn.Dense(cfg.hidden_dim)
    self.position_hidden_layer = nn.Dense(cfg.hidden_dim)

  def __call__(
      self, node_features, node_ids
  ):
    """Embeds nodes by features and node_id.

    Args:
      node_features: float or int tensor representing the current node's fixed
        features. These features are not learned.
      node_ids: id of the node in the image. Used in place of the position in
        the image.

    Returns:
      node_hiddens: hidden representations of the nodes, with an extra
        whole-graph embedding appended as the last row.
      position_hiddens: hidden representations of the node positions, with a
        zero row appended for the graph embedding, which has no position.
    """
    cfg = self.config

    num_nodes = len(node_ids)

    # Embed nodes.
    node_embs = self.node_embedding(node_features)
    node_embs = node_embs.reshape(num_nodes, -1)
    node_hiddens = self.node_hidden_layer(node_embs)
    graph_hidden = self.graph_embedding(jnp.zeros(1, dtype=int))
    node_hiddens = jnp.vstack((node_hiddens, graph_hidden))

    # Embed positions.
    # TODO(gnegiar): We need to clip the "not a node" node to make sure it
    # propagates gradients correctly. jax.experimental.sparse uses an out of
    # bounds index to encode elements with 0 value.
    # See https://github.com/google/jax/issues/5760
    node_ids = jnp.clip(node_ids, a_max=cfg.image_size - 1)
    position_embs = self.position_embedding(node_ids + 1)
    position_hiddens = self.position_hidden_layer(position_embs)
    # The graph node has no position.
    position_hiddens = jnp.vstack(
        (position_hiddens, jnp.zeros(position_hiddens.shape[-1])))

    return node_hiddens, position_hiddens

class TransformerGraphEncoder(nn.Module):
  """Encodes a bag of nodes into a subgraph representation.

  Adapted from https://github.com/google/flax/blob/main/examples/wmt/models.py
  """
  config: TransformerConfig

  @nn.compact
  def __call__(
      self,
      node_feature_embeddings,
      node_position_embeddings,
      adjacency_mat,
      qstar,
  ):
    """Applies the TransformerEncoder module.

    Args:
      node_feature_embeddings: Embeddings representing nodes.
      node_position_embeddings: Embeddings representing node positions.
      adjacency_mat: Adjacency matrix over the nodes. Not used for now.
      qstar: float tensor of shape (num_nodes,). The optimal q weighting over
        the nodes of the graph, from the subgraph selection module.

    Returns:
      encoded: Encoded nodes, with an extra graph node at the end.
    """
    cfg = self.config
    x = node_feature_embeddings + node_position_embeddings

    # Give the graph node the average weight, to keep it on the same scale.
    qstar = jnp.append(qstar, jnp.mean(qstar))

    # Multiply embeddings by node weights, so that gradients flow back to the
    # agent model that produced `qstar`.
    x = x * qstar[Ellipsis, None]
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)

    x = x.astype(cfg.dtype)
    # TODO(gnegiar): Plot x here to check.
    # Keep nodes with positive weights.
    mask1d = qstar != 0
    encoder_mask = nn.attention.make_attention_mask(mask1d, mask1d)

    # Input Encoder
    for lyr in range(cfg.num_layers):
      x = Encoder1DBlock(
          config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask)
      x = x * mask1d[Ellipsis, None]
      # TODO(gnegiar): Also plot x here.
      # Possibly plot gradient norms per encoder layer.
      # Plot attention weights?
    encoded = nn.LayerNorm(dtype=cfg.dtype, name="encoder_norm")(x)
    return encoded

class ClassificationHead(nn.Module):
  """A 2 layer fully connected network for classification."""
  config: TransformerConfig

  def setup(self):
    cfg = self.config
    self.fc1 = nn.Dense(cfg.hidden_dim)
    self.fc2 = nn.Dense(cfg.num_classes if cfg.num_classes > 2 else 1)

  def __call__(self, x):
    x = nn.relu(self.fc1(x))
    logits = self.fc2(x)
    if self.config.num_classes > 2:
      logits = jax.nn.log_softmax(logits)
    return logits

class TransformerClassifier(nn.Module):
  """A transformer based graph classifier.

  Attributes:
    config: Configuration for the model.
  """
  config: TransformerConfig

  def setup(self):
    cfg = self.config
    self.embedder = SubgraphEmbedding(cfg)
    self.encoder = TransformerGraphEncoder(cfg)
    self.classifier = ClassificationHead(cfg)

  def encode(
      self,
      node_features,
      node_ids,
      adjacency_mat,
      qstar,
  ):
    node_feature_embeddings, node_position_embeddings = self.embedder(
        node_features, node_ids)
    return self.encoder(node_feature_embeddings, node_position_embeddings,
                        adjacency_mat, qstar)

  def decode(self, encoded_graph):
    graph_embedding = encoded_graph[-1]
    logits = self.classifier(graph_embedding)
    return logits

  def __call__(
      self,
      node_features,
      node_ids,
      adjacency_mat,
      qstar,
  ):
    adjacency_mat = adjacency_mat.squeeze(-1)
    encoded_graph = self.encode(node_features, node_ids, adjacency_mat, qstar)
    # The encoder encodes the whole graph in a special token in the last
    # position.
    return self.decode(encoded_graph)

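# -----------------------------------------------------------------------------
# Hedged usage sketch (added for illustration; not part of the original
# library). It shows how one might instantiate and run TransformerClassifier on
# dummy data. The GraphParameters stand-in below is a hypothetical stub: only
# the `node_vocab_size` attribute read by SubgraphEmbedding is provided, and
# all sizes are made up for a quick smoke test.
if __name__ == "__main__":
  import types

  graph_params = types.SimpleNamespace(node_vocab_size=16)  # Hypothetical stub.
  cfg = TransformerConfig(
      graph_parameters=graph_params,
      hidden_dim=32,
      num_classes=10,
      image_size=64,
      embedding_dim=32,
      num_heads=2,
      num_layers=1,
      qkv_dim=32,
      mlp_dim=64,
      deterministic=True,  # Disable dropout so no dropout RNG is needed.
  )
  model = TransformerClassifier(cfg)

  num_nodes = 8
  node_features = jnp.zeros((num_nodes, 1), dtype=jnp.int32)  # Token ids.
  node_ids = jnp.arange(num_nodes)  # Node positions in the image.
  adjacency_mat = jnp.zeros((num_nodes, num_nodes, 1))  # Unused by this model.
  qstar = jnp.ones(num_nodes) / num_nodes  # Uniform node weights.

  variables = model.init(
      jax.random.PRNGKey(0), node_features, node_ids, adjacency_mat, qstar)
  logprobs = model.apply(
      variables, node_features, node_ids, adjacency_mat, qstar)
  print(logprobs.shape)  # Expected: (10,) under these assumptions.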