CSS-LM

Форк
0
/
gen_opendomain_dataset_roberta.py 
96 строк · 2.8 Кб
1
import pickle
2
import json
3
import argparse
4
import logging
5
import random
6
import numpy as np
7
import os
8
import json
9
import sys
10

11
import torch
12
from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
13
from tqdm import tqdm, trange
14
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
15
from torch.utils.data.distributed import DistributedSampler
16
from transformers.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
17
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
18

19
#with open(FILE) as f:
20
#    file = pickle.load(f)
21

22
file_in = sys.argv[1]
23
file_out = sys.argv[2]
24

25
all_data_dict = dict()
26
max_length = 100
27
tail_hidd_list = list()
28
#device = "cpu"
29
device = "cuda"
30

31

32
pretrained_weights = 'roberta-base'
33
tokenizer = RobertaTokenizer.from_pretrained(pretrained_weights)
34

35
fine_tuned_weight = 'roberta-base'
36
model = RobertaForMaskedLM.from_pretrained(pretrained_weights, output_hidden_states=True,return_dict=True)
37
#model.load_state_dict(torch.load(fine_tuned_weight), strict=False)
38

39
#model.to(device).half()
40
model.to(device)
41
model.eval()
42

43

44
old = torch.FloatTensor(768)
45
with open(file_in) as f:
46
    data = json.load(f)
47
    for index, d in enumerate(tqdm(data)):
48
        #print(d["sentence"])
49
        ids = tokenizer.encode(d["sentence"],add_special_tokens=True)
50
        #print(ids)
51
        ids_tail = len(ids)-1
52
        attention_mask = [1]*len(ids)
53
        #<pad> --> 1
54
        ids = ids+[1]*(max_length-len(ids))
55
        padding = [0]*(max_length-len(ids))
56
        attention_mask += padding
57

58
        torch_ids = torch.tensor([ids]).to(device)
59
        output = model(input_ids=torch_ids, attention_mask=attention_mask) #([1,100,768])
60
        #last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
61
        with torch.no_grad():
62
            ###
63
            tail_hidd = output.hidden_states[0][0][ids_tail]
64
            #tail_hidd = output.hidden_states[0][0][0]
65
            ###
66

67
            ###
68
            #tail_hidd = output.hidden_states[0][0].mean(dim=0)
69
            ###
70

71

72
        tail_hidd = tail_hidd.to("cpu")
73
        all_data_dict[index] = {"sentence":d["sentence"], "aspect":d["aspect"], "sentiment":d["sentiment"], "ids":ids}
74
        tail_hidd_list.append(tail_hidd)
75

76
        #########
77
        '''
78
        if torch.equal(tail_hidd,old):
79
            for
80
            print(tail_hidd)
81
            print(old)
82
            exit()
83
            print(index)
84
            old = tail_hidd
85
        else:
86
            old = tail_hidd
87
        '''
88

89
with open(file_out+'.json', 'w') as outfile:
90
    json.dump(all_data_dict, outfile)
91
    #torch.save(all_data_dict, outfile)
92

93
#tail_hidd_tensor = torch.FloatTensor(CLS_hidd_list)
94
tail_hidd_tensor = torch.stack(tail_hidd_list)
95
print(tail_hidd_tensor.shape)
96
torch.save(tail_hidd_tensor, file_out+'_CLS.pt')
97

98

99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.