# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
import os
import random
from io import open
import json
import time

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
from transformers.modeling_roberta import RobertaForMaskedLMDomainTask
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


#def default_all_type_sentence(batch):


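# Note (added comment): return_Classifier rebuilds a frozen nn.Linear on CPU from
# classifier weights/biases exported by the model (see the "return_task_binary_classifier"
# and "return_domain_binary_classifier" calls in the training loop), so the retrieval
# ranking can be scored under torch.no_grad() without touching the GPU copy.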
def return_Classifier(weight, bias, dim_in, dim_out):
    #LeakyReLU = torch.nn.LeakyReLU
    classifier = torch.nn.Linear(dim_in, dim_out, bias=True)
    #print(classifier)
    #print(classifier.weight)
    #print(classifier.weight.shape)
    #print(classifier.weight.data)
    #print(classifier.weight.data.shape)
    #print("---")
    classifier.weight.data = weight.to("cpu")
    classifier.bias.data = bias.to("cpu")
    classifier.requires_grad = False
    #print(classifier)
    #print(classifier.weight)
    #print(classifier.weight.shape)
    #print("---")
    #exit()
    #print(classifier)
    #exit()
    #logit = LeakyReLU(classifier)
    return classifier


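# Note (added comment): the load_GeneralDomain* helpers read a precomputed tensor of
# <s>/CLS representations (train_CLS.pt / opendomain_CLS.pt) plus the matching raw
# sentences (train.json / opendomain.json) that make up the open-domain retrieval corpus;
# the expected directory layouts are hard-coded below.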
def load_GeneralDomain(dir_data_out):
    ###Test
    if dir_data_out == "data/open_domain_preprocessed_roberta/":
        docs = torch.load(dir_data_out+"opendomain_CLS.pt")
        with open(dir_data_out+"opendomain.json") as file:
            data = json.load(file)
        print("train.json Done")
        print("===========")
        docs = docs.unsqueeze(1)
        return docs, data
    ###


    ###
    elif dir_data_out == "data/yelp/":
        print("===========")
        print("Load CLS.pt and train.json")
        print("-----------")
        docs = torch.load(dir_data_out+"train_CLS.pt")
        print("CLS.pt Done")
        print(docs.shape)
        print("-----------")
        with open(dir_data_out+"train.json") as file:
            data = json.load(file)
        print("train.json Done")
        print("===========")
        return docs, data
    ###


    ###
    elif dir_data_out == "data/yelp_finetune_noword_10000/":
        print("===========")
        print("Load CLS.pt and train.json")
        print("-----------")
        docs = torch.load(dir_data_out+"train_CLS.pt")
        print("CLS.pt Done")
        print(docs.shape)
        print("-----------")
        with open(dir_data_out+"train.json") as file:
            data = json.load(file)
        print("train.json Done")
        print("===========")
        return docs, data
    ###


def load_GeneralDomain_docs(dir_data_out):
    ###Test
    if dir_data_out == "data/open_domain_preprocessed_roberta/":
        docs = torch.load(dir_data_out+"opendomain_CLS.pt")
        #with open(dir_data_out+"opendomain.json") as file:
        #    data = json.load(file)
        #print("train.json Done")
        print("===========")
        docs = docs.unsqueeze(1)
        return docs
    ###

    elif dir_data_out == "data/yelp/":
        ###
        print("===========")
        print("Load CLS.pt and train.json")
        print("-----------")
        docs = torch.load(dir_data_out+"train_CLS.pt")
        print("CLS.pt Done")
        print(docs.shape)
        #print("-----------")
        #with open(dir_data_out+"train.json") as file:
        #    data = json.load(file)
        #print("train.json Done")
        #print("===========")
        return docs
    ###


def load_GeneralDomain_data(dir_data_out):
    ###Test
    if dir_data_out == "data/open_domain_preprocessed_roberta/":
        #docs = torch.load(dir_data_out+"opendomain_CLS.pt")
        with open(dir_data_out+"opendomain.json") as file:
            data = json.load(file)
        #print("train.json Done")
        #print("===========")
        #docs = docs.unsqueeze(1)
        return data
    ###

    elif dir_data_out == "data/yelp/":
        ###
        #print("===========")
        #print("Load CLS.pt and train.json")
        #print("-----------")
        #docs = torch.load(dir_data_out+"train_CLS.pt")
        #print("CLS.pt Done")
        #print(docs.shape)
        print("-----------")
        with open(dir_data_out+"train.json") as file:
            data = json.load(file)
        print("train.json Done")
        print("===========")
        return data
    ###

#Load outDomainData
###Test
#docs, data = load_GeneralDomain("data/open_domain_preprocessed_roberta/")
#data = load_GeneralDomain_data("data/open_domain_preprocessed_roberta/")
#data = load_GeneralDomain_data("data/yelp/")
#docs = load_GeneralDomain_docs("data/yelp/")
######
###
#docs, data = load_GeneralDomain("data/open_domain_preprocessed_roberta/")
#docs, data = load_GeneralDomain("data/yelp")
docs, data = load_GeneralDomain("data/yelp_finetune_noword_10000/")
######
if docs.shape[1] != 1:  #UnboundLocalError: local variable 'docs' referenced before assignment
    #last <s>
    #docs = docs[:,0,:].unsqueeze(1)
    #mean 13 layers <s>
    docs = docs.mean(1).unsqueeze(1)
    print(docs.shape)
else:
    print(docs.shape)
######

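# Note (added comment): in_Domain_Task_Data_mutiple builds the in-domain training set.
# It maps aspect and sentiment strings to integer ids, converts every sentence to
# masked/original input ids via convert_example_to_features, and keeps one example per
# sentiment class (all_type_sentiment_sentence) to seed the retrieval step in main().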
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
    ###Open
    with open(data_dir_indomain+"train.json") as file:
        data = json.load(file)

    ###Preprocess
    num_label_list = list()
    label_sentence_dict = dict()
    num_sentiment_label_list = list()
    sentiment_label_dict = dict()
    for line in data:
        #line["sentence"]
        #line["aspect"]
        #line["sentiment"]
        num_sentiment_label_list.append(line["sentiment"])
        num_label_list.append(line["aspect"])

    num_label = sorted(list(set(num_label_list)))
    label_map = {label: i for i, label in enumerate(num_label)}
    num_sentiment_label = sorted(list(set(num_sentiment_label_list)))
    sentiment_label_map = {label: i for i, label in enumerate(num_sentiment_label)}
    print("=======")
    print("label_map:")
    print(label_map)
    print("=======")
    print("=======")
    print("sentiment_label_map:")
    print(sentiment_label_map)
    print("=======")

    ###Create data: 1 chosen sample along with the rest of the 7 class data

    '''
    all_input_ids = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_sentence_labels = list()
    '''
    cur_tensors_list = list()
    #print(list(label_map.values()))
    candidate_label_list = list(label_map.values())
    candidate_sentiment_label_list = list(sentiment_label_map.values())
    all_type_sentence = [0]*len(candidate_label_list)
    all_type_sentiment_sentence = [0]*len(candidate_sentiment_label_list)
    for line in data:
        #line["sentence"]
        #line["aspect"]
        sentiment = line["sentiment"]
        sentence = line["sentence"]
        label = line["aspect"]


        tokens_a = tokenizer.tokenize(sentence)
        #input_ids = tokenizer.encode(sentence, add_special_tokens=False)
        '''
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] == "s"
        '''


        # tokenize
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
        # transform sample to features
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(0),
                       torch.tensor(cur_features.tail_idxs),
                       torch.tensor(label_map[label]),
                       torch.tensor(sentiment_label_map[sentiment]))

        cur_tensors_list.append(cur_tensors)

        ###
        if label_map[label] in candidate_label_list:
            all_type_sentence[label_map[label]] = cur_tensors
            candidate_label_list.remove(label_map[label])

        if sentiment_label_map[sentiment] in candidate_sentiment_label_list:
            #print("----")
            #print(sentiment_label_map[sentiment])
            #print("----")
            all_type_sentiment_sentence[sentiment_label_map[sentiment]] = cur_tensors
            candidate_sentiment_label_list.remove(sentiment_label_map[sentiment])
        ###


    '''
    all_input_ids.append(torch.tensor(cur_features.input_ids))
    all_input_mask.append(torch.tensor(cur_features.input_mask))
    all_segment_ids.append(torch.tensor(cur_features.segment_ids))
    all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
    all_is_next.append(torch.tensor(0))
    all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
    all_sentence_labels.append(torch.tensor(label_map[label]))

    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_sentence_labels))
    '''


    '''
    print("=====")
    print(candidate_label_list)
    print("---")
    print(all_type_sentence)
    print("---")
    print(len(cur_tensors_list))
    exit()
    '''

    #return cur_tensors
    #for line in all_type_sentiment_sentence:
    #    print(line[-1])
    #exit()
    return all_type_sentiment_sentence, cur_tensors_list


def in_Domain_Task_Data_binary(data_dir_indomain, tokenizer, max_seq_length):
    ###Open
    with open(data_dir_indomain+"train.json") as file:
        data = json.load(file)

    ###Preprocess
    num_label_list = list()
    label_sentence_dict = dict()
    for line in data:
        #line["sentence"]
        #line["aspect"]
        #line["sentiment"]
        num_label_list.append(line["aspect"])
        try:
            label_sentence_dict[line["aspect"]].append([line["sentence"]])
        except:
            label_sentence_dict[line["aspect"]] = [line["sentence"]]

    num_label = sorted(list(set(num_label_list)))
    label_map = {label: i for i, label in enumerate(num_label)}

    ###Create data: 1 chosen sample along with the rest of the 7 class data
    all_cur_tensors = list()
    for line in data:
        #line["sentence"]
        #line["aspect"]
        #line["sentiment"]
        sentence = line["sentence"]
        label = line["aspect"]
        sentence_out = [(random.choice(label_sentence_dict[label_out])[0], label_out) for label_out in num_label if label_out != label]
        all_sentence = [(sentence, label)] + sentence_out  #1st sentence is the chosen one

        all_input_ids = list()
        all_input_ids_org = list()  #was missing in the original: appended to below
        all_input_mask = list()
        all_segment_ids = list()
        all_lm_labels_ids = list()
        all_is_next = list()
        all_tail_idxs = list()
        all_sentence_labels = list()
        for id, sentence_label in enumerate(all_sentence):
            #tokens_a = tokenizer.tokenize(sentence_label[0])
            tokens_a = tokenizer.tokenize(sentence_label[0])
            '''
            if "</s>" in tokens_a:
                print("Have more than 1 </s>")
                for i in range(len(tokens_a)):
                    if tokens_a[i] == "</s>":
                        tokens_a[i] = "s"
            '''

            # tokenize
            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
            # transform sample to features
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            all_sentence_labels.append(torch.tensor(label_map[sentence_label[1]]))

        cur_tensors = (torch.stack(all_input_ids),
                       torch.stack(all_input_ids_org),
                       torch.stack(all_input_mask),
                       torch.stack(all_segment_ids),
                       torch.stack(all_lm_labels_ids),
                       torch.stack(all_is_next),
                       torch.stack(all_tail_idxs),
                       torch.stack(all_sentence_labels))

        all_cur_tensors.append(cur_tensors)

    return all_cur_tensors


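# Note (added comment): AugmentationData_Domain turns the retrieved top-k corpus indices
# (top_k["indices"]) back into model inputs by looking the sentences up in the global
# `data` dict and re-running convert_example_to_features on them.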
def AugmentationData_Domain(top_k, tokenizer, max_seq_length):
    #top_k_shape = top_k.indices.shape
    #ids = top_k.indices.reshape(top_k_shape[0]*top_k_shape[1]).tolist()
    top_k_shape = top_k["indices"].shape
    ids = top_k["indices"].reshape(top_k_shape[0]*top_k_shape[1]).tolist()

    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()

    for id, i in enumerate(ids):
        t1 = data[str(i)]['sentence']

        #tokens_a = tokenizer.tokenize(t1)
        tokens_a = tokenizer.tokenize(t1)
        '''
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"
        '''

        # tokenize
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        all_input_ids.append(torch.tensor(cur_features.input_ids))
        all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
        all_input_mask.append(torch.tensor(cur_features.input_mask))
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
        all_is_next.append(torch.tensor(0))
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))


    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs))

    return cur_tensors


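# Note (added comment): AugmentationData_Task does the same lookup for the task branch,
# but it also copies the aspect/sentiment labels of the query sentence onto its retrieved
# neighbours (pseudo-labels) and appends the original batch example after each group.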
def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None):
    top_k_shape = top_k["indices"].shape
    sentence_ids = top_k["indices"]

    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_sentence_labels = list()
    all_sentiment_labels = list()

    add_org = tuple(t.to('cpu') for t in add_org)
    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org

    ###
    #print("input_ids_",input_ids_.shape)
    #print("---")
    #print("sentence_ids",sentence_ids.shape)
    #print("---")
    #print("sentence_label_",sentence_label_.shape)
    #exit()


    for id_1, sent in enumerate(sentence_ids):
        for id_2, sent_id in enumerate(sent):

            t1 = data[str(int(sent_id))]['sentence']

            tokens_a = tokenizer.tokenize(t1)

            # tokenize
            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)

            # transform sample to features
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
            all_sentiment_labels.append(torch.tensor(sentiment_label_[id_1]))
            '''
            #if len(sentence_label_) != len(sentence_label_):
            #    print(len(sentence_label_) != len(sentence_label_))
            try:
                all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
            except:
                #all_sentence_labels.append(torch.tensor([0]))
                print(sentence_ids)
                print(sentence_label_)
                print("==========================")
                print("input_ids_",input_ids_.shape)
                print("---")
                print("sentence_ids",sentence_ids.shape)
                print("---")
                print("sentence_label_",sentence_label_.shape)
                exit()
            '''

        all_input_ids.append(input_ids_[id_1])
        all_input_ids_org.append(input_ids_org_[id_1])
        all_input_mask.append(input_mask_[id_1])
        all_segment_ids.append(segment_ids_[id_1])
        all_lm_labels_ids.append(lm_label_ids_[id_1])
        all_is_next.append(is_next_[id_1])
        all_tail_idxs.append(tail_idxs_[id_1])
        all_sentence_labels.append(sentence_label_[id_1])
        all_sentiment_labels.append(sentiment_label_[id_1])


    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_sentence_labels),
                   torch.stack(all_sentiment_labels)
                   )


    return cur_tensors


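# Note (added comment): AugmentationData_Task_pos_and_neg builds pairwise training data for
# the task binary classifier: for every example it concatenates its task representation with
# every other example's representation and labels the pair 1 when the two share a label,
# 0 otherwise.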
def AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None):
    '''
    top_k_shape = top_k.indices.shape
    sentence_ids = top_k.indices
    '''
    #top_k_shape = top_k["indices"].shape
    #sentence_ids = top_k["indices"]


    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org

    #print(sentence_label_.shape)

    all_sentence_binary_label = list()
    all_in_task_rep_comb = list()
    for id_1, num in enumerate(sentence_label_):
        #print([sentence_label_==num])
        #print(type([sentence_label_==num]))
        sentence_label_int = (sentence_label_ == num).to(torch.long)
        #print(in_task_rep[id_1].shape)
        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0], -1)
        in_task_rep_comb = torch.cat((in_task_rep_append, in_task_rep), -1)
        #print(in_task_rep_comb.shape)
        #exit()
        #sentence_label_int = sentence_label_int.to(torch.float32)
        #print(sentence_label_int)
        #exit()
        #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
        #all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
        all_sentence_binary_label.append(sentence_label_int)
        all_in_task_rep_comb.append(in_task_rep_comb)
    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
    all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)

    cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)

    return cur_tensors


    '''
    all_input_ids_batch = list()
    all_input_ids_org_batch = list()
    all_input_mask_batch = list()
    all_segment_ids_batch = list()
    all_lm_labels_ids_batch = list()
    all_is_next_batch = list()
    all_tail_idxs_batch = list()
    all_sentence_labels_batch = list()
    all_sentence_binary_label_batch = list()

    for id_1, sent in enumerate(sentence_ids):
        all_input_ids = list()
        all_input_ids_org = list()
        all_input_mask = list()
        all_segment_ids = list()
        all_lm_labels_ids = list()
        all_is_next = list()
        all_tail_idxs = list()
        all_sentence_labels = list()
        all_sentence_binary_label = list()

        #print(sent)
        #print(sent.shape)
        #exit()
        for id_2, sent_id in enumerate(sent):

            t1 = data[str(int(sent_id))]['sentence']

            #tokens_a = tokenizer.tokenize(t1)
            tokens_a = tokenizer.tokenize(t1)

            # tokenize
            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)

            # transform sample to features
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
            #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
            all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
        #all_sentence_binary_label = torch.tensor(all_sentence_binary_label)
        #print(all_sentence_binary_label)
        #print(all_sentence_binary_label[0].shape)
        #exit()

        all_input_ids_batch.append(torch.stack(all_input_ids))
        all_input_ids_org_batch.append(torch.stack(all_input_ids_org))
        all_input_mask_batch.append(torch.stack(all_input_mask))
        all_segment_ids_batch.append(torch.stack(all_segment_ids))
        all_lm_labels_ids_batch.append(torch.stack(all_lm_labels_ids))
        all_is_next_batch.append(torch.stack(all_is_next))
        all_tail_idxs_batch.append(torch.stack(all_tail_idxs))
        all_sentence_labels_batch.append(torch.stack(all_sentence_labels))
        all_sentence_binary_label_batch.append(torch.stack(all_sentence_binary_label))
        #print(all_sentence_binary_label_batch)
        #print(all_sentence_binary_label_batch[0].shape)
        #exit()
    #print("===")
    #print(all_sentence_binary_label_batch)
    #print(len(all_sentence_binary_label_batch), len(all_sentence_binary_label_batch[0]))
    #exit()



    cur_tensors = (torch.stack(all_input_ids_batch),
                   torch.stack(all_input_ids_org_batch),
                   torch.stack(all_input_mask_batch),
                   torch.stack(all_segment_ids_batch),
                   torch.stack(all_lm_labels_ids_batch),
                   torch.stack(all_is_next_batch),
                   torch.stack(all_tail_idxs_batch),
                   torch.stack(all_sentence_labels_batch),
                   torch.stack(all_sentence_binary_label_batch)
                   )


    return cur_tensors
    '''


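# Note (added comment): Dataset_noNext is kept from the original LM-finetuning script; it reads
# a plain-text corpus (one sentence per line, blank line between documents) but always returns a
# single sentence with is_next fixed to 0, i.e. the next-sentence objective is disabled. It is
# not used by main() below, which builds its data through in_Domain_Task_Data_mutiple instead.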
class Dataset_noNext(Dataset):
    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):

        self.vocab_size = tokenizer.vocab_size
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = []  # map sample index to doc and line

        # load samples into memory
        if on_memory:
            self.all_docs = []
            doc = []
            self.corpus_lines = 0
            with open(corpus_path, "r", encoding=encoding) as f:
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    line = line.strip()
                    if line == "":
                        self.all_docs.append(doc)
                        doc = []
                        #remove last added sample because there won't be a subsequent line anymore in the doc
                        self.sample_to_doc.pop()
                    else:
                        #store as one sample
                        sample = {"doc_id": len(self.all_docs),
                                  "line": len(doc)}
                        self.sample_to_doc.append(sample)
                        doc.append(line)
                        self.corpus_lines = self.corpus_lines + 1

            # if last row in file is not empty
            if self.all_docs[-1] != doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

        # load samples later lazily from disk
        else:
            if self.corpus_lines is None:
                with open(corpus_path, "r", encoding=encoding) as f:
                    self.corpus_lines = 0
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                        if line.strip() == "":
                            self.num_docs += 1
                        else:
                            self.corpus_lines += 1

                    # if doc does not end with empty line
                    if line.strip() != "":
                        self.num_docs += 1

            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def __len__(self):
        # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)

        #t1, t2, is_next_label = self.random_sent(item)
        t1, is_next_label = self.random_sent(item)
        if is_next_label is None:
            is_next_label = 0


        #tokens_a = self.tokenizer.tokenize(t1)
        tokens_a = tokenizer.tokenize(t1)
        '''
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"
        '''
        #tokens_b = self.tokenizer.tokenize(t2)

        # tokenize
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next),
                       torch.tensor(cur_features.tail_idxs))

        return cur_tensors

    def random_sent(self, index):
        """
        Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
        from one doc. With 50% the second sentence will be a random one from another doc.
        :param index: int, index of sample.
        :return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
        """
        t1, t2 = self.get_corpus_line(index)
        return t1, None

    def get_corpus_line(self, item):
        """
        Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
        :param item: int, index of sample.
        :return: (str, str), two subsequent sentences from corpus
        """
        t1 = ""
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            # used later to avoid random nextSentence from same doc
            self.current_doc = sample["doc_id"]
            return t1, t2
            #return t1
        else:
            if self.line_buffer is None:
                # read first non-empty line of file
                while t1 == "":
                    t1 = next(self.file).strip()
            else:
                # use t2 from previous iteration as new t1
                t1 = self.line_buffer
                # skip empty rows that are used for separating documents and keep track of current doc id
                while t1 == "":
                    t1 = next(self.file).strip()
                    self.current_doc = self.current_doc + 1
            self.line_buffer = next(self.file).strip()

        assert t1 != ""
        return t1, t2


    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs)-1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
            else:
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
                #pick random line
                for _ in range(rand_index):
                    line = self.get_next_line()
            #check if our picked random line is really from another doc like we want it to be
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """ Gets next line of random_file and starts over when reaching end of file"""
        try:
            line = next(self.random_file).strip()
            #keep track of which document we are currently looking at to later avoid having the same doc as t1
            if line == "":
                self.current_random_doc = self.current_random_doc + 1
                line = next(self.random_file).strip()
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file).strip()
        return line


class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs an InputExample.
        Args:
            guid: Unique id for the example.
            tokens_a: string. The untokenized text of the first sequence. For single
                sequence tasks, only this sequence must be specified.
            tokens_b: (Optional) string. The untokenized text of the second sequence.
                Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next  # nextSentence
        self.lm_labels = lm_labels  # masked words for language model


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
        self.input_ids = input_ids
        self.input_ids_org = input_ids_org
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids
        self.tail_idxs = tail_idxs


def random_word(tokens, tokenizer):
    """
    Masking some random tokens for the Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):

        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            prob /= 0.15
            #candidate_id = random.randint(0,tokenizer.vocab_size)
            #print(tokenizer.convert_ids_to_tokens(candidate_id))


            # 80% randomly change token to mask token
            if prob < 0.8:
                #tokens[i] = "[MASK]"
                tokens[i] = "<mask>"

            # 10% randomly change token to random token
            elif prob < 0.9:
                #tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
                #tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)
                candidate_id = random.randint(0, tokenizer.vocab_size)
                w = tokenizer.convert_ids_to_tokens(candidate_id)
                '''
                if tokens[i] == None:
                    candidate_id = 100
                    w = tokenizer.convert_ids_to_tokens(candidate_id)
                '''
                tokens[i] = w


            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                #output_label.append(tokenizer.vocab[token])
                w = tokenizer.convert_tokens_to_ids(token)
                if w is not None:
                    output_label.append(w)
                else:
                    print("Token '{}' has no id in the vocab".format(token))
                    exit()
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                #output_label.append(tokenizer.vocab["<unk>"])
                w = tokenizer.convert_tokens_to_ids("<unk>")
                output_label.append(w)
                logger.warning("Cannot find token '{}' in vocab. Using <unk> instead".format(token))
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    return tokens, output_label


def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    #now tokens_a is input_ids
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    #_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)

    #print(tokens_a)
    tokens_a_org = tokens_a.copy()
    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    #print("----")
    #print(tokens_a)
    #print(tokens_a_org)
    #exit()
    #print(t1_label)
    #exit()
    #tokens_b, t2_label = random_word(tokens_b, tokenizer)
    # concatenate lm labels and account for CLS, SEP, SEP
    #lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
    lm_label_ids = ([-1] + t1_label + [-1])

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0 0 0 0 0 0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    tokens_org = []
    segment_ids = []
    tokens.append("<s>")
    tokens_org.append("<s>")
    segment_ids.append(0)
    for i, token in enumerate(tokens_a):
        if token != "</s>":
            tokens.append(tokens_a[i])
            tokens_org.append(tokens_a_org[i])
            segment_ids.append(0)
        else:
            tokens.append("s")
            tokens_org.append("s")
            segment_ids.append(0)
    tokens.append("</s>")
    tokens_org.append("</s>")
    segment_ids.append(0)

    #tokens.append("[SEP]")
    #segment_ids.append(1)

    #input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = tokenizer.encode(tokens, add_special_tokens=False)
    input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
    tail_idxs = len(input_ids) + 1

    #print(input_ids)
    input_ids = [w if w != None else 0 for w in input_ids]
    input_ids_org = [w if w != None else 0 for w in input_ids_org]
    #print(input_ids)
    #exit()

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
    while len(input_ids) < max_seq_length:
        input_ids.append(pad_id)
        input_ids_org.append(pad_id)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)


    assert len(input_ids) == max_seq_length
    assert len(input_ids_org) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    '''
    if example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
            [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
            "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))
    '''

    features = InputFeatures(input_ids=input_ids,
                             input_ids_org=input_ids_org,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next,
                             tail_idxs=tail_idxs)
    return features


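# Note (added comment): main() wires everything together: argument parsing, seeding,
# RobertaForMaskedLMDomainTask setup, AdamW with linear warmup, optional apex fp16, and a
# retrieval-augmented training loop over the in-domain data. A typical invocation might look
# roughly like the hypothetical command below (script name and paths are placeholders):
#   python finetune_roberta.py --data_dir_indomain data/... --data_dir_outdomain data/... \
#       --pretrain_model roberta-base --output_dir out/ --augment_times 1 \
#       --num_labels_task 8 --task 0 --do_train --on_memory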
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir_indomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.(In Domain)")
    parser.add_argument("--data_dir_outdomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.(Out Domain)")
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--augment_times",
                        default=None,
                        type=int,
                        required=True,
                        help="Default batch_size/augment_times to save model")
    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory",
                        action='store_true',
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    ####
    parser.add_argument("--num_labels_task",
                        default=None, type=int,
                        required=True,
                        help="num_labels_task")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument('--fp16_opt_level',
                        type=str,
                        default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task",
                        default=None,
                        type=int,
                        required=True,
                        help="Choose Task")
    ####

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    #tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model, do_lower_case=args.do_lower_case)
    tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)


    #train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.data_dir_indomain)
        #train_dataset = Dataset_noNext(args.data_dir, tokenizer, seq_len=args.max_seq_length, corpus_lines=None, on_memory=args.on_memory)
        all_type_sentence, train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()



    # Prepare model
    model = RobertaForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
    #model = RobertaForSequenceClassification.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
    model.to(device)



    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        '''
        for par in param_optimizer:
            print(par[0])
        exit()
        '''
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*0.1), num_training_steps=num_train_optimization_steps)

        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
                exit()

            model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)


        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)



    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
            #all_type_sentence_sampler = RandomSampler(all_type_sentence)
        else:
            #TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
            #all_type_sentence_sampler = DistributedSampler(all_type_sentence)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
        #all_type_sentence_dataloader = DataLoader(all_type_sentence, sampler=all_type_sentence_sampler, batch_size=len(all_type_sentence_label))

        output_loss_file = os.path.join(args.output_dir, "loss")
        loss_fout = open(output_loss_file, 'w')


        output_loss_file_no_pseudo = os.path.join(args.output_dir, "loss_no_pseudo")
        loss_fout_no_pseudo = open(output_loss_file_no_pseudo, 'w')
        model.train()


        #print(model.parameters)
        #print(model.modules.module)
        #print(model.modules.module.RobertaForMaskedLMDomainTask)
        #print(list(model.named_parameters()))
        #print([i for i,j in model.named_parameters()])
        #print(model.parameters["RobertaForMaskedLMDomainTask"])

        ##Need to confirm use input_ids or input_ids_org !!!!!!!!
        ###
        #[10000000, 13, 768] ---> [1000000, 768] --> [1,,] --> [batch_size,,]
        ###
        ###
        ###
        #print(docs.shape)
        #exit()
        #docs = docs[:,0,:]
        #docs = docs.unsqueeze(0)
        #docs = docs.expand(batch_size, -1, -1)

        #################
        #################
        #alpha = float(1/(args.num_train_epochs*len(train_dataloader)))
        alpha = float(1/args.num_train_epochs)

        #print(docs.shape)
        #([1000000, 13, 768]) -> yelp
        #([5474, 1, 768]) -> open domain

        #Test
        #docs = load_GeneralDomain_docs("data/open_domain_preprocessed_roberta/")
        #
        '''
        docs = load_GeneralDomain_docs("data/yelp/")
        if docs.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
            docs = docs[:,0,:].unsqueeze(1)
            print(docs.shape)
        else:
            print(docs.shape)
        '''
        ### All label rank (first)
        #1.train --> classifier (after 1 epoch)
        #2.all_label in batch_size --> the same number as batch_size
        #3.reduce variable
        #4.


        k = 8
        #k=16
        all_type_sentence_label = list()
        all_previous_sentence_label = list()
        all_type_sentiment_label = list()
        all_previous_sentiment_label = list()
        top_k_all_type = dict()
        bottom_k_all_type = dict()
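        # Note (added comment): training loop structure, as far as the code below shows:
        #   * on the very first step, rank the retrieval corpus against one example per
        #     sentiment class and cache top-k / bottom-k indices (top_k_all_type / bottom_k_all_type);
        #   * every 10 steps (or on a short batch) recompute the ranking for the current batch
        #     and extend the cached pools (top_k_previous / bottom_k_previous);
        #   * otherwise reuse the cached pools, selecting cached rows whose sentiment label
        #     matches each example in the batch (used_idx).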
1392for epo in trange(int(args.num_train_epochs), desc="Epoch"):
1393tr_loss = 0
1394nb_tr_examples, nb_tr_steps = 0, 0
1395for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
1396
1397#input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = batch_
1398#print(input_ids_.shape)
1399#exit()
1400
1401#######################
1402######################
1403###Init 8 type sentence
1404###Init 2 type sentiment
1405if (step == 0) and (epo == 0):
1406#batch_ = tuple(t.to(device) for t in batch_)
1407#all_type_sentence_ = tuple(t.to(device) for t in all_type_sentence)
1408
1409input_ids_ = torch.stack([line[0] for line in all_type_sentence]).to(device)
1410input_ids_org_ = torch.stack([line[1] for line in all_type_sentence]).to(device)
1411input_mask_ = torch.stack([line[2] for line in all_type_sentence]).to(device)
1412segment_ids_ = torch.stack([line[3] for line in all_type_sentence]).to(device)
1413lm_label_ids_ = torch.stack([line[4] for line in all_type_sentence]).to(device)
1414is_next_ = torch.stack([line[5] for line in all_type_sentence]).to(device)
1415tail_idxs_ = torch.stack([line[6] for line in all_type_sentence]).to(device)
1416sentence_label_ = torch.stack([line[7] for line in all_type_sentence]).to(device)
1417sentiment_label_ = torch.stack([line[8] for line in all_type_sentence]).to(device)
1418
1419#print(sentence_label_)
1420#print(sentiment_label_)
1421#exit()
1422
1423with torch.no_grad():
1424'''
1425in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1426# Search id from Docs and ranking via (Domain/Task)
1427query_domain = in_domain_rep.float().to("cpu")
1428query_domain = query_domain.unsqueeze(1)
1429query_task = in_task_rep.float().to("cpu")
1430query_task = query_task.unsqueeze(1)
1431'''
1432
1433in_domain_rep_mean, in_task_rep_mean = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep_mean")
1434# Search id from Docs and ranking via (Domain/Task)
1435query_domain = in_domain_rep_mean.float().to("cpu")
1436query_domain = query_domain.unsqueeze(1)
1437query_task = in_task_rep_mean.float().to("cpu")
1438query_task = query_task.unsqueeze(1)
1439
1440######Attend to a certain layer
1441'''
1442results_domain = torch.matmul(query_domain,docs[-1,:,:].T)
1443results_task = torch.matmul(query_task,docs[-1,:,:].T)
1444'''
1445######
1446######Attend to all 13 layers
1447'''
1448start = time.time()
1449results_domain = torch.matmul(docs, query_domain.transpose(0,1))
1450domain_attention =
1451results_domain = results_domain.transpose(1,2).transpose(0,1).sum(2)
1452results_task = torch.matmul(docs, query_task.transpose(0,1))
1453task_attention =
1454results_task = results_task.transpose(1,2).transpose(0,1).sum(2)
1455end = time.time()
1456print("Time:", (end-start)/60)
1457'''
1458
1459#start = time.time()
1460#docs: [batch_size, 1000000, 768]
1461#query: [batch_size, 1, 768]
1462
1463#docs = docs[:,0,:]
1464#docs = docs.unsqueeze(0)
1465#docs = docs.expand(batch_size, -1, -1)
1466
1467task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
1468task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
1469task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
1470task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)
1471
1472
1473domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
1474domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
1475domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
1476domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768, 2)
1477
1478#start = time.time()
1479query_domain = query_domain.expand(-1, docs.shape[0], -1)
1480query_task = query_domain.expand(-1, docs.shape[0], -1)
1481
1482#################
1483#################
1484#Ranking
1485
1486LeakyReLU = torch.nn.LeakyReLU()
1487#Domain logit
1488domain_binary_logit = LeakyReLU(domain_binary_classifier(docs))
1489domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1490#domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentence_label_.shape[0], -1)
1491domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
1492#Task logit
1493#task_binary_logit = LeakyReLU(task_binary_classifier(torch.cat([query_task, docs[:,0,:].unsqueeze(0).expand(sentence_label_.shape[0], -1, -1)], dim=2)))
1494task_binary_logit = LeakyReLU(task_binary_classifier(torch.cat([query_task, docs[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2)))
1495task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]
1496
1497#end = time.time()
1498#print("Time:", (end-start)/60)
1499######
1500results_all_type = domain_binary_logit + task_binary_logit
1501del domain_binary_logit, task_binary_logit
1502bottom_k_all_type = torch.topk(results_all_type, k, dim=1, largest=False, sorted=False)
1503top_k_all_type = torch.topk(results_all_type, k, dim=1, largest=True, sorted=False)
1504del results_all_type
1505#all_type_sentence_label = sentence_label_.to('cpu')
1506all_type_sentiment_label = sentiment_label_.to('cpu')
1507#print("--")
1508#print(bottom_k_all_type.values)
1509#print("--")
1510#exit()
1511bottom_k_all_type = {"values":bottom_k_all_type.values, "indices":bottom_k_all_type.indices}
1512top_k_all_type = {"values":top_k_all_type.values, "indices":top_k_all_type.indices}
1513
1514######################
1515######################
1516
1517
1518
1519###Normal mode
1520batch_ = tuple(t.to(device) for t in batch_)
1521#input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = batch_
1522input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = batch_
1523
1524
1525###
1526# Generate query representation
1527in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1528
1529in_domain_rep_mean, in_task_rep_mean = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep_mean")
1530
1531#if (step%10 == 0) or (sentence_label_.shape[0] != args.train_batch_size):
1532if (step%10 == 0) or (sentiment_label_.shape[0] != args.train_batch_size):
1533with torch.no_grad():
1534# Search id from Docs and ranking via (Domain/Task)
1535query_domain = in_domain_rep.float().to("cpu")
1536query_domain = query_domain.unsqueeze(1)
1537query_task = in_task_rep.float().to("cpu")
1538query_task = query_task.unsqueeze(1)
1539
1540######Attend to a certain layer
1541'''
1542results_domain = torch.matmul(query_domain,docs[-1,:,:].T)
1543results_task = torch.matmul(query_task,docs[-1,:,:].T)
1544'''
1545######
1546######Attend to all 13 layers
1547'''
1548start = time.time()
1549results_domain = torch.matmul(docs, query_domain.transpose(0,1))
1550domain_attention =
1551results_domain = results_domain.transpose(1,2).transpose(0,1).sum(2)
1552results_task = torch.matmul(docs, query_task.transpose(0,1))
1553task_attention =
1554results_task = results_task.transpose(1,2).transpose(0,1).sum(2)
1555end = time.time()
1556print("Time:", (end-start)/60)
1557'''

#start = time.time()
#docs: [batch_size, 1000000, 768]
#query: [batch_size, 1, 768]

#docs = docs[:,0,:]
#docs = docs.unsqueeze(0)
#docs = docs.expand(batch_size, -1, -1)

task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)


domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768, 2)

#start = time.time()
# Broadcast the queries across all cached document vectors
query_domain = query_domain.expand(-1, docs.shape[0], -1)
query_task = query_task.expand(-1, docs.shape[0], -1)

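# Rank the cached corpus representations with CPU copies of the model's
# domain/task binary classifier heads (rebuilt above via return_Classifier).
# Each score is the class-1 minus class-0 logit; the top-k rows later feed
# the task augmentation and the bottom-k rows the domain binary classifier.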
#################
#################
#Ranking

LeakyReLU = torch.nn.LeakyReLU()
#Domain logit
domain_binary_logit = LeakyReLU(domain_binary_classifier(docs))
domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
#domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentence_label_.shape[0], -1)
domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
#Task logit
#task_binary_logit = LeakyReLU(task_binary_classifier(torch.cat([query_task, docs[:,0,:].unsqueeze(0).expand(sentence_label_.shape[0], -1, -1)], dim=2)))
task_binary_logit = LeakyReLU(task_binary_classifier(torch.cat([query_task, docs[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2)))
task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]

#end = time.time()
#print("Time:", (end-start)/60)
######
results = domain_binary_logit + task_binary_logit
del domain_binary_logit, task_binary_logit
bottom_k = torch.topk(results, k, dim=1, largest=False, sorted=False)
bottom_k = {"values":bottom_k.values, "indices":bottom_k.indices}
top_k = torch.topk(results, k, dim=1, largest=True, sorted=False)
top_k = {"values":top_k.values, "indices":top_k.indices}
del results

#all_previous_sentence_label = sentence_label_.to('cpu')
all_previous_sentiment_label = sentiment_label_.to('cpu')

#print(bottom_k.values)
#print(bottom_k["values"])
#print("==")
#print(bottom_k_all_type.values)
#exit()

#print(torch.cat((bottom_k.values, bottom_k_all_type.values)))
#exit()
bottom_k_previous = {"values":torch.cat((bottom_k["values"], bottom_k_all_type["values"]),0), "indices":torch.cat((bottom_k["indices"], bottom_k_all_type["indices"]),0)}
top_k_previous = {"values":torch.cat((top_k["values"], top_k_all_type["values"]),0), "indices":torch.cat((top_k["indices"], top_k_all_type["indices"]),0)}
#all_previous_sentence_label = torch.cat((all_previous_sentence_label, all_type_sentence_label))
#print("=====")
#print(all_previous_sentiment_label.shape)
#print(all_type_sentiment_label.shape)
#print("=====")
#exit()
all_previous_sentiment_label = torch.cat((all_previous_sentiment_label, all_type_sentiment_label))
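# Cache the rankings for reuse: concatenate this batch's top-k/bottom-k with
# the "all type" rankings computed above, together with the sentiment labels,
# so intermediate steps can sample from them instead of re-ranking the corpus.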

else:
#print("all_type_sentence_label",all_type_sentence_label) #fix
#print("all_previous_sentence_label",all_previous_sentence_label) #prev
#print("sentence_label_",sentence_label_) #present

#used_idx = torch.tensor([random.choice(((all_previous_sentence_label==int(idx_)).nonzero()).tolist())[0] for idx_ in sentence_label_])
used_idx = torch.tensor([random.choice(((all_previous_sentiment_label==int(idx_)).nonzero()).tolist())[0] for idx_ in sentiment_label_])
top_k = {"values":top_k_previous["values"].index_select(0,used_idx), "indices":top_k_previous["indices"].index_select(0,used_idx)}

bottom_k = {"values":bottom_k_previous["values"].index_select(0,used_idx), "indices":bottom_k_previous["indices"].index_select(0,used_idx)}
#random.choice(((all_previous_sentence_label==id_).nonzero()).tolist()[0]) for id_ in


#################
#################
#Train Domain Binary Classifier
#Domain: pos = n (the in-domain query sentences), neg = k*n (the bottom-k retrieved open-domain sentences)
#Use sampling at first
#bottom_k = torch.topk(results, k, dim=1, largest=False, sorted=False)
batch = AugmentationData_Domain(bottom_k, tokenizer, args.max_seq_length)
batch = tuple(t.to(device) for t in batch)
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs = batch
#domain_binary_loss, domain_binary_logit = model(input_ids_org=input_ids_org, masked_lm_labels=lm_label_ids, attention_mask=input_mask, func="domain_binary_classifier", in_domain_rep=in_domain_rep.to(device))
domain_binary_loss, domain_binary_logit = model(input_ids_org=input_ids_org, masked_lm_labels=lm_label_ids, attention_mask=input_mask, func="domain_binary_classifier_mean", in_domain_rep=in_domain_rep_mean.to(device))

#################
#################
#Train Task Binary Classifier
#Pseudo task --> does not backprop into the PLM: only the classifier head is trained [in-domain data]
#batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep)
batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep_mean)
batch = tuple(t.to(device) for t in batch)
all_in_task_rep_comb, all_sentence_binary_label = batch
#task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier")
task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier_mean")

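# all_in_task_rep_comb holds paired task representations and
# all_sentence_binary_label the corresponding positive/negative pair labels;
# the pairing itself is built by AugmentationData_Task_pos_and_neg (defined
# elsewhere).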

#################
#################
#Only train Task classifier
#Task
#top_k
#top_k = torch.topk(results, k, dim=1, largest=True, sorted=False)
batch = AugmentationData_Task(top_k, tokenizer, args.max_seq_length, add_org=batch_)
batch = tuple(t.to(device) for t in batch)
#input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label = batch
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
#split into: in_dom and query_ --> different weight
#task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentence_label_, attention_mask=input_mask_, func="task_class")
task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")

if epo > 2:
#task_loss_query, class_logit_query = model(input_ids_org=input_ids_org, sentence_label=sentence_label, attention_mask=input_mask, func="task_class")
task_loss_query, class_logit_query = model(input_ids_org=input_ids_org, sentence_label=sentiment_label, attention_mask=input_mask, func="task_class")
else:
task_loss_query = torch.tensor([0.0])

#loss = task_loss_org + (task_loss_query*alpha*epo*step)/k
#loss = domain_binary_loss + task_binary_loss + task_loss_org + (task_loss_query*alpha*epo*step)/k

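# The combined objective below is: domain binary loss + task binary loss +
# task classification loss on the original batch + the pseudo-labeled query
# loss scaled by alpha * epoch (the query term stays zero until epoch > 2).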
##############################
##############################

if n_gpu > 1:
#loss = loss.mean() # mean() to average on multi-gpu.
#loss = domain_binary_loss.mean() + task_binary_loss.mean() + task_loss_org.mean() + (task_loss_query.mean()*alpha*epo*step)/k

#pseudo = (task_loss_query.mean()*alpha*epo*step)/k
#pseudo = (task_loss_query.mean()*alpha*epo*step)
pseudo = (task_loss_query.mean()*alpha*epo)
loss = domain_binary_loss.mean() + task_binary_loss.mean() + task_loss_org.mean() + pseudo
else:
# Single-GPU/CPU fallback (an assumption mirroring the multi-GPU weighting)
# so that `pseudo` and `loss` are always defined before the backward pass.
pseudo = (task_loss_query.mean()*alpha*epo)
loss = domain_binary_loss.mean() + task_binary_loss.mean() + task_loss_org.mean() + pseudo

if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
#optimizer.backward(loss)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()

###
loss_fout.write("{}\n".format(loss.item()))
###

###
loss_fout_no_pseudo.write("{}\n".format(loss.item()-pseudo.item()))
###

tr_loss += loss.item()
#nb_tr_examples += input_ids.size(0)
nb_tr_examples += input_ids_.size(0)
nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
# With fp16, clip the gradients held by amp's master parameters
#lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
###
else:
# Otherwise clip the model parameters directly
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
###

optimizer.step()
###
scheduler.step()
###
#optimizer.zero_grad()
model.zero_grad()
global_step += 1


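# Gradients accumulate for args.gradient_accumulation_steps mini-batches; only
# then are they clipped, the optimizer and LR scheduler stepped, and the
# gradients cleared via model.zero_grad().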

model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
torch.save(model_to_save.state_dict(), output_model_file)
####
'''
#if args.num_train_epochs/args.augment_times in [1,2,3]:
if (args.num_train_epochs/(args.augment_times/5))%5 == 0:
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
torch.save(model_to_save.state_dict(), output_model_file)
'''
####

loss_fout.close()
loss_fout_no_pseudo.close()

# Save a trained model
logger.info("** ** * Saving fine-tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
if args.do_train:
torch.save(model_to_save.state_dict(), output_model_file)


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence (pair) in place to the maximum length; in this script only tokens_a is counted and truncated."""

# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
#total_length = len(tokens_a) + len(tokens_b)
total_length = len(tokens_a)
if total_length <= max_length:
break
else:
tokens_a.pop()

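# Minimal usage sketch (hypothetical values): reserve room for the <s> and
# </s> special tokens before building the input ids.
#   tokens = tokenizer.tokenize("an example review")
#   _truncate_seq_pair(tokens, [], max_seq_length - 2)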

def accuracy(out, labels):
# Returns the number of correct predictions (a count, not a ratio)
outputs = np.argmax(out, axis=1)
return np.sum(outputs == labels)


if __name__ == "__main__":
main()