12
from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
13
from tqdm import tqdm, trange
14
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
15
from torch.utils.data.distributed import DistributedSampler
16
from transformers.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
17
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
# NOTE(review): the bare integer lines throughout this chunk (`20`, `27`, `32`, ...)
# are paste artifacts — the original file's line numbers got interleaved with the
# code. They parse as no-op expression statements but should be stripped when the
# file is restored. The numbering is non-contiguous, so lines of the original file
# are MISSING from this view (see the empty `if` body below).
20
# file = pickle.load(f)
27
# Accumulates one hidden-state vector per input example; stacked and saved at the end.
tail_hidd_list = list()
32
pretrained_weights = 'roberta-base'
33
tokenizer = RobertaTokenizer.from_pretrained(pretrained_weights)
35
# NOTE(review): `fine_tuned_weight` is set but the load below is commented out,
# so the plain pretrained checkpoint is what actually runs.
fine_tuned_weight = 'roberta-base'
36
# output_hidden_states=True makes `output.hidden_states` available (embeddings +
# one entry per transformer layer); return_dict=True gives attribute-style access.
model = RobertaForMaskedLM.from_pretrained(pretrained_weights, output_hidden_states=True,return_dict=True)
37
#model.load_state_dict(torch.load(fine_tuned_weight), strict=False)
39
#model.to(device).half()
44
# NOTE(review): torch.FloatTensor(768) allocates an UNINITIALIZED 768-dim vector
# (arbitrary memory contents) — presumably a sentinel for the duplicate check at
# the `torch.equal` below; confirm that is intentional and not meant to be zeros.
old = torch.FloatTensor(768)
45
# `file_in`, `data`, `device`, `max_length`, `ids_tail`, `all_data_dict` are
# defined in lines missing from this chunk — TODO confirm against the full file.
with open(file_in) as f:
47
# Indentation of this loop (and everything through the json.dump) was flattened
# by the paste; bodies below logically belong inside the `with`/`for`.
for index, d in enumerate(tqdm(data)):
49
# Token ids with <s> ... </s> added (add_special_tokens=True).
ids = tokenizer.encode(d["sentence"],add_special_tokens=True)
52
# 1 for every real (pre-padding) token.
attention_mask = [1]*len(ids)
54
# Pad token ids up to max_length. Pad value 1 — presumably RoBERTa's pad_token_id
# (which is 1 for roberta-base); verify rather than hard-coding.
ids = ids+[1]*(max_length-len(ids))
55
# FIXME(review): BUG — `ids` was already padded to max_length on the line above,
# so `max_length-len(ids)` is now 0 and `padding` is always the empty list.
# The attention mask is therefore never extended to max_length (it keeps only the
# leading 1s), and padded positions are not masked out. The padding computation
# should use the pre-padding length (or be moved before the `ids` padding line).
padding = [0]*(max_length-len(ids))
56
attention_mask += padding
58
torch_ids = torch.tensor([ids]).to(device)
59
# NOTE(review): attention_mask is passed as a plain Python list (and, per the bug
# above, with the wrong length) rather than a [1, max_length] tensor on `device` —
# confirm the transformers version in use accepts this.
output = model(input_ids=torch_ids, attention_mask=attention_mask) #([1,100,768])
60
#last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
63
# NOTE(review): with output_hidden_states=True, hidden_states[0] is the EMBEDDING
# layer output, not the last transformer layer; hidden_states[-1] would be the
# final layer. Taking index 0 here may be deliberate (raw embeddings at position
# `ids_tail`) — confirm intent, since the commented line above talks about the
# "last hidden-state".
tail_hidd = output.hidden_states[0][0][ids_tail]
64
#tail_hidd = output.hidden_states[0][0][0]
68
#tail_hidd = output.hidden_states[0][0].mean(dim=0)
72
# Move to CPU so the accumulated list doesn't hold GPU memory across the loop.
tail_hidd = tail_hidd.to("cpu")
73
# Record the example metadata keyed by loop index, parallel to tail_hidd_list.
all_data_dict[index] = {"sentence":d["sentence"], "aspect":d["aspect"], "sentiment":d["sentiment"], "ids":ids}
74
tail_hidd_list.append(tail_hidd)
78
# NOTE(review): the body of this `if` is in the lines missing from this chunk
# (original numbering jumps 78 -> 89). Comparing against the uninitialized `old`
# tensor looks like a duplicate/degenerate-output check — confirm in full file.
if torch.equal(tail_hidd,old):
89
# Dump per-example metadata. json.dump will coerce the integer indices used as
# keys in all_data_dict to strings — readers of this file must expect str keys.
with open(file_out+'.json', 'w') as outfile:
90
json.dump(all_data_dict, outfile)
91
#torch.save(all_data_dict, outfile)
93
#tail_hidd_tensor = torch.FloatTensor(CLS_hidd_list)
94
# Stack the per-example vectors into a single [num_examples, hidden_dim] tensor.
tail_hidd_tensor = torch.stack(tail_hidd_list)
95
print(tail_hidd_tensor.shape)
96
# NOTE(review): filename suffix says '_CLS' but the saved vectors are the
# `ids_tail` position hidden states, not the CLS token — misleading name.
torch.save(tail_hidd_tensor, file_out+'_CLS.pt')