WantWords
/
model.py
108 строк · 5.2 Кб
1import torch
2
class Encoder(torch.nn.Module):
    """Score every candidate target word for a batch of definitions.

    A wrapped sentence encoder turns the definition token ids into hidden
    states.  The hidden state at position 0 is projected into the word
    embedding space and compared (dot product) with the first ``class_num``
    rows of the embedding table.  Depending on ``mode``, extra scores
    predicted by auxiliary heads are added on top:

    * ``'P'`` -- 13-way head (part of speech, per the original comments)
    * ``'s'`` -- sememe head (``sememe_num`` outputs)
    * ``'c'`` -- character head (``chara_num`` outputs)
    * ``'C'`` -- three heads with 12 / 95 / 1425 outputs (Cilin hierarchy
      levels, per the original comments)

    Args:
        vocab_size: number of rows of the embedding table.
        embed_dim: accepted for interface compatibility, NOT used; the
            embedding width is pinned to 200.
        hidden_dim: accepted for interface compatibility, NOT used; the
            encoder width is pinned to 768.
        layers: stored on the instance, not used by this class.
        class_num: number of candidate target words; they are taken to be
            the first ``class_num`` rows of the embedding table.
        encoder: module called as ``encoder(x, attention_mask=...)`` whose
            first return element holds the hidden states
            (presumably a BERT-style model -- confirm against the caller).
        sememe_num: output size of the sememe head.
        chara_num: output size of the character head.
        mode: string of flags selecting which auxiliary heads to create.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, layers, class_num,
                 encoder, sememe_num, chara_num, mode):
        super().__init__()
        self.vocab_size = vocab_size
        # NOTE: the constructor arguments ``embed_dim`` / ``hidden_dim`` are
        # deliberately ignored here; the sizes are fixed.
        self.embed_dim = 200
        self.hidden_dim = 768
        self.layers = layers
        self.class_num = class_num
        self.sememe_num = sememe_num
        self.chara_num = chara_num
        # Remembered so that forward() can fall back to it when the caller
        # does not pass ``mode`` explicitly.
        self.mode = mode

        self.embedding = torch.nn.Embedding(
            self.vocab_size, self.embed_dim,
            padding_idx=0, max_norm=5, sparse=True,
        )
        # The embedding table is frozen; it is only used as a target space.
        self.embedding.weight.requires_grad = False
        self.embedding_dropout = torch.nn.Dropout(0.2)
        self.encoder = encoder
        self.fc = torch.nn.Linear(self.hidden_dim, self.embed_dim)
        self.loss = torch.nn.CrossEntropyLoss()
        self.relu = torch.nn.ReLU()

        # Auxiliary heads are only created for the flags present in ``mode``.
        if 'P' in mode:
            self.fc2 = torch.nn.Linear(self.hidden_dim, 13)
        if 's' in mode:
            self.fc1 = torch.nn.Linear(self.hidden_dim, self.sememe_num)
        if 'c' in mode:
            self.fc3 = torch.nn.Linear(self.hidden_dim, self.chara_num)
        if 'C' in mode:
            self.fc_C1 = torch.nn.Linear(self.hidden_dim, 12)
            self.fc_C2 = torch.nn.Linear(self.hidden_dim, 95)
            self.fc_C3 = torch.nn.Linear(self.hidden_dim, 1425)

    def forward(self, operation, x=None, w=None, ws=None, wP=None, wc=None,
                wC=None, msk_s=None, msk_c=None, mode=None):
        """Compute candidate-word scores (and the loss when training).

        Shapes below are taken from the original inline comments.

        Args:
            operation: ``'train'`` or ``'test'``.
            x: T(bat, max_word_num) token ids; id 0 is treated as padding.
            w: T(bat) target word ids (only needed for ``'train'``).
            ws: T(class_num, sememe_num) word-to-sememe matrix ('s' mode).
            wP: T(class_num, 13) multi-hot word-to-POS matrix ('P' mode).
            wc: T(class_num, chara_num) word-to-character matrix ('c' mode).
            wC: sequence of three T(class_num, Ci_size) matrices ('C' mode).
            msk_s: T(class_num) mask used to add the mean sememe score
                (per the original comment: words that have no sememes).
            msk_c: T(class_num) mask selecting words whose Cilin score is
                replaced by the per-sample mean.
            mode: flags selecting which auxiliary scores to add; when
                ``None`` the mode given at construction time is used.

        Returns:
            ``(loss, score, indices)`` for ``'train'`` where ``indices`` is
            the candidate ranking (descending score), or ``score`` alone
            for ``'test'``.  ``score`` is T(bat, class_num).

        Raises:
            ValueError: if ``operation`` is neither ``'train'`` nor ``'test'``.
        """
        if operation not in ('train', 'test'):
            raise ValueError(
                "operation must be 'train' or 'test', got %r" % (operation,)
            )
        if mode is None:
            mode = self.mode

        x = x.long()
        # Non-zero ids are real tokens, zeros are padding.
        attention_mask = torch.gt(x, 0).to(torch.int64)
        # h: hidden states for every position; h_1: state at position 0.
        h = self.encoder(x, attention_mask=attention_mask)[0]
        h_1 = self.embedding_dropout(h[:, 0, :])
        vd = self.fc(h_1)

        # score0: T(bat, class_num) = [bat, emb] .mm [class_num, emb].t()
        # The candidate words are the first ``class_num`` embedding rows;
        # detach() keeps the table out of the autograd graph.
        target_vectors = self.embedding.weight.detach()[:self.class_num]
        score0 = vd.mm(target_vectors.t())
        score = score0

        if 'C' in mode:
            # scC[i]: T(bat, Ci_size)
            # Original author's note: the hierarchical Cilin classification
            # trains slowly, which is unbalanced -- word prediction converges
            # first while the Cilin classifiers are not effective yet.  The
            # other auxiliary signals have the same problem; they do not
            # necessarily converge at the same time.
            scC = [self.fc_C1(h_1), self.fc_C2(h_1), self.fc_C3(h_1)]
            # Allocate on the same device / dtype as the main score.
            score2 = torch.zeros_like(score0)
            rank = 0.6
            for level, (level_score, level_map) in enumerate(zip(scC, wC)):
                # level_map: T(class_num, Ci_size) -> contribution T(bat, class_num),
                # down-weighted by rank**level.
                score2 = score2 + self.relu(
                    level_score.mm(level_map.t()) * (rank ** level)
                )
            # Words selected by msk_c get the per-sample mean Cilin score
            # instead of their own (original comment: words that have no
            # Cilin class).
            mean_cilin_sc = torch.mean(score2, 1)
            score2 = (score2 * (1 - msk_c)
                      + mean_cilin_sc.unsqueeze(1).mm(msk_c.unsqueeze(0)))
            score = score + score2 / 2

        if 'P' in mode:
            # score_POS: T(bat, 13)
            score_POS = self.fc2(h_1)
            # weight_sc: T(bat, class_num) = [bat, 13] .mm [class_num, 13].t()
            weight_sc = self.relu(score_POS.mm(wP.t()))
            score = score + weight_sc

        if 's' in mode:
            # pos_score: T(bat, max_word_num, sememe_num)
            pos_score = self.fc1(h)
            # sem_score: T(bat, sememe_num) -- max over positions
            sem_score, _ = torch.max(pos_score, dim=1)
            # score1: T(bat, class_num) = [bat, sememe_num] .mm [class_num, sememe_num].t()
            score1 = self.relu(sem_score.mm(ws.t()))
            # Add the per-sample mean sememe score to the words selected by
            # msk_s (original comment: words that have no sememes).
            mean_sem_sc = torch.mean(score1, 1)
            score1 = score1 + mean_sem_sc.unsqueeze(1).mm(msk_s.unsqueeze(0))
            score = score + score1

        if 'c' in mode:
            # pos_score: T(bat, max_word_num, chara_num)
            pos_score = self.fc3(h)
            # chara_score: T(bat, chara_num) -- max over positions
            chara_score, _ = torch.max(pos_score, dim=1)
            # score3: T(bat, class_num) = [bat, chara_num] .mm [class_num, chara_num].t()
            score3 = self.relu(chara_score.mm(wc.t()))
            score = score + score3

        # NOTE: an earlier, disabled experiment masked out candidates that
        # already occur in the definition; it is intentionally not applied.

        if operation == 'train':
            loss = self.loss(score, w.long())
            # Ranking of the candidates, best first.
            _, indices = torch.sort(score, descending=True)
            return loss, score, indices
        return score
109
110