# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GATRNN model.
31

32
Defines the encoder, decoder, and complete GATRNN model architecture.
33
The gumbel softmax sampling function is also implemented to enable
34
gradient passing through the predicted graph.
35
"""

import numpy as np
from pytorch_lightning import LightningModule
import torch
from torch import nn
import torch.nn.functional as F

from editable_graph_temporal.model import gat_cell

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def sample_gumbel(shape, eps=1e-20):
  """Samples from the Gumbel distribution.

  Args:
    shape: shape of the random variables to be sampled.
    eps: a small value used to avoid taking the logarithm of zero.

  Returns:
    A tensor of the given shape sampled from the Gumbel distribution.
  """
  uniform_rand = torch.rand(shape).to(device)
  return -torch.log(-torch.log(uniform_rand + eps) + eps)
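

# Illustrative sketch (not part of the original module): the standard
# Gumbel distribution sampled above has mean equal to the Euler-Mascheroni
# constant (~0.5772), which gives a quick sanity check of the formula.
def _demo_sample_gumbel_mean():
  samples = sample_gumbel((100000,))
  print(samples.mean().item())  # Expected to be close to 0.5772.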


def gumbel_softmax_sample(logits, temperature, eps=1e-10):
  """Samples from the Gumbel-Softmax distribution.

  Args:
    logits: [batch_size, n_class] unnormalized log-probs.
    temperature: non-negative scalar; lower values give sharper samples.
    eps: a small value used to avoid taking the logarithm of zero.

  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
  """
  sample = sample_gumbel(logits.size(), eps=eps)
  y = logits + sample
  return F.softmax(y / temperature, dim=-1)
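

# Illustrative sketch (not part of the original module): the temperature
# controls how close the relaxed sample is to a one-hot vector.
def _demo_temperature_effect():
  logits = torch.log(torch.tensor([[0.7, 0.2, 0.1]], device=device))
  smooth = gumbel_softmax_sample(logits, temperature=5.0)
  sharp = gumbel_softmax_sample(logits, temperature=0.1)
  # sharp.max() is typically much closer to 1 than smooth.max().
  print(smooth.max().item(), sharp.max().item())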


def gumbel_softmax(logits, temperature, hard=False, eps=1e-10):
  """Samples from the Gumbel-Softmax distribution and optionally discretizes.

  Args:
    logits: [batch_size, n_class] unnormalized log-probs.
    temperature: non-negative scalar.
    hard: if True, take argmax, but differentiate w.r.t. soft sample y.
    eps: a small value used to avoid taking the logarithm of zero.

  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, the returned sample will be one-hot; otherwise it will
    be a probability distribution that sums to 1 across classes.
  """
  y_soft = gumbel_softmax_sample(logits, temperature=temperature, eps=eps)
  if hard:
    shape = logits.size()
    _, k = y_soft.data.max(-1)
    y_hard = torch.zeros(*shape, device=device)
    y_hard = y_hard.scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)
    # Straight-through estimator: the forward pass uses the one-hot
    # y_hard, while gradients flow through the soft sample y_soft.
    y = (y_hard - y_soft).detach() + y_soft
  else:
    y = y_soft
  return y
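

# Illustrative sketch (not part of the original module): with hard=True the
# forward sample is discrete, yet gradients still reach the logits through
# the straight-through estimator; this is what lets the model learn the
# predicted graph end to end.
def _demo_straight_through_gradients():
  logits = torch.randn(4, 3, device=device, requires_grad=True)
  sample = gumbel_softmax(logits, temperature=0.5, hard=True)
  loss = (sample * torch.randn(3, device=device)).sum()
  loss.backward()
  assert logits.grad is not None  # Gradients flow despite the argmax.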


class Encoder(LightningModule, gat_cell.Seq2SeqAttrs):
  """Implements the GATRNN encoder model.

  Encodes the input time series sequence into hidden vectors.
  """

  def __init__(self, args):
    """Instantiates the GATRNN encoder model.

    Args:
      args: python argparse.ArgumentParser class, we only use model-related
        arguments here.
    """
    super().__init__()
    self._initialize_arguments(args)
    self.embedding = nn.Linear(self.input_dim, self.rnn_units)
    torch.nn.init.normal_(self.embedding.weight)

    self.gat_layers = nn.ModuleList(
        [gat_cell.GATGRUCell(args) for _ in range(self.num_rnn_layers)])
    self.dropout = nn.Dropout(self.dropout)
    self.tanh = nn.Tanh()
    self.relu = nn.ReLU()

  def forward(self, inputs, adj, global_embs, hidden_state=None):
    r"""Encoder forward pass.

    Args:
      inputs: input one-step time series, with shape (batch_size,
        self.num_nodes, self.input_dim).
      adj: adjacency matrix, with shape (self.num_nodes, self.num_nodes).
      global_embs: global embedding matrix, with shape (self.num_nodes,
        self.rnn_units).
      hidden_state: hidden vectors, with shape (num_layers, batch_size,
        self.num_nodes, self.rnn_units); optional, zeros if not provided.

    Returns:
      output: outputs, with shape (batch_size, self.num_nodes,
        self.rnn_units).
      hidden_state: output hidden vectors, with shape (num_layers,
        batch_size, self.num_nodes, self.rnn_units),
        (lower indices mean lower layers).
    """
    linear_weights = self.embedding.weight
    if torch.any(torch.isnan(linear_weights)):
      print("weight nan")
    embedded = self.embedding(inputs)
    embedded = self.tanh(embedded)

    output = self.dropout(embedded)

    if hidden_state is None:
      hidden_state = torch.zeros((self.num_rnn_layers, inputs.shape[0],
                                  self.num_nodes, self.rnn_units),
                                 device=device)
    hidden_states = []
    for layer_num, gat_layer in enumerate(self.gat_layers):
      next_hidden_state = gat_layer(output, hidden_state[layer_num], adj,
                                    global_embs)
      hidden_states.append(next_hidden_state)
      output = next_hidden_state

    if self.activation == "relu":
      output = self.relu(output)
    elif self.activation == "tanh":
      output = self.tanh(output)
    elif self.activation == "linear":
      pass

    return output, torch.stack(
        hidden_states)  # runs in O(num_layers) so not too slow


class Decoder(LightningModule, gat_cell.Seq2SeqAttrs):
  """Implements the GATRNN decoder model.

  Decodes the input hidden vector into the output time series sequence.
  """

  def __init__(self, args):
    """Instantiates the GATRNN decoder model.

    Args:
      args: python argparse.ArgumentParser class, we only use model-related
        arguments here.
    """
    super().__init__()
    self._initialize_arguments(args)
    self.embedding = nn.Linear(self.output_dim, self.rnn_units)

    self.gat_layers = nn.ModuleList(
        [gat_cell.GATGRUCell(args) for _ in range(self.num_rnn_layers)])
    self.fc_out = nn.Linear(self.rnn_units, self.output_dim)
    self.dropout = nn.Dropout(self.dropout)
    self.relu = nn.ReLU()
    self.tanh = nn.Tanh()

  def forward(self, inputs, adj, global_embs, hidden_state=None):
    r"""Decoder forward pass.

    Args:
      inputs: input one-step time series, with shape (batch_size,
        self.num_nodes, self.output_dim).
      adj: adjacency matrix, with shape (self.num_nodes, self.num_nodes).
      global_embs: global embedding matrix, with shape (self.num_nodes,
        self.rnn_units).
      hidden_state: hidden vectors, with shape (num_layers, batch_size,
        self.num_nodes, self.rnn_units); optional, zeros if not provided.

    Returns:
      output: outputs, with shape (batch_size, self.num_nodes,
        self.output_dim).
      hidden_state: output hidden vectors, with shape (num_layers,
        batch_size, self.num_nodes, self.rnn_units),
        (lower indices mean lower layers).
    """
    embedded = self.tanh(self.embedding(inputs))
    output = self.dropout(embedded)

    if hidden_state is None:
      hidden_state = torch.zeros((self.num_rnn_layers, inputs.shape[0],
                                  self.num_nodes, self.rnn_units),
                                 device=device)
    hidden_states = []
    for layer_num, gat_layer in enumerate(self.gat_layers):
      next_hidden_state = gat_layer(output, hidden_state[layer_num], adj,
                                    global_embs)
      hidden_states.append(next_hidden_state)
      output = next_hidden_state

    output = self.fc_out(output.view(-1, self.rnn_units))
    output = output.view(-1, self.num_nodes, self.output_dim)

    if self.activation == "relu":
      output = self.relu(output)
    elif self.activation == "tanh":
      output = self.tanh(output)
    elif self.activation == "linear":
      pass

    return output, torch.stack(hidden_states)


class GATRNN(LightningModule, gat_cell.Seq2SeqAttrs):
  """Implements the GATRNN model."""

  def __init__(self, adj_mx, args):
    """Instantiates the GATRNN model.

    Args:
      adj_mx: adjacency matrix, with shape (self.num_nodes, self.num_nodes).
      args: python argparse.ArgumentParser class, we only use model-related
        arguments here.
    """
    super().__init__()
    self._initialize_arguments(args)
    self.temperature = args.temperature
    self.adj_type = args.adj_type
    if args.adj_type == "fixed":
      self.adj_mx = adj_mx.to(device)
    elif args.adj_type == "empty":
      self.adj_mx = torch.zeros(
          size=(args.num_nodes, args.num_nodes, args.num_relation_types),
          device=device).float()

    self.global_embs = nn.Parameter(
        torch.empty((self.num_nodes, self.rnn_units), device=device))
    torch.nn.init.xavier_normal_(self.global_embs)
    self.fc_out = nn.Linear(self.rnn_units * 2, self.rnn_units)
    self.fc_cat = nn.Linear(self.rnn_units, self.num_relation_types)
    self.encoder = Encoder(args)
    self.decoder = Decoder(args)
    self.fc_graph_rec, self.fc_graph_send = self._get_fc_graph_rec_send()
    self.loss = nn.L1Loss()

  def _get_fc_graph_rec_send(self):
    """Gets the receiver and sender node indices for all two-node pairs.

    This returns one-hot vectors for each of the pairs.

    Returns:
      (receiver node one-hot indices, sender node one-hot indices).
    """

    def encode_onehot(labels):
      """One-hot encoding.

      Args:
        labels: input labels containing integer numbers.

      Returns:
        labels_onehot: one-hot vectors of labels.
      """
      classes = set(labels)
      classes_dict = {
          c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)
      }
      labels_onehot = np.array(
          list(map(classes_dict.get, labels)), dtype=np.int32)
      return labels_onehot

    off_diag = np.ones([self.num_nodes, self.num_nodes])
    rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
    rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)

    return torch.FloatTensor(rel_rec).to(device), torch.FloatTensor(
        rel_send).to(device)
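
  # Illustrative note (not part of the original module): for num_nodes = 2,
  # np.where(off_diag) enumerates the ordered pairs (0,0), (0,1), (1,0),
  # (1,1), so the one-hot rows are
  #   rel_rec  = [[1,0],[1,0],[0,1],[0,1]]  (receiver of each pair)
  #   rel_send = [[1,0],[0,1],[1,0],[0,1]]  (sender of each pair)
  # Multiplying these by global_embs in pred_adj gathers the receiver and
  # sender embeddings for every pair with a single matmul.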

  def forward(self, inputs):
    """GATRNN forward pass.

    Args:
      inputs: input time series sequence, with shape (batch_size,
        self.input_len, self.num_nodes, self.input_dim).

    Returns:
      outputs: output time series sequence, with shape (batch_size,
        self.output_len, self.num_nodes, self.output_dim).
      edge_prob: predicted relation type probability for each two-node pair,
        with shape (self.num_nodes*self.num_nodes, self.num_relation_types).
    """
    if self.adj_type == "learned":
      adj, edge_prob = self.pred_adj()
    elif self.adj_type in ["fixed", "empty"]:
      edge_prob = None
      adj = self.adj_mx if self.adj_mx.dim() == 3 else self.adj_mx.unsqueeze(2)

    # Transpose to time-major layout: (input_len, batch_size, num_nodes,
    # input_dim).
    inputs = inputs.permute(1, 0, 2, 3)

    encoder_hidden_state = None
    for t in range(self.input_len):
      _, encoder_hidden_state = self.encoder(inputs[t], adj, self.global_embs,
                                             encoder_hidden_state)

    # Decode autoregressively, feeding each prediction back as the next
    # decoder input, starting from a zero input.
    decoder_hidden_state = encoder_hidden_state
    decoder_input = torch.zeros((encoder_hidden_state.shape[1], self.num_nodes,
                                 self.decoder.output_dim),
                                device=device)
    outputs = []
    for t in range(self.output_len):
      decoder_output, decoder_hidden_state = self.decoder(
          decoder_input, adj, self.global_embs, decoder_hidden_state)
      outputs.append(decoder_output)
      decoder_input = decoder_output

    outputs = torch.stack(outputs)

    del encoder_hidden_state
    del decoder_hidden_state

    # Back to batch-major layout: (batch_size, output_len, num_nodes,
    # output_dim).
    outputs = outputs.permute(1, 0, 2, 3)
    return outputs, edge_prob

  def pred_adj(self):
    """Predicts the relational graph.

    Returns:
      adj: predicted adjacency matrix of the relational graph,
        with shape (self.num_nodes, self.num_nodes,
        self.num_relation_types-1).
      prob: predicted relation type probability for each two-node pair,
        with shape (self.num_nodes*self.num_nodes, self.num_relation_types).
    """
    receivers = torch.matmul(self.fc_graph_rec, self.global_embs)
    senders = torch.matmul(self.fc_graph_send, self.global_embs)
    x = torch.cat([senders, receivers], dim=1)
    x = torch.relu(self.fc_out(x))
    x = self.fc_cat(x)
    prob = F.softmax(x, dim=-1)

    if self.training:
      # Differentiable discrete sampling during training.
      adj = gumbel_softmax(x, temperature=self.temperature, hard=True)
    else:
      # Deterministic argmax decoding at evaluation time.
      adj = x.argmax(dim=1)
      adj = F.one_hot(adj, num_classes=self.num_relation_types)

    # Drop the first relation type, which acts as the null (no-edge) type.
    adj = adj[:, 1:].clone().reshape(self.num_nodes, self.num_nodes,
                                     self.num_relation_types - 1)

    # Zero out self-loops on the diagonal for every relation type.
    mask = torch.eye(self.num_nodes, self.num_nodes).bool().to(device)
    mask = mask.unsqueeze(2).repeat_interleave(
        self.num_relation_types - 1, dim=2)
    adj.masked_fill_(mask, 0)

    return adj, prob
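

# Illustrative usage sketch (not part of the original module). The field
# names below mirror the attributes this file reads from `args`; the values
# are hypothetical, and gat_cell.GATGRUCell may require further arguments
# defined elsewhere in the repository.
#
#   from argparse import Namespace
#   args = Namespace(
#       num_nodes=10, input_dim=1, output_dim=1, rnn_units=32,
#       num_rnn_layers=2, dropout=0.1, activation="linear",
#       input_len=12, output_len=12, temperature=0.5,
#       adj_type="learned", num_relation_types=2)
#   model = GATRNN(adj_mx=None, args=args)
#   x = torch.randn(8, args.input_len, args.num_nodes, args.input_dim)
#   outputs, edge_prob = model(x.to(device))
#   # outputs: (8, output_len, num_nodes, output_dim)
#   # edge_prob: (num_nodes * num_nodes, num_relation_types)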