import json
import time

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from transformers import BertTokenizer, BertForPreTraining, BertForSequenceClassification
from transformers.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
# --- Model setup -------------------------------------------------------------
# Load the tokenizer and a BertForPreTraining model from the local base
# checkpoint, then overlay fine-tuned weights on top of it.
pretrained_weights = '/data5/private/suyusheng/task_selecte/bert-base-uncased-128/'
tokenizer = BertTokenizer.from_pretrained(pretrained_weights, do_lower_case=True)

fine_tuned_weight = '/data5/private/suyusheng/task_selecte/output_finetune/pytorch_model.bin_1314'
# output_hidden_states=True / return_dict=True: the ranking loop below reads
# output["hidden_states"][-1] to get the last-layer [CLS] embedding.
model = BertForPreTraining.from_pretrained(pretrained_weights, output_hidden_states=True, return_dict=True)
# strict=False on purpose: the fine-tuned checkpoint may not cover every head
# of BertForPreTraining; missing/unexpected keys are tolerated.
model.load_state_dict(torch.load(fine_tuned_weight), strict=False)
# FIX: the model is only used for inference below; without eval() dropout
# stays active and the [CLS] embeddings become non-deterministic.
model.eval()
# --- Candidate pool and query set --------------------------------------------
# Pre-computed [CLS] embeddings for the open-domain pool; assumed shape
# (num_pool_sentences, hidden_dim) -- TODO confirm against the producer script.
out_CLS = torch.load("/data5/private/suyusheng/task_selecte/data/open_domain_preprocessed/opendomain_CLS_res.pt")
out_CLS = out_CLS.to(device)

# Pool metadata: stringified row index -> {sentence, aspect, sentiment};
# the ranking loop indexes it as out_data[str(int(index))].
with open("/data5/private/suyusheng/task_selecte/data/open_domain_preprocessed/opendomain_res.json") as f:
    out_data = json.load(f)

# Query set: restaurant-domain training sentences, iterated below as `data`.
with open("../data/restaurant/train.json") as f:
    # NOTE(review): this line was missing from the mangled paste; reconstructed
    # from the parallel pattern above and the use of `data` in the loop -- verify.
    data = json.load(f)
# --- Ranking loop ------------------------------------------------------------
# For each restaurant query sentence: embed it with the fine-tuned model, score
# every open-domain pool sentence by [CLS] dot product, and print the k most
# and k least similar pool entries.
for index, d in enumerate(tqdm(data)):
    # Tokenize with [CLS]/[SEP] and right-pad to max_length
    # (id 0 is [PAD] for bert-base-uncased).  No attention mask is built, so
    # padding positions do attend -- NOTE(review): confirm this is intended.
    ids = tokenizer.encode(d["sentence"], add_special_tokens=True)
    ids = ids + [0] * (max_length - len(ids))
    torch_ids = torch.tensor([ids])
    torch_ids = torch_ids.to(device)

    t_start = time.time()  # NOTE(review): timing reconstructed from mangled paste
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        output = model(torch_ids)
    # Last hidden layer -> batch item 0 -> token 0 ([CLS]); shape (hidden_dim,).
    CLS_hidd = output["hidden_states"][-1][0][0]

    # BUG FIX: the original computed
    #   CLS_hidd.matmul(out_CLS.reshape(out_CLS.shape[1], out_CLS.shape[0]))
    # but reshape reorders elements row-major instead of transposing, so the
    # scores were garbage.  matmul with the true transpose yields one
    # dot-product similarity per pool sentence.
    result = CLS_hidd.matmul(out_CLS.t())

    # Unsorted top/bottom-k similarity scores and their pool row indices.
    top_n = result.topk(k=k, dim=0, largest=True, sorted=False)
    bottom_n = result.topk(k=k, dim=0, largest=False, sorted=False)
    t_end = time.time()
    print("time:", t_end - t_start)

    print("===Ranking===")
    # Most similar pool sentences.
    # NOTE(review): the inner loops below were missing from the mangled paste
    # and are reconstructed from the visible bottom_n indexing pattern -- verify.
    # NOTE(review): the inner `index` deliberately mirrors the original code and
    # clobbers the outer enumerate index; the outer value is not used afterwards.
    for i in range(k):
        score = top_n[0][i]
        index = top_n[1][i]
        print("===============")
        print("===============")
        print(out_data[str(int(index))]['sentence'])
        print(out_data[str(int(index))]['aspect'])
        print(out_data[str(int(index))]['sentiment'])

    # Least similar pool sentences.
    for i in range(k):
        print("===============")
        print("===============")
        score = bottom_n[0][i]
        index = bottom_n[1][i]
        print(out_data[str(int(index))]['sentence'])
        print(out_data[str(int(index))]['aspect'])
        print(out_data[str(int(index))]['sentiment'])