WantWords / model.py

import torch


class Encoder(torch.nn.Module):
4
    def __init__(self, vocab_size, embed_dim, hidden_dim, layers, class_num, encoder, sememe_num, chara_num, mode):
5
        super().__init__()
6
        self.vocab_size = vocab_size
7
        self.embed_dim = 200
8
        self.hidden_dim = 768
9
        self.layers = layers
10
        self.class_num = class_num
11
        self.sememe_num = sememe_num
12
        self.chara_num = chara_num
13
        self.embedding = torch.nn.Embedding(self.vocab_size, self.embed_dim, padding_idx=0, max_norm=5, sparse=True)
14
        self.embedding.weight.requires_grad = False
15
        self.embedding_dropout = torch.nn.Dropout(0.2)
16
        self.encoder = encoder
17
        self.fc = torch.nn.Linear(self.hidden_dim, self.embed_dim)
18
        self.loss = torch.nn.CrossEntropyLoss()
19
        self.relu = torch.nn.ReLU()
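        # Optional auxiliary prediction heads, selected by letters in `mode`.
        # The interpretation below is inferred from how each head is used in
        # forward() (an assumption, not documented in this file):
        #   'P' - part-of-speech prediction (13 = 12 POS tags + 1 "none")
        #   's' - sememe prediction (sememe_num sememes, likely from HowNet)
        #   'c' - Chinese character prediction (chara_num characters)
        #   'C' - Cilin (词林) thesaurus hierarchy, 3 levels with 12/95/1425 classes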
        if 'P' in mode:
            self.fc2 = torch.nn.Linear(self.hidden_dim, 13)
        if 's' in mode:
            self.fc1 = torch.nn.Linear(self.hidden_dim, self.sememe_num)
        if 'c' in mode:
            self.fc3 = torch.nn.Linear(self.hidden_dim, self.chara_num)
        if 'C' in mode:
            self.fc_C1 = torch.nn.Linear(self.hidden_dim, 12)
            self.fc_C2 = torch.nn.Linear(self.hidden_dim, 95)
            self.fc_C3 = torch.nn.Linear(self.hidden_dim, 1425)
    def forward(self, operation, x=None, w=None, ws=None, wP=None, wc=None, wC=None, msk_s=None, msk_c=None, mode=None):
        # x: T(bat, max_word_num) - definition token ids, 0 = padding
        # w: T(bat) - target word indices (training labels)
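        # Shapes of the auxiliary inputs, inferred from the matrix products
        # below (an assumption, not documented in this file):
        # ws:    T(class_num, sememe_num) multi-hot sememes per candidate word
        # wP:    T(class_num, 13) multi-hot POS tags per candidate word
        # wc:    T(class_num, chara_num) multi-hot characters per candidate word
        # wC:    list of 3 tensors T(class_num, Ci_size), Cilin class one-hots
        # msk_s: T(class_num), 1.0 for words without sememes
        # msk_c: T(class_num), 1.0 for words without Cilin classes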
        x = x.long()
        attention_mask = torch.gt(x, 0).to(torch.int64)
        h = self.encoder(x, attention_mask=attention_mask)[0]
        h_1 = self.embedding_dropout(h[:, 0, :])
        vd = self.fc(h_1)
        # score0: T(bat, class_num) = [bat, embed_dim] .mm [class_num, embed_dim].t()
        score0 = vd.mm(self.embedding.weight.data[:self.class_num].t())
        score = score0
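        # The base score matches the encoded definition against the frozen
        # word embeddings: the [CLS]-position hidden state is projected to
        # embed_dim and dotted with each of the first class_num embeddings.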

        if 'C' in mode:
            # scC[i]: T(bat, Ci_size)
            # The hierarchical Cilin classifier trains slowly, which is unfair
            # and unbalanced: word prediction converges first, while the Cilin
            # classifier is not yet effective. Exploiting the other auxiliary
            # information has the same problem; the channels do not
            # necessarily converge at the same time!!!
            scC = [self.fc_C1(h_1), self.fc_C2(h_1), self.fc_C3(h_1)]
            # zeros_like keeps score2 on the same device/dtype as score0
            score2 = torch.zeros_like(score0)
            rank = 0.6
            for i in range(3):
                # wC[i]: T(class_num, Ci_size)
                # C_sc: T(bat, class_num); deeper Cilin levels are down-weighted by rank**i
                score2 += self.relu(scC[i].mm(wC[i].t()) * (rank ** i))
            # ---------- add the mean Cilin-class score to words that have no Cilin class
            mean_cilin_sc = torch.mean(score2, 1)
            score2 = score2 * (1 - msk_c) + mean_cilin_sc.unsqueeze(1).mm(msk_c.unsqueeze(0))
            # ----------
            score = score + score2 / 2
        if 'P' in mode:
            ## POS prediction
            # score_POS: T(bat, 13), pos_num = 12 + 1
            score_POS = self.fc2(h_1)
            # wP: T(class_num, 13) multi-hot
            # weight_sc: T(bat, class_num) = [bat, 13] .mm [class_num, 13].t()
            weight_sc = self.relu(score_POS.mm(wP.t()))
            score = score + weight_sc
        if 's' in mode:
            ## sememe prediction
            # pos_score: T(bat, max_word_num, sememe_num)
            pos_score = self.fc1(h)
            # sem_score: T(bat, sememe_num)
            sem_score, _ = torch.max(pos_score, dim=1)
            # score1: T(bat, class_num) = [bat, sememe_num] .mm [class_num, sememe_num].t()
            score1 = self.relu(sem_score.mm(ws.t()))
            # ---------- add the mean sememe score to words that have no sememes
            # mean_sem_sc: T(bat)
            mean_sem_sc = torch.mean(score1, 1)
            # msk_s: T(class_num)
            score1 = score1 + mean_sem_sc.unsqueeze(1).mm(msk_s.unsqueeze(0))
            # ----------
            score = score + score1
        if 'c' in mode:
            ## character prediction
            # pos_score: T(bat, max_word_num, chara_num)
            pos_score = self.fc3(h)
            # chara_score: T(bat, chara_num)
            chara_score, _ = torch.max(pos_score, dim=1)
            # score3: T(bat, class_num) = [bat, chara_num] .mm [class_num, chara_num].t()
            score3 = self.relu(chara_score.mm(wc.t()))
            score = score + score3
        '''
        if RD_mode == 'CC':
            # fine-tune: the target word should not appear in its own definition.
            # score_res = score.clone().detach()
            mask1 = torch.lt(x, self.class_num).to(torch.int64)
            mask2 = torch.ones((score.shape[0], score.shape[1]), dtype=torch.float32)
            for i in range(x.shape[0]):
                mask2[i][x[i] * mask1[i]] = 0.
            score = score * mask2 + (-1e6) * (1 - mask2)
        '''
        # The train branch returns indices, so the sort must stay active.
        _, indices = torch.sort(score, descending=True)
        if operation == 'train':
            loss = self.loss(score, w.long())
            return loss, score, indices
        elif operation == 'test':
            return score  # , indices
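

# --- Usage sketch (not from the original file) ---
# A minimal example of driving the Encoder with a Hugging Face BERT model as
# the `encoder` argument (the hardcoded hidden_dim of 768 matches BERT-base).
# The model name, sizes, and the 'b' mode string (any mode without the
# letters P/s/c/C disables all auxiliary channels) are illustrative
# assumptions, not values taken from the WantWords codebase.
if __name__ == '__main__':
    from transformers import BertModel

    bert = BertModel.from_pretrained('bert-base-chinese')  # assumed encoder
    vocab_size, class_num = 30000, 20000                   # made-up sizes
    model = Encoder(vocab_size=vocab_size, embed_dim=200, hidden_dim=768,
                    layers=1, class_num=class_num, encoder=bert,
                    sememe_num=1400, chara_num=4000, mode='b')

    # Fake batch: 2 definitions of 8 tokens each; ids must stay inside the
    # BERT vocabulary because x is fed straight into the encoder.
    x = torch.randint(1, bert.config.vocab_size, (2, 8))
    w = torch.randint(0, class_num, (2,))  # target word indices

    loss, score, indices = model('train', x=x, w=w, mode='b')
    print(loss.item(), score.shape, indices.shape)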