# google-research (source listing: 246 lines, 9.5 KB)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# coding=utf-8
17# Copyright 2022 The Google Research Authors.
18#
19# Licensed under the Apache License, Version 2.0 (the "License");
20# you may not use this file except in compliance with the License.
21# You may obtain a copy of the License at
22#
23# http://www.apache.org/licenses/LICENSE-2.0
24#
25# Unless required by applicable law or agreed to in writing, software
26# distributed under the License is distributed on an "AS IS" BASIS,
27# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28# See the License for the specific language governing permissions and
29# limitations under the License.
30"""GATRNN unit cell model.
31
32Defines the unit cell used in the GATRNN model.
33"""
34
from pytorch_lightning import LightningModule
import torch
from torch import nn
from torch_geometric.utils import softmax

# Run all tensors on GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
class Seq2SeqAttrs:
  """Mixin that copies model-related hyperparameters from an args namespace."""

  def _initialize_arguments(self, args):
    """Copies model arguments onto this object as attributes.

    Args:
      args: python argparse.ArgumentParser class, we only use model-related
        arguments here.
    """
    # Maps the attribute name stored on `self` to the field name on `args`.
    attr_to_arg = {
        'input_dim': 'input_dim',
        'output_dim': 'output_dim',
        'rnn_units': 'hidden_dim',
        'num_nodes': 'num_nodes',
        'input_len': 'input_len',
        'output_len': 'output_len',
        'num_relation_types': 'num_relation_types',
        'dropout': 'dropout',
        'negative_slope': 'negative_slope',
        'num_rnn_layers': 'num_layers',
        'lr': 'learning_rate',
        'activation': 'activation',
        'share_attn_weights': 'share_attn_weights',
    }
    for attr_name, arg_name in attr_to_arg.items():
      setattr(self, attr_name, getattr(args, arg_name))
69
class GATGRUCell(LightningModule, Seq2SeqAttrs):
  """Implements a single unit cell of GATRNN model.

  A GRU-style recurrent cell whose reset gate, update gate and candidate
  hidden state are each produced by a graph-attention-based convolution
  (`_gconv`) instead of dense linear maps.
  """

  def __init__(self, args):
    """Instantiates the GATRNN unit cell model.

    Args:
      args: python argparse.ArgumentParser class, we only use model-related
        arguments here.
    """
    super().__init__()
    self._initialize_arguments(args)

    self.activation = torch.tanh
    # _gconv consumes [inputs, state] concatenated along the feature axis.
    input_size = 2 * self.rnn_units

    # gconv
    # When attention weights are shared, a single (left) weight/bias pair is
    # used for both endpoints of an edge; otherwise left and right halves are
    # stacked along the first dimension and sliced apart in _gconv.
    weight_dim = input_size if self.share_attn_weights else 2 * input_size
    biases_dim = self.rnn_units if self.share_attn_weights else 2 * self.rnn_units

    # One weight/bias slice per relation type (last axis); the "no-relation"
    # type is excluded, hence num_relation_types - 1.
    self.r_weights = nn.Parameter(
        torch.empty((weight_dim, self.rnn_units, self.num_relation_types - 1),
                    device=device))
    self.r_biases = nn.Parameter(
        torch.zeros((biases_dim, self.num_relation_types - 1), device=device))
    self.u_weights = nn.Parameter(
        torch.empty((weight_dim, self.rnn_units, self.num_relation_types - 1),
                    device=device))
    self.u_biases = nn.Parameter(
        torch.zeros((biases_dim, self.num_relation_types - 1), device=device))
    self.c_weights = nn.Parameter(
        torch.empty((weight_dim, self.rnn_units, self.num_relation_types - 1),
                    device=device))
    self.c_biases = nn.Parameter(
        torch.zeros((biases_dim, self.num_relation_types - 1), device=device))

    # Weights are Xavier-initialized; biases stay at zero (torch.zeros above).
    torch.nn.init.xavier_normal_(self.r_weights)
    torch.nn.init.xavier_normal_(self.u_weights)
    torch.nn.init.xavier_normal_(self.c_weights)

  def forward(self, inputs, hx, adj, global_embs):
    r"""Forward computation of a single unit cell of GATRNN model.

    The forward computation is generally the same as
    that of a GRU cell of sequence model, but gate vectors and candidate
    hidden vectors are computed by graph attention
    network based convolutions.

    Args:
      inputs: input one-step time series, with shape (batch_size,
        self.num_nodes, self.rnn_units).
      hx: hidden vectors from the last unit, with shape(batch_size,
        self.num_nodes, self.rnn_units). If this is the first unit, usually hx
        is supposed to be a zero vector.
      adj: adjacency matrix, with shape (self.num_nodes, self.num_nodes).
      global_embs: global embedding matrix, with shape (self.num_nodes,
        self.rnn_units).

    Returns:
      hx: new hidden vector.
    """
    # NOTE(review): gates use tanh rather than the conventional GRU sigmoid,
    # so r and u lie in (-1, 1) — this appears deliberate in this model.
    r = torch.tanh(self._gconv(inputs, adj, global_embs, hx, 'r'))
    u = torch.tanh(self._gconv(inputs, adj, global_embs, hx, 'u'))
    c = self._gconv(inputs, adj, global_embs, r * hx,
                    'c')  # element-wise multiplication
    if self.activation is not None:
      c = self.activation(c)

    # Standard GRU blend of previous state and candidate.
    hx = u * hx + (1.0 - u) * c

    # Free intermediate tensors promptly (helps peak GPU memory).
    del r
    del u
    del c

    return hx

  @staticmethod
  def _concat(x, x_):
    r"""Concatenates two tensors along the first dimension.

    `x_` gains a new leading axis before concatenation, so it must have one
    dimension fewer than `x`.

    Args:
      x: first input tensor.
      x_: second input tensor.

    Returns:
      concatenation tensor of x and x_.
    """
    x_ = x_.unsqueeze(0)
    return torch.cat([x, x_], dim=0)

  def _gconv(self, inputs, adj_mx, global_embs, state, option='r'):
    r"""Graph attention network based convolution computation.

    Args:
      inputs: input vector, with shape (batch_size, self.num_nodes,
        self.rnn_units).
      adj_mx: adjacency matrix, with shape (self.num_nodes, self.num_nodes).
        NOTE(review): the indexing below (`adj_mx[:, :, relation_id]`) implies
        an actual shape of (num_nodes, num_nodes, num_relation_types - 1) —
        the docstring above understates the rank; verify against callers.
      global_embs: global embedding matrix, with shape (self.num_nodes,
        self.rnn_units).
      state: hidden vectors from the last unit, with shape(batch_size,
        self.num_nodes, self.rnn_units). If this is the first unit, usually hx
        is supposed to be a zero vector.
      option: indicate whether the output is reset gate vector ('r'), update
        gate vector ('u'), or candidate hidden vector ('c').

    Returns:
      out: output, can be reset gate vector (option is 'r'), update gate
        vector (option is 'u'), or
        candidate hidden vector (option is 'c').
    """
    batch_size = inputs.shape[0]
    num_nodes = self.num_nodes

    # x: (batch, num_nodes, 2 * rnn_units) — input and state stacked on the
    # feature axis.
    x = torch.cat([inputs, state], dim=-1)  # input_dim
    # Accumulates the per-relation contributions.
    out = torch.zeros(
        size=(batch_size, num_nodes, self.rnn_units), device=device)

    for relation_id in range(self.num_relation_types - 1):
      # Select the parameter set matching the requested gate. The "left"
      # slice covers the first 2 * rnn_units weight rows / rnn_units bias
      # entries; when weights are not shared, the "right" endpoint uses the
      # remaining rows/entries instead.
      if option == 'r':
        r_weights_left = self.r_weights[:2 * self.rnn_units, :, relation_id]
        r_biases_left = self.r_biases[:self.rnn_units, relation_id]
        r_weights_right = r_weights_left if self.share_attn_weights else self.r_weights[
            2 * self.rnn_units:, :, relation_id]
        r_biases_right = r_biases_left if self.share_attn_weights else self.r_biases[
            self.rnn_units:, relation_id]
        x_left = torch.matmul(x, r_weights_left) + r_biases_left
        x_right = torch.matmul(x, r_weights_right) + r_biases_right
      elif option == 'u':
        u_weights_left = self.u_weights[:2 * self.rnn_units, :, relation_id]
        u_biases_left = self.u_biases[:self.rnn_units, relation_id]
        u_weights_right = u_weights_left if self.share_attn_weights else self.u_weights[
            2 * self.rnn_units:, :, relation_id]
        u_biases_right = u_biases_left if self.share_attn_weights else self.u_biases[
            self.rnn_units:, relation_id]
        x_left = torch.matmul(x, u_weights_left) + u_biases_left
        x_right = torch.matmul(x, u_weights_right) + u_biases_right
      elif option == 'c':
        c_weights_left = self.c_weights[:2 * self.rnn_units, :, relation_id]
        c_biases_left = self.c_biases[:self.rnn_units, relation_id]
        c_weights_right = c_weights_left if self.share_attn_weights else self.c_weights[
            2 * self.rnn_units:, :, relation_id]
        c_biases_right = c_biases_left if self.share_attn_weights else self.c_biases[
            self.rnn_units:, relation_id]
        x_left = torch.matmul(x, c_weights_left) + c_biases_left
        x_right = torch.matmul(x, c_weights_right) + c_biases_right

      # Edge list for this relation: i = source rows, j = target columns of
      # nonzero adjacency entries.
      i, j = torch.nonzero(adj_mx[:, :, relation_id], as_tuple=True)
      i, j = i.to(device), j.to(device)
      # Gather per-edge endpoint features: (batch, num_edges, rnn_units).
      x_left_per_edge = x_left.index_select(1, i)
      x_right_per_edge = x_right.index_select(1, j)
      x_per_edge = x_left_per_edge + x_right_per_edge
      x_per_edge = nn.functional.leaky_relu(x_per_edge, self.negative_slope)

      # Attention logits: inner product of edge features with the source
      # node's global embedding, then softmax grouped by source node.
      alpha = (x_per_edge * global_embs[i]).sum(dim=2)
      alpha = softmax(alpha, index=i, num_nodes=num_nodes, dim=1)

      # Scatter the per-edge coefficients into a dense (batch, n, n) matrix
      # via index_put_ on (batch, row, col) index triples.
      attns = torch.zeros([batch_size, num_nodes, num_nodes], device=device)
      batch_idxs = torch.arange(batch_size, device=device)
      batch_expand = torch.repeat_interleave(batch_idxs, len(i), dim=0)
      i_expand = torch.repeat_interleave(
          i.view(1, -1), batch_size, dim=0).view(-1)
      j_expand = torch.repeat_interleave(
          j.view(1, -1), batch_size, dim=0).view(-1)
      indices = (batch_expand, i_expand, j_expand)
      attns.index_put_(indices, alpha.view(-1))

      # Fill non-edge positions with a uniform coefficient per row.
      # NOTE(review): these filled entries are multiplied by the (zero)
      # adjacency values in the bmm below, so they do not reach `out`;
      # presumably kept for inspection/debugging of `attns` — confirm.
      zero_mask = (adj_mx[:, :,
                          relation_id] == 0).unsqueeze(0).repeat_interleave(
                              batch_size, dim=0)
      zero_coeffs = torch.ones([batch_size, num_nodes, num_nodes],
                               device=device) / zero_mask.float().sum(
                                   dim=-1, keepdim=True)
      attns[zero_mask] = zero_coeffs[zero_mask]

      # Attention-weighted aggregation of neighbor features plus the node's
      # own transformed features (residual-style x_left term).
      out += torch.bmm(adj_mx[:, :, relation_id] * attns, x_right) + x_left

    return out