2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
9
# http://www.apache.org/licenses/LICENSE-2.0
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
"""BERT finetuning runner."""
18
from __future__ import absolute_import, division, print_function, unicode_literals
27
from torch.autograd import Variable
28
import torch.nn.functional as F
32
from torch.utils.data import DataLoader, Dataset, RandomSampler
33
from torch.utils.data.distributed import DistributedSampler
34
from tqdm import tqdm, trange
35
from torch.nn import CrossEntropyLoss
37
from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
38
#from transformers.modeling_roberta import RobertaForMaskedLMDomainTask
39
from transformers.modeling_roberta_updateRep_self import RobertaForMaskedLMDomainTask
40
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
42
# Module-level logging configuration plus a shared cross-entropy loss.
# NOTE(review): the bare integers interleaved through this file look like
# original line numbers left by a lossy extraction; short lines (such as the
# final argument / closing paren of basicConfig) appear to be missing here —
# verify against the full file.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
43
datefmt='%m/%d/%Y %H:%M:%S',
45
logger = logging.getLogger(__name__)
# reduction='none' makes the loss return one value per element instead of a
# mean; ce_loss is not referenced in the visible part of this file.
48
ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
51
# Register this script's command-line options on `parser`.
# The module-level code calls `parser = get_parameter(parser)`, so this
# function presumably returns the same parser; that return, and most of the
# per-option type=/default=/action= lines, are not visible in this view —
# confirm against the full file.
# NOTE(review): the --pretrain_model help lists BERT checkpoints although the
# file imports RoBERTa classes, and the --gradient_accumulation_steps help text
# contains the typo "accumualte"; both are runtime strings left unchanged here.
def get_parameter(parser):
53
## Required parameters
54
parser.add_argument("--data_dir_indomain",
58
help="The input train corpus.(In Domain)")
59
parser.add_argument("--data_dir_outdomain",
63
help="The input train corpus.(Out Domain)")
64
parser.add_argument("--pretrain_model", default=None, type=str, required=True,
65
help="Bert pre-trained model selected in the list: bert-base-uncased, "
66
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
67
parser.add_argument("--output_dir",
71
help="The output directory where the model checkpoints will be written.")
72
parser.add_argument("--augment_times",
76
help="Default batch_size/augment_times to save model")
78
parser.add_argument("--max_seq_length",
81
help="The maximum total input sequence length after WordPiece tokenization. \n"
82
"Sequences longer than this will be truncated, and sequences shorter \n"
83
"than this will be padded.")
84
parser.add_argument("--do_train",
86
help="Whether to run training.")
87
parser.add_argument("--train_batch_size",
90
help="Total batch size for training.")
91
parser.add_argument("--learning_rate",
94
help="The initial learning rate for Adam.")
95
parser.add_argument("--num_train_epochs",
98
help="Total number of training epochs to perform.")
99
parser.add_argument("--warmup_proportion",
102
help="Proportion of training to perform linear learning rate warmup for. "
103
"E.g., 0.1 = 10%% of training.")
104
parser.add_argument("--no_cuda",
106
help="Whether not to use CUDA when available")
107
parser.add_argument("--on_memory",
109
help="Whether to load train samples into memory or use disk")
110
parser.add_argument("--do_lower_case",
112
help="Whether to lower case the input text. True for uncased models, False for cased models.")
113
parser.add_argument("--local_rank",
116
help="local_rank for distributed training on gpus")
117
parser.add_argument('--seed',
120
help="random seed for initialization")
121
parser.add_argument('--gradient_accumulation_steps',
124
help="Number of updates steps to accumualte before performing a backward/update pass.")
125
parser.add_argument('--fp16',
127
help="Whether to use 16-bit float precision instead of 32-bit")
128
parser.add_argument('--loss_scale',
129
type = float, default = 0,
130
help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
131
"0 (default value): dynamic loss scaling.\n"
132
"Positive power of 2: static loss scaling value.\n")
134
parser.add_argument("--num_labels_task",
135
default=None, type=int,
137
help="num_labels_task")
138
parser.add_argument("--weight_decay",
141
help="Weight decay if we apply some.")
142
parser.add_argument("--adam_epsilon",
145
help="Epsilon for Adam optimizer.")
146
parser.add_argument("--max_grad_norm",
149
help="Max gradient norm.")
150
parser.add_argument('--fp16_opt_level',
153
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
154
"See details at https://nvidia.github.io/apex/amp.html")
155
parser.add_argument("--task",
160
parser.add_argument("--K",
169
# Build a torch.nn.Linear(dim_in, dim_out, bias=True) and overwrite its weight
# and bias with copies of `weight` / `bias` moved to the CPU.
# `.data` assignment does not validate shapes; nn.Linear stores its weight as
# (out_features, in_features), so `weight` presumably has shape
# (dim_out, dim_in) — TODO confirm at call sites.
# NOTE(review): the return statement is not visible in this view; it
# presumably returns `classifier`.
def return_Classifier(weight, bias, dim_in, dim_out):
170
#LeakyReLU = torch.nn.LeakyReLU
171
classifier = torch.nn.Linear(dim_in, dim_out , bias=True)
173
#print(classifier.weight)
174
#print(classifier.weight.shape)
175
#print(classifier.weight.data)
176
#print(classifier.weight.data.shape)
178
classifier.weight.data = weight.to("cpu")
179
classifier.bias.data = bias.to("cpu")
# NOTE(review): assigning `requires_grad` on an nn.Module only creates a plain
# Python attribute; the weight/bias Parameters keep requires_grad=True. Use
# `classifier.requires_grad_(False)` if freezing the layer is intended.
180
classifier.requires_grad=False
182
#print(classifier.weight)
183
#print(classifier.weight.shape)
188
#logit = LeakyReLU(classifier)
192
# Load the out-of-domain (general) corpus stored under `dir_data_out`:
# train_head.pt and train_tail.pt via torch.load, and train.json via json.load.
# Returns (docs_tail_head, docs_head, docs_tail, data), where docs_tail_head is
# torch.cat([docs_tail, docs_head], 2), i.e. tail then head along dim 2.
# NOTE(review): paths are built by string concatenation, so `dir_data_out`
# must end with a path separator.
# NOTE(review): depending on the torch version / weights_only setting,
# torch.load may unpickle arbitrary objects; only load trusted files.
def load_GeneralDomain(dir_data_out):
196
print("Load CLS.pt and train.json")
198
docs_head = torch.load(dir_data_out+"train_head.pt")
199
docs_tail = torch.load(dir_data_out+"train_tail.pt")
201
print(docs_head.shape)
202
print(docs_tail.shape)
204
with open(dir_data_out+"train.json") as file:
205
data = json.load(file)
206
print("train.json Done")
208
docs_tail_head = torch.cat([docs_tail, docs_head],2)
209
return docs_tail_head, docs_head, docs_tail, data
212
# Read `dir_data_in` + "train.json" and store each record's "sentence" field in
# `in_data`, keyed by the record's position in the JSON list (enumerate index).
# NOTE(review): `in_data` is not created in any visible line of this function
# and no return is visible, yet the caller does `in_data = load_InDomain(...)`;
# presumably a short initialisation and `return in_data` were elided — confirm.
# NOTE(review): the path is built by string concatenation (needs a trailing
# separator), and the loop variable `id` shadows the builtin.
def load_InDomain(dir_data_in):
214
with open(dir_data_in+"train.json") as file:
215
data = json.load(file)
216
for id, line in enumerate(data):
217
in_data[id]=line["sentence"]
220
# Module-level script setup: parse CLI arguments, then load the general-domain
# data (`data`, docs_head, docs_tail, docs_tail_head) and the in-domain
# sentences (`in_data`). The augmentation helpers in this file read the
# module-level `data` and `in_data` names bound here.
# If docs_head / docs_tail have more than one entry along dim 1, they are
# averaged over dim 1 and a singleton dim 1 is restored via unsqueeze(1).
# NOTE(review): this runs at import time (no `if __name__ == "__main__":`
# guard), so importing the module parses sys.argv and reads the datasets.
# NOTE(review): docs_tail_head is computed before the averaging and is not
# recomputed in the visible lines, so it keeps the un-averaged tensors.
# The trailing "UnboundLocalError ... 'docs'" comments refer to an older
# variable name and no longer describe these lines.
parser = argparse.ArgumentParser()
221
parser = get_parameter(parser)
222
args = parser.parse_args()
225
docs_tail_head, docs_head, docs_tail, data = load_GeneralDomain(args.data_dir_outdomain)
226
in_data = load_InDomain(args.data_dir_indomain)
228
if docs_head.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
230
#docs = docs[:,0,:].unsqueeze(1)
232
docs_head = docs_head.mean(1).unsqueeze(1)
233
print(docs_head.shape)
235
print(docs_head.shape)
236
if docs_tail.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
238
#docs = docs[:,0,:].unsqueeze(1)
240
docs_tail = docs_tail.mean(1).unsqueeze(1)
241
print(docs_tail.shape)
243
print(docs_tail.shape)
246
# Build per-example tensors for the in-domain task data in
# `data_dir_indomain` + "train.json".
# Label maps: the sorted distinct values of each record's "sentiment" field are
# numbered 0..n-1. Because `label` is also read from line["sentiment"] (the
# "aspect" variant is commented out), label_map and sentiment_label_map end up
# identical.
# Per example: the "sentence" is tokenized, converted with
# convert_example_to_features and packed as the tuple
# (input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, tail_idxs,
#  label id, sentiment id).
# Returns (all_type_sentiment_sentence, cur_tensors_list):
#   all_type_sentiment_sentence[k] is the first tuple seen with sentiment id k
#   (it stays 0 if none was seen); cur_tensors_list collects every tuple.
# NOTE(review): the loop headers over `data`, the "</s>" replacement line and
# some indentation are elided in this view. `id` is not bound in any visible
# line here; if the elided loop header does not bind it, `guid=id` passes the
# builtin `id` function. The all_input_ids/... lists and all_type_sentence are
# not used in the visible lines.
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
248
with open(data_dir_indomain+"train.json") as file:
249
data = json.load(file)
252
num_label_list = list()
253
label_sentence_dict = dict()
254
num_sentiment_label_list = list()
255
sentiment_label_dict = dict()
260
num_sentiment_label_list.append(line["sentiment"])
261
#num_label_list.append(line["aspect"])
262
num_label_list.append(line["sentiment"])
264
num_label = sorted(list(set(num_label_list)))
265
label_map = {label : i for i , label in enumerate(num_label)}
266
num_sentiment_label = sorted(list(set(num_sentiment_label_list)))
267
sentiment_label_map = {label : i for i , label in enumerate(num_sentiment_label)}
273
print("sentiment_label_map:")
274
print(sentiment_label_map)
277
###Create data: 1 choosed data along with the rest of 7 class data
280
all_input_ids = list()
281
all_input_mask = list()
282
all_segment_ids = list()
283
all_lm_labels_ids = list()
285
all_tail_idxs = list()
286
all_sentence_labels = list()
288
cur_tensors_list = list()
289
#print(list(label_map.values()))
290
candidate_label_list = list(label_map.values())
291
candidate_sentiment_label_list = list(sentiment_label_map.values())
292
all_type_sentence = [0]*len(candidate_label_list)
293
all_type_sentiment_sentence = [0]*len(candidate_sentiment_label_list)
297
sentiment = line["sentiment"]
298
sentence = line["sentence"]
299
#label = line["aspect"]
300
label = line["sentiment"]
303
tokens_a = tokenizer.tokenize(sentence)
304
#input_ids = tokenizer.encode(sentence, add_special_tokens=False)
306
if "</s>" in tokens_a:
307
print("Have more than 1 </s>")
308
#tokens_a[tokens_a.index("<s>")] = "s"
309
for i in range(len(tokens_a)):
310
if tokens_a[i] == "</s>":
316
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
317
# transform sample to features
318
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
320
cur_tensors = (torch.tensor(cur_features.input_ids),
321
torch.tensor(cur_features.input_ids_org),
322
torch.tensor(cur_features.input_mask),
323
torch.tensor(cur_features.segment_ids),
324
torch.tensor(cur_features.lm_label_ids),
326
torch.tensor(cur_features.tail_idxs),
327
torch.tensor(label_map[label]),
328
torch.tensor(sentiment_label_map[sentiment]))
330
cur_tensors_list.append(cur_tensors)
333
if label_map[label] in candidate_label_list:
334
all_type_sentence[label_map[label]]=cur_tensors
335
candidate_label_list.remove(label_map[label])
337
if sentiment_label_map[sentiment] in candidate_sentiment_label_list:
339
#print(sentiment_label_map[sentiment])
341
all_type_sentiment_sentence[sentiment_label_map[sentiment]]=cur_tensors
342
candidate_sentiment_label_list.remove(sentiment_label_map[sentiment])
348
return all_type_sentiment_sentence, cur_tensors_list
352
# Sample a domain-classification batch:
#   - `train_batch_size` ids drawn without replacement from range(len(in_data)),
#     each labelled domain 1;
#   - `train_batch_size * k` ids drawn from range(len(data)), each labelled
#     domain 0.
# Each sentence is tokenized and converted with convert_example_to_features.
# The per-field lists are stacked into cur_tensors =
# (input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next,
#  tail_idxs, id_domain).
# random.sample raises ValueError if a requested sample size exceeds its
# population (len(in_data) or len(data)).
# NOTE(review): `all_is_next` is not initialised and no return is visible in
# this view; both presumably sit on elided lines — confirm.
def AugmentationData_Domain(train_batch_size, k, tokenizer, max_seq_length):
353
#top_k_shape = top_k.indices.shape
354
#ids = top_k.indices.reshape(top_k_shape[0]*top_k_shape[1]).tolist()
355
#top_k_shape = top_k["indices"].shape
356
#ids_pos = top_k["indices"].reshape(top_k_shape[0]*top_k_shape[1]).tolist()
357
#ids = top_k["indices"]
359
#bottom_k_shape = bottom_k["indices"].shape
360
#ids_neg = bottom_k["indices"].reshape(bottom_k_shape[0]*bottom_k_shape[1]).tolist()
361
ids_pos = random.sample(range(0,len(in_data)),train_batch_size)
363
ids_neg = random.sample(range(0,len(data)),train_batch_size*k)
365
#ids = ids_pos+ids_neg
368
all_input_ids = list()
369
all_input_ids_org = list()
370
all_input_mask = list()
371
all_segment_ids = list()
372
all_lm_labels_ids = list()
374
all_tail_idxs = list()
375
all_id_domain = list()
377
for id, i in enumerate(ids_pos):
# NOTE(review): these ids were sampled from range(len(in_data)), but the
# sentence is read from the general-domain `data`; reading `in_data[i]` looks
# like the intent for the domain-1 (in-domain) samples — confirm.
378
t1 = data[str(i)]['sentence']
380
#tokens_a = tokenizer.tokenize(t1)
381
tokens_a = tokenizer.tokenize(t1)
383
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
385
# transform sample to features
386
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
388
all_input_ids.append(torch.tensor(cur_features.input_ids))
389
all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
390
all_input_mask.append(torch.tensor(cur_features.input_mask))
391
all_segment_ids.append(torch.tensor(cur_features.segment_ids))
392
all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
393
all_is_next.append(torch.tensor(0))
394
all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
396
# all_id_domain.append(torch.tensor(0))
398
# all_id_domain.append(torch.tensor(1))
399
all_id_domain.append(torch.tensor(1))
402
for id, i in enumerate(ids_neg):
403
t1 = data[str(i)]['sentence']
405
#tokens_a = tokenizer.tokenize(t1)
406
tokens_a = tokenizer.tokenize(t1)
408
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
410
# transform sample to features
411
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
413
all_input_ids.append(torch.tensor(cur_features.input_ids))
414
all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
415
all_input_mask.append(torch.tensor(cur_features.input_mask))
416
all_segment_ids.append(torch.tensor(cur_features.segment_ids))
417
all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
418
all_is_next.append(torch.tensor(0))
419
all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
421
# all_id_domain.append(torch.tensor(0))
423
# all_id_domain.append(torch.tensor(1))
424
all_id_domain.append(torch.tensor(0))
428
cur_tensors = (torch.stack(all_input_ids),
429
torch.stack(all_input_ids_org),
430
torch.stack(all_input_mask),
431
torch.stack(all_segment_ids),
432
torch.stack(all_lm_labels_ids),
433
torch.stack(all_is_next),
434
torch.stack(all_tail_idxs),
435
torch.stack(all_id_domain))
440
# Expand a task batch with retrieved general-domain sentences.
# `add_org` is moved to the CPU and unpacked into nine tensors
# (input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_,
#  is_next_, tail_idxs_, sentence_label_, sentiment_label_).
# For each row id_1 of top_k["indices"], every retrieved sentence id is looked
# up in the module-level `data`, tokenized, converted with
# convert_example_to_features, and given example id_1's sentence and sentiment
# labels. Example id_1 from `add_org` is also appended; whether that happens
# once per row or once per retrieved sentence depends on indentation that is
# lost in this view. Everything is stacked into cur_tensors (nine fields, in
# the same order as `add_org`).
# NOTE(review): `add_org` defaults to None but is iterated unconditionally in
# the visible lines, so omitting it would raise TypeError unless an elided line
# guards it.
# NOTE(review): `id` is not bound in any visible line here, so `guid=id` may
# pass the builtin `id` function; convert_example_to_features later evaluates
# `example.guid < 5`, which would raise TypeError on Python 3.
# NOTE(review): torch.tensor(<tensor>) copies and emits a UserWarning;
# `sentence_label_[id_1].clone()` is the recommended equivalent.
# `all_is_next`, the closing paren of cur_tensors and the return are not
# visible in this view; `top_k_shape` is unused in the visible lines.
def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None):
441
top_k_shape = top_k["indices"].shape
442
sentence_ids = top_k["indices"]
444
all_input_ids = list()
445
all_input_ids_org = list()
446
all_input_mask = list()
447
all_segment_ids = list()
448
all_lm_labels_ids = list()
450
all_tail_idxs = list()
451
all_sentence_labels = list()
452
all_sentiment_labels = list()
454
add_org = tuple(t.to('cpu') for t in add_org)
455
#input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
456
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
459
#print("input_ids_",input_ids_.shape)
461
#print("sentence_ids",sentence_ids.shape)
463
#print("sentence_label_",sentence_label_.shape)
467
for id_1, sent in enumerate(sentence_ids):
468
for id_2, sent_id in enumerate(sent):
470
t1 = data[str(int(sent_id))]['sentence']
472
tokens_a = tokenizer.tokenize(t1)
475
cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
477
# transform sample to features
478
cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
480
all_input_ids.append(torch.tensor(cur_features.input_ids))
481
all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
482
all_input_mask.append(torch.tensor(cur_features.input_mask))
483
all_segment_ids.append(torch.tensor(cur_features.segment_ids))
484
all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
485
all_is_next.append(torch.tensor(0))
486
all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
487
all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
488
all_sentiment_labels.append(torch.tensor(sentiment_label_[id_1]))
490
all_input_ids.append(input_ids_[id_1])
491
all_input_ids_org.append(input_ids_org_[id_1])
492
all_input_mask.append(input_mask_[id_1])
493
all_segment_ids.append(segment_ids_[id_1])
494
all_lm_labels_ids.append(lm_label_ids_[id_1])
495
all_is_next.append(is_next_[id_1])
496
all_tail_idxs.append(tail_idxs_[id_1])
497
all_sentence_labels.append(sentence_label_[id_1])
498
all_sentiment_labels.append(sentiment_label_[id_1])
501
cur_tensors = (torch.stack(all_input_ids),
502
torch.stack(all_input_ids_org),
503
torch.stack(all_input_mask),
504
torch.stack(all_segment_ids),
505
torch.stack(all_lm_labels_ids),
506
torch.stack(all_is_next),
507
torch.stack(all_tail_idxs),
508
torch.stack(all_sentence_labels),
509
torch.stack(all_sentiment_labels)
519
# Build pairwise same-label targets and paired representations for a batch.
# For each example id_1 (iterating sentence_label_ from `add_org`):
#   - sentence_label_int = (sentence_label_ == label of id_1) as torch.long,
#     i.e. 1 for examples sharing id_1's label and 0 otherwise;
#   - in_rep_comb: for every row j, the last-dim concatenation
#     [in_domain_rep[id_1], in_domain_rep[j], in_task_rep[id_1], in_task_rep[j]].
# The `.unsqueeze(0).expand(N, -1)` calls require in_task_rep / in_domain_rep
# to be 2-D (rows, features). The per-row results are stacked into
# cur_tensors = (all_in_rep_comb, all_sentence_binary_label). This assumes
# sentence_label_ is 1-D with as many entries as the rep tensors have rows —
# TODO confirm.
# NOTE(review): top_k defaults to None but `top_k.indices` is read
# unconditionally in the visible lines (AttributeError unless an elided line
# guards it); top_k_shape, sentence_ids and most unpacked tensors are unused
# here. The return is not visible in this view.
# NOTE(review): the loop could be replaced by broadcasting, e.g.
# (sentence_label_.unsqueeze(1) == sentence_label_.unsqueeze(0)).long().
def AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None, in_domain_rep=None):
521
top_k_shape = top_k.indices.shape
522
sentence_ids = top_k.indices
524
#top_k_shape = top_k["indices"].shape
525
#sentence_ids = top_k["indices"]
528
#input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
529
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
532
#uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))
534
all_sentence_binary_label = list()
535
#all_in_task_rep_comb = list()
536
all_in_rep_comb = list()
538
for id_1, num in enumerate(sentence_label_):
539
#print([sentence_label_==num])
540
#print(type([sentence_label_==num]))
541
sentence_label_int = (sentence_label_==num).to(torch.long)
542
#print(sentence_label_int)
543
#print(sentence_label_int.shape)
544
#print(in_task_rep[id_1].shape)
545
#print(in_task_rep.shape)
547
in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
548
in_domain_rep_append = in_domain_rep[id_1].unsqueeze(0).expand(in_domain_rep.shape[0],-1)
549
#print(in_task_rep_append)
550
#print(in_task_rep_append.shape)
551
in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
552
in_domain_rep_comb = torch.cat((in_domain_rep_append,in_domain_rep),-1)
553
#print(in_task_rep_comb)
554
#print(in_task_rep_comb.shape)
556
#sentence_label_int = sentence_label_int.to(torch.float32)
557
#print(sentence_label_int)
559
#all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
560
#all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
561
#print(in_task_rep_comb.shape)
562
#print(in_domain_rep_comb.shape)
563
in_rep_comb = torch.cat([in_domain_rep_comb,in_task_rep_comb],-1)
566
all_sentence_binary_label.append(sentence_label_int)
567
#all_in_task_rep_comb.append(in_task_rep_comb)
568
all_in_rep_comb.append(in_rep_comb)
569
all_sentence_binary_label = torch.stack(all_sentence_binary_label)
570
#all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)
571
all_in_rep_comb = torch.stack(all_in_rep_comb)
573
#cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)
574
cur_tensors = (all_in_rep_comb, all_sentence_binary_label)
581
# Task-only variant of the pairwise builder: for each example id_1,
#   - sentence_label_int = (sentence_label_ == label of id_1) as torch.long;
#   - in_task_rep_comb: for every row j, the last-dim concatenation
#     [in_task_rep[id_1], in_task_rep[j]].
# The `.unsqueeze(0).expand(N, -1)` call requires in_task_rep to be 2-D
# (rows, features). The per-row results are stacked into
# cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label).
# NOTE(review): top_k defaults to None but `top_k.indices` is read
# unconditionally in the visible lines; top_k_shape and sentence_ids are
# unused here, and the return is not visible in this view.
def AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None):
583
top_k_shape = top_k.indices.shape
584
sentence_ids = top_k.indices
586
#top_k_shape = top_k["indices"].shape
587
#sentence_ids = top_k["indices"]
590
#input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
591
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
594
#uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))
596
all_sentence_binary_label = list()
597
all_in_task_rep_comb = list()
599
for id_1, num in enumerate(sentence_label_):
600
#print([sentence_label_==num])
601
#print(type([sentence_label_==num]))
602
sentence_label_int = (sentence_label_==num).to(torch.long)
603
#print(sentence_label_int)
604
#print(sentence_label_int.shape)
605
#print(in_task_rep[id_1].shape)
606
#print(in_task_rep.shape)
608
in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
609
#print(in_task_rep_append)
610
#print(in_task_rep_append.shape)
611
in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
612
#print(in_task_rep_comb)
613
#print(in_task_rep_comb.shape)
615
#sentence_label_int = sentence_label_int.to(torch.float32)
616
#print(sentence_label_int)
618
#all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
619
#all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
620
all_sentence_binary_label.append(sentence_label_int)
621
all_in_task_rep_comb.append(in_task_rep_comb)
622
all_sentence_binary_label = torch.stack(all_sentence_binary_label)
623
all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)
625
cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)
632
# torch Dataset over a line-oriented text corpus at `corpus_path`.
# Several comments and commented-out lines refer to a two-sentence
# (next-sentence) setup, but the visible __getitem__ uses only the first
# sentence t1: it is tokenized, passed through convert_example_to_features, and
# packed as (input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids,
# is_next, tail_idxs). The return of that tuple is not visible here.
# Presumably, with on_memory the corpus is read into self.all_docs /
# self.sample_to_doc; otherwise lines are streamed from self.file and
# self.random_file.
# NOTE(review): indentation and many short lines (else:/try:/return ...,
# some method headers) are missing from this view, so method boundaries below
# are partly inferred — verify against the full file.
class Dataset_noNext(Dataset):
633
def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
635
self.vocab_size = tokenizer.vocab_size
636
self.tokenizer = tokenizer
637
self.seq_len = seq_len
638
self.on_memory = on_memory
639
self.corpus_lines = corpus_lines # number of non-empty lines in input corpus
640
self.corpus_path = corpus_path
641
self.encoding = encoding
642
self.current_doc = 0 # to avoid random sentence from same doc
644
# for loading samples directly from file
645
self.sample_counter = 0 # used to keep track of full epochs on file
646
self.line_buffer = None # keep second sentence of a pair in memory and use as first sentence in next pair
648
# for loading samples in memory
649
self.current_random_doc = 0
651
self.sample_to_doc = [] # map sample index to doc and line
653
# load samples into memory
657
self.corpus_lines = 0
658
with open(corpus_path, "r", encoding=encoding) as f:
659
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
662
self.all_docs.append(doc)
664
#remove last added sample because there won't be a subsequent line anymore in the doc
665
self.sample_to_doc.pop()
668
sample = {"doc_id": len(self.all_docs),
670
self.sample_to_doc.append(sample)
672
self.corpus_lines = self.corpus_lines + 1
674
# if last row in file is not empty
675
if self.all_docs[-1] != doc:
676
self.all_docs.append(doc)
677
self.sample_to_doc.pop()
679
self.num_docs = len(self.all_docs)
681
# load samples later lazily from disk
683
if self.corpus_lines is None:
684
with open(corpus_path, "r", encoding=encoding) as f:
685
self.corpus_lines = 0
686
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
687
if line.strip() == "":
690
self.corpus_lines += 1
692
# if doc does not end with empty line
693
if line.strip() != "":
# NOTE(review): no close() of self.file is visible in this view, and
# self.random_file is only closed when get_next_line reopens it; consider
# closing both when the dataset is discarded.
696
self.file = open(corpus_path, "r", encoding=encoding)
697
self.random_file = open(corpus_path, "r", encoding=encoding)
# NOTE(review): the method header for the return below appears to be elided;
# it presumably belongs to __len__ (len(self) is used in __getitem__).
700
# last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
701
return self.corpus_lines - self.num_docs - 1
703
def __getitem__(self, item):
704
cur_id = self.sample_counter
705
self.sample_counter += 1
706
if not self.on_memory:
707
# after one epoch we start again from beginning of file
708
if cur_id != 0 and (cur_id % len(self) == 0):
710
self.file = open(self.corpus_path, "r", encoding=self.encoding)
712
#t1, t2, is_next_label = self.random_sent(item)
713
t1, is_next_label = self.random_sent(item)
714
if is_next_label == None:
718
#tokens_a = self.tokenizer.tokenize(t1)
# NOTE(review): this uses the module-level name `tokenizer`, not the
# self.tokenizer stored in __init__ (the self.tokenizer version is commented
# out above); it only works if a global `tokenizer` exists and matches.
719
tokens_a = tokenizer.tokenize(t1)
721
if "</s>" in tokens_a:
722
print("Have more than 1 </s>")
723
#tokens_a[tokens_a.index("<s>")] = "s"
724
for i in range(len(tokens_a)):
725
if tokens_a[i] == "</s>":
728
#tokens_b = self.tokenizer.tokenize(t2)
731
cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)
733
# transform sample to features
734
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
736
cur_tensors = (torch.tensor(cur_features.input_ids),
737
torch.tensor(cur_features.input_ids_org),
738
torch.tensor(cur_features.input_mask),
739
torch.tensor(cur_features.segment_ids),
740
torch.tensor(cur_features.lm_label_ids),
741
torch.tensor(cur_features.is_next),
742
torch.tensor(cur_features.tail_idxs))
# NOTE(review): random_sent's docstring describes a 3-tuple
# (sentence 1, sentence 2, label), but __getitem__ unpacks only two values
# (t1, is_next_label); the return itself is not visible — confirm.
746
def random_sent(self, index):
748
Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
749
from one doc. With 50% the second sentence will be a random one from another doc.
750
:param index: int, index of sample.
751
:return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
753
t1, t2 = self.get_corpus_line(index)
756
def get_corpus_line(self, item):
758
Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
759
:param item: int, index of sample.
760
:return: (str, str), two subsequent sentences from corpus
764
assert item < self.corpus_lines
766
sample = self.sample_to_doc[item]
767
t1 = self.all_docs[sample["doc_id"]][sample["line"]]
768
# used later to avoid random nextSentence from same doc
769
self.current_doc = sample["doc_id"]
773
if self.line_buffer is None:
774
# read first non-empty line of file
776
t1 = next(self.file).strip()
778
# use t2 from previous iteration as new t1
779
t1 = self.line_buffer
780
# skip empty rows that are used for separating documents and keep track of current doc id
782
t1 = next(self.file).strip()
783
self.current_doc = self.current_doc+1
784
self.line_buffer = next(self.file).strip()
790
def get_random_line(self):
792
Get random line from another document for nextSentence task.
793
:return: str, content of one line
795
# Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
796
# corpora. However, just to be careful, we try to make sure that
797
# the random document is not the same as the document we're processing.
800
rand_doc_idx = random.randint(0, len(self.all_docs)-1)
801
rand_doc = self.all_docs[rand_doc_idx]
802
line = rand_doc[random.randrange(len(rand_doc))]
804
rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
806
for _ in range(rand_index):
807
line = self.get_next_line()
808
#check if our picked random line is really from another doc like we want it to be
809
if self.current_random_doc != self.current_doc:
813
def get_next_line(self):
814
""" Gets next line of random_file and starts over when reaching end of file"""
816
line = next(self.random_file).strip()
817
#keep track of which document we are currently looking at to later avoid having the same doc as t1
819
self.current_random_doc = self.current_random_doc + 1
820
line = next(self.random_file).strip()
821
except StopIteration:
822
self.random_file.close()
823
self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
824
line = next(self.random_file).strip()
828
# Plain record for one language-model example. The constructor stores
# tokens_a, tokens_b, is_next and lm_labels unchanged.
# `example.guid` is read by convert_example_to_features, but its assignment is
# not visible in this view; presumably `self.guid = guid` (and the docstring's
# closing quotes) sit on elided lines — confirm.
# NOTE(review): callers in this file pass already-tokenized lists
# (tokenizer.tokenize output) as tokens_a, not untokenized strings as the
# docstring says; the docstring also documents a `label` argument that the
# constructor does not accept (it takes is_next / lm_labels instead).
class InputExample(object):
829
"""A single training/test example for the language model."""
831
def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
832
"""Constructs a InputExample.
834
guid: Unique id for the example.
835
tokens_a: string. The untokenized text of the first sequence. For single
836
sequence tasks, only this sequence must be specified.
837
tokens_b: (Optional) string. The untokenized text of the second sequence.
838
Only must be specified for sequence pair tasks.
839
label: (Optional) string. The label of the example. This should be
840
specified for train and dev examples, but not for test examples.
843
self.tokens_a = tokens_a
844
self.tokens_b = tokens_b
845
self.is_next = is_next # nextSentence
846
self.lm_labels = lm_labels # masked words for language model
849
class InputFeatures(object):
    """A single set of features of data.

    Plain record: the constructor stores every argument unchanged under an
    attribute of the same name. In this file the fields mirror the per-example
    outputs of ``convert_example_to_features``:

    Attributes:
        input_ids: token ids fed to the model (after random masking).
        input_ids_org: token ids of the same sequence before masking.
        input_mask: 1 for real tokens, 0 for padding positions.
        segment_ids: segment (token-type) ids.
        is_next: next-sentence label carried over from the example.
        lm_label_ids: masked-LM targets (-1 where no prediction is made).
        tail_idxs: index of the last real token before padding.
    """

    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
        self.input_ids = input_ids
        self.input_ids_org = input_ids_org
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids
        self.tail_idxs = tail_idxs
862
# Randomly select tokens for the masked-LM objective and build the matching
# label list; returns (tokens, output_label).
# Visible structure: for each token, prob = random.random() is drawn; selected
# tokens get their original id (tokenizer.convert_tokens_to_ids(token))
# appended to output_label, with a fallback to the "<unk>" id (the exception
# handling around it is on elided lines); unselected tokens get -1, which the
# inline comment says the loss ignores.
# NOTE(review): the probability thresholds, the assignments that actually
# replace tokens[i], and the try/except/else lines are not visible in this
# view; the 15% / 80% / 10% split is stated only by the comments — confirm.
# NOTE(review): the warning text contains the typo "insetad" (runtime string,
# left unchanged here).
def random_word(tokens, tokenizer):
864
Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
865
:param tokens: list of str, tokenized sentence.
866
:param tokenizer: Tokenizer, object used for tokenization (we need it's vocab here)
867
:return: (list of str, list of int), masked tokens and related labels for LM prediction
871
for i, token in enumerate(tokens):
873
prob = random.random()
874
# mask token with 15% probability
877
#candidate_id = random.randint(0,tokenizer.vocab_size)
878
#print(tokenizer.convert_ids_to_tokens(candidate_id))
881
# 80% randomly change token to mask token
883
#tokens[i] = "[MASK]"
886
# 10% randomly change token to random token
888
#tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
889
#tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)
# NOTE(review): random.randint includes its upper bound, so candidate_id can
# equal tokenizer.vocab_size, which is one past the last id if ids run
# 0..vocab_size-1; random.randrange(tokenizer.vocab_size) excludes it. The
# `== None` check below (prefer `is None`) may be compensating for this.
890
candidate_id = random.randint(0,tokenizer.vocab_size)
891
w = tokenizer.convert_ids_to_tokens(candidate_id)
893
if tokens[i] == None:
895
w = tokenizer.convert_ids_to_tokens(candidate_id)
900
# -> rest 10% randomly keep current token
902
# append current token to output (we will predict these later)
904
#output_label.append(tokenizer.vocab[token])
905
w = tokenizer.convert_tokens_to_ids(token)
907
output_label.append(w)
909
print("Have no this tokens in ids")
912
# For unknown words (should not occur with BPE vocab)
913
#output_label.append(tokenizer.vocab["<unk>"])
914
w = tokenizer.convert_tokens_to_ids("<unk>")
915
output_label.append(w)
916
logger.warning("Cannot find token '{}' in vocab. Using <unk> insetad".format(token))
918
# no masking token (will be ignored by loss function later)
919
output_label.append(-1)
921
return tokens, output_label
924
def convert_example_to_features(example, max_seq_length, tokenizer):
926
Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
927
IDs, LM labels, input_mask, CLS and SEP tokens etc.
928
:param example: InputExample, containing sentence input as strings and is_next label
929
:param max_seq_length: int, maximum length of sequence.
930
:param tokenizer: Tokenizer
931
:return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
933
#now tokens_a is input_ids
934
tokens_a = example.tokens_a
935
tokens_b = example.tokens_b
936
# Modifies `tokens_a` and `tokens_b` in place so that the total
937
# length is less than the specified length.
938
# Account for [CLS], [SEP], [SEP] with "- 3"
939
#_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
940
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)
943
tokens_a_org = tokens_a.copy()
944
tokens_a, t1_label = random_word(tokens_a, tokenizer)
951
#tokens_b, t2_label = random_word(tokens_b, tokenizer)
952
# concatenate lm labels and account for CLS, SEP, SEP
953
#lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
954
lm_label_ids = ([-1] + t1_label + [-1])
956
# The convention in BERT is:
957
# (a) For sequence pairs:
958
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
959
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
960
# (b) For single sequences:
961
# tokens: [CLS] the dog is hairy . [SEP]
962
# type_ids: 0 0 0 0 0 0 0
964
# Where "type_ids" are used to indicate whether this is the first
965
# sequence or the second sequence. The embedding vectors for `type=0` and
966
# `type=1` were learned during pre-training and are added to the wordpiece
967
# embedding vector (and position vector). This is not *strictly* necessary
968
# since the [SEP] token unambigiously separates the sequences, but it makes
969
# it easier for the model to learn the concept of sequences.
971
# For classification tasks, the first vector (corresponding to [CLS]) is
972
# used as as the "sentence vector". Note that this only makes sense because
973
# the entire model is fine-tuned.
978
tokens_org.append("<s>")
979
segment_ids.append(0)
980
for i, token in enumerate(tokens_a):
982
tokens.append(tokens_a[i])
983
tokens_org.append(tokens_a_org[i])
984
segment_ids.append(0)
987
tokens_org.append("s")
988
segment_ids.append(0)
989
tokens.append("</s>")
990
tokens_org.append("</s>")
991
segment_ids.append(0)
993
#tokens.append("[SEP]")
994
#segment_ids.append(1)
996
#input_ids = tokenizer.convert_tokens_to_ids(tokens)
997
input_ids = tokenizer.encode(tokens, add_special_tokens=False)
998
input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
999
tail_idxs = len(input_ids)-1
1002
input_ids = [w if w!=None else 0 for w in input_ids]
1003
input_ids_org = [w if w!=None else 0 for w in input_ids_org]
1007
# The mask has 1 for real tokens and 0 for padding tokens. Only real
1008
# tokens are attended to.
1009
input_mask = [1] * len(input_ids)
1011
# Zero-pad up to the sequence length.
1012
pad_id = tokenizer.convert_tokens_to_ids("<pad>")
1013
while len(input_ids) < max_seq_length:
1014
input_ids.append(pad_id)
1015
input_ids_org.append(pad_id)
1016
input_mask.append(0)
1017
segment_ids.append(0)
1018
lm_label_ids.append(-1)
1021
assert len(input_ids) == max_seq_length
1022
assert len(input_ids_org) == max_seq_length
1023
assert len(input_mask) == max_seq_length
1024
assert len(segment_ids) == max_seq_length
1025
assert len(lm_label_ids) == max_seq_length
1027
print("!!!Warning!!!")
1028
input_ids = input_ids[:max_seq_length-1]
1029
if 2 not in input_ids:
1032
input_ids += [pad_id]
1033
input_ids_org = input_ids_org[:max_seq_length-1]
1034
if 2 not in input_ids_org:
1035
input_ids_org += [2]
1037
input_ids_org += [pad_id]
1038
input_mask = input_mask[:max_seq_length-1]+[0]
1039
segment_ids = segment_ids[:max_seq_length-1]+[0]
1040
lm_label_ids = lm_label_ids[:max_seq_length-1]+[-1]
1043
if len(input_ids) != max_seq_length:
1044
print(len(input_ids))
1046
if len(input_ids_org) != max_seq_length:
1047
print(len(input_ids_org))
1049
if len(input_mask) != max_seq_length:
1050
print(len(input_mask))
1052
if len(segment_ids) != max_seq_length:
1053
print(len(segment_ids))
1055
if len(lm_label_ids) != max_seq_length:
1056
print(len(lm_label_ids))
1064
if example.guid < 5:
1065
logger.info("*** Example ***")
1066
logger.info("guid: %s" % (example.guid))
1067
logger.info("tokens: %s" % " ".join(
1068
[str(x) for x in tokens]))
1069
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
1070
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
1072
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
1073
logger.info("LM label: %s " % (lm_label_ids))
1074
logger.info("Is next sentence label: %s " % (example.is_next))
1077
features = InputFeatures(input_ids=input_ids,
1078
input_ids_org = input_ids_org,
1079
input_mask=input_mask,
1080
segment_ids=segment_ids,
1081
lm_label_ids=lm_label_ids,
1082
is_next=example.is_next,
1083
tail_idxs=tail_idxs)
1088
parser = argparse.ArgumentParser()
1090
parser = get_parameter(parser)
1092
args = parser.parse_args()
1094
if args.local_rank == -1 or args.no_cuda:
1095
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
1096
n_gpu = torch.cuda.device_count()
1098
torch.cuda.set_device(args.local_rank)
1099
device = torch.device("cuda", args.local_rank)
1101
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
1102
torch.distributed.init_process_group(backend='nccl')
1103
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
1104
device, n_gpu, bool(args.local_rank != -1), args.fp16))
1106
if args.gradient_accumulation_steps < 1:
1107
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
1108
args.gradient_accumulation_steps))
1110
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
1112
random.seed(args.seed)
1113
np.random.seed(args.seed)
1114
torch.manual_seed(args.seed)
1116
torch.cuda.manual_seed_all(args.seed)
1118
if not args.do_train:
1119
raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")
1121
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
1122
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
1123
if not os.path.exists(args.output_dir):
1124
os.makedirs(args.output_dir)
1126
#tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model, do_lower_case=args.do_lower_case)
1127
tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)
1130
#train_examples = None
1131
num_train_optimization_steps = None
1133
print("Loading Train Dataset", args.data_dir_indomain)
1134
#train_dataset = Dataset_noNext(args.data_dir, tokenizer, seq_len=args.max_seq_length, corpus_lines=None, on_memory=args.on_memory)
1135
all_type_sentence, train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
1136
num_train_optimization_steps = int(
1137
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
1138
if args.local_rank != -1:
1139
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
1144
model = RobertaForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1145
#model = RobertaForSequenceClassification.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1152
param_optimizer = list(model.named_parameters())
1154
for par in param_optimizer:
1158
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
1159
optimizer_grouped_parameters = [
1160
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
1161
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
1163
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
1164
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*0.1), num_training_steps=num_train_optimization_steps)
1168
from apex import amp
1170
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
1173
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
1177
model = torch.nn.DataParallel(model)
1179
if args.local_rank != -1:
1180
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
1186
logger.info("***** Running training *****")
1187
logger.info(" Num examples = %d", len(train_dataset))
1188
logger.info(" Batch size = %d", args.train_batch_size)
1189
logger.info(" Num steps = %d", num_train_optimization_steps)
1191
if args.local_rank == -1:
1192
train_sampler = RandomSampler(train_dataset)
1193
#all_type_sentence_sampler = RandomSampler(all_type_sentence)
1195
#TODO: check if this works with current data generator from disk that relies on next(file)
1196
# (it doesn't return item back by index)
1197
train_sampler = DistributedSampler(train_dataset)
1198
#all_type_sentence_sampler = DistributedSampler(all_type_sentence)
1199
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
1200
#all_type_sentence_dataloader = DataLoader(all_type_sentence, sampler=all_type_sentence_sampler, batch_size=len(all_type_sentence_label))
1202
output_loss_file = os.path.join(args.output_dir, "loss")
1203
loss_fout = open(output_loss_file, 'w')
1206
output_loss_file_no_pseudo = os.path.join(args.output_dir, "loss_no_pseudo")
1207
loss_fout_no_pseudo = open(output_loss_file_no_pseudo, 'w')
1213
#alpha = float(1/(args.num_train_epochs*len(train_dataloader)))
1214
#alpha = float(1/args.num_train_epochs)
1221
#retrive_gate = args.num_labels_task
1222
#retrive_gate = len(train_dataset)/100
1224
all_type_sentence_label = list()
1225
all_previous_sentence_label = list()
1226
all_type_sentiment_label = list()
1227
all_previous_sentiment_label = list()
1228
top_k_all_type = dict()
1229
bottom_k_all_type = dict()
1230
for epo in trange(int(args.num_train_epochs), desc="Epoch"):
1232
nb_tr_examples, nb_tr_steps = 0, 0
1233
for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
1236
#######################
1237
######################
1240
batch_ = tuple(t.to(device) for t in batch_)
1241
input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = batch_
1245
# Generate query representation
1246
in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1252
#Domain Binary Classifier - Outdomain
1253
#batch = AugmentationData_Domain(bottom_k, tokenizer, args.max_seq_length)
1254
batch = AugmentationData_Domain(in_domain_rep.shape[0], k, tokenizer, args.max_seq_length)
1255
batch = tuple(t.to(device) for t in batch)
1256
input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, domain_id = batch
1258
out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, lm_label=lm_label_ids, attention_mask=input_mask, func="in_domain_task_rep")
1260
#print(domain_top_k["indices"].shape)
1261
#print(input_ids_org.shape)
1262
#print(out_domain_rep_tail.shape)
1263
#print(in_domain_rep.shape)
1265
############Construct contrastive instances
1266
#print("=============")
1267
#print(out_domain_rep_tail.shape)
1268
#print(in_domain_rep.shape)
1269
#print("=============")
1270
comb_rep_pos = torch.cat([in_domain_rep,in_domain_rep.flip(0)], 1)
1271
in_domain_rep_ready = in_domain_rep.repeat(1,int(out_domain_rep_tail.shape[0]/in_domain_rep.shape[0])).reshape(out_domain_rep_tail.shape[0],out_domain_rep_tail.shape[1])
1272
comb_rep_unknow = torch.cat([in_domain_rep_ready, out_domain_rep_tail], 1)
1274
in_domain_binary_loss, domain_binary_logit = model(func="domain_binary_classifier", in_domain_rep=comb_rep_pos.to(device), out_domain_rep=comb_rep_unknow.to(device), domain_id=domain_id,use_detach=True)
1282
#Task Binary Classifier in domain
1283
#Pseudo Task --> Won't bp to PLM: only train classifier [In domain data]
1284
#batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
1285
batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep)
1286
batch = tuple(t.to(device) for t in batch)
1287
all_in_task_rep_comb, all_sentence_binary_label = batch
1288
in_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=True)
1294
#Train Task org - finetune
1295
#split into: in_dom and query_ --> different weight
1296
task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")
1301
#Domain Task binary including outdomain
1302
#batch = AugmentationData_Task(task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
1303
#batch = tuple(t.to(device) for t in batch)
1304
#input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
1305
#out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
1307
#batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
1308
#batch = tuple(t.to(device) for t in batch)
1309
#all_in_task_rep_comb, all_sentence_binary_label = batch
1310
#out_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
1316
#Domain-Task binary Level (in domain task)
1318
batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
1319
batch = tuple(t.to(device) for t in batch)
1320
all_in_task_rep_comb, all_sentence_binary_label = batch
1321
in_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
1329
#loss = task_loss_org.mean()*2 + in_task_binary_loss.mean() + in_domain_task_binary_loss.mean()*0.5
1330
loss = task_loss_org.mean() + in_domain_task_binary_loss.mean()
1331
#loss = task_loss_org.mean() + in_domain_task_binary_loss.mean()
1332
#loss = task_loss_org.mean()
1334
#loss = mix_domain_binary_loss + (in_task_binary_loss + out_task_binary_loss)/2 + task_loss_org + out_domain_task_binary_loss
1335
print("No Using GPU")
1338
if args.gradient_accumulation_steps > 1:
1339
loss = loss / args.gradient_accumulation_steps
1341
with amp.scale_loss(loss, optimizer) as scaled_loss:
1342
scaled_loss.backward()
1347
loss_fout.write("{}\n".format(loss.item()))
1351
#loss_fout_no_pseudo.write("{}\n".format(loss.item()-pseudo.item()))
1354
tr_loss += loss.item()
1355
#nb_tr_examples += input_ids.size(0)
1356
nb_tr_examples += input_ids_.size(0)
1358
if (step + 1) % args.gradient_accumulation_steps == 0:
1360
# modify learning rate with special warm up BERT uses
1361
# if args.fp16 is False, BertAdam is used that handles this automatically
1362
#lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
1363
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
1366
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
1373
#optimizer.zero_grad()
1381
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
1382
#output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
1383
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(epo))
1384
torch.save(model_to_save.state_dict(), output_model_file)
1387
#if args.num_train_epochs/args.augment_times in [1,2,3]:
1388
if (args.num_train_epochs/(args.augment_times/5))%5 == 0:
1389
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
1390
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
1391
torch.save(model_to_save.state_dict(), output_model_file)
1397
# Save a trained model
1398
logger.info("** ** * Saving fine - tuned model ** ** * ")
1399
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
1400
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
1402
torch.save(model_to_save.state_dict(), output_model_file)
1406
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
1407
"""Truncates a sequence pair in place to the maximum length."""
1409
# This is a simple heuristic which will always truncate the longer sequence
1410
# one token at a time. This makes more sense than truncating an equal percent
1411
# of tokens from each, since if one sequence is very short then each token
1412
# that's truncated likely contains more information than a longer sequence.
1414
#total_length = len(tokens_a) + len(tokens_b)
1415
total_length = len(tokens_a)
1416
if total_length <= max_length:
1422
def accuracy(out, labels):
    """Count how many examples have their arg-max prediction equal to the label.

    Despite the name, this returns a raw count of correct predictions, not a
    fraction; divide by the number of examples to obtain an accuracy rate.

    Args:
        out: array-like of scores whose class dimension is axis 1
            (one row per example in the 2-D case).
        labels: array-like of gold class indices, compared element-wise
            against the predicted indices.

    Returns:
        The number of matching predictions, as produced by ``np.sum``.
    """
    # np.argmax resolves ties to the lowest class index.
    predicted_classes = np.argmax(out, axis=1)
    is_correct = predicted_classes == labels
    return np.sum(is_correct)
1427
if __name__ == "__main__":