16
"""BERT finetuning runner."""
18
from __future__ import absolute_import, division, print_function, unicode_literals
27
from torch.autograd import Variable
28
import torch.nn.functional as F
32
from torch.utils.data import DataLoader, Dataset, RandomSampler
33
from torch.utils.data.distributed import DistributedSampler
34
from tqdm import tqdm, trange
35
from torch.nn import CrossEntropyLoss
37
from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
39
from transformers.modeling_roberta_updateRep_self import RobertaForMaskedLMDomainTask
40
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
42
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
43
datefmt='%m/%d/%Y %H:%M:%S',
45
logger = logging.getLogger(__name__)
48
ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
51
def get_parameter(parser):
54
parser.add_argument("--data_dir_indomain",
58
help="The input train corpus.(In Domain)")
59
parser.add_argument("--data_dir_outdomain",
63
help="The input train corpus.(Out Domain)")
64
parser.add_argument("--pretrain_model", default=None, type=str, required=True,
65
help="Bert pre-trained model selected in the list: bert-base-uncased, "
66
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
67
parser.add_argument("--output_dir",
71
help="The output directory where the model checkpoints will be written.")
72
parser.add_argument("--augment_times",
76
help="Default batch_size/augment_times to save model")
78
parser.add_argument("--max_seq_length",
81
help="The maximum total input sequence length after WordPiece tokenization. \n"
82
"Sequences longer than this will be truncated, and sequences shorter \n"
83
"than this will be padded.")
84
parser.add_argument("--do_train",
86
help="Whether to run training.")
87
parser.add_argument("--train_batch_size",
90
help="Total batch size for training.")
91
parser.add_argument("--learning_rate",
94
help="The initial learning rate for Adam.")
95
parser.add_argument("--num_train_epochs",
98
help="Total number of training epochs to perform.")
99
parser.add_argument("--warmup_proportion",
102
help="Proportion of training to perform linear learning rate warmup for. "
103
"E.g., 0.1 = 10%% of training.")
104
parser.add_argument("--no_cuda",
106
help="Whether not to use CUDA when available")
107
parser.add_argument("--on_memory",
109
help="Whether to load train samples into memory or use disk")
110
parser.add_argument("--do_lower_case",
112
help="Whether to lower case the input text. True for uncased models, False for cased models.")
113
parser.add_argument("--local_rank",
116
help="local_rank for distributed training on gpus")
117
parser.add_argument('--seed',
120
help="random seed for initialization")
121
parser.add_argument('--gradient_accumulation_steps',
124
help="Number of updates steps to accumualte before performing a backward/update pass.")
125
parser.add_argument('--fp16',
127
help="Whether to use 16-bit float precision instead of 32-bit")
128
parser.add_argument('--loss_scale',
129
type = float, default = 0,
130
help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
131
"0 (default value): dynamic loss scaling.\n"
132
"Positive power of 2: static loss scaling value.\n")
134
parser.add_argument("--num_labels_task",
135
default=None, type=int,
137
help="num_labels_task")
138
parser.add_argument("--weight_decay",
141
help="Weight decay if we apply some.")
142
parser.add_argument("--adam_epsilon",
145
help="Epsilon for Adam optimizer.")
146
parser.add_argument("--max_grad_norm",
149
help="Max gradient norm.")
150
parser.add_argument('--fp16_opt_level',
153
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
154
"See details at https://nvidia.github.io/apex/amp.html")
155
parser.add_argument("--task",
160
parser.add_argument("--K",
169
def return_Classifier(weight, bias, dim_in, dim_out):
    """Build a frozen CPU ``torch.nn.Linear`` initialised from given tensors.

    Args:
        weight: Tensor installed as the layer's weight after ``.to("cpu")``.
            Presumably shaped ``(dim_out, dim_in)``; not checked here.
        bias: Tensor installed as the layer's bias after ``.to("cpu")``.
        dim_in: Input feature size passed to ``torch.nn.Linear``.
        dim_out: Output feature size passed to ``torch.nn.Linear``.

    Returns:
        The ``torch.nn.Linear`` layer with every parameter's
        ``requires_grad`` set to ``False``.
    """
    classifier = torch.nn.Linear(dim_in, dim_out, bias=True)
    classifier.weight.data = weight.to("cpu")
    classifier.bias.data = bias.to("cpu")

    # Assigning ``requires_grad`` on the module itself only creates a plain
    # Python attribute; it does not stop autograd from tracking the layer's
    # parameters. The attribute is kept for any caller that reads it, and
    # the parameters are frozen explicitly below.
    classifier.requires_grad = False
    for param in classifier.parameters():
        param.requires_grad = False

    # NOTE(review): the return statement was not visible in the reviewed
    # excerpt; it is restored from the function's name -- confirm.
    return classifier
192
def load_GeneralDomain(dir_data_out):
    """Load the out-of-domain representations and sentences from disk.

    ``dir_data_out`` is used as a raw string prefix: the file names are
    appended with ``+`` and no separator is inserted, so the value must
    already end with a path separator (e.g. ``"data/out/"``).

    Args:
        dir_data_out: Prefix under which ``train_head.pt``,
            ``train_tail.pt`` and ``train.json`` are found.

    Returns:
        A tuple ``(docs_tail_head, docs_head, docs_tail, data)`` where
        ``docs_head`` / ``docs_tail`` are the objects loaded from the two
        ``.pt`` files, ``docs_tail_head`` is their concatenation (tail
        first) along dimension 2, and ``data`` is the parsed ``train.json``.
    """
    print("Load CLS.pt and train.json")

    docs_head = torch.load(dir_data_out + "train_head.pt")
    docs_tail = torch.load(dir_data_out + "train_tail.pt")

    print(docs_head.shape)
    print(docs_tail.shape)

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(dir_data_out + "train.json", encoding="utf-8") as file:
        data = json.load(file)
    print("train.json Done")

    # Concatenating along dim 2 requires both tensors to be at least 3-D.
    docs_tail_head = torch.cat([docs_tail, docs_head], 2)
    return docs_tail_head, docs_head, docs_tail, data
213
parser = argparse.ArgumentParser()
214
parser = get_parameter(parser)
215
args = parser.parse_args()
219
docs_tail_head, docs_head, docs_tail, data = load_GeneralDomain(args.data_dir_outdomain)
221
if docs_head.shape[1]!=1:
225
docs_head = docs_head.mean(1).unsqueeze(1)
226
print(docs_head.shape)
228
print(docs_head.shape)
229
if docs_tail.shape[1]!=1:
233
docs_tail = docs_tail.mean(1).unsqueeze(1)
234
print(docs_tail.shape)
236
print(docs_tail.shape)
239
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
241
with open(data_dir_indomain+"train.json") as file:
242
data = json.load(file)
245
num_label_list = list()
246
label_sentence_dict = dict()
247
num_sentiment_label_list = list()
248
sentiment_label_dict = dict()
253
num_sentiment_label_list.append(line["sentiment"])
255
num_label_list.append(line["sentiment"])
257
num_label = sorted(list(set(num_label_list)))
258
label_map = {label : i for i , label in enumerate(num_label)}
259
num_sentiment_label = sorted(list(set(num_sentiment_label_list)))
260
sentiment_label_map = {label : i for i , label in enumerate(num_sentiment_label)}
266
print("sentiment_label_map:")
267
print(sentiment_label_map)
273
all_input_ids = list()
274
all_input_mask = list()
275
all_segment_ids = list()
276
all_lm_labels_ids = list()
278
all_tail_idxs = list()
279
all_sentence_labels = list()
281
cur_tensors_list = list()
283
candidate_label_list = list(label_map.values())
284
candidate_sentiment_label_list = list(sentiment_label_map.values())
285
all_type_sentence = [0]*len(candidate_label_list)
286
all_type_sentiment_sentence = [0]*len(candidate_sentiment_label_list)
290
sentiment = line["sentiment"]
291
sentence = line["sentence"]
293
label = line["sentiment"]
296
tokens_a = tokenizer.tokenize(sentence)
299
if "</s>" in tokens_a:
300
print("Have more than 1 </s>")
301
#tokens_a[tokens_a.index("<s>")] = "s"
302
for i in range(len(tokens_a)):
303
if tokens_a[i] == "</s>":
309
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
311
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
313
cur_tensors = (torch.tensor(cur_features.input_ids),
314
torch.tensor(cur_features.input_ids_org),
315
torch.tensor(cur_features.input_mask),
316
torch.tensor(cur_features.segment_ids),
317
torch.tensor(cur_features.lm_label_ids),
319
torch.tensor(cur_features.tail_idxs),
320
torch.tensor(label_map[label]),
321
torch.tensor(sentiment_label_map[sentiment]))
323
cur_tensors_list.append(cur_tensors)
326
if label_map[label] in candidate_label_list:
327
all_type_sentence[label_map[label]]=cur_tensors
328
candidate_label_list.remove(label_map[label])
330
if sentiment_label_map[sentiment] in candidate_sentiment_label_list:
334
all_type_sentiment_sentence[sentiment_label_map[sentiment]]=cur_tensors
335
candidate_sentiment_label_list.remove(sentiment_label_map[sentiment])
341
return all_type_sentiment_sentence, cur_tensors_list
345
def AugmentationData_Domain(bottom_k, top_k, tokenizer, max_seq_length):
348
top_k_shape = top_k["indices"].shape
349
ids_pos = top_k["indices"].reshape(top_k_shape[0]*top_k_shape[1]).tolist()
352
bottom_k_shape = bottom_k["indices"].shape
353
ids_neg = bottom_k["indices"].reshape(bottom_k_shape[0]*bottom_k_shape[1]).tolist()
359
ids = ids_pos+ids_neg
362
all_input_ids = list()
363
all_input_ids_org = list()
364
all_input_mask = list()
365
all_segment_ids = list()
366
all_lm_labels_ids = list()
368
all_tail_idxs = list()
369
all_id_domain = list()
371
for id, i in enumerate(ids):
372
t1 = data[str(i)]['sentence']
375
tokens_a = tokenizer.tokenize(t1)
377
if "</s>" in tokens_a:
378
print("Have more than 1 </s>")
379
#tokens_a[tokens_a.index("<s>")] = "s"
380
for i in range(len(tokens_a)):
381
if tokens_a[i] == "</s>":
386
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
389
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
391
all_input_ids.append(torch.tensor(cur_features.input_ids))
392
all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
393
all_input_mask.append(torch.tensor(cur_features.input_mask))
394
all_segment_ids.append(torch.tensor(cur_features.segment_ids))
395
all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
396
all_is_next.append(torch.tensor(0))
397
all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
399
all_id_domain.append(torch.tensor(0))
401
all_id_domain.append(torch.tensor(1))
404
cur_tensors = (torch.stack(all_input_ids),
405
torch.stack(all_input_ids_org),
406
torch.stack(all_input_mask),
407
torch.stack(all_segment_ids),
408
torch.stack(all_lm_labels_ids),
409
torch.stack(all_is_next),
410
torch.stack(all_tail_idxs),
411
torch.stack(all_id_domain))
416
def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None, show_type=None):
417
top_k_shape = top_k["indices"].shape
418
sentence_ids = top_k["indices"]
420
all_input_ids = list()
421
all_input_ids_org = list()
422
all_input_mask = list()
423
all_segment_ids = list()
424
all_lm_labels_ids = list()
426
all_tail_idxs = list()
427
all_sentence_labels = list()
428
all_sentiment_labels = list()
430
add_org = tuple(t.to('cpu') for t in add_org)
432
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
443
for id_1, sent in enumerate(sentence_ids):
446
if int(id_1) in show_type and list(input_ids_org_[id_1]).index(2)>10:
447
print("=======================")
449
print(tokenizer.decode(input_ids_org_[id_1]))
450
print(int(sentence_label_[id_1]))
451
print("-----------------------")
452
for id_2, sent_id in enumerate(sent):
453
t1 = data[str(int(sent_id))]['sentence']
455
if int(id_1) in show_type and list(input_ids_org_[id_1]).index(2)>10:
456
print("-----------------------")
459
print("-----------------------")
463
tokens_a = tokenizer.tokenize(t1)
466
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
469
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
471
all_input_ids.append(torch.tensor(cur_features.input_ids))
472
all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
473
all_input_mask.append(torch.tensor(cur_features.input_mask))
474
all_segment_ids.append(torch.tensor(cur_features.segment_ids))
475
all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
476
all_is_next.append(torch.tensor(0))
477
all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
478
all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
479
all_sentiment_labels.append(torch.tensor(sentiment_label_[id_1]))
481
all_input_ids.append(input_ids_[id_1])
482
all_input_ids_org.append(input_ids_org_[id_1])
483
all_input_mask.append(input_mask_[id_1])
484
all_segment_ids.append(segment_ids_[id_1])
485
all_lm_labels_ids.append(lm_label_ids_[id_1])
486
all_is_next.append(is_next_[id_1])
487
all_tail_idxs.append(tail_idxs_[id_1])
488
all_sentence_labels.append(sentence_label_[id_1])
489
all_sentiment_labels.append(sentiment_label_[id_1])
492
cur_tensors = (torch.stack(all_input_ids),
493
torch.stack(all_input_ids_org),
494
torch.stack(all_input_mask),
495
torch.stack(all_segment_ids),
496
torch.stack(all_lm_labels_ids),
497
torch.stack(all_is_next),
498
torch.stack(all_tail_idxs),
499
torch.stack(all_sentence_labels),
500
torch.stack(all_sentiment_labels)
510
def AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None, in_domain_rep=None):
512
top_k_shape = top_k.indices.shape
513
sentence_ids = top_k.indices
520
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
525
all_sentence_binary_label = list()
527
all_in_rep_comb = list()
529
for id_1, num in enumerate(sentence_label_):
532
sentence_label_int = (sentence_label_==num).to(torch.long)
538
in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
539
in_domain_rep_append = in_domain_rep[id_1].unsqueeze(0).expand(in_domain_rep.shape[0],-1)
542
in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
543
in_domain_rep_comb = torch.cat((in_domain_rep_append,in_domain_rep),-1)
554
in_rep_comb = torch.cat([in_domain_rep_comb,in_task_rep_comb],-1)
557
all_sentence_binary_label.append(sentence_label_int)
559
all_in_rep_comb.append(in_rep_comb)
560
all_sentence_binary_label = torch.stack(all_sentence_binary_label)
562
all_in_rep_comb = torch.stack(all_in_rep_comb)
565
cur_tensors = (all_in_rep_comb, all_sentence_binary_label)
572
def AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None):
574
top_k_shape = top_k.indices.shape
575
sentence_ids = top_k.indices
582
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
587
all_sentence_binary_label = list()
588
all_in_task_rep_comb = list()
590
for id_1, num in enumerate(sentence_label_):
593
sentence_label_int = (sentence_label_==num).to(torch.long)
599
in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
602
in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
611
all_sentence_binary_label.append(sentence_label_int)
612
all_in_task_rep_comb.append(in_task_rep_comb)
613
all_sentence_binary_label = torch.stack(all_sentence_binary_label)
614
all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)
616
cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)
623
class Dataset_noNext(Dataset):
624
def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
626
self.vocab_size = tokenizer.vocab_size
627
self.tokenizer = tokenizer
628
self.seq_len = seq_len
629
self.on_memory = on_memory
630
self.corpus_lines = corpus_lines # number of non-empty lines in input corpus
631
self.corpus_path = corpus_path
632
self.encoding = encoding
633
self.current_doc = 0 # to avoid random sentence from same doc
635
# for loading samples directly from file
636
self.sample_counter = 0 # used to keep track of full epochs on file
637
self.line_buffer = None # keep second sentence of a pair in memory and use as first sentence in next pair
639
# for loading samples in memory
640
self.current_random_doc = 0
642
self.sample_to_doc = [] # map sample index to doc and line
644
# load samples into memory
648
self.corpus_lines = 0
649
with open(corpus_path, "r", encoding=encoding) as f:
650
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
653
self.all_docs.append(doc)
655
#remove last added sample because there won't be a subsequent line anymore in the doc
656
self.sample_to_doc.pop()
659
sample = {"doc_id": len(self.all_docs),
661
self.sample_to_doc.append(sample)
663
self.corpus_lines = self.corpus_lines + 1
665
# if last row in file is not empty
666
if self.all_docs[-1] != doc:
667
self.all_docs.append(doc)
668
self.sample_to_doc.pop()
670
self.num_docs = len(self.all_docs)
672
# load samples later lazily from disk
674
if self.corpus_lines is None:
675
with open(corpus_path, "r", encoding=encoding) as f:
676
self.corpus_lines = 0
677
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
678
if line.strip() == "":
681
self.corpus_lines += 1
683
# if doc does not end with empty line
684
if line.strip() != "":
687
self.file = open(corpus_path, "r", encoding=encoding)
688
self.random_file = open(corpus_path, "r", encoding=encoding)
691
# last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
692
return self.corpus_lines - self.num_docs - 1
694
def __getitem__(self, item):
695
cur_id = self.sample_counter
696
self.sample_counter += 1
697
if not self.on_memory:
698
# after one epoch we start again from beginning of file
699
if cur_id != 0 and (cur_id % len(self) == 0):
701
self.file = open(self.corpus_path, "r", encoding=self.encoding)
703
#t1, t2, is_next_label = self.random_sent(item)
704
t1, is_next_label = self.random_sent(item)
705
if is_next_label == None:
709
#tokens_a = self.tokenizer.tokenize(t1)
710
tokens_a = tokenizer.tokenize(t1)
711
#if "</s>" in tokens_a:
712
# print("Have more than 1 </s>")
713
# #tokens_a[tokens_a.index("<s>")] = "s"
714
# for i in range(len(tokens_a)):
715
# if tokens_a[i] == "</s>":
717
#tokens_b = self.tokenizer.tokenize(t2)
720
cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)
722
# transform sample to features
723
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
725
cur_tensors = (torch.tensor(cur_features.input_ids),
726
torch.tensor(cur_features.input_ids_org),
727
torch.tensor(cur_features.input_mask),
728
torch.tensor(cur_features.segment_ids),
729
torch.tensor(cur_features.lm_label_ids),
730
torch.tensor(cur_features.is_next),
731
torch.tensor(cur_features.tail_idxs))
735
def random_sent(self, index):
737
Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
738
from one doc. With 50% the second sentence will be a random one from another doc.
739
:param index: int, index of sample.
740
:return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
742
t1, t2 = self.get_corpus_line(index)
745
def get_corpus_line(self, item):
747
Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
748
:param item: int, index of sample.
749
:return: (str, str), two subsequent sentences from corpus
753
assert item < self.corpus_lines
755
sample = self.sample_to_doc[item]
756
t1 = self.all_docs[sample["doc_id"]][sample["line"]]
757
# used later to avoid random nextSentence from same doc
758
self.current_doc = sample["doc_id"]
762
if self.line_buffer is None:
763
# read first non-empty line of file
765
t1 = next(self.file).strip()
767
# use t2 from previous iteration as new t1
768
t1 = self.line_buffer
769
# skip empty rows that are used for separating documents and keep track of current doc id
771
t1 = next(self.file).strip()
772
self.current_doc = self.current_doc+1
773
self.line_buffer = next(self.file).strip()
779
def get_random_line(self):
781
Get random line from another document for nextSentence task.
782
:return: str, content of one line
784
# Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
785
# corpora. However, just to be careful, we try to make sure that
786
# the random document is not the same as the document we're processing.
789
rand_doc_idx = random.randint(0, len(self.all_docs)-1)
790
rand_doc = self.all_docs[rand_doc_idx]
791
line = rand_doc[random.randrange(len(rand_doc))]
793
rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
795
for _ in range(rand_index):
796
line = self.get_next_line()
797
#check if our picked random line is really from another doc like we want it to be
798
if self.current_random_doc != self.current_doc:
802
def get_next_line(self):
803
""" Gets next line of random_file and starts over when reaching end of file"""
805
line = next(self.random_file).strip()
806
#keep track of which document we are currently looking at to later avoid having the same doc as t1
808
self.current_random_doc = self.current_random_doc + 1
809
line = next(self.random_file).strip()
810
except StopIteration:
811
self.random_file.close()
812
self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
813
line = next(self.random_file).strip()
818
class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs a InputExample.

        Args:
            guid: Unique id for the example. ``convert_example_to_features``
                compares it with an integer (``example.guid < 5``) to decide
                whether to log the example.
            tokens_a: Tokens of the first sequence. Callers in this file pass
                the output of ``tokenizer.tokenize(...)``.
            tokens_b: (Optional) tokens of the second sequence. Callers in
                this file always pass ``None``.
            is_next: (Optional) next-sentence label carried through to the
                resulting features.
            lm_labels: (Optional) language-model labels; stored unchanged.
        """
        # ``guid`` must be kept: the feature conversion reads ``example.guid``.
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next
        self.lm_labels = lm_labels
839
class InputFeatures(object):
    """A single set of features of data.

    Plain container for the values produced by
    ``convert_example_to_features``; nothing is validated or copied.
    """

    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
        """Store the per-example features.

        Args:
            input_ids: Token ids of the (masked) input sequence.
            input_ids_org: Token ids of the unmasked original sequence.
            input_mask: Attention mask; the converter fills it with 1 for
                real tokens and 0 for padding.
            segment_ids: Segment ids; the converter only ever appends 0.
            is_next: Next-sentence label copied from the example.
            lm_label_ids: Language-model label ids, with -1 at positions
                that carry no label.
            tail_idxs: Index of the last token before padding
                (``len(input_ids) - 1`` in the converter).
        """
        self.input_ids = input_ids
        self.input_ids_org = input_ids_org
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids
        self.tail_idxs = tail_idxs
852
def random_word(tokens, tokenizer):
854
Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
855
:param tokens: list of str, tokenized sentence.
856
:param tokenizer: Tokenizer, object used for tokenization (we need it's vocab here)
857
:return: (list of str, list of int), masked tokens and related labels for LM prediction
861
for i, token in enumerate(tokens):
863
prob = random.random()
880
candidate_id = random.randint(0,tokenizer.vocab_size)
881
w = tokenizer.convert_ids_to_tokens(candidate_id)
883
if tokens[i] == None:
885
w = tokenizer.convert_ids_to_tokens(candidate_id)
895
w = tokenizer.convert_tokens_to_ids(token)
897
output_label.append(w)
899
print("Have no this tokens in ids")
904
w = tokenizer.convert_tokens_to_ids("<unk>")
905
output_label.append(w)
906
logger.warning("Cannot find token '{}' in vocab. Using <unk> insetad".format(token))
909
output_label.append(-1)
911
return tokens, output_label
914
def convert_example_to_features(example, max_seq_length, tokenizer):
916
Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
917
IDs, LM labels, input_mask, CLS and SEP tokens etc.
918
:param example: InputExample, containing sentence input as strings and is_next label
919
:param max_seq_length: int, maximum length of sequence.
920
:param tokenizer: Tokenizer
921
:return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
924
tokens_a = example.tokens_a
925
tokens_b = example.tokens_b
930
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)
933
tokens_a_org = tokens_a.copy()
934
tokens_a, t1_label = random_word(tokens_a, tokenizer)
944
lm_label_ids = ([-1] + t1_label + [-1])
968
tokens_org.append("<s>")
969
segment_ids.append(0)
970
for i, token in enumerate(tokens_a):
972
tokens.append(tokens_a[i])
973
tokens_org.append(tokens_a_org[i])
974
segment_ids.append(0)
977
tokens_org.append("s")
978
segment_ids.append(0)
979
tokens.append("</s>")
980
tokens_org.append("</s>")
981
segment_ids.append(0)
987
input_ids = tokenizer.encode(tokens, add_special_tokens=False)
988
input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
989
tail_idxs = len(input_ids)-1
992
input_ids = [w if w!=None else 0 for w in input_ids]
993
input_ids_org = [w if w!=None else 0 for w in input_ids_org]
999
input_mask = [1] * len(input_ids)
1002
pad_id = tokenizer.convert_tokens_to_ids("<pad>")
1003
while len(input_ids) < max_seq_length:
1004
input_ids.append(pad_id)
1005
input_ids_org.append(pad_id)
1006
input_mask.append(0)
1007
segment_ids.append(0)
1008
lm_label_ids.append(-1)
1011
assert len(input_ids) == max_seq_length
1012
assert len(input_ids_org) == max_seq_length
1013
assert len(input_mask) == max_seq_length
1014
assert len(segment_ids) == max_seq_length
1015
assert len(lm_label_ids) == max_seq_length
1017
print("!!!Warning!!!")
1018
input_ids = input_ids[:max_seq_length-1]
1019
if 2 not in input_ids:
1022
input_ids += [pad_id]
1023
input_ids_org = input_ids_org[:max_seq_length-1]
1024
if 2 not in input_ids_org:
1025
input_ids_org += [2]
1027
input_ids_org += [pad_id]
1028
input_mask = input_mask[:max_seq_length-1]+[0]
1029
segment_ids = segment_ids[:max_seq_length-1]+[0]
1030
lm_label_ids = lm_label_ids[:max_seq_length-1]+[-1]
1033
if len(input_ids) != max_seq_length:
1034
print(len(input_ids))
1036
if len(input_ids_org) != max_seq_length:
1037
print(len(input_ids_org))
1039
if len(input_mask) != max_seq_length:
1040
print(len(input_mask))
1042
if len(segment_ids) != max_seq_length:
1043
print(len(segment_ids))
1045
if len(lm_label_ids) != max_seq_length:
1046
print(len(lm_label_ids))
1054
if example.guid < 5:
1055
logger.info("*** Example ***")
1056
logger.info("guid: %s" % (example.guid))
1057
logger.info("tokens: %s" % " ".join(
1058
[str(x) for x in tokens]))
1059
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
1060
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
1062
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
1063
logger.info("LM label: %s " % (lm_label_ids))
1064
logger.info("Is next sentence label: %s " % (example.is_next))
1067
features = InputFeatures(input_ids=input_ids,
1068
input_ids_org = input_ids_org,
1069
input_mask=input_mask,
1070
segment_ids=segment_ids,
1071
lm_label_ids=lm_label_ids,
1072
is_next=example.is_next,
1073
tail_idxs=tail_idxs)
1078
parser = argparse.ArgumentParser()
1080
parser = get_parameter(parser)
1082
args = parser.parse_args()
1084
if args.local_rank == -1 or args.no_cuda:
1085
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
1086
n_gpu = torch.cuda.device_count()
1088
torch.cuda.set_device(args.local_rank)
1089
device = torch.device("cuda", args.local_rank)
1092
torch.distributed.init_process_group(backend='nccl')
1093
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
1094
device, n_gpu, bool(args.local_rank != -1), args.fp16))
1096
if args.gradient_accumulation_steps < 1:
1097
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
1098
args.gradient_accumulation_steps))
1100
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
1102
random.seed(args.seed)
1103
np.random.seed(args.seed)
1104
torch.manual_seed(args.seed)
1106
torch.cuda.manual_seed_all(args.seed)
1108
if not args.do_train:
1109
raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")
1117
tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)
1121
num_train_optimization_steps = None
1123
print("Loading Train Dataset", args.data_dir_indomain)
1125
all_type_sentence, train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
1126
num_train_optimization_steps = int(
1127
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
1128
if args.local_rank != -1:
1129
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
1134
model = RobertaForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1142
param_optimizer = list(model.named_parameters())
1144
for par in param_optimizer:
1148
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
1149
optimizer_grouped_parameters = [
1150
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
1151
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
1153
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
1154
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*0.1), num_training_steps=num_train_optimization_steps)
1158
from apex import amp
1160
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
1163
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
1167
model = torch.nn.DataParallel(model)
1169
if args.local_rank != -1:
1170
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
1176
logger.info("***** Running training *****")
1177
logger.info(" Num examples = %d", len(train_dataset))
1178
logger.info(" Batch size = %d", args.train_batch_size)
1179
logger.info(" Num steps = %d", num_train_optimization_steps)
1181
if args.local_rank == -1:
1182
train_sampler = RandomSampler(train_dataset)
1187
train_sampler = DistributedSampler(train_dataset)
1189
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
1192
output_loss_file = os.path.join(args.output_dir, "loss")
1193
loss_fout = open(output_loss_file, 'w')
1196
output_loss_file_no_pseudo = os.path.join(args.output_dir, "loss_no_pseudo")
1197
loss_fout_no_pseudo = open(output_loss_file_no_pseudo, 'w')
1218
all_type_sentence_label = list()
1219
all_previous_sentence_label = list()
1220
all_type_sentiment_label = list()
1221
all_previous_sentiment_label = list()
1222
top_k_all_type = dict()
1223
bottom_k_all_type = dict()
1227
for epo in trange(int(args.num_train_epochs), desc="Epoch"):
1229
nb_tr_examples, nb_tr_steps = 0, 0
1230
for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
1237
if (step == 0) and (epo == 0):
1241
input_ids_ = torch.stack([line[0] for line in all_type_sentence]).to(device)
1242
input_ids_org_ = torch.stack([line[1] for line in all_type_sentence]).to(device)
1243
input_mask_ = torch.stack([line[2] for line in all_type_sentence]).to(device)
1244
segment_ids_ = torch.stack([line[3] for line in all_type_sentence]).to(device)
1245
lm_label_ids_ = torch.stack([line[4] for line in all_type_sentence]).to(device)
1246
is_next_ = torch.stack([line[5] for line in all_type_sentence]).to(device)
1247
tail_idxs_ = torch.stack([line[6] for line in all_type_sentence]).to(device)
1248
sentence_label_ = torch.stack([line[7] for line in all_type_sentence]).to(device)
1249
sentiment_label_ = torch.stack([line[8] for line in all_type_sentence]).to(device)
1251
with torch.no_grad():
1254
in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1257
query_domain = in_domain_rep.float().to("cpu")
1258
query_domain = query_domain.unsqueeze(1)
1260
query_task = in_task_rep.float().to("cpu")
1261
query_task = query_task.unsqueeze(1)
1265
task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
1266
task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
1267
task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
1268
task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)
1271
domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
1272
domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
1273
domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
1274
domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768*2, 2)
1283
query_domain = query_domain.expand(-1, docs_tail.shape[0], -1)
1284
query_task = query_task.expand(-1, docs_head.shape[0], -1)
1294
domain_binary_logit = LeakyReLU(domain_binary_classifier(docs_tail))
1295
domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1296
domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
1298
domain_binary_logit = domain_binary_classifier(torch.cat([query_domain, docs_tail[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1299
target = torch.zeros(domain_binary_logit.shape[0], domain_binary_logit.shape[1], dtype=torch.long)
1301
domain_binary_logit = ce_loss(domain_binary_logit.view(-1, 2), target.view(-1)).reshape(domain_binary_logit.shape[0],domain_binary_logit.shape[1])
1304
task_binary_logit = task_binary_classifier(torch.cat([query_task, docs_head[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1307
task_binary_logit = ce_loss(task_binary_logit.view(-1, 2), target.view(-1)).reshape(task_binary_logit.shape[0],task_binary_logit.shape[1])
1311
domain_task_binary_logit = task_binary_logit
1317
domain_top_k_all_type = torch.topk(domain_binary_logit, k, dim=1, largest=True, sorted=False)
1318
perm = torch.randperm(domain_binary_logit.shape[1])
1319
domain_bottom_k_all_type_indices = perm[:k]
1320
domain_bottom_k_all_type_values = domain_binary_logit[:,domain_bottom_k_all_type_indices]
1321
domain_bottom_k_all_type_indices = torch.stack(args.domain_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])
1325
task_top_k_all_type = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
1327
domain_task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1334
domain_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1336
rand_seed = torch.randint(0,k,(choose_n,))
1337
domain_top_k_all_type_indices = domain_top_k_all_type.indices[:,rand_seed]
1338
domain_top_k_all_type_values = domain_top_k_all_type.values[:,rand_seed]
1348
domain_bottom_k_all_type_indices = torch.randint(k+1,domain_binary_logit.shape[1],(choose_n*2,))
1349
domain_bottom_k_all_type_values = domain_task_binary_logit[:,domain_bottom_k_all_type_indices]
1350
domain_bottom_k_all_type_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])
1354
task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1356
rand_seed = torch.randint(0,k,(choose_n,))
1357
task_top_k_all_type_indices = task_top_k_all_type.indices[:,rand_seed]
1358
task_top_k_all_type_values = task_top_k_all_type.values[:,rand_seed]
1363
domain_task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1365
rand_seed = torch.randint(0,k,(choose_n,))
1366
domain_task_top_k_all_type_indices = domain_task_top_k_all_type.indices[:,rand_seed]
1367
domain_task_top_k_all_type_values = domain_task_top_k_all_type.values[:,rand_seed]
1374
del domain_task_binary_logit, domain_binary_logit, task_binary_logit
1376
all_type_sentiment_label = sentiment_label_.to('cpu')
1379
domain_bottom_k_all_type = {"values":domain_bottom_k_all_type_values, "indices":domain_bottom_k_all_type_indices}
1381
domain_top_k_all_type = {"values":domain_top_k_all_type_values, "indices":domain_top_k_all_type_indices}
1383
task_top_k_all_type = {"values":task_top_k_all_type_values, "indices":task_top_k_all_type_indices}
1385
domain_task_top_k_all_type = {"values":domain_task_top_k_all_type_values, "indices":domain_task_top_k_all_type_indices}
1392
batch_ = tuple(t.to(device) for t in batch_)
1393
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = batch_
1398
in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1402
if (step%retrive_gate == 0) or (sentiment_label_.shape[0] != args.train_batch_size):
1404
with torch.no_grad():
1405
query_domain = in_domain_rep.float().to("cpu")
1406
query_domain = query_domain.unsqueeze(1)
1408
query_task = in_task_rep.float().to("cpu")
1409
query_task = query_task.unsqueeze(1)
1410
query_domain_task = torch.cat([query_domain,query_task],2)
1413
task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
1414
task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
1415
task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
1416
task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)
1419
domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
1420
domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
1421
domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
1422
domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768*2, 2)
1432
query_domain = query_domain.expand(-1, docs_tail.shape[0], -1)
1434
query_task = query_task.expand(-1, docs_head.shape[0], -1)
1447
domain_binary_logit = LeakyReLU(domain_binary_classifier(docs_tail))
1448
domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1449
domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
1451
domain_binary_logit = domain_binary_classifier(torch.cat([query_domain, docs_tail[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1452
target = torch.zeros(domain_binary_logit.shape[0], domain_binary_logit.shape[1], dtype=torch.long)
1454
domain_binary_logit = ce_loss(domain_binary_logit.view(-1, 2), target.view(-1)).reshape(domain_binary_logit.shape[0],domain_binary_logit.shape[1])
1457
task_binary_logit = task_binary_classifier(torch.cat([query_task, docs_head[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1459
task_binary_logit = ce_loss(task_binary_logit.view(-1, 2), target.view(-1)).reshape(task_binary_logit.shape[0],task_binary_logit.shape[1])
1462
domain_task_binary_logit = task_binary_logit + domain_binary_logit*0.5
1469
domain_top_k = torch.topk(domain_binary_logit, k, dim=1, largest=True, sorted=False)
1471
rand_seed = torch.randint(0,k,(choose_n,))
1472
domain_top_k_indices = domain_top_k.indices[:,rand_seed]
1473
domain_top_k_values = domain_top_k.values[:,rand_seed]
1477
perm = torch.randperm(domain_binary_logit.shape[1])
1478
domain_bottom_k_indices = perm[:k]
1479
domain_bottom_k_values = domain_binary_logit[:,domain_bottom_k_indices]
1480
domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
1484
domain_bottom_k_indices = torch.randint(k+1,domain_binary_logit.shape[1],(choose_n*2,))
1485
domain_bottom_k_values = domain_task_binary_logit[:,domain_bottom_k_indices]
1486
domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
1489
task_top_k = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
1492
task_top_k_indices = task_top_k.indices[:,rand_seed]
1493
task_top_k_values = task_top_k.values[:,rand_seed]
1497
domain_task_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1499
domain_task_top_k_indices = domain_task_top_k.indices[:,rand_seed]
1500
domain_task_top_k_values = domain_task_top_k.values[:,rand_seed]
1506
domain_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1507
perm = torch.randperm(domain_task_binary_logit.shape[1])
1508
domain_bottom_k_indices = perm[:k]
1509
domain_bottom_k_values = domain_task_binary_logit[:,domain_bottom_k_indices]
1510
domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
1511
task_top_k = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
1512
domain_task_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1517
del domain_task_binary_logit, domain_binary_logit, task_binary_logit
1519
all_previous_sentiment_label = sentiment_label_.to('cpu')
1524
domain_bottom_k = {"values":domain_bottom_k_values, "indices":domain_bottom_k_indices}
1526
domain_top_k = {"values":domain_top_k_values, "indices":domain_top_k_indices}
1528
task_top_k = {"values":task_top_k_values, "indices":task_top_k_indices}
1530
domain_task_top_k = {"values":domain_task_top_k_values, "indices":domain_task_top_k_indices}
1535
domain_bottom_k_previous = {"values":torch.cat((domain_bottom_k["values"], domain_bottom_k_all_type["values"]),0), "indices":torch.cat((domain_bottom_k["indices"], domain_bottom_k_all_type["indices"]),0)}
1536
domain_top_k_previous = {"values":torch.cat((domain_top_k["values"], domain_top_k_all_type["values"]),0), "indices":torch.cat((domain_top_k["indices"], domain_top_k_all_type["indices"]),0)}
1537
task_top_k_previous = {"values":torch.cat((task_top_k["values"], task_top_k_all_type["values"]),0), "indices":torch.cat((task_top_k["indices"], task_top_k_all_type["indices"]),0)}
1538
domain_task_top_k_previous = {"values":torch.cat((domain_task_top_k["values"], domain_task_top_k_all_type["values"]),0), "indices":torch.cat((domain_task_top_k["indices"], domain_task_top_k_all_type["indices"]),0)}
1540
all_previous_sentiment_label = torch.cat((all_previous_sentiment_label, all_type_sentiment_label))
1543
used_idx = torch.tensor([random.choice(((all_previous_sentiment_label==int(idx_)).nonzero()).tolist())[0] for idx_ in sentiment_label_])
1545
domain_top_k = {"values":domain_top_k_previous["values"].index_select(0,used_idx), "indices":domain_top_k_previous["indices"].index_select(0,used_idx)}
1546
task_top_k = {"values":task_top_k_previous["values"].index_select(0,used_idx), "indices":task-top_k_previous["indices"].index_select(0,used_idx)}
1547
domain_task_top_k = {"values":domain_task_top_k_previous["values"].index_select(0,used_idx), "indices":domain_task_top_k_previous["indices"].index_select(0,used_idx)}
1550
domaion_bottom_k = {"values":domain_bottom_k_previous["values"].index_select(0,used_idx), "indices":domain_bottom_k_previous["indices"].index_select(0,used_idx)}
1562
batch = AugmentationData_Domain(domain_top_k, domain_bottom_k, tokenizer, args.max_seq_length)
1563
batch = tuple(t.to(device) for t in batch)
1564
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, domain_id = batch
1566
out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, lm_label=lm_label_ids, attention_mask=input_mask, func="in_domain_task_rep")
1574
comb_rep_pos = torch.cat([in_domain_rep,in_domain_rep.flip(0)], 1)
1575
in_domain_rep_ready = in_domain_rep.repeat(1,int(out_domain_rep_tail.shape[0]/in_domain_rep.shape[0])).reshape(out_domain_rep_tail.shape[0],out_domain_rep_tail.shape[1])
1576
comb_rep_unknow = torch.cat([in_domain_rep_ready, out_domain_rep_tail], 1)
1578
mix_domain_binary_loss, domain_binary_logit = model(func="domain_binary_classifier", in_domain_rep=comb_rep_pos.to(device), out_domain_rep=comb_rep_unknow.to(device), domain_id=domain_id, use_detach=False)
1585
indices = domain_top_k["indices"].reshape(domain_top_k["indices"].shape[0]*domain_top_k["indices"].shape[1])
1586
indices_ = domain_bottom_k["indices"].reshape(domain_bottom_k["indices"].shape[0]*domain_bottom_k["indices"].shape[1])
1587
indices = torch.cat([indices,indices_],0)
1589
out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
1590
out_domain_rep_head.requires_grad=True
1592
out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
1593
out_domain_rep_tail.requires_grad=True
1596
with torch.no_grad():
1599
docs_head.index_copy_(0, indices, out_domain_rep_head)
1600
docs_tail.index_copy_(0, indices, out_domain_rep_tail)
1602
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
1603
print("head",out_domain_rep_head.shape)
1604
print("tail",out_domain_rep_head.shape)
1605
print("doc_h",docs_head.shape)
1606
print("doc_t",docs_tail.shape)
1607
print("ind",indices.shape)
1615
batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
1617
batch = tuple(t.to(device) for t in batch)
1618
all_in_task_rep_comb, all_sentence_binary_label = batch
1619
in_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)
1626
task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")
1632
batch = AugmentationData_Task(task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
1633
batch = tuple(t.to(device) for t in batch)
1634
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
1635
out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
1637
batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head)
1638
batch = tuple(t.to(device) for t in batch)
1639
all_in_task_rep_comb, all_sentence_binary_label = batch
1640
out_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)
1646
indices = task_top_k["indices"].reshape(task_top_k["indices"].shape[0]*task_top_k["indices"].shape[1])
1648
out_domain_rep_head = out_domain_rep_head[input_ids_org_.shape[0]:,:]
1649
out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
1650
out_domain_rep_head.requires_grad=True
1652
out_domain_rep_tail = out_domain_rep_tail[input_ids_org_.shape[0]:,:]
1653
out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
1654
out_domain_rep_tail.requires_grad=True
1656
with torch.no_grad():
1658
docs_head.index_copy_(0, indices, out_domain_rep_head)
1659
docs_tail.index_copy_(0, indices, out_domain_rep_tail)
1661
print("head",out_domain_rep_head.shape)
1662
print("head",out_domain_rep_head.get_device())
1663
print("tail",out_domain_rep_head.shape)
1664
print("tail",out_domain_rep_head.get_device())
1665
print("doc_h",docs_head.shape)
1666
print("doc_h",docs_head.get_device())
1667
print("doc_t",docs_tail.shape)
1668
print("doc_t",docs_tail.get_device())
1669
print("ind",indices.shape)
1670
print("ind",indices.get_device())
1678
batch = AugmentationData_Task(domain_task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
1679
batch = tuple(t.to(device) for t in batch)
1680
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
1681
out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
1683
batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head, in_domain_rep=out_domain_rep_tail)
1684
batch = tuple(t.to(device) for t in batch)
1685
all_in_task_rep_comb, all_sentence_binary_label = batch
1686
out_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
1692
batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
1693
batch = tuple(t.to(device) for t in batch)
1694
in_all_in_task_rep_comb, in_all_sentence_binary_label = batch
1695
in_domain_task_binary_loss, in_domain_task_binary_logit = model(all_in_task_rep_comb=in_all_in_task_rep_comb, all_sentence_binary_label=in_all_sentence_binary_label, func="domain_task_binary_classifier")
1702
indices = domain_task_top_k["indices"].reshape(domain_task_top_k["indices"].shape[0]*domain_task_top_k["indices"].shape[1])
1704
out_domain_rep_head = out_domain_rep_head[input_ids_org_.shape[0]:,:]
1705
out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
1706
out_domain_rep_head.requires_grad=True
1708
out_domain_rep_tail = out_domain_rep_tail[input_ids_org_.shape[0]:,:]
1709
out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
1710
out_domain_rep_tail.requires_grad=True
1712
with torch.no_grad():
1714
docs_head.index_copy_(0, indices, out_domain_rep_head)
1715
docs_tail.index_copy_(0, indices, out_domain_rep_tail)
1717
print("head",out_domain_rep_head.shape)
1718
print("head",out_domain_rep_head.get_device())
1719
print("tail",out_domain_rep_head.shape)
1720
print("tail",out_domain_rep_head.get_device())
1721
print("doc_h",docs_head.shape)
1722
print("doc_h",docs_head.get_device())
1723
print("doc_t",docs_tail.shape)
1724
print("doc_t",docs_tail.get_device())
1725
print("ind",indices.shape)
1726
print("ind",indices.get_device())
1734
batch = AugmentationData_Task(domain_task_top_k, tokenizer, args.max_seq_length, add_org=batch_, show_type=show_type)
1736
_, input_ids_org11, _, _, _, _, _, _, sentiment_label11 = batch_
1737
for indexxx, label in enumerate(sentiment_label11):
1738
if int(label) in show_type and list(input_ids_org11[indexxx]).index(2)>20:
1741
if list(input_ids_org11[indexxx]).index(2)>10:
1742
show_type.remove(int(label))
1745
if len(show_type)==0:
1750
batch = tuple(t.to(device) for t in batch)
1751
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
1752
out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
1754
batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head, in_domain_rep=out_domain_rep_tail)
1755
batch = tuple(t.to(device) for t in batch)
1756
all_in_task_rep_comb, all_sentence_binary_label = batch
1757
out_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
1760
batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
1761
batch = tuple(t.to(device) for t in batch)
1762
in_all_in_task_rep_comb, in_all_sentence_binary_label = batch
1763
in_domain_task_binary_loss, in_domain_task_binary_logit = model(all_in_task_rep_comb=in_all_in_task_rep_comb, all_sentence_binary_label=in_all_sentence_binary_label, func="domain_task_binary_classifier")
1767
task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")
1777
loss = mix_domain_binary_loss.mean()*0.5 + (in_task_binary_loss.mean() + out_task_binary_loss.mean())*0.5 + task_loss_org.mean() + out_domain_task_binary_loss
1780
print("No Using GPU")
1783
loss = task_loss_org.mean() + (in_domain_task_binary_loss.mean()+out_domain_task_binary_loss.mean())/2
1786
print("No Using GPU")
1789
if args.gradient_accumulation_steps > 1:
1790
loss = loss / args.gradient_accumulation_steps
1792
with amp.scale_loss(loss, optimizer) as scaled_loss:
1793
scaled_loss.backward()
1798
loss_fout.write("{}\n".format(loss.item()))
1805
tr_loss += loss.item()
1807
nb_tr_examples += input_ids_.size(0)
1809
if (step + 1) % args.gradient_accumulation_steps == 0:
1814
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
1817
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
1849
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
1850
"""Truncates a sequence pair in place to the maximum length."""
# NOTE(review): in the visible code only len(tokens_a) is measured against
# max_length; tokens_b is accepted as a parameter but does not contribute to
# total_length here. The statements that follow the length check are not
# visible in this chunk -- confirm the truncation behaviour against the
# full file before relying on this description.
1858
total_length = len(tokens_a)
1859
if total_length <= max_length:
1865
def accuracy(out, labels):
    """Count the rows of ``out`` whose arg-max equals the gold label.

    Despite the name, this returns the *number* of correct predictions; no
    division by the number of examples is performed here.

    Args:
        out: Array-like of shape ``(n_examples, n_classes)`` with per-class
            scores; the predicted class is the arg-max along axis 1.
        labels: Array-like holding one integer class id per example. It is
            flattened before comparison, so ``(n,)`` and ``(n, 1)`` layouts
            are treated the same.

    Returns:
        NumPy integer: how many predictions match their label.
    """
    predictions = np.argmax(out, axis=1)
    # Flatten the labels so a column vector of shape (n, 1) is compared
    # element-wise instead of broadcasting against the (n,) predictions into
    # an (n, n) matrix, which would silently inflate the count.
    gold = np.asarray(labels).reshape(-1)
    return np.sum(predictions == gold)
1870
if __name__ == "__main__":