CSS-LM

Форк
0
/
test_disb.py 
106 строк · 3.7 Кб
1
import pickle
2
import json
3
import argparse
4
import logging
5
import random
6
import numpy as np
7
import os
8
import json
9
import sys
10
import time
11

12
import torch
13
from transformers import BertTokenizer, BertForPreTraining, BertForSequenceClassification
14
from tqdm import tqdm, trange
15
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
16
from torch.utils.data.distributed import DistributedSampler
17
from transformers.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
18
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
19

20
max_length=100
21
k=10
22
device="cpu"
23

24
pretrained_weights = '/data5/private/suyusheng/task_selecte/bert-base-uncased-128/'
25
tokenizer = BertTokenizer.from_pretrained(pretrained_weights, do_lower_case=True)
26

27
fine_tuned_weight = '/data5/private/suyusheng/task_selecte/output_finetune/pytorch_model.bin_1314'
28
model = BertForPreTraining.from_pretrained(pretrained_weights, output_hidden_states=True,return_dict=True)
29
model.load_state_dict(torch.load(fine_tuned_weight), strict=False)
30
model.to(device)
31

32

33
#out_CLS = torch.load("/data5/private/suyusheng/task_selecte/data/open_domain_preprocessed/opendomain_CLS.pt")
34
out_CLS = torch.load("/data5/private/suyusheng/task_selecte/data/open_domain_preprocessed/opendomain_CLS_res.pt")
35
out_CLS = out_CLS.to(device)
36

37
#with open("/data5/private/suyusheng/task_selecte/data/open_domain_preprocessed/opendomain.json") as f:
38
with open("/data5/private/suyusheng/task_selecte/data/open_domain_preprocessed/opendomain_res.json") as f:
39
    out_data = json.load(f)
40

41
with open("../data/restaurant/train.json") as f:
42
    data = json.load(f)
43
    for index, d in enumerate(tqdm(data)):
44
        #if index <= 1:
45
        #    continue
46

47
        ids = tokenizer.encode(d["sentence"],add_special_tokens=True)
48
        ids = ids+[0]*(max_length-len(ids))
49
        torch_ids = torch.tensor([ids])
50
        torch_ids = torch_ids.to(device)
51
        output = model(torch_ids) #([1,100,768])
52
        CLS_hidd = output["hidden_states"][-1][0][0]
53
        #print(CLS_hidd.shape)
54
        t_start = time.time()
55
        result = CLS_hidd.matmul(out_CLS.reshape(out_CLS.shape[1],out_CLS.shape[0]))
56
        #top_n = torch.topk(result, 10, dim=None, largest=True, sorted=False, out=None)
57
        top_n = result.topk(k=k, dim=0, largest=True, sorted=False)
58
        #print(top_n)
59
        #print(type(top_n))
60
        #exit()
61
        bottom_n = result.topk(k=k, dim=0, largest=False, sorted=False)
62
        #print(bottom_n)
63
        #result = CLS_hidd.dot(out_CLS)
64
        t_end = time.time()
65
        print("time:",t_end-t_start)
66
        #print(result.shape)
67

68
        print("===Ranking===")
69
        print(d['sentence'])
70
        print(d['aspect'])
71
        print(d['sentiment'])
72
        print("===============")
73
        print("===============")
74
        print("top_n")
75
        for i in range(k):
76
            score = top_n[0][i]
77
            index = top_n[1][i]
78
            print(score)
79
            print(out_data[str(int(index))]['sentence'])
80
            print(out_data[str(int(index))]['aspect'])
81
            print(out_data[str(int(index))]['sentiment'])
82
            print("---")
83
        print("===============")
84
        print("===============")
85
        print("bottom_n")
86
        for i in range(k):
87
            score = bottom_n[0][i]
88
            index = bottom_n[1][i]
89
            #print(score, out_data[str(int(index))])
90
            print(score)
91
            print(out_data[str(int(index))]['sentence'])
92
            print(out_data[str(int(index))]['aspect'])
93
            print(out_data[str(int(index))]['sentiment'])
94
            print("---")
95

96
        #for v, id in top_n[0]:
97
        #    print(v,id)
98

99
        exit()
100

101
#CPU
102
#time: 0.018637657165527344
103
#time: 0.051631927490234375
104
#GPU
105
#time: 0.00010919570922851562
106
#time: 0.00010061264038085938
107

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.