google-research
389 строк · 13.2 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# coding=utf-8
17# Copyright 2022 The Google Research Authors.
18#
19# Licensed under the Apache License, Version 2.0 (the "License");
20# you may not use this file except in compliance with the License.
21# You may obtain a copy of the License at
22#
23# http://www.apache.org/licenses/LICENSE-2.0
24#
25# Unless required by applicable law or agreed to in writing, software
26# distributed under the License is distributed on an "AS IS" BASIS,
27# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28# See the License for the specific language governing permissions and
29# limitations under the License.
30"""GATRNN model.
31
32Defines the encoder, decoder, and complete GATRNN model architecture.
33The gumbel softmax sampling function is also implemented to enable
34gradient passing through the predicted graph.
35"""
36
37import numpy as np38from pytorch_lightning import LightningModule39import torch40from torch import nn41import torch.nn.functional as F42
43from editable_graph_temporal.model import gat_cell44
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
47
def sample_gumbel(shape, eps=1e-20):
  """Sample from the Gumbel distribution.

  Args:
    shape: shape of the random variables to be sampled.
    eps: a small value used to avoid doing logarithms on zero.

  Returns:
    Tensor of the given `shape` sampled from the Gumbel distribution.
  """
  uniform_rand = torch.rand(shape).to(device)
  # Inverse-CDF trick: -log(-log(U)) with U ~ Uniform(0, 1); `eps` guards both
  # logarithms against zero arguments. The original wrapped the result in
  # torch.autograd.Variable, which has been a deprecated no-op since
  # PyTorch 0.4, so it is removed here (behavior unchanged).
  return -torch.log(-torch.log(uniform_rand + eps) + eps)
62
def gumbel_softmax_sample(logits, temperature, eps=1e-10):
  """Draw a soft sample from the Gumbel-Softmax distribution.

  Args:
    logits: [batch_size, n_class] unnormalized log-probs.
    temperature: non-negative scalar.
    eps: a small value used to avoid doing logarithms on zero.

  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
  """
  # Perturb the logits with Gumbel noise, then anneal with the temperature.
  noisy_logits = logits + sample_gumbel(logits.size(), eps=eps)
  return F.softmax(noisy_logits / temperature, dim=-1)
78
def gumbel_softmax(logits, temperature, hard=False, eps=1e-10):
  """Sample from the Gumbel-Softmax distribution and optionally discretize.

  Args:
    logits: [batch_size, n_class] unnormalized log-probs.
    temperature: non-negative scalar.
    hard: if True, take argmax, but differentiate w.r.t. soft sample y.
    eps: a small value used to avoid doing logarithms on zero.

  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot, otherwise it
    will be a probability distribution that sums to 1 across classes.
  """
  y_soft = gumbel_softmax_sample(logits, temperature=temperature, eps=eps)
  if not hard:
    return y_soft

  shape = logits.size()
  _, k = y_soft.max(-1)
  # torch.zeros is already zero-initialized; the original's extra zero_()
  # call was redundant and is dropped. scatter_ writes 1.0 at each argmax
  # index, yielding a one-hot tensor.
  y_hard = torch.zeros(*shape, device=y_soft.device)
  y_hard.scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)
  # Straight-through estimator: the forward value is the one-hot y_hard,
  # while gradients flow through y_soft. detach() replaces the deprecated
  # torch.autograd.Variable(...) wrapper (a no-op since PyTorch 0.4).
  return (y_hard - y_soft).detach() + y_soft
104
class Encoder(LightningModule, gat_cell.Seq2SeqAttrs):
  """Implements GATRNN encoder model.

  Encodes the input time series sequence to the hidden vector.
  """

  def __init__(self, args):
    """Builds the GATRNN encoder modules.

    Args:
      args: python argparse.ArgumentParser class, we only use model-related
        arguments here.
    """
    super().__init__()
    self._initialize_arguments(args)
    self.embedding = nn.Linear(self.input_dim, self.rnn_units)
    torch.nn.init.normal_(self.embedding.weight)

    # One GAT-GRU cell per recurrent layer.
    self.gat_layers = nn.ModuleList(
        [gat_cell.GATGRUCell(args) for _ in range(self.num_rnn_layers)])
    # NOTE(review): self.dropout (presumably the dropout rate set by
    # _initialize_arguments) is rebound here to the Dropout module itself.
    self.dropout = nn.Dropout(self.dropout)
    self.tanh = nn.Tanh()
    self.relu = nn.ReLU()

  def forward(self, inputs, adj, global_embs, hidden_state=None):
    r"""Encoder forward pass.

    Args:
      inputs: input one-step time series, with shape (batch_size,
        self.num_nodes, self.input_dim).
      adj: adjacency matrix, with shape (self.num_nodes, self.num_nodes).
      global_embs: global embedding matrix, with shape (self.num_nodes,
        self.rnn_units).
      hidden_state: hidden vectors, with shape (num_layers, batch_size,
        self.num_nodes, self.rnn_units); optional, zeros if not provided.

    Returns:
      output: outputs, with shape (batch_size, self.num_nodes,
        self.rnn_units).
      hidden_state: output hidden vectors, with shape (num_layers,
        batch_size, self.num_nodes, self.rnn_units),
        (lower indices mean lower layers).
    """
    # Debugging guard: report NaNs appearing in the embedding weights.
    if torch.any(torch.isnan(self.embedding.weight)):
      print("weight nan")

    output = self.dropout(self.tanh(self.embedding(inputs)))

    if hidden_state is None:
      # Zero initial state, one slice per recurrent layer.
      hidden_state = torch.zeros(
          (self.num_rnn_layers, inputs.shape[0], self.num_nodes,
           self.rnn_units),
          device=device)

    new_states = []
    for layer_idx, cell in enumerate(self.gat_layers):
      output = cell(output, hidden_state[layer_idx], adj, global_embs)
      new_states.append(output)

    # Optional output non-linearity; "linear" means identity.
    if self.activation == "relu":
      output = self.relu(output)
    elif self.activation == "tanh":
      output = self.tanh(output)

    # Stacking is O(num_layers), so it is cheap.
    return output, torch.stack(new_states)
178
class Decoder(LightningModule, gat_cell.Seq2SeqAttrs):
  """Implements GATRNN decoder model.

  Decodes the input hidden vector to the output time series sequence.
  """

  def __init__(self, args):
    """Builds the GATRNN decoder modules.

    Args:
      args: python argparse.ArgumentParser class, we only use model-related
        arguments here.
    """
    super().__init__()
    self._initialize_arguments(args)
    self.embedding = nn.Linear(self.output_dim, self.rnn_units)

    # One GAT-GRU cell per recurrent layer.
    self.gat_layers = nn.ModuleList(
        [gat_cell.GATGRUCell(args) for _ in range(self.num_rnn_layers)])
    self.fc_out = nn.Linear(self.rnn_units, self.output_dim)
    # NOTE(review): as in Encoder, self.dropout is rebound from the rate to
    # the Dropout module.
    self.dropout = nn.Dropout(self.dropout)
    self.relu = nn.ReLU()
    self.tanh = nn.Tanh()

  def forward(self, inputs, adj, global_embs, hidden_state=None):
    r"""Decoder forward pass.

    Args:
      inputs: input one-step time series, with shape (batch_size,
        self.num_nodes, self.output_dim).
      adj: adjacency matrix, with shape (self.num_nodes, self.num_nodes).
      global_embs: global embedding matrix, with shape (self.num_nodes,
        self.rnn_units).
      hidden_state: hidden vectors, with shape (num_layers, batch_size,
        self.num_nodes, self.rnn_units). Despite the None default, it is
        indexed unconditionally below, so callers must provide it.

    Returns:
      output: outputs, with shape (batch_size, self.num_nodes,
        self.output_dim).
      hidden_state: output hidden vectors, with shape (num_layers,
        batch_size, self.num_nodes, self.rnn_units),
        (lower indices mean lower layers).
    """
    output = self.dropout(self.tanh(self.embedding(inputs)))

    new_states = []
    for layer_idx, cell in enumerate(self.gat_layers):
      output = cell(output, hidden_state[layer_idx], adj, global_embs)
      new_states.append(output)

    # Project hidden vectors back to the output dimension.
    output = self.fc_out(output.view(-1, self.rnn_units)).view(
        -1, self.num_nodes, self.output_dim)

    # Optional output non-linearity; "linear" means identity.
    if self.activation == "relu":
      output = self.relu(output)
    elif self.activation == "tanh":
      output = self.tanh(output)

    return output, torch.stack(new_states)
247
class GATRNN(LightningModule, gat_cell.Seq2SeqAttrs):
  """Implements the GATRNN model."""

  def __init__(self, adj_mx, args):
    """Instantiates the GATRNN model.

    Args:
      adj_mx: adjacency matrix, with shape (self.num_nodes, self.num_nodes).
      args: python argparse.ArgumentParser class, we only use model-related
        arguments here.
    """
    super().__init__()
    self._initialize_arguments(args)
    self.temperature = args.temperature
    self.adj_type = args.adj_type
    # For "learned" adj_type no adjacency is stored: the graph is sampled
    # per forward pass by pred_adj().
    if args.adj_type == "fixed":
      self.adj_mx = adj_mx.to(device)
    elif args.adj_type == "empty":
      self.adj_mx = torch.zeros(
          size=(args.num_nodes, args.num_nodes, args.num_relation_types),
          device=device).float()

    # Learnable per-node embeddings shared by the encoder/decoder cells and
    # by the graph-prediction head.
    self.global_embs = nn.Parameter(
        torch.empty((self.num_nodes, self.rnn_units), device=device))
    torch.nn.init.xavier_normal_(self.global_embs)
    self.fc_out = nn.Linear(self.rnn_units * 2, self.rnn_units)
    self.fc_cat = nn.Linear(self.rnn_units, self.num_relation_types)
    self.encoder = Encoder(args)
    self.decoder = Decoder(args)
    self.fc_graph_rec, self.fc_graph_send = self._get_fc_graph_rec_send()
    self.loss = nn.L1Loss()

  def _get_fc_graph_rec_send(self):
    """Gets all two-node receiver and sender node indexes.

    This returns one-hot vectors for each of the pairs.

    Returns:
      (receiver node one-hot indexs, sender node one-hot indexs).
    """

    def encode_onehot(labels):
      """One-hot encoding.

      Args:
        labels: input labels containing integer numbers.

      Returns:
        label_onehot: one-hot vectors of labels.
      """
      classes = set(labels)
      classes_dict = {
          c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)
      }
      labels_onehot = np.array(
          list(map(classes_dict.get, labels)), dtype=np.int32)
      return labels_onehot

    # off_diag marks every ordered node pair (including self-pairs);
    # np.where enumerates them as (receiver, sender) index arrays.
    off_diag = np.ones([self.num_nodes, self.num_nodes])
    rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
    rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)

    return torch.FloatTensor(rel_rec).to(device), torch.FloatTensor(
        rel_send).to(device)

  def forward(self, inputs):
    """GATRNN forward pass.

    Args:
      inputs: input time series sequence, with shape (batch_size,
        self.input_len, self.num_nodes, self.input_dim).

    Returns:
      outputs: output time series sequence, with shape (batch_size,
        self.output_len, self.num_nodes, self.output_dim).
      edge_prob: predicted relation type probability for each two-node pair,
        with shape (self.num_nodes*self.num_nodes, self.num_relation_types).
        None unless adj_type == "learned".
    """
    # NOTE(review): `adj` is unbound if adj_type is not one of
    # "learned"/"fixed"/"empty" — presumably validated upstream; confirm.
    if self.adj_type == "learned":
      adj, edge_prob = self.pred_adj()
    elif self.adj_type in ["fixed", "empty"]:
      edge_prob = None
      # Normalize to 3-D (num_nodes, num_nodes, num_relations).
      adj = self.adj_mx if self.adj_mx.dim() == 3 else self.adj_mx.unsqueeze(2)

    # Move the time axis first so inputs[t] is one step for all batches.
    inputs = inputs.permute(1, 0, 2, 3)

    # Encode the whole input sequence, carrying the hidden state forward.
    encoder_hidden_state = None
    for t in range(self.input_len):
      _, encoder_hidden_state = self.encoder(inputs[t], adj, self.global_embs,
                                             encoder_hidden_state)

    # Autoregressive decoding seeded with a zero step; each prediction is
    # fed back as the next decoder input.
    decoder_hidden_state = encoder_hidden_state
    decoder_input = torch.zeros((encoder_hidden_state.shape[1], self.num_nodes,
                                 self.decoder.output_dim),
                                device=device)
    outputs = []
    for t in range(self.output_len):
      decoder_output, decoder_hidden_state = self.decoder(
          decoder_input, adj, self.global_embs, decoder_hidden_state)
      outputs.append(decoder_output)
      decoder_input = decoder_output

    outputs = torch.stack(outputs)

    # Free the (potentially large) recurrent states promptly.
    del encoder_hidden_state
    del decoder_hidden_state

    # Back to (batch, time, nodes, features).
    outputs = outputs.permute(1, 0, 2, 3)
    return outputs, edge_prob

  def pred_adj(self):
    """Predict relational graph.

    Returns:
      adj: predicted adjacency matrix of relational graph,
        with shape (self.num_nodes, self.num_nodes,
        self.num_relation_types-1).
      prob: predicted relation type probability for each two-node pair,
        with shape (self.num_nodes*self.num_nodes, self.num_relation_types).
    """
    # Score every ordered (sender, receiver) pair from the global embeddings.
    receivers = torch.matmul(self.fc_graph_rec, self.global_embs)
    senders = torch.matmul(self.fc_graph_send, self.global_embs)
    x = torch.cat([senders, receivers], dim=1)
    x = torch.relu(self.fc_out(x))
    x = self.fc_cat(x)
    prob = F.softmax(x, dim=-1)

    # Training uses differentiable straight-through Gumbel-Softmax samples;
    # evaluation uses the deterministic argmax relation type.
    if self.training:
      adj = gumbel_softmax(x, temperature=self.temperature, hard=True)
    else:
      adj = x.argmax(dim=1)
      adj = F.one_hot(adj, num_classes=self.num_relation_types)

    # Drop relation type 0 (treated as "no edge") and reshape per node pair.
    adj = adj[:, 1:].clone().reshape(self.num_nodes, self.num_nodes,
                                     self.num_relation_types - 1)

    # Zero out self-loops across all relation types.
    mask = torch.eye(self.num_nodes, self.num_nodes).bool().to(device)
    mask = mask.unsqueeze(2).repeat_interleave(
        self.num_relation_types - 1, dim=2)
    adj.masked_fill_(mask, 0)

    return adj, prob