# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import json
import logging
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
#from transformers.modeling_bert import BertForMaskedLMDomainTask
from transformers.modeling_bert_updateRep_self import BertForMaskedLMDomainTask
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
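
# Hedged illustration (not part of the original flow): ce_loss uses reduction='none'
# so that the retrieval code further below can turn binary-classifier logits into one
# score per candidate document. The shapes used here are hypothetical.
def _example_per_document_scores(batch_size=2, num_docs=5):
    logits = torch.randn(batch_size * num_docs, 2)                # [batch*num_docs, 2]
    target = torch.zeros(batch_size * num_docs, dtype=torch.long)
    return ce_loss(logits.view(-1, 2), target.view(-1)).reshape(batch_size, num_docs)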
def get_parameter(parser):

    ## Required parameters
    parser.add_argument("--data_dir_indomain",
                        help="The input train corpus.(In Domain)")
    parser.add_argument("--data_dir_outdomain",
                        help="The input train corpus.(Out Domain)")
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--augment_times",
                        help="Default batch_size/augment_times to save model")
    parser.add_argument("--max_seq_length",
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory",
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case",
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--num_labels_task",
                        default=None, type=int,
                        help="num_labels_task")
    parser.add_argument("--weight_decay",
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        help="Max gradient norm.")
    parser.add_argument('--fp16_opt_level',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task",
    parser.add_argument("--K",
def return_Classifier(weight, bias, dim_in, dim_out):
    #LeakyReLU = torch.nn.LeakyReLU
    classifier = torch.nn.Linear(dim_in, dim_out, bias=True)
    #print(classifier.weight)
    #print(classifier.weight.shape)
    #print(classifier.weight.data)
    #print(classifier.weight.data.shape)
    classifier.weight.data = weight.to("cpu")
    classifier.bias.data = bias.to("cpu")
    # Freeze the copied parameters so the classifier is only used for scoring.
    for param in classifier.parameters():
        param.requires_grad = False
    #print(classifier.weight)
    #print(classifier.weight.shape)
    #logit = LeakyReLU(classifier)
    return classifier
def load_GeneralDomain(dir_data_out):

    print("Load train_head.pt, train_tail.pt and train.json")
    docs_head = torch.load(dir_data_out+"train_head.pt")
    docs_tail = torch.load(dir_data_out+"train_tail.pt")
    print(docs_head.shape)
    print(docs_tail.shape)
    with open(dir_data_out+"train.json") as file:
        data = json.load(file)
    print("train.json Done")
    docs_tail_head = torch.cat([docs_tail, docs_head], 2)
    return docs_tail_head, docs_head, docs_tail, data
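
# Note (reading of the code, not original documentation): docs_head / docs_tail hold
# pre-computed representations of every out-of-domain sentence; judging from how they
# are queried in the training loop, docs_head is matched against the task-side query
# and docs_tail against the domain (tail-token) query.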
parser = argparse.ArgumentParser()
parser = get_parameter(parser)
args = parser.parse_args()
#print(args.data_dir_outdomain)

docs_tail_head, docs_head, docs_tail, data = load_GeneralDomain(args.data_dir_outdomain)

if docs_head.shape[1] != 1:  # UnboundLocalError: local variable 'docs' referenced before assignment
    #docs = docs[:,0,:].unsqueeze(1)
    docs_head = docs_head.mean(1).unsqueeze(1)
    print(docs_head.shape)
else:
    print(docs_head.shape)
if docs_tail.shape[1] != 1:  # UnboundLocalError: local variable 'docs' referenced before assignment
    #docs = docs[:,0,:].unsqueeze(1)
    docs_tail = docs_tail.mean(1).unsqueeze(1)
    print(docs_tail.shape)
else:
    print(docs_tail.shape)
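
# Minimal shape-check sketch (assumption: after the mean-pooling above, each document is
# a single 768-dim BERT vector). Not part of the original flow; it only documents the
# invariant that the retrieval code below relies on.
def _check_memory_shapes(head, tail, hidden_size=768):
    assert head.dim() == 3 and head.shape[1] == 1 and head.shape[2] == hidden_size
    assert tail.shape == head.shape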
239
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
241
with open(data_dir_indomain+"train.json") as file:
242
data = json.load(file)
245
num_label_list = list()
246
label_sentence_dict = dict()
247
num_sentiment_label_list = list()
248
sentiment_label_dict = dict()
253
num_sentiment_label_list.append(line["sentiment"])
254
#num_label_list.append(line["aspect"])
255
num_label_list.append(line["sentiment"])
257
num_label = sorted(list(set(num_label_list)))
258
label_map = {label : i for i , label in enumerate(num_label)}
259
num_sentiment_label = sorted(list(set(num_sentiment_label_list)))
260
sentiment_label_map = {label : i for i , label in enumerate(num_sentiment_label)}
266
print("sentiment_label_map:")
267
print(sentiment_label_map)
270
###Create data: one chosen example along with data from the rest of the 7 classes
273
all_input_ids = list()
274
all_input_mask = list()
275
all_segment_ids = list()
276
all_lm_labels_ids = list()
278
all_tail_idxs = list()
279
all_sentence_labels = list()
281
cur_tensors_list = list()
282
#print(list(label_map.values()))
283
candidate_label_list = list(label_map.values())
284
candidate_sentiment_label_list = list(sentiment_label_map.values())
285
all_type_sentence = [0]*len(candidate_label_list)
286
all_type_sentiment_sentence = [0]*len(candidate_sentiment_label_list)
290
sentiment = line["sentiment"]
291
sentence = line["sentence"]
292
#label = line["aspect"]
293
label = line["sentiment"]
296
tokens_a = tokenizer.tokenize(sentence)
297
#input_ids = tokenizer.encode(sentence, add_special_tokens=False)
299
if "</s>" in tokens_a:
300
print("Have more than 1 </s>")
301
#tokens_a[tokens_a.index("<s>")] = "s"
302
for i in range(len(tokens_a)):
303
if tokens_a[i] == "</s>":
309
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
310
# transform sample to features
311
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
313
cur_tensors = (torch.tensor(cur_features.input_ids),
314
torch.tensor(cur_features.input_ids_org),
315
torch.tensor(cur_features.input_mask),
316
torch.tensor(cur_features.segment_ids),
317
torch.tensor(cur_features.lm_label_ids),
319
torch.tensor(cur_features.tail_idxs),
320
torch.tensor(label_map[label]),
321
torch.tensor(sentiment_label_map[sentiment]))
323
cur_tensors_list.append(cur_tensors)
326
if label_map[label] in candidate_label_list:
327
all_type_sentence[label_map[label]]=cur_tensors
328
candidate_label_list.remove(label_map[label])
330
if sentiment_label_map[sentiment] in candidate_sentiment_label_list:
332
#print(sentiment_label_map[sentiment])
334
all_type_sentiment_sentence[sentiment_label_map[sentiment]]=cur_tensors
335
candidate_sentiment_label_list.remove(sentiment_label_map[sentiment])
341
return all_type_sentiment_sentence, cur_tensors_list
345
def AugmentationData_Domain(bottom_k, top_k, tokenizer, max_seq_length):
346
#top_k_shape = top_k.indices.shape
347
#ids = top_k.indices.reshape(top_k_shape[0]*top_k_shape[1]).tolist()
348
top_k_shape = top_k["indices"].shape
349
ids_pos = top_k["indices"].reshape(top_k_shape[0]*top_k_shape[1]).tolist()
350
#ids = top_k["indices"]
352
bottom_k_shape = bottom_k["indices"].shape
353
ids_neg = bottom_k["indices"].reshape(bottom_k_shape[0]*bottom_k_shape[1]).tolist()
359
ids = ids_pos+ids_neg
362
all_input_ids = list()
363
all_input_ids_org = list()
364
all_input_mask = list()
365
all_segment_ids = list()
366
all_lm_labels_ids = list()
368
all_tail_idxs = list()
369
all_id_domain = list()
371
for id, i in enumerate(ids):
372
t1 = data[str(i)]['sentence']
374
#tokens_a = tokenizer.tokenize(t1)
375
tokens_a = tokenizer.tokenize(t1)
377
if "</s>" in tokens_a:
378
print("Have more than 1 </s>")
379
#tokens_a[tokens_a.index("<s>")] = "s"
380
for i in range(len(tokens_a)):
381
if tokens_a[i] == "</s>":
386
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
388
# transform sample to features
389
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
391
all_input_ids.append(torch.tensor(cur_features.input_ids))
392
all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
393
all_input_mask.append(torch.tensor(cur_features.input_mask))
394
all_segment_ids.append(torch.tensor(cur_features.segment_ids))
395
all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
396
all_is_next.append(torch.tensor(0))
397
all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
399
all_id_domain.append(torch.tensor(0))
401
all_id_domain.append(torch.tensor(1))
404
cur_tensors = (torch.stack(all_input_ids),
405
torch.stack(all_input_ids_org),
406
torch.stack(all_input_mask),
407
torch.stack(all_segment_ids),
408
torch.stack(all_lm_labels_ids),
409
torch.stack(all_is_next),
410
torch.stack(all_tail_idxs),
411
torch.stack(all_id_domain))
416
def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None):
417
top_k_shape = top_k["indices"].shape
418
sentence_ids = top_k["indices"]
420
all_input_ids = list()
421
all_input_ids_org = list()
422
all_input_mask = list()
423
all_segment_ids = list()
424
all_lm_labels_ids = list()
426
all_tail_idxs = list()
427
all_sentence_labels = list()
428
all_sentiment_labels = list()
430
add_org = tuple(t.to('cpu') for t in add_org)
431
#input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
432
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
435
#print("input_ids_",input_ids_.shape)
437
#print("sentence_ids",sentence_ids.shape)
439
#print("sentence_label_",sentence_label_.shape)
443
for id_1, sent in enumerate(sentence_ids):
444
for id_2, sent_id in enumerate(sent):
446
t1 = data[str(int(sent_id))]['sentence']
448
tokens_a = tokenizer.tokenize(t1)
451
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
453
# transform sample to features
454
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
456
all_input_ids.append(torch.tensor(cur_features.input_ids))
457
all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
458
all_input_mask.append(torch.tensor(cur_features.input_mask))
459
all_segment_ids.append(torch.tensor(cur_features.segment_ids))
460
all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
461
all_is_next.append(torch.tensor(0))
462
all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
463
all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
464
all_sentiment_labels.append(torch.tensor(sentiment_label_[id_1]))
466
all_input_ids.append(input_ids_[id_1])
467
all_input_ids_org.append(input_ids_org_[id_1])
468
all_input_mask.append(input_mask_[id_1])
469
all_segment_ids.append(segment_ids_[id_1])
470
all_lm_labels_ids.append(lm_label_ids_[id_1])
471
all_is_next.append(is_next_[id_1])
472
all_tail_idxs.append(tail_idxs_[id_1])
473
all_sentence_labels.append(sentence_label_[id_1])
474
all_sentiment_labels.append(sentiment_label_[id_1])
477
cur_tensors = (torch.stack(all_input_ids),
478
torch.stack(all_input_ids_org),
479
torch.stack(all_input_mask),
480
torch.stack(all_segment_ids),
481
torch.stack(all_lm_labels_ids),
482
torch.stack(all_is_next),
483
torch.stack(all_tail_idxs),
484
torch.stack(all_sentence_labels),
485
torch.stack(all_sentiment_labels)
495
def AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None, in_domain_rep=None):
497
top_k_shape = top_k.indices.shape
498
sentence_ids = top_k.indices
500
#top_k_shape = top_k["indices"].shape
501
#sentence_ids = top_k["indices"]
504
#input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
505
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
508
#uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))
510
all_sentence_binary_label = list()
511
#all_in_task_rep_comb = list()
512
all_in_rep_comb = list()
514
for id_1, num in enumerate(sentence_label_):
515
#print([sentence_label_==num])
516
#print(type([sentence_label_==num]))
517
sentence_label_int = (sentence_label_==num).to(torch.long)
518
#print(sentence_label_int)
519
#print(sentence_label_int.shape)
520
#print(in_task_rep[id_1].shape)
521
#print(in_task_rep.shape)
523
in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
524
in_domain_rep_append = in_domain_rep[id_1].unsqueeze(0).expand(in_domain_rep.shape[0],-1)
525
#print(in_task_rep_append)
526
#print(in_task_rep_append.shape)
527
in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
528
in_domain_rep_comb = torch.cat((in_domain_rep_append,in_domain_rep),-1)
529
#print(in_task_rep_comb)
530
#print(in_task_rep_comb.shape)
532
#sentence_label_int = sentence_label_int.to(torch.float32)
533
#print(sentence_label_int)
535
#all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
536
#all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
537
#print(in_task_rep_comb.shape)
538
#print(in_domain_rep_comb.shape)
539
in_rep_comb = torch.cat([in_domain_rep_comb,in_task_rep_comb],-1)
542
all_sentence_binary_label.append(sentence_label_int)
543
#all_in_task_rep_comb.append(in_task_rep_comb)
544
all_in_rep_comb.append(in_rep_comb)
545
all_sentence_binary_label = torch.stack(all_sentence_binary_label)
546
#all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)
547
all_in_rep_comb = torch.stack(all_in_rep_comb)
549
#cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)
550
cur_tensors = (all_in_rep_comb, all_sentence_binary_label)
557
def AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None):
559
top_k_shape = top_k.indices.shape
560
sentence_ids = top_k.indices
562
#top_k_shape = top_k["indices"].shape
563
#sentence_ids = top_k["indices"]
566
#input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
567
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
570
#uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))
572
all_sentence_binary_label = list()
573
all_in_task_rep_comb = list()
575
for id_1, num in enumerate(sentence_label_):
576
#print([sentence_label_==num])
577
#print(type([sentence_label_==num]))
578
sentence_label_int = (sentence_label_==num).to(torch.long)
579
#print(sentence_label_int)
580
#print(sentence_label_int.shape)
581
#print(in_task_rep[id_1].shape)
582
#print(in_task_rep.shape)
584
in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
585
#print(in_task_rep_append)
586
#print(in_task_rep_append.shape)
587
in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
588
#print(in_task_rep_comb)
589
#print(in_task_rep_comb.shape)
591
#sentence_label_int = sentence_label_int.to(torch.float32)
592
#print(sentence_label_int)
594
#all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
595
#all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
596
all_sentence_binary_label.append(sentence_label_int)
597
all_in_task_rep_comb.append(in_task_rep_comb)
598
all_sentence_binary_label = torch.stack(all_sentence_binary_label)
599
all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)
601
cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)
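
# Hedged illustration (not used by the script): the loop above concatenates each anchor
# representation with every representation in the batch and labels a pair 1 when the two
# sentences share a label. The same/not-same matrix it builds is equivalent to:
def _example_pairwise_labels(labels):
    # labels: 1-D LongTensor -> [batch, batch] binary same-class matrix
    return (labels.unsqueeze(0) == labels.unsqueeze(1)).to(torch.long)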
608
class Dataset_noNext(Dataset):
609
def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
611
self.vocab_size = tokenizer.vocab_size
612
self.tokenizer = tokenizer
613
self.seq_len = seq_len
614
self.on_memory = on_memory
615
self.corpus_lines = corpus_lines # number of non-empty lines in input corpus
616
self.corpus_path = corpus_path
617
self.encoding = encoding
618
self.current_doc = 0 # to avoid random sentence from same doc
620
# for loading samples directly from file
621
self.sample_counter = 0 # used to keep track of full epochs on file
622
self.line_buffer = None # keep second sentence of a pair in memory and use as first sentence in next pair
624
# for loading samples in memory
625
self.current_random_doc = 0
627
self.sample_to_doc = [] # map sample index to doc and line
629
# load samples into memory
633
self.corpus_lines = 0
634
with open(corpus_path, "r", encoding=encoding) as f:
635
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
638
self.all_docs.append(doc)
640
#remove last added sample because there won't be a subsequent line anymore in the doc
641
self.sample_to_doc.pop()
644
sample = {"doc_id": len(self.all_docs),
646
self.sample_to_doc.append(sample)
648
self.corpus_lines = self.corpus_lines + 1
650
# if last row in file is not empty
651
if self.all_docs[-1] != doc:
652
self.all_docs.append(doc)
653
self.sample_to_doc.pop()
655
self.num_docs = len(self.all_docs)
657
# load samples later lazily from disk
659
if self.corpus_lines is None:
660
with open(corpus_path, "r", encoding=encoding) as f:
661
self.corpus_lines = 0
662
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
663
if line.strip() == "":
666
self.corpus_lines += 1
668
# if doc does not end with empty line
669
if line.strip() != "":
672
self.file = open(corpus_path, "r", encoding=encoding)
673
self.random_file = open(corpus_path, "r", encoding=encoding)
676
# last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
677
return self.corpus_lines - self.num_docs - 1
679
def __getitem__(self, item):
680
cur_id = self.sample_counter
681
self.sample_counter += 1
682
if not self.on_memory:
683
# after one epoch we start again from beginning of file
684
if cur_id != 0 and (cur_id % len(self) == 0):
686
self.file = open(self.corpus_path, "r", encoding=self.encoding)
688
#t1, t2, is_next_label = self.random_sent(item)
689
t1, is_next_label = self.random_sent(item)
690
if is_next_label is None:
694
#tokens_a = self.tokenizer.tokenize(t1)
695
tokens_a = tokenizer.tokenize(t1)
697
if "</s>" in tokens_a:
698
print("Have more than 1 </s>")
699
#tokens_a[tokens_a.index("<s>")] = "s"
700
for i in range(len(tokens_a)):
701
if tokens_a[i] == "</s>":
704
#tokens_b = self.tokenizer.tokenize(t2)
707
cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)
709
# transform sample to features
710
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
712
cur_tensors = (torch.tensor(cur_features.input_ids),
713
torch.tensor(cur_features.input_ids_org),
714
torch.tensor(cur_features.input_mask),
715
torch.tensor(cur_features.segment_ids),
716
torch.tensor(cur_features.lm_label_ids),
717
torch.tensor(cur_features.is_next),
718
torch.tensor(cur_features.tail_idxs))
722
def random_sent(self, index):
724
Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
725
from one doc. With 50% the second sentence will be a random one from another doc.
726
:param index: int, index of sample.
727
:return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
729
t1, t2 = self.get_corpus_line(index)
732
def get_corpus_line(self, item):
734
Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
735
:param item: int, index of sample.
736
:return: (str, str), two subsequent sentences from corpus
740
assert item < self.corpus_lines
742
sample = self.sample_to_doc[item]
743
t1 = self.all_docs[sample["doc_id"]][sample["line"]]
744
# used later to avoid random nextSentence from same doc
745
self.current_doc = sample["doc_id"]
749
if self.line_buffer is None:
750
# read first non-empty line of file
752
t1 = next(self.file).strip()
754
# use t2 from previous iteration as new t1
755
t1 = self.line_buffer
756
# skip empty rows that are used for separating documents and keep track of current doc id
758
t1 = next(self.file).strip()
759
self.current_doc = self.current_doc+1
760
self.line_buffer = next(self.file).strip()
766
def get_random_line(self):
768
Get random line from another document for nextSentence task.
769
:return: str, content of one line
771
# Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
772
# corpora. However, just to be careful, we try to make sure that
773
# the random document is not the same as the document we're processing.
776
rand_doc_idx = random.randint(0, len(self.all_docs)-1)
777
rand_doc = self.all_docs[rand_doc_idx]
778
line = rand_doc[random.randrange(len(rand_doc))]
780
rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
782
for _ in range(rand_index):
783
line = self.get_next_line()
784
#check if our picked random line is really from another doc like we want it to be
785
if self.current_random_doc != self.current_doc:
789
def get_next_line(self):
790
""" Gets next line of random_file and starts over when reaching end of file"""
792
line = next(self.random_file).strip()
793
#keep track of which document we are currently looking at to later avoid having the same doc as t1
795
self.current_random_doc = self.current_random_doc + 1
796
line = next(self.random_file).strip()
797
except StopIteration:
798
self.random_file.close()
799
self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
800
line = next(self.random_file).strip()
804
class InputExample(object):
805
"""A single training/test example for the language model."""
807
def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
808
"""Constructs a InputExample.
810
guid: Unique id for the example.
811
tokens_a: string. The untokenized text of the first sequence. For single
812
sequence tasks, only this sequence must be specified.
813
tokens_b: (Optional) string. The untokenized text of the second sequence.
814
Only must be specified for sequence pair tasks.
815
label: (Optional) string. The label of the example. This should be
816
specified for train and dev examples, but not for test examples.
819
self.tokens_a = tokens_a
820
self.tokens_b = tokens_b
821
self.is_next = is_next # nextSentence
822
self.lm_labels = lm_labels # masked words for language model
825
class InputFeatures(object):
826
"""A single set of features of data."""
828
def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
829
self.input_ids = input_ids
830
self.input_ids_org = input_ids_org
831
self.input_mask = input_mask
832
self.segment_ids = segment_ids
833
self.is_next = is_next
834
self.lm_label_ids = lm_label_ids
835
self.tail_idxs = tail_idxs
838
def random_word(tokens, tokenizer):
840
Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
841
:param tokens: list of str, tokenized sentence.
842
:param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
843
:return: (list of str, list of int), masked tokens and related labels for LM prediction
847
for i, token in enumerate(tokens):
849
prob = random.random()
850
# mask token with 15% probability
853
#candidate_id = random.randint(0,tokenizer.vocab_size)
854
#print(tokenizer.convert_ids_to_tokens(candidate_id))
857
# 80% randomly change token to mask token
859
#tokens[i] = "[MASK]"
862
# 10% randomly change token to random token
864
#tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
865
#tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)
866
candidate_id = random.randint(0,tokenizer.vocab_size)
867
w = tokenizer.convert_ids_to_tokens(candidate_id)
869
if tokens[i] == None:
871
w = tokenizer.convert_ids_to_tokens(candidate_id)
876
# -> rest 10% randomly keep current token
878
# append current token to output (we will predict these later)
880
#output_label.append(tokenizer.vocab[token])
881
w = tokenizer.convert_tokens_to_ids(token)
883
output_label.append(w)
885
print("Have no this tokens in ids")
888
# For unknown words (should not occur with BPE vocab)
889
#output_label.append(tokenizer.vocab["<unk>"])
890
w = tokenizer.convert_tokens_to_ids("<unk>")
891
output_label.append(w)
892
logger.warning("Cannot find token '{}' in vocab. Using <unk> insetad".format(token))
894
# no masking token (will be ignored by loss function later)
895
output_label.append(-1)
897
return tokens, output_label
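
# Hedged sketch of the standard BERT masking rule that the (partially elided) body of
# random_word() follows, per its comments: 15% of tokens are selected; of those, 80%
# become the mask token, 10% become a random vocabulary token, 10% are kept unchanged.
# Illustrative only; not the exact original implementation.
def _example_masking_decision(prob):
    if prob < 0.8:
        return "mask"      # replace with the mask token
    elif prob < 0.9:
        return "random"    # replace with a random vocabulary token
    else:
        return "keep"      # keep the original token (still predicted by the LM head)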
900
def convert_example_to_features(example, max_seq_length, tokenizer):
902
Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
903
IDs, LM labels, input_mask, CLS and SEP tokens etc.
904
:param example: InputExample, containing sentence input as strings and is_next label
905
:param max_seq_length: int, maximum length of sequence.
906
:param tokenizer: Tokenizer
907
:return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
909
#now tokens_a is input_ids
910
tokens_a = example.tokens_a
911
tokens_b = example.tokens_b
912
# Modifies `tokens_a` and `tokens_b` in place so that the total
913
# length is less than the specified length.
914
# Account for [CLS], [SEP], [SEP] with "- 3"
915
#_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
916
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)
919
tokens_a_org = tokens_a.copy()
920
tokens_a, t1_label = random_word(tokens_a, tokenizer)
927
#tokens_b, t2_label = random_word(tokens_b, tokenizer)
928
# concatenate lm labels and account for CLS, SEP, SEP
929
#lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
930
lm_label_ids = ([-1] + t1_label + [-1])
932
# The convention in BERT is:
933
# (a) For sequence pairs:
934
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
935
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
936
# (b) For single sequences:
937
# tokens: [CLS] the dog is hairy . [SEP]
938
# type_ids: 0 0 0 0 0 0 0
940
# Where "type_ids" are used to indicate whether this is the first
941
# sequence or the second sequence. The embedding vectors for `type=0` and
942
# `type=1` were learned during pre-training and are added to the wordpiece
943
# embedding vector (and position vector). This is not *strictly* necessary
944
# since the [SEP] token unambiguously separates the sequences, but it makes
945
# it easier for the model to learn the concept of sequences.
947
# For classification tasks, the first vector (corresponding to [CLS]) is
948
# used as the "sentence vector". Note that this only makes sense because
949
# the entire model is fine-tuned.
953
tokens.append("[CLS]")
954
tokens_org.append("[CLS]")
955
segment_ids.append(0)
956
for i, token in enumerate(tokens_a):
958
tokens.append(tokens_a[i])
959
tokens_org.append(tokens_a_org[i])
960
segment_ids.append(0)
963
tokens_org.append("s")
964
segment_ids.append(0)
965
tokens.append("[SEP]")
966
tokens_org.append("[SEP]")
967
segment_ids.append(0)
969
#tokens.append("[SEP]")
970
#segment_ids.append(1)
972
#input_ids = tokenizer.convert_tokens_to_ids(tokens)
973
input_ids = tokenizer.encode(tokens, add_special_tokens=False)
974
input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
975
tail_idxs = len(input_ids)-1
978
input_ids = [w if w!=None else 0 for w in input_ids]
979
input_ids_org = [w if w!=None else 0 for w in input_ids_org]
983
# The mask has 1 for real tokens and 0 for padding tokens. Only real
984
# tokens are attended to.
985
input_mask = [1] * len(input_ids)
987
# Zero-pad up to the sequence length.
988
pad_id = tokenizer.convert_tokens_to_ids("<pad>")
989
while len(input_ids) < max_seq_length:
990
input_ids.append(pad_id)
991
input_ids_org.append(pad_id)
993
segment_ids.append(0)
994
lm_label_ids.append(-1)
997
assert len(input_ids) == max_seq_length
998
assert len(input_ids_org) == max_seq_length
999
assert len(input_mask) == max_seq_length
1000
assert len(segment_ids) == max_seq_length
1001
assert len(lm_label_ids) == max_seq_length
1003
print("!!!Warning!!!")
1004
input_ids = input_ids[:max_seq_length-1]
1005
if 102 not in input_ids:
1008
input_ids += [pad_id]
1009
input_ids_org = input_ids_org[:max_seq_length-1]
1010
if 102 not in input_ids_org:
1011
input_ids_org += [102]
1013
input_ids_org += [pad_id]
1014
input_mask = input_mask[:max_seq_length-1]+[0]
1015
segment_ids = segment_ids[:max_seq_length-1]+[0]
1016
lm_label_ids = lm_label_ids[:max_seq_length-1]+[-1]
1019
if len(input_ids) != max_seq_length:
1020
print(len(input_ids))
1022
if len(input_ids_org) != max_seq_length:
1023
print(len(input_ids_org))
1025
if len(input_mask) != max_seq_length:
1026
print(len(input_mask))
1028
if len(segment_ids) != max_seq_length:
1029
print(len(segment_ids))
1031
if len(lm_label_ids) != max_seq_length:
1032
print(len(lm_label_ids))
1040
if example.guid < 5:
1041
logger.info("*** Example ***")
1042
logger.info("guid: %s" % (example.guid))
1043
logger.info("tokens: %s" % " ".join(
1044
[str(x) for x in tokens]))
1045
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
1046
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
1048
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
1049
logger.info("LM label: %s " % (lm_label_ids))
1050
logger.info("Is next sentence label: %s " % (example.is_next))
1053
features = InputFeatures(input_ids=input_ids,
1054
input_ids_org = input_ids_org,
1055
input_mask=input_mask,
1056
segment_ids=segment_ids,
1057
lm_label_ids=lm_label_ids,
1058
is_next=example.is_next,
1059
tail_idxs=tail_idxs)
1064
parser = argparse.ArgumentParser()
1066
parser = get_parameter(parser)
1068
args = parser.parse_args()
1070
if args.local_rank == -1 or args.no_cuda:
1071
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
1072
n_gpu = torch.cuda.device_count()
1074
torch.cuda.set_device(args.local_rank)
1075
device = torch.device("cuda", args.local_rank)
1077
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
1078
torch.distributed.init_process_group(backend='nccl')
1079
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
1080
device, n_gpu, bool(args.local_rank != -1), args.fp16))
1082
if args.gradient_accumulation_steps < 1:
1083
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
1084
args.gradient_accumulation_steps))
1086
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
1088
random.seed(args.seed)
1089
np.random.seed(args.seed)
1090
torch.manual_seed(args.seed)
1092
torch.cuda.manual_seed_all(args.seed)
1094
if not args.do_train:
1095
raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")
1097
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
1098
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
1099
if not os.path.exists(args.output_dir):
1100
os.makedirs(args.output_dir)
1102
#tokenizer = BertTokenizer.from_pretrained(args.pretrain_model, do_lower_case=args.do_lower_case)
1103
tokenizer = BertTokenizer.from_pretrained(args.pretrain_model)
1106
#train_examples = None
1107
num_train_optimization_steps = None
1109
print("Loading Train Dataset", args.data_dir_indomain)
1110
#train_dataset = Dataset_noNext(args.data_dir, tokenizer, seq_len=args.max_seq_length, corpus_lines=None, on_memory=args.on_memory)
1111
all_type_sentence, train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
1112
num_train_optimization_steps = int(
1113
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
1114
if args.local_rank != -1:
1115
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
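
# Worked example of the step count above (illustrative numbers, not from the source):
# 10,000 training examples, batch size 32, gradient accumulation 2 and 3 epochs give
# int(10000 / 32 / 2) * 3 = 468 optimization steps, further divided by the world size
# when running distributed.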
1120
model = BertForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1121
#model = BertForSequenceClassification.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1128
param_optimizer = list(model.named_parameters())
1130
for par in param_optimizer:
1134
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
1135
optimizer_grouped_parameters = [
1136
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
1137
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
1139
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
1140
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*0.1), num_training_steps=num_train_optimization_steps)
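
# Note (reading of the code, not original documentation): the warmup length passed to the
# scheduler is hard-coded to 10% of num_train_optimization_steps; the --warmup_proportion
# command-line argument is not used at this point.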
1144
from apex import amp
1146
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
1149
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
1153
model = torch.nn.DataParallel(model)
1155
if args.local_rank != -1:
1156
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
1162
logger.info("***** Running training *****")
1163
logger.info(" Num examples = %d", len(train_dataset))
1164
logger.info(" Batch size = %d", args.train_batch_size)
1165
logger.info(" Num steps = %d", num_train_optimization_steps)
1167
if args.local_rank == -1:
1168
train_sampler = RandomSampler(train_dataset)
1169
#all_type_sentence_sampler = RandomSampler(all_type_sentence)
1171
#TODO: check if this works with current data generator from disk that relies on next(file)
1172
# (it doesn't return item back by index)
1173
train_sampler = DistributedSampler(train_dataset)
1174
#all_type_sentence_sampler = DistributedSampler(all_type_sentence)
1175
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
1176
#all_type_sentence_dataloader = DataLoader(all_type_sentence, sampler=all_type_sentence_sampler, batch_size=len(all_type_sentence_label))
1178
output_loss_file = os.path.join(args.output_dir, "loss")
1179
loss_fout = open(output_loss_file, 'w')
1182
output_loss_file_no_pseudo = os.path.join(args.output_dir, "loss_no_pseudo")
1183
loss_fout_no_pseudo = open(output_loss_file_no_pseudo, 'w')
1189
#alpha = float(1/(args.num_train_epochs*len(train_dataloader)))
1190
#alpha = float(1/args.num_train_epochs)
1201
#retrive_gate = args.num_labels_task
1202
#retrive_gate = len(train_dataset)/100
1204
all_type_sentence_label = list()
1205
all_previous_sentence_label = list()
1206
all_type_sentiment_label = list()
1207
all_previous_sentiment_label = list()
1208
top_k_all_type = dict()
1209
bottom_k_all_type = dict()
1210
for epo in trange(int(args.num_train_epochs), desc="Epoch"):
1212
nb_tr_examples, nb_tr_steps = 0, 0
1213
for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
1216
#######################
1217
######################
1218
###Init 8 type sentence
1219
###Init 2 type sentiment
1220
if (step == 0) and (epo == 0):
1221
#batch_ = tuple(t.to(device) for t in batch_)
1222
#all_type_sentence_ = tuple(t.to(device) for t in all_type_sentence)
1224
input_ids_ = torch.stack([line[0] for line in all_type_sentence]).to(device)
1225
input_ids_org_ = torch.stack([line[1] for line in all_type_sentence]).to(device)
1226
input_mask_ = torch.stack([line[2] for line in all_type_sentence]).to(device)
1227
segment_ids_ = torch.stack([line[3] for line in all_type_sentence]).to(device)
1228
lm_label_ids_ = torch.stack([line[4] for line in all_type_sentence]).to(device)
1229
is_next_ = torch.stack([line[5] for line in all_type_sentence]).to(device)
1230
tail_idxs_ = torch.stack([line[6] for line in all_type_sentence]).to(device)
1231
sentence_label_ = torch.stack([line[7] for line in all_type_sentence]).to(device)
1232
sentiment_label_ = torch.stack([line[8] for line in all_type_sentence]).to(device)
1234
with torch.no_grad():
1236
#in_domain_rep_mean, in_task_rep_mean = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep_mean")
1237
in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1238
# Search id from Docs and ranking via (Domain/Task)
1239
#query_domain = in_domain_rep_mean.float().to("cpu")
1240
query_domain = in_domain_rep.float().to("cpu")
1241
query_domain = query_domain.unsqueeze(1)
1242
#query_task = in_task_rep_mean.float().to("cpu")
1243
query_task = in_task_rep.float().to("cpu")
1244
query_task = query_task.unsqueeze(1)
1245
#query_domain_task = torch.cat([query_domain,query_task],2)
1248
task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
1249
task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
1250
task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
1251
task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)
1254
domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
1255
domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
1256
domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
1257
domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768*2, 2)
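
# Note (reading of the code, not original documentation): under nn.DataParallel the
# tensors returned by func="return_*_binary_classifier" come back concatenated once per
# GPU along dim 0, so the slices above keep a single copy of the weights before the
# frozen CPU-side classifiers are rebuilt with return_Classifier().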
1260
#domain_task_binary_classifier_weight, domain_task_binary_classifier_bias = model(func="return_domain_task_binary_classifier")
1261
#domain_task_binary_classifier_weight = domain_task_binary_classifier_weight[:int(domain_task_binary_classifier_weight.shape[0]/n_gpu)][:]
1262
#domain_task_binary_classifier_bias = domain_task_binary_classifier_bias[:int(domain_task_binary_classifier_bias.shape[0]/n_gpu)][:]
1263
#domain_task_binary_classifier = return_Classifier(domain_task_binary_classifier_weight, domain_task_binary_classifier_bias, 768*4, 2)
1265
#start = time.time()
1266
query_domain = query_domain.expand(-1, docs_tail.shape[0], -1)
1267
query_task = query_task.expand(-1, docs_head.shape[0], -1)
1268
#query_domain_task = query_domain_task.expand(-1, docs_head.shape[0], -1)
1274
#LeakyReLU = torch.nn.LeakyReLU()
1277
domain_binary_logit = LeakyReLU(domain_binary_classifier(docs_tail))
1278
domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1279
domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
1281
domain_binary_logit = domain_binary_classifier(torch.cat([query_domain, docs_tail[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1282
target = torch.zeros(domain_binary_logit.shape[0], domain_binary_logit.shape[1], dtype=torch.long)
1283
#domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1284
domain_binary_logit = ce_loss(domain_binary_logit.view(-1, 2), target.view(-1)).reshape(domain_binary_logit.shape[0],domain_binary_logit.shape[1])
1287
task_binary_logit = task_binary_classifier(torch.cat([query_task, docs_head[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1288
#task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]
1289
#target = torch.zeros(task_binary_logit.shape[0], task_binary_logit.shape[1], dtype=torch.long)
1290
task_binary_logit = ce_loss(task_binary_logit.view(-1, 2), target.view(-1)).reshape(task_binary_logit.shape[0],task_binary_logit.shape[1])
1293
domain_task_binary_logit = task_binary_logit+domain_binary_logit*0.5
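
# Reading of the scoring above (not from original docs): each entry is a cross-entropy
# loss against label 0, so larger values mean the frozen classifier pushes that document
# toward class 1; the combined score adds the task score to a half-weighted domain score,
# and the torch.topk(..., largest=True) calls below therefore retrieve the documents most
# confidently predicted as positive.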
1299
domain_top_k_all_type = torch.topk(domain_binary_logit, k, dim=1, largest=True, sorted=False)
1300
perm = torch.randperm(domain_binary_logit.shape[1])
1301
domain_bottom_k_all_type_indices = perm[:k]
1302
domain_bottom_k_all_type_values = domain_binary_logit[:,domain_bottom_k_all_type_indices]
1303
domain_bottom_k_all_type_indices = torch.stack(domain_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])
1307
task_top_k_all_type = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
1309
domain_task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1312
###########################
1316
domain_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1318
rand_seed = torch.randint(0,k,(choose_n,))
1319
domain_top_k_all_type_indices = domain_top_k_all_type.indices[:,rand_seed]
1320
domain_top_k_all_type_values = domain_top_k_all_type.values[:,rand_seed]
1324
#perm = torch.randperm(domain_task_binary_logit.shape[1])
1325
#domain_bottom_k_all_type_indices = perm[:k]
1326
#domain_bottom_k_all_type_values = domain_task_binary_logit[:,domain_bottom_k_all_type_indices]
1327
#domain_bottom_k_all_type_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])
1329
#domain_bottom_k_all_type = torch.topk(domain_task_binary_logit, k*2, dim=1, largest=False, sorted=False)
1330
domain_bottom_k_all_type_indices = torch.randint(k+1,domain_binary_logit.shape[1],(choose_n*2,))
1331
domain_bottom_k_all_type_values = domain_task_binary_logit[:,domain_bottom_k_all_type_indices]
1332
domain_bottom_k_all_type_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])
1336
task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1338
rand_seed = torch.randint(0,k,(choose_n,))
1339
task_top_k_all_type_indices = task_top_k_all_type.indices[:,rand_seed]
1340
task_top_k_all_type_values = task_top_k_all_type.values[:,rand_seed]
1345
domain_task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1347
rand_seed = torch.randint(0,k,(choose_n,))
1348
domain_task_top_k_all_type_indices = domain_task_top_k_all_type.indices[:,rand_seed]
1349
domain_task_top_k_all_type_values = domain_task_top_k_all_type.values[:,rand_seed]
1353
###########################
1356
del domain_task_binary_logit, domain_binary_logit, task_binary_logit
1358
all_type_sentiment_label = sentiment_label_.to('cpu')
1361
domain_bottom_k_all_type = {"values":domain_bottom_k_all_type_values, "indices":domain_bottom_k_all_type_indices}
1362
#domain_top_k_all_type = {"values":domain_top_k_all_type.values, "indices":domain_top_k_all_type.indices}
1363
domain_top_k_all_type = {"values":domain_top_k_all_type_values, "indices":domain_top_k_all_type_indices}
1364
#task_top_k_all_type = {"values":task_top_k_all_type.values, "indices":task_top_k_all_type.indices}
1365
task_top_k_all_type = {"values":task_top_k_all_type_values, "indices":task_top_k_all_type_indices}
1366
#domain_task_top_k_all_type = {"values":domain_task_top_k_all_type.values, "indices":domain_task_top_k_all_type.indices}
1367
domain_task_top_k_all_type = {"values":domain_task_top_k_all_type_values, "indices":domain_task_top_k_all_type_indices}
1369
######################
1370
######################
1374
batch_ = tuple(t.to(device) for t in batch_)
1375
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = batch_
1379
# Generate query representation
1380
in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1383
#if (step%10 == 0) or (sentence_label_.shape[0] != args.train_batch_size):
1384
if (step%retrive_gate == 0) or (sentiment_label_.shape[0] != args.train_batch_size):
1386
with torch.no_grad():
1387
query_domain = in_domain_rep.float().to("cpu")
1388
query_domain = query_domain.unsqueeze(1)
1389
#query_task = in_task_rep_mean.float().to("cpu")
1390
query_task = in_task_rep.float().to("cpu")
1391
query_task = query_task.unsqueeze(1)
1392
query_domain_task = torch.cat([query_domain,query_task],2)
1395
task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
1396
task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
1397
task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
1398
task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)
1401
domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
1402
domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
1403
domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
1404
domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768*2, 2)
1407
#domain_task_binary_classifier_weight, domain_task_binary_classifier_bias = model(func="return_domain_task_binary_classifier")
1408
#domain_task_binary_classifier_weight = domain_task_binary_classifier_weight[:int(domain_task_binary_classifier_weight.shape[0]/n_gpu)][:]
1409
#domain_task_binary_classifier_bias = domain_task_binary_classifier_bias[:int(domain_task_binary_classifier_bias.shape[0]/n_gpu)][:]
1410
#domain_task_binary_classifier = return_Classifier(domain_task_binary_classifier_weight, domain_task_binary_classifier_bias, 768*4, 2)
1412
#start = time.time()
1413
#query_domain = query_domain.expand(-1, docs.shape[0], -1)
1414
query_domain = query_domain.expand(-1, docs_tail.shape[0], -1)
1415
#query_task = query_task.expand(-1, docs.shape[0], -1)
1416
query_task = query_task.expand(-1, docs_head.shape[0], -1)
1417
#print(docs_head.shape)
1418
#print(query_domain_task.shape)
1420
#query_domain_task = query_domain_task.expand(-1, docs_head.shape[0], -1)
1426
#LeakyReLU = torch.nn.LeakyReLU()
1429
domain_binary_logit = LeakyReLU(domain_binary_classifier(docs_tail))
1430
domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1431
domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
1433
domain_binary_logit = domain_binary_classifier(torch.cat([query_domain, docs_tail[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1434
target = torch.zeros(domain_binary_logit.shape[0], domain_binary_logit.shape[1], dtype=torch.long)
1435
#domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1436
domain_binary_logit = ce_loss(domain_binary_logit.view(-1, 2), target.view(-1)).reshape(domain_binary_logit.shape[0],domain_binary_logit.shape[1])
1439
task_binary_logit = task_binary_classifier(torch.cat([query_task, docs_head[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1440
#task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]
1441
task_binary_logit = ce_loss(task_binary_logit.view(-1, 2), target.view(-1)).reshape(task_binary_logit.shape[0],task_binary_logit.shape[1])
1444
domain_task_binary_logit = task_binary_logit + domain_binary_logit*0.5
1446
####################
1450
#[batch_size, 36603]
1451
domain_top_k = torch.topk(domain_binary_logit, k, dim=1, largest=True, sorted=False)
1453
rand_seed = torch.randint(0,k,(choose_n,))
1454
domain_top_k_indices = domain_top_k.indices[:,rand_seed]
1455
domain_top_k_values = domain_top_k.values[:,rand_seed]
1459
perm = torch.randperm(domain_binary_logit.shape[1])
1460
domain_bottom_k_indices = perm[:k]
1461
domain_bottom_k_values = domain_binary_logit[:,domain_bottom_k_indices]
1462
domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
1465
#domain_top_k = torch.topk(domain_binary_logit, k, dim=1, largest=False, sorted=False)
1466
domain_bottom_k_indices = torch.randint(k+1,domain_binary_logit.shape[1],(choose_n*2,))
1467
domain_bottom_k_values = domain_task_binary_logit[:,domain_bottom_k_indices]
1468
domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
1471
task_top_k = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
1473
#rand_seed = torch.randint(0,k,(choose_n,))
1474
task_top_k_indices = task_top_k.indices[:,rand_seed]
1475
task_top_k_values = task_top_k.values[:,rand_seed]
1479
domain_task_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1480
#rand_seed = torch.randint(0,k,(choose_n,))
1481
domain_task_top_k_indices = domain_task_top_k.indices[:,rand_seed]
1482
domain_task_top_k_values = domain_task_top_k.values[:,rand_seed]
1485
####################
1488
domain_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1489
perm = torch.randperm(domain_task_binary_logit.shape[1])
1490
domain_bottom_k_indices = perm[:k]
1491
domain_bottom_k_values = domain_task_binary_logit[:,domain_bottom_k_indices]
1492
domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
1493
task_top_k = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
1494
domain_task_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1496
####################
1499
del domain_task_binary_logit, domain_binary_logit, task_binary_logit
1501
all_previous_sentiment_label = sentiment_label_.to('cpu')
1506
domain_bottom_k = {"values":domain_bottom_k_values, "indices":domain_bottom_k_indices}
1507
#domain_top_k = {"values":domain_top_k.values, "indices":domain_top_k.indices}
1508
domain_top_k = {"values":domain_top_k_values, "indices":domain_top_k_indices}
1509
#task_top_k = {"values":task_top_k.values, "indices":task_top_k.indices}
1510
task_top_k = {"values":task_top_k_values, "indices":task_top_k_indices}
1511
#domain_task_top_k = {"values":domain_task_top_k.values, "indices":domain_task_top_k.indices}
1512
domain_task_top_k = {"values":domain_task_top_k_values, "indices":domain_task_top_k_indices}
1517
domain_bottom_k_previous = {"values":torch.cat((domain_bottom_k["values"], domain_bottom_k_all_type["values"]),0), "indices":torch.cat((domain_bottom_k["indices"], domain_bottom_k_all_type["indices"]),0)}
1518
domain_top_k_previous = {"values":torch.cat((domain_top_k["values"], domain_top_k_all_type["values"]),0), "indices":torch.cat((domain_top_k["indices"], domain_top_k_all_type["indices"]),0)}
1519
task_top_k_previous = {"values":torch.cat((task_top_k["values"], task_top_k_all_type["values"]),0), "indices":torch.cat((task_top_k["indices"], task_top_k_all_type["indices"]),0)}
1520
domain_task_top_k_previous = {"values":torch.cat((domain_task_top_k["values"], domain_task_top_k_all_type["values"]),0), "indices":torch.cat((domain_task_top_k["indices"], domain_task_top_k_all_type["indices"]),0)}
1522
all_previous_sentiment_label = torch.cat((all_previous_sentiment_label, all_type_sentiment_label))
1524
###Need to fix --> choice
1525
used_idx = torch.tensor([random.choice(((all_previous_sentiment_label==int(idx_)).nonzero()).tolist())[0] for idx_ in sentiment_label_])
1526
#top_k = {"values":top_k_previous["values"].index_select(0,used_idx), "indices":top_k_previous["indices"].index_select(0,used_idx)}
1527
domain_top_k = {"values":domain_top_k_previous["values"].index_select(0,used_idx), "indices":domain_top_k_previous["indices"].index_select(0,used_idx)}
1528
task_top_k = {"values":task_top_k_previous["values"].index_select(0,used_idx), "indices":task-top_k_previous["indices"].index_select(0,used_idx)}
1529
domain_task_top_k = {"values":domain_task_top_k_previous["values"].index_select(0,used_idx), "indices":domain_task_top_k_previous["indices"].index_select(0,used_idx)}
1531
#bottom_k = {"values":bottom_k_previous["values"].index_select(0,used_idx), "indices":bottom_k_previous["indices"].index_select(0,used_idx)}
1532
domain_bottom_k = {"values":domain_bottom_k_previous["values"].index_select(0,used_idx), "indices":domain_bottom_k_previous["indices"].index_select(0,used_idx)}
1542
#Domain Binary Classifier - Outdomain
1543
#batch = AugmentationData_Domain(bottom_k, tokenizer, args.max_seq_length)
1544
batch = AugmentationData_Domain(domain_top_k, domain_bottom_k, tokenizer, args.max_seq_length)
1545
batch = tuple(t.to(device) for t in batch)
1546
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, domain_id = batch
1548
out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, lm_label=lm_label_ids, attention_mask=input_mask, func="in_domain_task_rep")
1550
#print(domain_top_k["indices"].shape)
1551
#print(input_ids_org.shape)
1552
#print(out_domain_rep_tail.shape)
1553
#print(in_domain_rep.shape)
1555
############Construct constrive instances
1556
comb_rep_pos = torch.cat([in_domain_rep,in_domain_rep.flip(0)], 1)
1557
in_domain_rep_ready = in_domain_rep.repeat(1,int(out_domain_rep_tail.shape[0]/in_domain_rep.shape[0])).reshape(out_domain_rep_tail.shape[0],out_domain_rep_tail.shape[1])
1558
comb_rep_unknow = torch.cat([in_domain_rep_ready, out_domain_rep_tail], 1)
1560
mix_domain_binary_loss, domain_binary_logit = model(func="domain_binary_classifier", in_domain_rep=comb_rep_pos.to(device), out_domain_rep=comb_rep_unknow.to(device), domain_id=domain_id, use_detach=False)
indices = domain_top_k["indices"].reshape(domain_top_k["indices"].shape[0]*domain_top_k["indices"].shape[1])
indices_ = domain_bottom_k["indices"].reshape(domain_bottom_k["indices"].shape[0]*domain_bottom_k["indices"].shape[1])
indices = torch.cat([indices,indices_],0)

# Move the refreshed representations to CPU before writing them back into docs_head / docs_tail below.
out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
out_domain_rep_head.requires_grad=True

out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
out_domain_rep_tail.requires_grad=True

with torch.no_grad():
    docs_head.index_copy_(0, indices, out_domain_rep_head)
    docs_tail.index_copy_(0, indices, out_domain_rep_tail)
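# Note: docs_head / docs_tail serve as CPU-side caches of document representations,
# indexed by the retrieval indices; index_copy_ overwrites the rows of the retrieved
# documents with their refreshed representations. The same pattern in isolation
# (sizes are illustrative only):
#
#   bank = torch.zeros(10, 1, 4)          # 10 cached documents, one 1x4 vector each
#   idx = torch.tensor([2, 7])            # documents to refresh
#   new = torch.randn(2, 1, 4)            # refreshed representations
#   with torch.no_grad():
#       bank.index_copy_(0, idx, new)     # rows 2 and 7 are replaced in place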
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("head",out_domain_rep_head.shape)
print("tail",out_domain_rep_tail.shape)
print("doc_h",docs_head.shape)
print("doc_t",docs_tail.shape)
print("ind",indices.shape)
#Task Binary Classifier - in-domain
#Pseudo task --> no backprop to the PLM: only train the classifier (in-domain data)
batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
batch = tuple(t.to(device) for t in batch)
all_in_task_rep_comb, all_sentence_binary_label = batch
in_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)

#Train task org - finetune
#split into in_dom and query_ --> different weights
task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")
#Task Level including outdomain
batch = AugmentationData_Task(task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
batch = tuple(t.to(device) for t in batch)
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")

batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head)
batch = tuple(t.to(device) for t in batch)
all_in_task_rep_comb, all_sentence_binary_label = batch
out_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)
indices = task_top_k["indices"].reshape(task_top_k["indices"].shape[0]*task_top_k["indices"].shape[1])

# Skip the first input_ids_org_.shape[0] rows (presumably the original batch prepended via add_org),
# then move the remaining representations to CPU before writing them back into the caches.
out_domain_rep_head = out_domain_rep_head[input_ids_org_.shape[0]:,:]
out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
out_domain_rep_head.requires_grad=True

out_domain_rep_tail = out_domain_rep_tail[input_ids_org_.shape[0]:,:]
out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
out_domain_rep_tail.requires_grad=True

with torch.no_grad():
    docs_head.index_copy_(0, indices, out_domain_rep_head)
    docs_tail.index_copy_(0, indices, out_domain_rep_tail)

print("head",out_domain_rep_head.shape)
print("head",out_domain_rep_head.get_device())
print("tail",out_domain_rep_tail.shape)
print("tail",out_domain_rep_tail.get_device())
print("doc_h",docs_head.shape)
print("doc_h",docs_head.get_device())
print("doc_t",docs_tail.shape)
print("doc_t",docs_tail.get_device())
print("ind",indices.shape)
print("ind",indices.get_device())
##############################
##############################

#Domain-Task Level (Out-domain)
batch = AugmentationData_Task(domain_task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
batch = tuple(t.to(device) for t in batch)
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")

batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head, in_domain_rep=out_domain_rep_tail)
batch = tuple(t.to(device) for t in batch)
all_in_task_rep_comb, all_sentence_binary_label = batch
out_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")

#Domain-Task Level (in-domain)
batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
batch = tuple(t.to(device) for t in batch)
in_all_in_task_rep_comb, in_all_sentence_binary_label = batch
in_domain_task_binary_loss, in_domain_task_binary_logit = model(all_in_task_rep_comb=in_all_in_task_rep_comb, all_sentence_binary_label=in_all_sentence_binary_label, func="domain_task_binary_classifier")
indices = domain_task_top_k["indices"].reshape(domain_task_top_k["indices"].shape[0]*domain_task_top_k["indices"].shape[1])

out_domain_rep_head = out_domain_rep_head[input_ids_org_.shape[0]:,:]
out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
out_domain_rep_head.requires_grad=True

out_domain_rep_tail = out_domain_rep_tail[input_ids_org_.shape[0]:,:]
out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
out_domain_rep_tail.requires_grad=True

with torch.no_grad():
    docs_head.index_copy_(0, indices, out_domain_rep_head)
    docs_tail.index_copy_(0, indices, out_domain_rep_tail)

print("head",out_domain_rep_head.shape)
print("head",out_domain_rep_head.get_device())
print("tail",out_domain_rep_tail.shape)
print("tail",out_domain_rep_tail.get_device())
print("doc_h",docs_head.shape)
print("doc_h",docs_head.get_device())
print("doc_t",docs_tail.shape)
print("doc_t",docs_tail.get_device())
print("ind",indices.shape)
print("ind",indices.get_device())
##############################
##############################

#Domain-Task Level (Out-domain)
batch = AugmentationData_Task(domain_task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
batch = tuple(t.to(device) for t in batch)
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
#out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
task_loss_out, class_logit_out = model(input_ids_org=input_ids_org, sentence_label=sentiment_label, attention_mask=input_mask, func="task_class")

#batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head, in_domain_rep=out_domain_rep_tail)
batch = tuple(t.to(device) for t in batch)
all_in_task_rep_comb, all_sentence_binary_label = batch
out_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")

#Domain-Task Level (in-domain)
batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
batch = tuple(t.to(device) for t in batch)
in_all_in_task_rep_comb, in_all_sentence_binary_label = batch
in_domain_task_binary_loss, in_domain_task_binary_logit = model(all_in_task_rep_comb=in_all_in_task_rep_comb, all_sentence_binary_label=in_all_sentence_binary_label, func="domain_task_binary_classifier")

task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")
############################################
############################################

#loss = mix_domain_binary_loss.mean()*0.5 + (in_task_binary_loss.mean() + out_task_binary_loss.mean())*0.5 + task_loss_org.mean() + out_domain_task_binary_loss
loss = task_loss_out.mean() + task_loss_org.mean()
#loss = mix_domain_binary_loss + (in_task_binary_loss + out_task_binary_loss)/2 + task_loss_org + out_domain_task_binary_loss
print("Not using GPU")

loss = task_loss_out.mean() + task_loss_org.mean()
#loss = task_loss_org.mean() + (in_domain_task_binary_loss.mean()+out_domain_task_binary_loss.mean())/2
#loss = mix_domain_binary_loss + (in_task_binary_loss + out_task_binary_loss)/2 + task_loss_org + out_domain_task_binary_loss
print("Not using GPU")
if args.gradient_accumulation_steps > 1:
    loss = loss / args.gradient_accumulation_steps

with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
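# Note: amp.scale_loss is the NVIDIA apex mixed-precision idiom -- the loss is scaled
# before backward() so fp16 gradients do not underflow, and the optimizer un-scales
# them before stepping. The usual setup looks roughly like this (opt_level shown is
# illustrative; this script is assumed to have called amp.initialize earlier):
#
#   from apex import amp
#   model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
#   with amp.scale_loss(loss, optimizer) as scaled_loss:
#       scaled_loss.backward()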
loss_fout.write("{}\n".format(loss.item()))
#loss_fout_no_pseudo.write("{}\n".format(loss.item()-pseudo.item()))

tr_loss += loss.item()
#nb_tr_examples += input_ids.size(0)
nb_tr_examples += input_ids_.size(0)

if (step + 1) % args.gradient_accumulation_steps == 0:
    # modify learning rate with the special warmup BERT uses
    # if args.fp16 is False, BertAdam is used, which handles this automatically
    #lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
    #optimizer.zero_grad()
model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
#output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(epo))
torch.save(model_to_save.state_dict(), output_model_file)

# Save a trained model
logger.info("** ** * Saving fine-tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)
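# Note: the hasattr(model, 'module') check unwraps torch.nn.DataParallel /
# DistributedDataParallel before saving, so the state_dict keys are not prefixed
# with "module.". Loading the checkpoint follows the same pattern, e.g.:
#
#   state_dict = torch.load(output_model_file, map_location="cpu")
#   model_to_load = model.module if hasattr(model, 'module') else model
#   model_to_load.load_state_dict(state_dict)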
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        #total_length = len(tokens_a) + len(tokens_b)
        total_length = len(tokens_a)
        if total_length <= max_length:
            break
        tokens_a.pop()
def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)
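# Note: accuracy() returns the *count* of correct predictions, not a ratio; callers
# typically divide by the number of evaluated examples. Illustrative only:
#
#   logits = np.array([[0.2, 0.8], [0.9, 0.1]])
#   labels = np.array([1, 1])
#   accuracy(logits, labels)   # -> 1 (only the first prediction matches)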
if __name__ == "__main__":