google-research / graph_models.py

# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Models on graphs.

These models take a graph as input, and output per-node or per-edge features.
"""

import abc
from typing import Any, Callable, List, Optional, Tuple

import flax
from flax import struct
import flax.linen as nn

import jax
from jax.nn import initializers
import jax.numpy as jnp

import numpy as np

from jaxsel._src import graph_api

################
# Graph models #
################


class GraphModel(abc.ABC):
  """Abstract class for all graph models.

  Graph models take a batch of problem-specific features (node, task, edge)
  as input.
  Their output is task specific: usually a feature vector per node, which may
  be aggregated later, e.g. into class logits.
  """

  @abc.abstractmethod
  def __call__(
      self,
      node_features,
      adjacency_mat,
      qstar,
  ):
    """Performs a forward pass on the model.

    Args:
      node_features: features associated with the nodes of the extracted
        subgraph.
      adjacency_mat: Extracted adjacency matrix.
      qstar: Optimal weights on the nodes, given by our subgraph extraction
        scheme. If not using subgraph extraction, `qstar` should be a vector of
        ones.

    Returns:
      Output of the model, e.g. logprobs for a classification task...
    """
    Ellipsis
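

# Illustrative sketch (added for exposition; not part of the original module):
# a minimal concrete GraphModel that pools node features under the qstar
# weighting and applies a fixed readout matrix. It only demonstrates how the
# abstract interface is meant to be called; the real models in this file are
# the flax modules defined below.
class _WeightedPoolingGraphModel(GraphModel):
  """Toy GraphModel: qstar-weighted mean of node features, then a linear map."""

  def __init__(self, readout_matrix):
    # readout_matrix: array of shape (feature_dim, num_classes), assumed given.
    self._readout_matrix = readout_matrix

  def __call__(self, node_features, adjacency_mat, qstar):
    del adjacency_mat  # Unused in this toy model.
    pooled = (qstar[:, None] * node_features).sum(axis=0) / qstar.sum()
    return pooled @ self._readout_matrix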

###########################
# Flax based Graph Models #
###########################

# Transformer models are adapted from
# https://github.com/google/flax/blob/main/examples/wmt/models.py


@struct.dataclass
class TransformerConfig:
  """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
  graph_parameters: graph_api.GraphParameters
  hidden_dim: int  # Used to standardize node feature and position embeddings.
  num_classes: int
  image_size: int
  dtype: Any = jnp.float32
  embedding_dim: int = 512
  num_heads: int = 8
  num_layers: int = 6
  qkv_dim: int = 512
  mlp_dim: int = 2048
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  deterministic: bool = False
  # Initializers take in (key, shape, dtype) and return arrays.
  kernel_init: Callable[[Any, Any, Any], jnp.ndarray] = (
      nn.initializers.xavier_uniform()
  )
  bias_init: Callable[[Any, Any, Any], jnp.ndarray] = nn.initializers.normal(
      stddev=1e-6
  )

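# Illustrative sketch (added for exposition): constructing a config. The values
# below are placeholders, and `graph_parameters` is elided because its
# constructor lives in `graph_api`. Since TransformerConfig is a
# `flax.struct.dataclass`, an evaluation-time variant can be derived with
# `.replace`, e.g.:
#
#   cfg = TransformerConfig(
#       graph_parameters=...,  # a graph_api.GraphParameters instance
#       hidden_dim=128,
#       num_classes=10,
#       image_size=28 * 28,
#   )
#   eval_cfg = cfg.replace(deterministic=True, dropout_rate=0.0)

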
class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block.

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
    out_dim: optionally specify out dimension.
  """
  config: TransformerConfig
  out_dim: Optional[int] = None

  @nn.compact
  def __call__(self, inputs):
    """Applies Transformer MlpBlock module."""
    cfg = self.config
    actual_out_dim = (
        inputs.shape[-1] if self.out_dim is None else self.out_dim)
    x = nn.Dense(
        cfg.mlp_dim,
        dtype=cfg.dtype,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init)(
            inputs)
    x = nn.relu(x)
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
    output = nn.Dense(
        actual_out_dim,
        dtype=cfg.dtype,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init)(
            x)
    output = nn.Dropout(rate=cfg.dropout_rate)(
        output, deterministic=cfg.deterministic)
    return output


class Encoder1DBlock(nn.Module):
  """Transformer encoder layer.

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
  """
  config: TransformerConfig

  @nn.compact
  def __call__(  # pytype: disable=annotation-type-mismatch  # jnp-array
      self, inputs, encoder_mask = None
  ):
    """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      encoder_mask: encoder self-attention mask.

    Returns:
      output after transformer encoder block.
    """
    cfg = self.config

    # Attention block.
    assert inputs.ndim == 2
    x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(x, encoder_mask)

    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    y = MlpBlock(config=cfg)(y)

    return x + y

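# Note (added for exposition): Encoder1DBlock is a pre-LayerNorm transformer
# layer. Ignoring dropout, for an input h of shape (sequence_length, features):
#
#   h   <-  h + SelfAttention(LayerNorm(h), encoder_mask)
#   out <-  h + MlpBlock(LayerNorm(h))
#
# The residual connections keep the input on the main path while attention and
# the feed-forward block contribute normalized updates.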

class SubgraphEmbedding(nn.Module):
  """Embeds a bag of node features and positions."""
  config: TransformerConfig

  def setup(self):
    cfg = self.config
    self.node_embedding = nn.Embed(cfg.graph_parameters.node_vocab_size,
                                   cfg.embedding_dim)
    # graph_embedding is for embedding the whole bag of nodes. Similar to the
    # CLS token in BERT.
    self.graph_embedding = nn.Embed(1, cfg.hidden_dim)
    # The +2 accounts for the -1 "out of bounds" node, and the "not a node"
    # index.
    # The "not a node" index stems from jax sparse: for an array of shape n,
    # if some of the `nse` elements of the array are actually 0,
    # they will be matched to the index `n`.
    # This happens in our setup when we use L1 penalties causing the actual
    # size of the subgraph to be strictly smaller than max_subgraph_size.
    # Because jax arrays clip out of bounds indices, we only need to
    # add 1 element in the embedding to account for this, and not mix the info
    # with a different node.
    self.position_embedding = nn.Embed(cfg.image_size + 2, cfg.embedding_dim)

    self.node_hidden_layer = nn.Dense(cfg.hidden_dim)
    self.position_hidden_layer = nn.Dense(cfg.hidden_dim)

  def __call__(
      self, node_features, node_ids
  ):
    """Embeds nodes by features and node_id.

    Args:
      node_features: float or int tensor representing the nodes' fixed
        features. These features are not learned.
      node_ids: ids of the nodes in the image. Used in place of positions in
        the image.

    Returns:
      node_hiddens: float tensor of shape (num_nodes + 1, hidden_dim) with the
        embedded node features, followed by the whole-graph embedding.
      position_hiddens: float tensor of shape (num_nodes + 1, hidden_dim) with
        the embedded node positions; the graph node's position is all zeros.
    """
    cfg = self.config

    num_nodes = len(node_ids)

    # Embed nodes
    node_embs = self.node_embedding(node_features)
    node_embs = node_embs.reshape(num_nodes, -1)
    node_hiddens = self.node_hidden_layer(node_embs)
    graph_hidden = self.graph_embedding(jnp.zeros(1, dtype=int))
    node_hiddens = jnp.vstack((node_hiddens, graph_hidden))

    # Embed positions
    # TODO(gnegiar): We need to clip the "not a node" node to make sure it
    # propagates gradients correctly. jax.experimental.sparse uses an out of
    # bounds index to encode elements with 0 value.
    # See https://github.com/google/jax/issues/5760
    node_ids = jnp.clip(node_ids, a_max=cfg.image_size - 1)
    position_embs = self.position_embedding(node_ids + 1)
    position_hiddens = self.position_hidden_layer(position_embs)
    # The graph node has no position.
    position_hiddens = jnp.vstack(
        (position_hiddens, jnp.zeros(position_hiddens.shape[-1])))

    return node_hiddens, position_hiddens


class TransformerGraphEncoder(nn.Module):
  """Encodes a bag of nodes into a subgraph representation.

  Adapted from https://github.com/google/flax/blob/main/examples/wmt/models.py
  """
  config: TransformerConfig

  @nn.compact
  def __call__(
      self,
      node_feature_embeddings,
      node_position_embeddings,
      adjacency_mat,
      qstar,
  ):
    """Applies the TransformerEncoder module.

    Args:
      node_feature_embeddings: Embeddings representing nodes.
      node_position_embeddings: Embeddings representing node positions.
      adjacency_mat: Adjacency matrix over the nodes. Not used for now.
      qstar: float tensor of shape (num_nodes,). The optimal q weighting over
        the nodes of the graph, from the subgraph selection module.

    Returns:
      encoded: Encoded nodes, with extra Graph node at the end.
    """
    cfg = self.config
    x = node_feature_embeddings + node_position_embeddings

    # Add average weight to graph node for scale
    qstar = jnp.append(qstar, jnp.mean(qstar))

    # Multiply embeddings by node weights so the agent model can be learned.
    x = x * qstar[Ellipsis, None]
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)

    x = x.astype(cfg.dtype)
    # TODO(gnegiar): Plot x here to check
    # Keep nodes with nonzero weights
    mask1d = qstar != 0
    encoder_mask = nn.attention.make_attention_mask(mask1d, mask1d)

    # Input Encoder
    for lyr in range(cfg.num_layers):
      x = Encoder1DBlock(
          config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask)
      x = x * mask1d[Ellipsis, None]
      # TODO(gnegiar): Also plot x here
      # Possibly plot gradient norms per encoder layer
      # Plot attention weights?
    encoded = nn.LayerNorm(dtype=cfg.dtype, name="encoder_norm")(x)
    return encoded

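# Worked example (added for exposition) of the qstar handling above: with
# qstar = [0.6, 0.0, 0.2], the appended graph-node weight is the mean
# (0.6 + 0.0 + 0.2) / 3 ≈ 0.267, so the encoder sees weights
# [0.6, 0.0, 0.2, 0.267]. mask1d is then [True, False, True, True]: the
# zero-weight node is masked out of self-attention, and its features are
# re-zeroed after every encoder block by `x * mask1d[Ellipsis, None]`.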

class ClassificationHead(nn.Module):
  """A 2-layer fully connected network for classification."""
  config: TransformerConfig

  def setup(self):
    cfg = self.config
    self.fc1 = nn.Dense(cfg.hidden_dim)
    self.fc2 = nn.Dense(cfg.num_classes if cfg.num_classes > 2 else 1)

  def __call__(self, x):
    x = nn.relu(self.fc1(x))
    logits = self.fc2(x)
    if self.config.num_classes > 2:
      logits = jax.nn.log_softmax(logits)
    return logits
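# Note (added for exposition): for num_classes > 2 the head returns
# log-probabilities of shape (num_classes,) (log_softmax is applied); for a
# binary task it returns a single unnormalized logit of shape (1,), presumably
# meant for a sigmoid cross-entropy loss in the training code.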


class TransformerClassifier(nn.Module):
  """A transformer-based graph classifier.

  Attributes:
    config: Configuration for the model.
  """
  config: TransformerConfig

  def setup(self):
    cfg = self.config
    self.embedder = SubgraphEmbedding(cfg)
    self.encoder = TransformerGraphEncoder(cfg)
    self.classifier = ClassificationHead(cfg)

  def encode(
      self,
      node_features,
      node_ids,
      adjacency_mat,
      qstar,
  ):
    node_feature_embeddings, node_position_embeddings = self.embedder(
        node_features, node_ids)
    return self.encoder(node_feature_embeddings, node_position_embeddings,
                        adjacency_mat, qstar)

  def decode(self, encoded_graph):
    graph_embedding = encoded_graph[-1]
    logits = self.classifier(graph_embedding)
    return logits

  def __call__(
      self,
      node_features,
      node_ids,
      adjacency_mat,
      qstar,
  ):
    adjacency_mat = adjacency_mat.squeeze(-1)
    encoded_graph = self.encode(node_features, node_ids, adjacency_mat, qstar)
    # The encoder encodes the whole graph in a special token in the last
    # position.
    return self.decode(encoded_graph)

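
# Illustrative smoke test (added for exposition; not part of the original
# module). It wires the classifier together on random dummy data. The sizes
# and the stand-in graph parameters below are assumptions made only for this
# sketch; a real run would construct a graph_api.GraphParameters instead.
if __name__ == "__main__":

  class _DummyGraphParameters:
    """Stand-in exposing the only field this module reads."""
    node_vocab_size = 16

  cfg = TransformerConfig(
      graph_parameters=_DummyGraphParameters(),
      hidden_dim=32,
      num_classes=10,
      image_size=64,
      embedding_dim=8,
      num_heads=2,
      num_layers=1,
      qkv_dim=8,
      mlp_dim=16,
      deterministic=True,
  )
  model = TransformerClassifier(cfg)

  num_nodes, features_per_node = 8, 3
  rng = jax.random.PRNGKey(0)
  node_features = jax.random.randint(
      rng, (num_nodes, features_per_node), 0, 16)
  node_ids = jnp.arange(num_nodes)
  adjacency_mat = jnp.ones((num_nodes, num_nodes, 1))
  qstar = jnp.ones(num_nodes) / num_nodes

  params = model.init(rng, node_features, node_ids, adjacency_mat, qstar)
  log_probs = model.apply(params, node_features, node_ids, adjacency_mat, qstar)
  print("log-probabilities:", log_probs)  # Shape (num_classes,).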
