paddlenlp
89 lines · 3.1 KB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15import paddle16import paddle.nn as nn17import paddle.nn.functional as F18
19
class PointwiseMatching(nn.Layer):
    """Binary text-matching head on top of a pretrained encoder.

    Feeds a text pair through the pretrained model, takes the pooled
    [CLS] embedding, and classifies the pair as similar / dissimilar.
    """

    def __init__(self, pretrained_model, dropout=None):
        super().__init__()
        self.ptm = pretrained_model
        # Fall back to a 0.1 dropout rate when none is supplied.
        rate = 0.1 if dropout is None else dropout
        self.dropout = nn.Dropout(rate)
        # Two output classes: similar vs. dissimilar.
        self.classifier = nn.Linear(self.ptm.config["hidden_size"], 2)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
        """Return class probabilities of shape [batch_size, 2]."""
        # The encoder returns (sequence_output, pooled_output); only the
        # pooled [CLS] representation is used here.
        _, pooled = self.ptm(input_ids, token_type_ids, position_ids, attention_mask)
        pooled = self.dropout(pooled)
        return F.softmax(self.classifier(pooled))
39
class PairwiseMatching(nn.Layer):
    """Pairwise ranking head on top of a pretrained encoder.

    Scores a (query, candidate) pair with a scalar similarity and is
    trained with a margin ranking loss over positive/negative pairs.
    """

    def __init__(self, pretrained_model, dropout=None, margin=0.1):
        super().__init__()
        self.ptm = pretrained_model
        # Fall back to a 0.1 dropout rate when none is supplied.
        rate = 0.1 if dropout is None else dropout
        self.dropout = nn.Dropout(rate)
        self.margin = margin
        # Project the pooled hidden state to a single similarity score.
        self.similarity = nn.Linear(self.ptm.config["hidden_size"], 1)

    def predict(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
        """Return a sigmoid similarity score in (0, 1) for each pair."""
        _, pooled = self.ptm(input_ids, token_type_ids, position_ids, attention_mask)
        pooled = self.dropout(pooled)
        return F.sigmoid(self.similarity(pooled))

    def forward(
        self,
        pos_input_ids,
        neg_input_ids,
        pos_token_type_ids=None,
        neg_token_type_ids=None,
        pos_position_ids=None,
        neg_position_ids=None,
        pos_attention_mask=None,
        neg_attention_mask=None,
    ):
        """Compute the margin ranking loss between positive and negative pairs."""
        # Encode both pairs; keep only the pooled [CLS] embeddings.
        _, pos_pooled = self.ptm(pos_input_ids, pos_token_type_ids, pos_position_ids, pos_attention_mask)
        _, neg_pooled = self.ptm(neg_input_ids, neg_token_type_ids, neg_position_ids, neg_attention_mask)

        pos_score = F.sigmoid(self.similarity(self.dropout(pos_pooled)))
        neg_score = F.sigmoid(self.similarity(self.dropout(neg_pooled)))

        # Target of +1 for every example: the positive pair must score
        # higher than the negative pair by at least `margin`.
        targets = paddle.ones([pos_pooled.shape[0]], dtype="float32")

        return F.margin_ranking_loss(pos_score, neg_score, targets, margin=self.margin)