CSS-LM

Форк
0
/
gen_yelp_dataset_roberta_task_finetune.py 
109 строк · 3.6 Кб
1
import pickle
2
import json
3
import argparse
4
import logging
5
import random
6
import numpy as np
7
import os
8
import json
9
import sys
10

11
import torch
12
from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
13
from tqdm import tqdm, trange
14
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
15
from torch.utils.data.distributed import DistributedSampler
16
from transformers.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
17
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
18

19
#with open(FILE) as f:
20
#    file = pickle.load(f)
21

22
file_in = sys.argv[1]
23
file_out = sys.argv[2]
24
model = sys.argv[3]
25
num_samples = int(sys.argv[4])
26
#num_samples = 1000000
27

28
all_data_dict = dict()
29
max_length = 100
30
tail_hidd_list = list()
31
#device = "cpu"
32
device = "cuda"
33

34

35
pretrained_weights = model
36
tokenizer = RobertaTokenizer.from_pretrained(pretrained_weights)
37

38
fine_tuned_weight = model
39
model = RobertaForMaskedLM.from_pretrained(pretrained_weights, output_hidden_states=True,return_dict=True)
40
#model = RobertaForMaskedLMDomainTask.from_pretrained(pretrained_weights, output_hidden_states=True,return_dict=True)
41
#model.load_state_dict(torch.load(fine_tuned_weight), strict=False)
42

43
#model.to(device).half()
44
model.to(device)
45
model.eval()
46

47
#num_samples = 1000000
48

49
old = torch.FloatTensor(768)
50
with open(file_in) as f:
51
    #data = json.load(f)
52
    for index, d in tqdm(enumerate(f)):
53
        #print(type(index),index)
54
        #if index == 1000000:
55
        if index == int(num_samples):
56
            break
57
        if len(d) == 0:
58
            continue
59
        #print(d["sentence"])
60
        tokens = tokenizer.tokenize(d)
61
        if len(tokens)>=max_length-2:
62
            tokens = tokens[:max_length-2]
63
            tokens = ["<s>"] + tokens + ["</s>"]
64
            ids_tail = len(tokens)-1
65
        else:
66
            ids_tail = len(tokens)-1
67
            tokens = ["<s>"]+tokens+["</s>"]
68
        attention_mask = [1]*len(tokens)
69
        padding = ["<pad>"]*(max_length-len(tokens))
70
        tokens += padding
71
        attention_mask += [0]*len(padding)
72

73

74
        ids = tokenizer.encode(tokens, add_special_tokens=False)
75
        torch_ids = torch.tensor([ids]).to(device)
76
        attention_mask = torch.tensor([attention_mask]).to(device)
77
        output = model(input_ids=torch_ids, attention_mask=attention_mask) #([1,100,768])
78
        #output = model(input_ids=torch_ids, attention_mask=attention_mask, func="in_domain_task_rep") #([1,100,768])
79
        #last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
80
        with torch.no_grad():
81
            #every <s> in each layer
82
            tail_hidd = [x[0] for x in output.hidden_states[:]]
83
            tail_hidd = torch.stack(tail_hidd)
84
            tail_hidd = tail_hidd[:,0,:]
85
            #tail_hidd = output
86

87
        tail_hidd = tail_hidd.to("cpu")
88
        #all_data_dict[index] = {"sentence":d["sentence"], "aspect":d["aspect"], "sentiment":d["sentiment"], "ids":ids}
89
        all_data_dict[index] = {"sentence":d}
90
        tail_hidd_list.append(tail_hidd)
91

92
        #########
93
        if torch.equal(tail_hidd,old):
94
            #print(tail_hidd)
95
            #print(old)
96
            print(index)
97
            print("------")
98
            old = tail_hidd
99
        else:
100
            old = tail_hidd
101

102
with open(file_out+'.json', 'w') as outfile:
103
    json.dump(all_data_dict, outfile)
104
    #torch.save(all_data_dict, outfile)
105

106
#tail_hidd_tensor = torch.FloatTensor(CLS_hidd_list)
107
tail_hidd_tensor = torch.stack(tail_hidd_list)
108
print(tail_hidd_tensor.shape)
109
torch.save(tail_hidd_tensor, file_out+'_CLS.pt')
110

111

112

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.