# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from paddlenlp.transformers import ElectraConfig, ElectraModel, ElectraPretrainedModel


class ElectraForBinaryTokenClassification(ElectraPretrainedModel):
    """
    Electra Model with two linear layers on top of the hidden-states output layers,
    designed for token classification tasks with nesting.

    Args:
        config (:class:`ElectraConfig`):
            An instance of ElectraConfig used to construct the underlying ElectraModel.
        num_classes_oth (int):
            The number of classes predicted by the first classifier head (`oth` labels).
        num_classes_sym (int):
            The number of classes predicted by the second classifier head (`sym` labels).

    Note:
        The dropout probability applied to the encoder output is taken from
        `hidden_dropout_prob` of `config`.
    """

    def __init__(self, config: ElectraConfig, num_classes_oth, num_classes_sym):
        super(ElectraForBinaryTokenClassification, self).__init__(config)
        self.num_classes_oth = num_classes_oth
        self.num_classes_sym = num_classes_sym
        self.electra = ElectraModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier_oth = nn.Linear(config.hidden_size, self.num_classes_oth)
        self.classifier_sym = nn.Linear(config.hidden_size, self.num_classes_sym)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, attention_mask=None):
        sequence_output = self.electra(input_ids, token_type_ids, position_ids, attention_mask)
        sequence_output = self.dropout(sequence_output)

        # Two independent per-token classification heads over the shared encoder output.
        logits_sym = self.classifier_sym(sequence_output)
        logits_oth = self.classifier_oth(sequence_output)

        return logits_oth, logits_sym


class MultiHeadAttentionForSPO(nn.Layer):
    """
    Multi-head attention layer for SPO task.
    """

    def __init__(self, embed_dim, num_heads, scale_value=768):
        super(MultiHeadAttentionForSPO, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.scale_value = scale_value**-0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim * num_heads)
        self.k_proj = nn.Linear(embed_dim, embed_dim * num_heads)

    def forward(self, query, key):
        # Project to one query/key vector of size `embed_dim` per head.
        q = self.q_proj(query)
        k = self.k_proj(key)
        # [batch_size, seq_len, num_heads * embed_dim] -> [batch_size, num_heads, seq_len, embed_dim]
        q = paddle.reshape(q, shape=[0, 0, self.num_heads, self.embed_dim])
        k = paddle.reshape(k, shape=[0, 0, self.num_heads, self.embed_dim])
        q = paddle.transpose(q, perm=[0, 2, 1, 3])
        k = paddle.transpose(k, perm=[0, 2, 1, 3])
        # Scaled dot-product scores of shape [batch_size, num_heads, seq_len, seq_len].
        scores = paddle.matmul(q, k, transpose_y=True)
        scores = paddle.scale(scores, scale=self.scale_value)
        return scores


class ElectraForSPO(ElectraPretrainedModel):
    """
    Electra Model with a linear layer on top of the hidden-states output
    layers for entity recognition, and a multi-head attention layer for
    relation classification.

    Args:
        config (:class:`ElectraConfig`):
            An instance of ElectraConfig used to construct the underlying ElectraModel.
            Its `num_labels` attribute gives the number of relation classes, and its
            `hidden_dropout_prob` is used as the dropout probability for the encoder output.
    """

    def __init__(self, config: ElectraConfig):
        super(ElectraForSPO, self).__init__(config)
        self.num_classes = config.num_labels
        self.electra = ElectraModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 2)
        self.span_attention = MultiHeadAttentionForSPO(config.hidden_size, config.num_labels)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, attention_mask=None):
        outputs = self.electra(
            input_ids, token_type_ids, position_ids, attention_mask, output_hidden_states=True, return_dict=True
        )
        sequence_outputs = outputs.last_hidden_state
        all_hidden_states = outputs.hidden_states
        sequence_outputs = self.dropout(sequence_outputs)
        # Token-level entity logits (binary head).
        ent_logits = self.classifier(sequence_outputs)

        # Add the [CLS] vector of the final layer to every position of the
        # second-to-last layer's hidden states (broadcast over the sequence dimension).
        subject_output = all_hidden_states[-2]
        cls_output = paddle.unsqueeze(sequence_outputs[:, 0, :], axis=1)
        subject_output = subject_output + cls_output

        output_size = self.num_classes + self.electra.config["hidden_size"]  # noqa:F841
        # One attention head per relation class scores every query/key token pair.
        rel_logits = self.span_attention(sequence_outputs, subject_output)

        return ent_logits, rel_logits
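

# A minimal smoke-test sketch (not part of the original module): it builds a small,
# randomly initialized ElectraConfig purely to illustrate the expected input and
# output shapes of the two task heads. The hyperparameter values and label counts
# below are illustrative assumptions, not values used by any released checkpoint.
if __name__ == "__main__":
    config = ElectraConfig(
        vocab_size=1000,
        embedding_size=64,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
        max_position_embeddings=64,
        num_labels=4,
    )
    # Dummy batch of 2 sequences with 16 tokens each.
    input_ids = paddle.randint(low=0, high=config.vocab_size, shape=[2, 16], dtype="int64")

    # Nested token classification: one logit tensor per label set.
    ner_model = ElectraForBinaryTokenClassification(config, num_classes_oth=9, num_classes_sym=5)
    logits_oth, logits_sym = ner_model(input_ids)
    print(logits_oth.shape, logits_sym.shape)  # [2, 16, 9] [2, 16, 5]

    # SPO extraction: per-token entity logits and per-relation token-pair scores.
    spo_model = ElectraForSPO(config)
    ent_logits, rel_logits = spo_model(input_ids)
    print(ent_logits.shape, rel_logits.shape)  # [2, 16, 2] [2, 4, 16, 16]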