"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import json
import logging
import os
import random

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from torch.nn import CrossEntropyLoss

from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
from transformers.modeling_roberta_updateRep_self import RobertaForMaskedLMDomainTask
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
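
# This runner fine-tunes RobertaForMaskedLMDomainTask on an in-domain sentiment
# task while drawing auxiliary signal from an out-of-domain corpus: a domain
# binary classifier over in-/out-of-domain representation pairs, a task-level
# binary classifier over same-label pairs, and a sentence-level task classifier.
# (In the training loop below, the sentence-level task loss and the task-level
# binary loss are the ones summed into the optimized loss.)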


def get_parameter(parser):
    parser.add_argument("--data_dir_indomain",
                        help="The input train corpus. (In Domain)")
    parser.add_argument("--data_dir_outdomain",
                        help="The input train corpus. (Out Domain)")
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--augment_times",
                        help="Default batch_size/augment_times to save model")
    parser.add_argument("--max_seq_length",
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory",
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case",
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--num_labels_task",
                        default=None, type=int,
                        help="num_labels_task")
    parser.add_argument("--weight_decay",
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        help="Max gradient norm.")
    parser.add_argument('--fp16_opt_level',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task")
    parser.add_argument("--K")
    return parser


def return_Classifier(weight, bias, dim_in, dim_out):
    classifier = torch.nn.Linear(dim_in, dim_out, bias=True)
    classifier.weight.data = weight.to("cpu")
    classifier.bias.data = bias.to("cpu")
    classifier.requires_grad = False
    return classifier
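

# load_GeneralDomain expects three files under dir_data_out: precomputed head and
# tail representations ("train_head.pt", "train_tail.pt") and the raw sentences
# ("train.json", a dict keyed by string indices whose entries carry a "sentence" field).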
def load_GeneralDomain(dir_data_out):
    print("Load CLS.pt and train.json")
    docs_head = torch.load(dir_data_out+"train_head.pt")
    docs_tail = torch.load(dir_data_out+"train_tail.pt")
    print(docs_head.shape)
    print(docs_tail.shape)
    with open(dir_data_out+"train.json") as file:
        data = json.load(file)
    print("train.json Done")
    docs_tail_head = torch.cat([docs_tail, docs_head], 2)
    return docs_tail_head, docs_head, docs_tail, data


def load_InDomain(dir_data_in):
    with open(dir_data_in+"train.json") as file:
        data = json.load(file)
    in_data = dict()
    for id, line in enumerate(data):
        in_data[id] = line["sentence"]
    return in_data


parser = argparse.ArgumentParser()
parser = get_parameter(parser)
args = parser.parse_args()

docs_tail_head, docs_head, docs_tail, data = load_GeneralDomain(args.data_dir_outdomain)
in_data = load_InDomain(args.data_dir_indomain)

# Collapse multi-vector document representations to a single mean vector per document.
if docs_head.shape[1] != 1:
    docs_head = docs_head.mean(1).unsqueeze(1)
    print(docs_head.shape)
else:
    print(docs_head.shape)
if docs_tail.shape[1] != 1:
    docs_tail = docs_tail.mean(1).unsqueeze(1)
    print(docs_tail.shape)
else:
    print(docs_tail.shape)
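

# in_Domain_Task_Data_mutiple reads the in-domain train.json, converts every
# sentence to masked-LM features, and returns (i) one example tensor tuple per
# sentiment class and (ii) the list of tensor tuples for the whole set.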
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
    with open(data_dir_indomain+"train.json") as file:
        data = json.load(file)

    # Build label maps
    num_label_list = list()
    label_sentence_dict = dict()
    num_sentiment_label_list = list()
    sentiment_label_dict = dict()
    for line in data:
        num_sentiment_label_list.append(line["sentiment"])
        num_label_list.append(line["sentiment"])

    num_label = sorted(list(set(num_label_list)))
    label_map = {label: i for i, label in enumerate(num_label)}
    num_sentiment_label = sorted(list(set(num_sentiment_label_list)))
    sentiment_label_map = {label: i for i, label in enumerate(num_sentiment_label)}
    print("sentiment_label_map:")
    print(sentiment_label_map)

    # Convert every sentence to masked-LM features
    all_input_ids = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_tail_idxs = list()
    all_sentence_labels = list()

    cur_tensors_list = list()

    candidate_label_list = list(label_map.values())
    candidate_sentiment_label_list = list(sentiment_label_map.values())
    all_type_sentence = [0]*len(candidate_label_list)
    all_type_sentiment_sentence = [0]*len(candidate_sentiment_label_list)
    for id, line in enumerate(data):
        sentiment = line["sentiment"]
        sentence = line["sentence"]
        label = line["sentiment"]

        tokens_a = tokenizer.tokenize(sentence)
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"

        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next),
                       torch.tensor(cur_features.tail_idxs),
                       torch.tensor(label_map[label]),
                       torch.tensor(sentiment_label_map[sentiment]))

        cur_tensors_list.append(cur_tensors)

        # Keep one representative example per still-unseen class.
        if label_map[label] in candidate_label_list:
            all_type_sentence[label_map[label]] = cur_tensors
            candidate_label_list.remove(label_map[label])

        if sentiment_label_map[sentiment] in candidate_sentiment_label_list:
            all_type_sentiment_sentence[sentiment_label_map[sentiment]] = cur_tensors
            candidate_sentiment_label_list.remove(sentiment_label_map[sentiment])

    return all_type_sentiment_sentence, cur_tensors_list
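

# AugmentationData_Domain builds a batch for the domain discriminator:
# train_batch_size in-domain sentences (domain_id=1) plus train_batch_size*k
# out-of-domain sentences sampled from the general corpus (domain_id=0).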
def AugmentationData_Domain(train_batch_size, k, tokenizer, max_seq_length):
    ids_pos = random.sample(range(0, len(in_data)), train_batch_size)
    ids_neg = random.sample(range(0, len(data)), train_batch_size*k)

    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_id_domain = list()

    for id, i in enumerate(ids_pos):
        # In-domain sentence for the positive (domain_id=1) sample.
        t1 = in_data[i]

        tokens_a = tokenizer.tokenize(t1)

        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        all_input_ids.append(torch.tensor(cur_features.input_ids))
        all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
        all_input_mask.append(torch.tensor(cur_features.input_mask))
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
        all_is_next.append(torch.tensor(0))
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
        all_id_domain.append(torch.tensor(1))

    for id, i in enumerate(ids_neg):
        # Out-of-domain sentence for the negative (domain_id=0) sample.
        t1 = data[str(i)]['sentence']

        tokens_a = tokenizer.tokenize(t1)

        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        all_input_ids.append(torch.tensor(cur_features.input_ids))
        all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
        all_input_mask.append(torch.tensor(cur_features.input_mask))
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
        all_is_next.append(torch.tensor(0))
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
        all_id_domain.append(torch.tensor(0))

    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_id_domain))

    return cur_tensors
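

# AugmentationData_Task pulls the out-of-domain sentences whose indices appear in
# top_k["indices"], converts them to features, and appends the original in-domain
# example of each row, propagating that row's sentence/sentiment labels.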
def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None):
    top_k_shape = top_k["indices"].shape
    sentence_ids = top_k["indices"]

    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_sentence_labels = list()
    all_sentiment_labels = list()

    add_org = tuple(t.to('cpu') for t in add_org)
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org

    for id_1, sent in enumerate(sentence_ids):
        for id_2, sent_id in enumerate(sent):
            t1 = data[str(int(sent_id))]['sentence']

            tokens_a = tokenizer.tokenize(t1)

            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
            all_sentiment_labels.append(torch.tensor(sentiment_label_[id_1]))

        # Also keep the original in-domain example for this row.
        all_input_ids.append(input_ids_[id_1])
        all_input_ids_org.append(input_ids_org_[id_1])
        all_input_mask.append(input_mask_[id_1])
        all_segment_ids.append(segment_ids_[id_1])
        all_lm_labels_ids.append(lm_label_ids_[id_1])
        all_is_next.append(is_next_[id_1])
        all_tail_idxs.append(tail_idxs_[id_1])
        all_sentence_labels.append(sentence_label_[id_1])
        all_sentiment_labels.append(sentiment_label_[id_1])

    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_sentence_labels),
                   torch.stack(all_sentiment_labels))

    return cur_tensors


def AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None, in_domain_rep=None):
    if top_k is not None:
        top_k_shape = top_k.indices.shape
        sentence_ids = top_k.indices

    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org

    all_sentence_binary_label = list()
    all_in_rep_comb = list()

    for id_1, num in enumerate(sentence_label_):
        # 1 where another example in the batch shares this example's label, else 0.
        sentence_label_int = (sentence_label_ == num).to(torch.long)

        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0], -1)
        in_domain_rep_append = in_domain_rep[id_1].unsqueeze(0).expand(in_domain_rep.shape[0], -1)

        in_task_rep_comb = torch.cat((in_task_rep_append, in_task_rep), -1)
        in_domain_rep_comb = torch.cat((in_domain_rep_append, in_domain_rep), -1)

        in_rep_comb = torch.cat([in_domain_rep_comb, in_task_rep_comb], -1)

        all_sentence_binary_label.append(sentence_label_int)
        all_in_rep_comb.append(in_rep_comb)

    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
    all_in_rep_comb = torch.stack(all_in_rep_comb)

    cur_tensors = (all_in_rep_comb, all_sentence_binary_label)

    return cur_tensors
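

# AugmentationData_Task_pos_and_neg pairs every in-task representation in the batch
# with every other one; a pair is labelled 1 when both examples share the same
# sentence label, 0 otherwise. The DT variant above additionally concatenates the
# domain representations.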
def AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None):
    if top_k is not None:
        top_k_shape = top_k.indices.shape
        sentence_ids = top_k.indices

    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org

    all_sentence_binary_label = list()
    all_in_task_rep_comb = list()

    for id_1, num in enumerate(sentence_label_):
        sentence_label_int = (sentence_label_ == num).to(torch.long)

        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0], -1)
        in_task_rep_comb = torch.cat((in_task_rep_append, in_task_rep), -1)

        all_sentence_binary_label.append(sentence_label_int)
        all_in_task_rep_comb.append(in_task_rep_comb)

    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
    all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)

    cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)

    return cur_tensors
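

# Dataset_noNext streams a line-per-sentence corpus (blank lines separate documents)
# and yields masked-LM features without a next-sentence-prediction objective.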
class Dataset_noNext(Dataset):
    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
        self.vocab_size = tokenizer.vocab_size
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.current_doc = 0

        # for loading samples directly from file
        self.sample_counter = 0
        self.line_buffer = None

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = []  # map sample index to doc and line

        # load samples into memory
        if on_memory:
            self.all_docs = []
            doc = []
            self.corpus_lines = 0
            with open(corpus_path, "r", encoding=encoding) as f:
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    line = line.strip()
                    if line == "":
                        self.all_docs.append(doc)
                        doc = []
                        # remove last added sample: no subsequent line in this doc
                        self.sample_to_doc.pop()
                    else:
                        sample = {"doc_id": len(self.all_docs),
                                  "line": len(doc)}
                        self.sample_to_doc.append(sample)
                        doc.append(line)
                        self.corpus_lines = self.corpus_lines + 1

            # if last row in file is not empty
            if self.all_docs[-1] != doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

        # load samples later lazily from disk
        else:
            if self.corpus_lines is None:
                with open(corpus_path, "r", encoding=encoding) as f:
                    self.corpus_lines = 0
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                        if line.strip() == "":
                            self.num_docs += 1
                        else:
                            self.corpus_lines += 1

                    # if doc does not end with empty line
                    if line.strip() != "":
                        self.num_docs += 1

            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def __len__(self):
        # last line of a doc won't be used, because there is no subsequent line.
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from the beginning of the file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)

        t1, is_next_label = self.random_sent(item)
        if is_next_label is None:
            is_next_label = 0

        tokens_a = self.tokenizer.tokenize(t1)

        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"

        # combine into one sample and transform to features
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next),
                       torch.tensor(cur_features.tail_idxs))

        return cur_tensors

    def random_sent(self, index):
        """
        Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
        from one doc. With 50% the second sentence will be a random one from another doc.
        :param index: int, index of sample.
        :return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
        """
        t1, t2 = self.get_corpus_line(index)
        return t1, None

    def get_corpus_line(self, item):
        """
        Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
        :param item: int, index of sample.
        :return: (str, str), two subsequent sentences from corpus
        """
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            # used later to avoid drawing a random sentence from the same doc
            self.current_doc = sample["doc_id"]
            return t1
        else:
            if self.line_buffer is None:
                # read first line of file
                t1 = next(self.file).strip()
            else:
                # continue from the buffered line of the previous iteration
                t1 = self.line_buffer
                t1 = next(self.file).strip()
                self.current_doc = self.current_doc + 1
            self.line_buffer = next(self.file).strip()
            return t1

    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # try a few times to make sure the random line is not from the current doc
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs)-1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
            else:
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
                for _ in range(rand_index):
                    line = self.get_next_line()
            # check whether the picked random line really comes from another doc
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """ Gets next line of random_file and starts over when reaching end of file"""
        try:
            line = next(self.random_file).strip()
            # keep track of which document we are currently looking at
            if line == "":
                self.current_random_doc = self.current_random_doc + 1
                line = next(self.random_file).strip()
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file).strip()
        return line


class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example.
            tokens_a: string. The untokenized text of the first sequence. For single
                sequence tasks, only this sequence must be specified.
            tokens_b: (Optional) string. The untokenized text of the second sequence.
                Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next
        self.lm_labels = lm_labels


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
        self.input_ids = input_ids
        self.input_ids_org = input_ids_org
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids
        self.tail_idxs = tail_idxs
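

# random_word selects tokens for the masked-LM objective: selected positions get
# their original token id as the LM label, all other positions get label -1
# (ignored by the loss).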
def random_word(tokens, tokenizer):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            # draw a random replacement token from the vocabulary
            candidate_id = random.randint(0, tokenizer.vocab_size - 1)
            w = tokenizer.convert_ids_to_tokens(candidate_id)

            if tokens[i] is None:
                w = tokenizer.convert_ids_to_tokens(candidate_id)

            # record the original token id as the LM label (predicted later)
            try:
                w = tokenizer.convert_tokens_to_ids(token)
                output_label.append(w)
            except Exception:
                print("Token has no id in the vocabulary")
                w = tokenizer.convert_tokens_to_ids("<unk>")
                output_label.append(w)
                logger.warning("Cannot find token '{}' in vocab. Using <unk> instead".format(token))
        else:
            # no masking for this token (will be ignored by the loss function later)
            output_label.append(-1)

    return tokens, output_label
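

# convert_example_to_features produces two parallel id sequences: "input_ids" built
# from the random_word-processed tokens and "input_ids_org" from the original,
# unmasked tokens; "tail_idxs" records the position of the final </s> token.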
def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b

    # Account for <s> and </s> with "- 2"
    if tokens_b:
        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)

    # Keep a copy of the unmasked tokens and pass a copy through random_word for the LM objective.
    tokens_a_org = tokens_a.copy()
    tokens_a, t1_label = random_word(tokens_a, tokenizer)

    lm_label_ids = ([-1] + t1_label + [-1])

    # Assemble the token sequence and its unmasked counterpart.
    tokens = []
    tokens_org = []
    segment_ids = []

    tokens.append("<s>")
    tokens_org.append("<s>")
    segment_ids.append(0)
    for i, token in enumerate(tokens_a):
        if token != "</s>":
            tokens.append(tokens_a[i])
            tokens_org.append(tokens_a_org[i])
            segment_ids.append(0)
        else:
            tokens.append("s")
            tokens_org.append("s")
            segment_ids.append(0)
    tokens.append("</s>")
    tokens_org.append("</s>")
    segment_ids.append(0)

    input_ids = tokenizer.encode(tokens, add_special_tokens=False)
    input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
    tail_idxs = len(input_ids) - 1

    input_ids = [w if w != None else 0 for w in input_ids]
    input_ids_org = [w if w != None else 0 for w in input_ids_org]

    # The mask has 1 for real tokens and 0 for padding tokens.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
    while len(input_ids) < max_seq_length:
        input_ids.append(pad_id)
        input_ids_org.append(pad_id)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    try:
        assert len(input_ids) == max_seq_length
        assert len(input_ids_org) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(lm_label_ids) == max_seq_length
    except AssertionError:
        # Over-long sequence: truncate and re-append the </s> id (2) or a pad id
        # so every sequence is exactly max_seq_length.
        print("!!!Warning!!!")
        input_ids = input_ids[:max_seq_length-1]
        if 2 not in input_ids:
            input_ids += [2]
        else:
            input_ids += [pad_id]
        input_ids_org = input_ids_org[:max_seq_length-1]
        if 2 not in input_ids_org:
            input_ids_org += [2]
        else:
            input_ids_org += [pad_id]
        input_mask = input_mask[:max_seq_length-1]+[0]
        segment_ids = segment_ids[:max_seq_length-1]+[0]
        lm_label_ids = lm_label_ids[:max_seq_length-1]+[-1]

    if len(input_ids) != max_seq_length:
        print(len(input_ids))
    if len(input_ids_org) != max_seq_length:
        print(len(input_ids_org))
    if len(input_mask) != max_seq_length:
        print(len(input_mask))
    if len(segment_ids) != max_seq_length:
        print(len(segment_ids))
    if len(lm_label_ids) != max_seq_length:
        print(len(lm_label_ids))

    # Log the first few converted examples.
    if example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
            [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
            "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))

    features = InputFeatures(input_ids=input_ids,
                             input_ids_org=input_ids_org,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next,
                             tail_idxs=tail_idxs)
    return features
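

# main(): parse arguments, build the in-domain training set, load
# RobertaForMaskedLMDomainTask, and run the joint training loop defined below.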
def main():
    parser = argparse.ArgumentParser()
    parser = get_parameter(parser)
    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which takes care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)

    num_train_optimization_steps = None
    print("Loading Train Dataset", args.data_dir_indomain)
    all_type_sentence, train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
    num_train_optimization_steps = int(
        len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = RobertaForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    for par in param_optimizer:
        par[1].requires_grad = True  # assumed: make sure every parameter is trainable
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*0.1), num_training_steps=num_train_optimization_steps)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)

    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    output_loss_file = os.path.join(args.output_dir, "loss")
    loss_fout = open(output_loss_file, 'w')

    output_loss_file_no_pseudo = os.path.join(args.output_dir, "loss_no_pseudo")
    loss_fout_no_pseudo = open(output_loss_file_no_pseudo, 'w')

    model.train()
    global_step = 0
    k = args.K  # number of out-of-domain sentences sampled per in-domain sentence (assumed from --K)

    all_type_sentence_label = list()
    all_previous_sentence_label = list()
    all_type_sentiment_label = list()
    all_previous_sentiment_label = list()
    top_k_all_type = dict()
    bottom_k_all_type = dict()
    for epo in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):

            # Move the batch to the device and unpack it.
            batch_ = tuple(t.to(device) for t in batch_)
            input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = batch_
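
            # Per step: compute in-domain/in-task representations, build the domain
            # contrastive batch, then the task-level pair batches; the optimized loss
            # combines the sentence-level task loss with the task-level binary loss.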
            in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")

            #Domain Binary Classifier - Outdomain
            #batch = AugmentationData_Domain(bottom_k, tokenizer, args.max_seq_length)
            batch = AugmentationData_Domain(in_domain_rep.shape[0], k, tokenizer, args.max_seq_length)
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, domain_id = batch

            out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, lm_label=lm_label_ids, attention_mask=input_mask, func="in_domain_task_rep")
            #print(domain_top_k["indices"].shape)
            #print(input_ids_org.shape)
            #print(out_domain_rep_tail.shape)
            #print(in_domain_rep.shape)

            ############ Construct contrastive instances ############
            #print("=============")
            #print(out_domain_rep_tail.shape)
            #print(in_domain_rep.shape)
            #print("=============")
            # Positive pairs: two in-domain representations; unknown pairs: each
            # in-domain representation paired with an augmented representation
            # (labelled by domain_id).
            comb_rep_pos = torch.cat([in_domain_rep, in_domain_rep.flip(0)], 1)
            in_domain_rep_ready = in_domain_rep.repeat(1, int(out_domain_rep_tail.shape[0]/in_domain_rep.shape[0])).reshape(out_domain_rep_tail.shape[0], out_domain_rep_tail.shape[1])
            comb_rep_unknow = torch.cat([in_domain_rep_ready, out_domain_rep_tail], 1)

            in_domain_binary_loss, domain_binary_logit = model(func="domain_binary_classifier", in_domain_rep=comb_rep_pos.to(device), out_domain_rep=comb_rep_unknow.to(device), domain_id=domain_id, use_detach=True)

            # Task-level binary classifier over pairs of in-task representations.
            batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep)
            batch = tuple(t.to(device) for t in batch)
            all_in_task_rep_comb, all_sentence_binary_label = batch
            in_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)

            # Sentence-level task (sentiment) classification on the original batch.
            task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")

            # Domain+task binary classifier over combined domain/task representations.
            batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
            batch = tuple(t.to(device) for t in batch)
            all_in_task_rep_comb, all_sentence_binary_label = batch
            in_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")

            # Combined objective actually optimized here.
            loss = task_loss_org.mean() + in_task_binary_loss.mean()
            if n_gpu == 0:
                print("No Using GPU")

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            loss_fout.write("{}\n".format(loss.item()))

            tr_loss += loss.item()
            nb_tr_examples += input_ids_.size(0)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
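
        # End of epoch: checkpoint the (unwrapped) model.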
        model_to_save = model.module if hasattr(model, 'module') else model
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(epo))
        torch.save(model_to_save.state_dict(), output_model_file)

        #if args.num_train_epochs/args.augment_times in [1,2,3]:
        if (args.num_train_epochs/(args.augment_times/5)) % 5 == 0:
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
            output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
            torch.save(model_to_save.state_dict(), output_model_file)

    logger.info("** ** * Saving fine-tuned model ** ** * ")
    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""
    while True:
        total_length = len(tokens_a)
        if total_length <= max_length:
            break
        tokens_a.pop()


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


if __name__ == "__main__":
    main()