# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GATRNN unit cell model.
31

32
Defines the unit cell used in the GATRNN model.
33
"""
34

35
from pytorch_lightning import LightningModule
36
import torch
37
from torch import nn
38
from torch_geometric.utils import softmax
39

40
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Seq2SeqAttrs:
  """Stores model-related arguments."""

  def _initialize_arguments(self, args):
    """Initializes model arguments.

    Args:
      args: parsed command-line arguments (an argparse namespace); only the
        model-related arguments are used here.
    """
    self.input_dim = args.input_dim
    self.output_dim = args.output_dim
    self.rnn_units = args.hidden_dim
    self.num_nodes = args.num_nodes
    self.input_len = args.input_len
    self.output_len = args.output_len
    self.num_relation_types = args.num_relation_types

    self.dropout = args.dropout
    self.negative_slope = args.negative_slope

    self.num_rnn_layers = args.num_layers
    self.lr = args.learning_rate
    self.activation = args.activation
    self.share_attn_weights = args.share_attn_weights
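
  # Illustrative sketch only: `args` is expected to expose the fields read
  # above. A hypothetical namespace with arbitrary values (not defaults from
  # this codebase) might look like:
  #
  #   args = argparse.Namespace(
  #       input_dim=1, output_dim=1, hidden_dim=16, num_nodes=5,
  #       input_len=12, output_len=12, num_relation_types=2, dropout=0.0,
  #       negative_slope=0.2, num_layers=1, learning_rate=1e-3,
  #       activation='tanh', share_attn_weights=True)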


class GATGRUCell(LightningModule, Seq2SeqAttrs):
  """Implements a single unit cell of the GATRNN model."""

  def __init__(self, args):
    """Instantiates the GATRNN unit cell model.

    Args:
      args: parsed command-line arguments (an argparse namespace); only the
        model-related arguments are used here.
    """
    super().__init__()
    self._initialize_arguments(args)

    self.activation = torch.tanh
    input_size = 2 * self.rnn_units

    # Graph-convolution parameters: one weight and bias slice per non-null
    # relation type.
    weight_dim = input_size if self.share_attn_weights else 2 * input_size
    biases_dim = (
        self.rnn_units if self.share_attn_weights else 2 * self.rnn_units)

    self.r_weights = nn.Parameter(
        torch.empty((weight_dim, self.rnn_units, self.num_relation_types - 1),
                    device=device))
    self.r_biases = nn.Parameter(
        torch.zeros((biases_dim, self.num_relation_types - 1), device=device))
    self.u_weights = nn.Parameter(
        torch.empty((weight_dim, self.rnn_units, self.num_relation_types - 1),
                    device=device))
    self.u_biases = nn.Parameter(
        torch.zeros((biases_dim, self.num_relation_types - 1), device=device))
    self.c_weights = nn.Parameter(
        torch.empty((weight_dim, self.rnn_units, self.num_relation_types - 1),
                    device=device))
    self.c_biases = nn.Parameter(
        torch.zeros((biases_dim, self.num_relation_types - 1), device=device))

    torch.nn.init.xavier_normal_(self.r_weights)
    torch.nn.init.xavier_normal_(self.u_weights)
    torch.nn.init.xavier_normal_(self.c_weights)

  def forward(self, inputs, hx, adj, global_embs):
    r"""Forward computation of a single unit cell of the GATRNN model.

    The forward computation is generally the same as that of a GRU cell in a
    sequence model, but the gate vectors and the candidate hidden vector are
    computed by graph-attention-network-based convolutions.

    Args:
      inputs: one-step input time series, with shape (batch_size,
        self.num_nodes, self.rnn_units).
      hx: hidden vectors from the last unit, with shape (batch_size,
        self.num_nodes, self.rnn_units). If this is the first unit, hx is
        usually a zero vector.
      adj: adjacency matrix, with shape (self.num_nodes, self.num_nodes,
        self.num_relation_types - 1).
      global_embs: global embedding matrix, with shape (self.num_nodes,
        self.rnn_units).

    Returns:
      hx: new hidden vector.
    """
    # Reset and update gates. Note: a standard GRU computes its gates with
    # sigmoid; this cell uses tanh.
    r = torch.tanh(self._gconv(inputs, adj, global_embs, hx, 'r'))
    u = torch.tanh(self._gconv(inputs, adj, global_embs, hx, 'u'))
    # Candidate hidden vector; r * hx is an element-wise multiplication.
    c = self._gconv(inputs, adj, global_embs, r * hx, 'c')
    if self.activation is not None:
      c = self.activation(c)

    hx = u * hx + (1.0 - u) * c

    del r
    del u
    del c

    return hx

  @staticmethod
  def _concat(x, x_):
    r"""Concatenates two tensors along the first dimension.

    Args:
      x: first input tensor.
      x_: second input tensor; a new leading dimension is added to it before
        concatenation.

    Returns:
      the concatenation of x and x_.
    """
    x_ = x_.unsqueeze(0)
    return torch.cat([x, x_], dim=0)
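
  # Illustrative example: _concat(torch.zeros(2, 3), torch.ones(3)) stacks
  # the second tensor as a new leading row, returning a (3, 3) tensor.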

  def _gconv(self, inputs, adj_mx, global_embs, state, option='r'):
    r"""Graph attention network based convolution computation.

    Args:
      inputs: input vector, with shape (batch_size, self.num_nodes,
        self.rnn_units).
      adj_mx: adjacency matrix, with shape (self.num_nodes, self.num_nodes,
        self.num_relation_types - 1).
      global_embs: global embedding matrix, with shape (self.num_nodes,
        self.rnn_units).
      state: hidden vectors from the last unit, with shape (batch_size,
        self.num_nodes, self.rnn_units). If this is the first unit, state is
        usually a zero vector.
      option: indicates whether the output is the reset gate vector ('r'),
        the update gate vector ('u'), or the candidate hidden vector ('c').

    Returns:
      out: the reset gate vector (option 'r'), the update gate vector
        (option 'u'), or the candidate hidden vector (option 'c').
    """
    batch_size = inputs.shape[0]
    num_nodes = self.num_nodes

    # Concatenate input and state along the feature dimension:
    # (batch_size, num_nodes, 2 * rnn_units), i.e. input_size.
    x = torch.cat([inputs, state], dim=-1)
    out = torch.zeros(
        size=(batch_size, num_nodes, self.rnn_units), device=device)

    for relation_id in range(self.num_relation_types - 1):
      # Select the parameters for the requested output: reset gate ('r'),
      # update gate ('u'), or candidate hidden vector ('c').
      if option == 'r':
        weights, biases = self.r_weights, self.r_biases
      elif option == 'u':
        weights, biases = self.u_weights, self.u_biases
      elif option == 'c':
        weights, biases = self.c_weights, self.c_biases
      else:
        raise ValueError("option must be one of 'r', 'u' or 'c'.")

      weights_left = weights[:2 * self.rnn_units, :, relation_id]
      biases_left = biases[:self.rnn_units, relation_id]
      if self.share_attn_weights:
        weights_right, biases_right = weights_left, biases_left
      else:
        weights_right = weights[2 * self.rnn_units:, :, relation_id]
        biases_right = biases[self.rnn_units:, relation_id]
      x_left = torch.matmul(x, weights_left) + biases_left
      x_right = torch.matmul(x, weights_right) + biases_right

      # Edge list (i, j) of the current relation type.
      i, j = torch.nonzero(adj_mx[:, :, relation_id], as_tuple=True)
      i, j = i.to(device), j.to(device)
      x_left_per_edge = x_left.index_select(1, i)
      x_right_per_edge = x_right.index_select(1, j)
      x_per_edge = x_left_per_edge + x_right_per_edge
      x_per_edge = nn.functional.leaky_relu(x_per_edge, self.negative_slope)

      # Attention logits, normalized per node i over its neighbors j.
      alpha = (x_per_edge * global_embs[i]).sum(dim=2)
      alpha = softmax(alpha, index=i, num_nodes=num_nodes, dim=1)

      # Scatter the per-edge coefficients into a dense
      # (batch_size, num_nodes, num_nodes) attention matrix.
      attns = torch.zeros([batch_size, num_nodes, num_nodes], device=device)
      batch_idxs = torch.arange(batch_size, device=device)
      batch_expand = torch.repeat_interleave(batch_idxs, len(i), dim=0)
      i_expand = torch.repeat_interleave(
          i.view(1, -1), batch_size, dim=0).view(-1)
      j_expand = torch.repeat_interleave(
          j.view(1, -1), batch_size, dim=0).view(-1)
      indices = (batch_expand, i_expand, j_expand)
      attns.index_put_(indices, alpha.view(-1))

      # Positions without an edge share a uniform coefficient within each row.
      zero_mask = (adj_mx[:, :,
                          relation_id] == 0).unsqueeze(0).repeat_interleave(
                              batch_size, dim=0)
      zero_coeffs = torch.ones([batch_size, num_nodes, num_nodes],
                               device=device) / zero_mask.float().sum(
                                   dim=-1, keepdim=True)
      attns[zero_mask] = zero_coeffs[zero_mask]

      out += torch.bmm(adj_mx[:, :, relation_id] * attns, x_right) + x_left

    return out
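

# Minimal smoke test (illustrative sketch): assumes a hypothetical argument
# namespace with arbitrary values, builds the cell, and runs one forward step
# (exercising _gconv) on random tensors.
if __name__ == '__main__':
  import argparse

  args = argparse.Namespace(
      input_dim=1, output_dim=1, hidden_dim=16, num_nodes=5,
      input_len=12, output_len=12, num_relation_types=2, dropout=0.0,
      negative_slope=0.2, num_layers=1, learning_rate=1e-3,
      activation='tanh', share_attn_weights=True)

  cell = GATGRUCell(args).to(device)
  batch_size = 4
  inputs = torch.randn(
      batch_size, args.num_nodes, args.hidden_dim, device=device)
  hx = torch.zeros(
      batch_size, args.num_nodes, args.hidden_dim, device=device)
  # One (num_nodes, num_nodes) adjacency slice per non-null relation type.
  adj = (torch.rand(
      args.num_nodes, args.num_nodes,
      args.num_relation_types - 1, device=device) > 0.5).float()
  global_embs = torch.randn(args.num_nodes, args.hidden_dim, device=device)

  new_hx = cell(inputs, hx, adj, global_embs)
  print(new_hx.shape)  # Expected: torch.Size([4, 5, 16])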